Spaces:
Build error
Build error
Sasha (Spock) commited on
Commit ·
7a18c1b
1
Parent(s): c8cd022
materialize ai-toolkit (binaries + bak files removed); fix Dockerfile with PYTHONPATH
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- Dockerfile +16 -10
- ai-toolkit/.gitignore +187 -0
- ai-toolkit/.gitmodules +0 -0
- ai-toolkit/FAQ.md +10 -0
- ai-toolkit/LICENSE +21 -0
- ai-toolkit/README.md +316 -0
- ai-toolkit/build_and_push_docker +29 -0
- ai-toolkit/build_and_push_docker_dev +21 -0
- ai-toolkit/config/examples/extract.example.yml +75 -0
- ai-toolkit/config/examples/generate.example.yaml +60 -0
- ai-toolkit/config/examples/mod_lora_scale.yaml +48 -0
- ai-toolkit/config/examples/modal/modal_train_lora_flux_24gb.yaml +96 -0
- ai-toolkit/config/examples/modal/modal_train_lora_flux_schnell_24gb.yaml +98 -0
- ai-toolkit/config/examples/train_flex_redux.yaml +112 -0
- ai-toolkit/config/examples/train_full_fine_tune_flex.yaml +107 -0
- ai-toolkit/config/examples/train_full_fine_tune_lumina.yaml +99 -0
- ai-toolkit/config/examples/train_lora_chroma_24gb.yaml +104 -0
- ai-toolkit/config/examples/train_lora_flex2_24gb.yaml +165 -0
- ai-toolkit/config/examples/train_lora_flex_24gb.yaml +101 -0
- ai-toolkit/config/examples/train_lora_flux_24gb.yaml +96 -0
- ai-toolkit/config/examples/train_lora_flux_kontext_24gb.yaml +106 -0
- ai-toolkit/config/examples/train_lora_flux_schnell_24gb.yaml +98 -0
- ai-toolkit/config/examples/train_lora_hidream_48.yaml +112 -0
- ai-toolkit/config/examples/train_lora_lumina.yaml +96 -0
- ai-toolkit/config/examples/train_lora_omnigen2_24gb.yaml +94 -0
- ai-toolkit/config/examples/train_lora_qwen_image_24gb.yaml +95 -0
- ai-toolkit/config/examples/train_lora_qwen_image_edit_2509_32gb.yaml +105 -0
- ai-toolkit/config/examples/train_lora_qwen_image_edit_32gb.yaml +102 -0
- ai-toolkit/config/examples/train_lora_sd35_large_24gb.yaml +97 -0
- ai-toolkit/config/examples/train_lora_wan21_14b_24gb.yaml +101 -0
- ai-toolkit/config/examples/train_lora_wan21_1b_24gb.yaml +90 -0
- ai-toolkit/config/examples/train_lora_wan22_14b_24gb.yaml +111 -0
- ai-toolkit/config/examples/train_slider.example.yml +230 -0
- ai-toolkit/dgx_instructions.md +84 -0
- ai-toolkit/dgx_requirements.txt +13 -0
- ai-toolkit/docker-compose.yml +25 -0
- ai-toolkit/docker/Dockerfile +108 -0
- ai-toolkit/docker/start.sh +70 -0
- ai-toolkit/extensions/example/ExampleMergeModels.py +129 -0
- ai-toolkit/extensions/example/__init__.py +25 -0
- ai-toolkit/extensions/example/config/config.example.yaml +48 -0
- ai-toolkit/extensions_built_in/advanced_generator/Img2ImgGenerator.py +256 -0
- ai-toolkit/extensions_built_in/advanced_generator/PureLoraGenerator.py +102 -0
- ai-toolkit/extensions_built_in/advanced_generator/ReferenceGenerator.py +212 -0
- ai-toolkit/extensions_built_in/advanced_generator/__init__.py +59 -0
- ai-toolkit/extensions_built_in/advanced_generator/config/train.example.yaml +91 -0
- ai-toolkit/extensions_built_in/audio_models/__init__.py +7 -0
- ai-toolkit/extensions_built_in/audio_models/ace_step/__init__.py +1 -0
- ai-toolkit/extensions_built_in/audio_models/ace_step/ace_step_15_model.py +335 -0
- ai-toolkit/extensions_built_in/audio_models/ace_step/src/__init__.py +0 -0
Dockerfile
CHANGED
|
@@ -1,22 +1,28 @@
|
|
| 1 |
-
|
| 2 |
FROM pytorch/pytorch:2.4.0-cuda12.1-cudnn9-runtime
|
| 3 |
|
| 4 |
WORKDIR /app
|
|
|
|
|
|
|
| 5 |
RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/*
|
| 6 |
|
| 7 |
-
#
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
-
# Install
|
| 14 |
-
RUN pip install --no-cache-dir huggingface_hub
|
| 15 |
|
| 16 |
# Copy training files
|
| 17 |
COPY . /app/
|
| 18 |
|
| 19 |
-
# Pre-download FLUX
|
| 20 |
-
RUN python -c "from huggingface_hub import snapshot_download; snapshot_download('Niansuh/FLUX.1-schnell'
|
| 21 |
|
| 22 |
CMD ["python", "/app/train_cloud.py"]
|
|
|
|
|
|
|
| 1 |
FROM pytorch/pytorch:2.4.0-cuda12.1-cudnn9-runtime
|
| 2 |
|
| 3 |
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
# System deps
|
| 6 |
RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/*
|
| 7 |
|
| 8 |
+
# Pre-baked ai-toolkit (numpy 2.5 / dctorch / torchdiffeq / torchsde / clip forks, 3.14-compatible)
|
| 9 |
+
# Copied from local checkout so HF doesn't have to clone the submodule / run pip install -e .
|
| 10 |
+
COPY ai-toolkit /app/ai-toolkit
|
| 11 |
+
|
| 12 |
+
# Make ai-toolkit importable. Upstream ai-toolkit ships without setup.py / pyproject.toml,
|
| 13 |
+
# so pip install -e . would fail. We add it to PYTHONPATH instead.
|
| 14 |
+
ENV PYTHONPATH=/app:/app/ai-toolkit
|
| 15 |
+
ENV PYTHONUNBUFFERED=1
|
| 16 |
+
ENV HF_HOME=/app/hf_cache
|
| 17 |
+
ENV TRANSFORMERS_CACHE=/app/hf_cache
|
| 18 |
|
| 19 |
+
# Install runtime deps
|
| 20 |
+
RUN pip install --no-cache-dir huggingface_hub hf_transfer
|
| 21 |
|
| 22 |
# Copy training files
|
| 23 |
COPY . /app/
|
| 24 |
|
| 25 |
+
# Pre-download FLUX base + training adapter at build time so they're in the image cache
|
| 26 |
+
RUN python -c "import os; os.environ['HF_HUB_ENABLE_HF_TRANSFER']='1'; from huggingface_hub import snapshot_download; snapshot_download('Niansuh/FLUX.1-schnell'); snapshot_download('ostris/FLUX.1-schnell-training-adapter')"
|
| 27 |
|
| 28 |
CMD ["python", "/app/train_cloud.py"]
|
ai-toolkit/.gitignore
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py,cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
.pybuilder/
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 88 |
+
# .python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# poetry
|
| 98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 100 |
+
# commonly ignored for libraries.
|
| 101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 102 |
+
#poetry.lock
|
| 103 |
+
|
| 104 |
+
# pdm
|
| 105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 106 |
+
#pdm.lock
|
| 107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 108 |
+
# in version control.
|
| 109 |
+
# https://pdm.fming.dev/#use-with-ide
|
| 110 |
+
.pdm.toml
|
| 111 |
+
|
| 112 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 113 |
+
__pypackages__/
|
| 114 |
+
|
| 115 |
+
# Celery stuff
|
| 116 |
+
celerybeat-schedule
|
| 117 |
+
celerybeat.pid
|
| 118 |
+
|
| 119 |
+
# SageMath parsed files
|
| 120 |
+
*.sage.py
|
| 121 |
+
|
| 122 |
+
# Environments
|
| 123 |
+
.env
|
| 124 |
+
.venv
|
| 125 |
+
.python
|
| 126 |
+
.node
|
| 127 |
+
env/
|
| 128 |
+
venv/
|
| 129 |
+
ENV/
|
| 130 |
+
env.bak/
|
| 131 |
+
venv.bak/
|
| 132 |
+
|
| 133 |
+
# Spyder project settings
|
| 134 |
+
.spyderproject
|
| 135 |
+
.spyproject
|
| 136 |
+
|
| 137 |
+
# Rope project settings
|
| 138 |
+
.ropeproject
|
| 139 |
+
|
| 140 |
+
# mkdocs documentation
|
| 141 |
+
/site
|
| 142 |
+
|
| 143 |
+
# mypy
|
| 144 |
+
.mypy_cache/
|
| 145 |
+
.dmypy.json
|
| 146 |
+
dmypy.json
|
| 147 |
+
|
| 148 |
+
# Pyre type checker
|
| 149 |
+
.pyre/
|
| 150 |
+
|
| 151 |
+
# pytype static type analyzer
|
| 152 |
+
.pytype/
|
| 153 |
+
|
| 154 |
+
# Cython debug symbols
|
| 155 |
+
cython_debug/
|
| 156 |
+
|
| 157 |
+
# PyCharm
|
| 158 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 159 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 160 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 161 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 162 |
+
.idea/
|
| 163 |
+
|
| 164 |
+
/env.sh
|
| 165 |
+
/models
|
| 166 |
+
/datasets
|
| 167 |
+
/custom/*
|
| 168 |
+
!/custom/.gitkeep
|
| 169 |
+
/.tmp
|
| 170 |
+
/venv.bkp
|
| 171 |
+
/venv.*
|
| 172 |
+
/config/*
|
| 173 |
+
!/config/examples
|
| 174 |
+
!/config/_PUT_YOUR_CONFIGS_HERE).txt
|
| 175 |
+
/output/*
|
| 176 |
+
!/output/.gitkeep
|
| 177 |
+
/extensions/*
|
| 178 |
+
!/extensions/example
|
| 179 |
+
/temp
|
| 180 |
+
/wandb
|
| 181 |
+
.vscode/settings.json
|
| 182 |
+
.DS_Store
|
| 183 |
+
._.DS_Store
|
| 184 |
+
aitk_db.db
|
| 185 |
+
/notes.md
|
| 186 |
+
/data
|
| 187 |
+
.claude
|
ai-toolkit/.gitmodules
ADDED
|
File without changes
|
ai-toolkit/FAQ.md
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# FAQ
|
| 2 |
+
|
| 3 |
+
WIP. Will continue to add things as they are needed.
|
| 4 |
+
|
| 5 |
+
## FLUX.1 Training
|
| 6 |
+
|
| 7 |
+
#### How much VRAM is required to train a lora on FLUX.1?
|
| 8 |
+
|
| 9 |
+
24GB minimum is required.
|
| 10 |
+
|
ai-toolkit/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2024 Ostris, LLC
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
ai-toolkit/README.md
ADDED
|
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ostris AI Toolkit
|
| 2 |
+
|
| 3 |
+
AI Toolkit is an easy to use all in one training suite for diffusion models. I try to support all the latest models on consumer grade hardware. Image and video models. It can be run as a GUI or CLI. It is designed to be easy to use but still have every feature imaginable. Free and open source.
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
## Supported Models
|
| 8 |
+
|
| 9 |
+
### Image
|
| 10 |
+
- [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) (FLUX.1)
|
| 11 |
+
- [black-forest-labs/FLUX.2-dev](https://huggingface.co/black-forest-labs/FLUX.2-dev) (FLUX.2)
|
| 12 |
+
- [black-forest-labs/FLUX.2-klein-base-4B](https://huggingface.co/black-forest-labs/FLUX.2-klein-base-4B) (FLUX.2-klein-base-4B)
|
| 13 |
+
- [black-forest-labs/FLUX.2-klein-base-9B](https://huggingface.co/black-forest-labs/FLUX.2-klein-base-9B) (FLUX.2-klein-base-9B)
|
| 14 |
+
- [ostris/Flex.1-alpha](https://huggingface.co/ostris/Flex.1-alpha) (Flex.1)
|
| 15 |
+
- [ostris/Flex.2-preview](https://huggingface.co/ostris/Flex.2-preview) (Flex.2)
|
| 16 |
+
- [lodestones/Chroma1-Base](https://huggingface.co/lodestones/Chroma1-Base) (Chroma)
|
| 17 |
+
- [Alpha-VLLM/Lumina-Image-2.0](https://huggingface.co/Alpha-VLLM/Lumina-Image-2.0) (Lumina2)
|
| 18 |
+
- [Qwen/Qwen-Image](https://huggingface.co/Qwen/Qwen-Image) (Qwen-Image)
|
| 19 |
+
- [Qwen/Qwen-Image-2512](https://huggingface.co/Qwen/Qwen-Image-2512) (Qwen-Image-2512)
|
| 20 |
+
- [HiDream-ai/HiDream-I1-Full](https://huggingface.co/HiDream-ai/HiDream-I1-Full) (HiDream I1)
|
| 21 |
+
- [OmniGen2/OmniGen2](https://huggingface.co/OmniGen2/OmniGen2) (OmniGen2)
|
| 22 |
+
- [Tongyi-MAI/Z-Image-Turbo](https://huggingface.co/Tongyi-MAI/Z-Image-Turbo) (Z-Image Turbo)
|
| 23 |
+
- [Tongyi-MAI/Z-Image](https://huggingface.co/Tongyi-MAI/Z-Image) (Z-Image)
|
| 24 |
+
- [ostris/Z-Image-De-Turbo](https://huggingface.co/ostris/Z-Image-De-Turbo) (Z-Image De-Turbo)
|
| 25 |
+
- [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) (SDXL)
|
| 26 |
+
- [stable-diffusion-v1-5/stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) (SD 1.5)
|
| 27 |
+
- [baidu/ERNIE-Image](https://huggingface.co/baidu/ERNIE-Image) (ERNIE-Image)
|
| 28 |
+
- [NucleusAI/Nucleus-Image](https://huggingface.co/NucleusAI/Nucleus-Image) (Nucleus-Image)
|
| 29 |
+
- [HiDream-ai/HiDream-O1-Image](https://huggingface.co/HiDream-ai/HiDream-O1-Image) (HiDream O1)
|
| 30 |
+
- [Photoroom/prxpixel-t2i](https://huggingface.co/Photoroom/prxpixel-t2i) (PRXPixel)
|
| 31 |
+
|
| 32 |
+
### Instruction / Edit
|
| 33 |
+
- [black-forest-labs/FLUX.1-Kontext-dev](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev) (FLUX.1-Kontext-dev)
|
| 34 |
+
- [Qwen/Qwen-Image-Edit](https://huggingface.co/Qwen/Qwen-Image-Edit) (Qwen-Image-Edit)
|
| 35 |
+
- [Qwen/Qwen-Image-Edit-2509](https://huggingface.co/Qwen/Qwen-Image-Edit-2509) (Qwen-Image-Edit-2509)
|
| 36 |
+
- [Qwen/Qwen-Image-Edit-2511](https://huggingface.co/Qwen/Qwen-Image-Edit-2511) (Qwen-Image-Edit-2511)
|
| 37 |
+
- [HiDream-ai/HiDream-E1-1](https://huggingface.co/HiDream-ai/HiDream-E1-1) (HiDream E1)
|
| 38 |
+
|
| 39 |
+
### Video
|
| 40 |
+
- [Wan-AI/Wan2.1-T2V-1.3B-Diffusers](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B-Diffusers) (Wan 2.1 1.3B)
|
| 41 |
+
- [Wan-AI/Wan2.1-I2V-14B-480P-Diffusers](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P-Diffusers) (Wan 2.1 I2V 14B-480P)
|
| 42 |
+
- [Wan-AI/Wan2.1-I2V-14B-720P-Diffusers](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P-Diffusers) (Wan 2.1 I2V 14B-720P)
|
| 43 |
+
- [Wan-AI/Wan2.1-T2V-14B-Diffusers](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B-Diffusers) (Wan 2.1 14B)
|
| 44 |
+
- [Wan-AI/Wan2.2-T2V-A14B-Diffusers](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B-Diffusers) (Wan 2.2 14B)
|
| 45 |
+
- [Wan-AI/Wan2.2-I2V-A14B-Diffusers](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers) (Wan 2.2 I2V 14B)
|
| 46 |
+
- [Wan-AI/Wan2.2-TI2V-5B-Diffusers](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B-Diffusers) (Wan 2.2 TI2V 5B)
|
| 47 |
+
- [Lightricks/LTX-2](https://huggingface.co/Lightricks/LTX-2) (LTX-2)
|
| 48 |
+
- [Lightricks/LTX-2.3](https://huggingface.co/Lightricks/LTX-2.3) (LTX-2.3)
|
| 49 |
+
- [krea/Krea-2-Raw](https://huggingface.co/krea/Krea-2-Raw) (Krea 2)
|
| 50 |
+
|
| 51 |
+
### Audio
|
| 52 |
+
- [ACE-Step/Ace-Step1.5](https://huggingface.co/ACE-Step/Ace-Step1.5) (Ace Step 1.5)
|
| 53 |
+
- [ACE-Step/acestep-v15-xl-base](https://huggingface.co/ACE-Step/acestep-v15-xl-base) (Ace Step 1.5 XL)
|
| 54 |
+
|
| 55 |
+
### Experimental
|
| 56 |
+
- [lodestones/Zeta-Chroma](https://huggingface.co/lodestones/Zeta-Chroma) (Zeta Chroma)
|
| 57 |
+
- [ideogram-ai/ideogram-4-fp8](https://huggingface.co/ideogram-ai/ideogram-4-fp8) (Ideogram 4 FP8)
|
| 58 |
+
|
| 59 |
+
## Installation
|
| 60 |
+
|
| 61 |
+
Requirements:
|
| 62 |
+
- python >=3.10 (3.12 recommended)
|
| 63 |
+
- Nvidia GPU with enough ram to do what you need
|
| 64 |
+
- python venv
|
| 65 |
+
- git
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
Linux:
|
| 69 |
+
```bash
|
| 70 |
+
git clone https://github.com/ostris/ai-toolkit.git
|
| 71 |
+
cd ai-toolkit
|
| 72 |
+
python3 -m venv venv
|
| 73 |
+
source venv/bin/activate
|
| 74 |
+
# install torch first
|
| 75 |
+
pip3 install --no-cache-dir torch==2.9.1 torchvision==0.24.1 torchaudio==2.9.1 --index-url https://download.pytorch.org/whl/cu128
|
| 76 |
+
pip3 install -r requirements.txt
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
For devices running **DGX OS** (including DGX Spark), follow [these](dgx_instructions.md) instructions.
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
Windows:
|
| 83 |
+
|
| 84 |
+
If you are having issues with Windows. I recommend using the easy install script at [https://github.com/Tavris1/AI-Toolkit-Easy-Install](https://github.com/Tavris1/AI-Toolkit-Easy-Install)
|
| 85 |
+
|
| 86 |
+
```bash
|
| 87 |
+
git clone https://github.com/ostris/ai-toolkit.git
|
| 88 |
+
cd ai-toolkit
|
| 89 |
+
python -m venv venv
|
| 90 |
+
.\venv\Scripts\activate
|
| 91 |
+
pip install --no-cache-dir torch==2.9.1 torchvision==0.24.1 torchaudio==2.9.1 --index-url https://download.pytorch.org/whl/cu128
|
| 92 |
+
pip install -r requirements.txt
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
MacOS:
|
| 96 |
+
|
| 97 |
+
Experimental support for Silicon Macs is available. I do not have a Mac with enough RAM to fully test this
|
| 98 |
+
so please let me know if there are issues. There is a convience script to install and run on MacOS
|
| 99 |
+
locates at `./run_mac.zsh` that will install the dependencies locally and run the UI. To run this,
|
| 100 |
+
do the following:
|
| 101 |
+
|
| 102 |
+
```bash
|
| 103 |
+
git clone https://github.com/ostris/ai-toolkit.git
|
| 104 |
+
cd ai-toolkit
|
| 105 |
+
chmod +x run_mac.zsh
|
| 106 |
+
./run_mac.zsh
|
| 107 |
+
```
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# AI Toolkit UI
|
| 111 |
+
|
| 112 |
+
<img src="https://ostris.com/wp-content/uploads/2025/02/toolkit-ui.jpg" alt="AI Toolkit UI" width="100%">
|
| 113 |
+
|
| 114 |
+
The AI Toolkit UI is a web interface for the AI Toolkit. It allows you to easily start, stop, and monitor jobs. It also allows you to easily train models with a few clicks. It also allows you to set a token for the UI to prevent unauthorized access so it is mostly safe to run on an exposed server.
|
| 115 |
+
|
| 116 |
+
## Running the UI
|
| 117 |
+
|
| 118 |
+
Requirements:
|
| 119 |
+
- Node.js > 20
|
| 120 |
+
|
| 121 |
+
The UI does not need to be kept running for the jobs to run. It is only needed to start/stop/monitor jobs. The commands below
|
| 122 |
+
will install / update the UI and it's dependencies and start the UI.
|
| 123 |
+
|
| 124 |
+
```bash
|
| 125 |
+
cd ui
|
| 126 |
+
npm run build_and_start
|
| 127 |
+
```
|
| 128 |
+
|
| 129 |
+
You can now access the UI at `http://localhost:8675` or `http://<your-ip>:8675` if you are running it on a server.
|
| 130 |
+
|
| 131 |
+
## Securing the UI
|
| 132 |
+
|
| 133 |
+
If you are hosting the UI on a cloud provider or any network that is not secure, I highly recommend securing it with an auth token.
|
| 134 |
+
You can do this by setting the environment variable `AI_TOOLKIT_AUTH` to super secure password. This token will be required to access
|
| 135 |
+
the UI. You can set this when starting the UI like so:
|
| 136 |
+
|
| 137 |
+
```bash
|
| 138 |
+
# Linux
|
| 139 |
+
AI_TOOLKIT_AUTH=super_secure_password npm run build_and_start
|
| 140 |
+
|
| 141 |
+
# Windows
|
| 142 |
+
set AI_TOOLKIT_AUTH=super_secure_password && npm run build_and_start
|
| 143 |
+
|
| 144 |
+
# Windows Powershell
|
| 145 |
+
$env:AI_TOOLKIT_AUTH="super_secure_password"; npm run build_and_start
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
### Training
|
| 149 |
+
1. Copy the example config file located at `config/examples/train_lora_flux_24gb.yaml` (`config/examples/train_lora_flux_schnell_24gb.yaml` for schnell) to the `config` folder and rename it to `whatever_you_want.yml`
|
| 150 |
+
2. Edit the file following the comments in the file
|
| 151 |
+
3. Run the file like so `python run.py config/whatever_you_want.yml`
|
| 152 |
+
|
| 153 |
+
A folder with the name and the training folder from the config file will be created when you start. It will have all
|
| 154 |
+
checkpoints and images in it. You can stop the training at any time using ctrl+c and when you resume, it will pick back up
|
| 155 |
+
from the last checkpoint.
|
| 156 |
+
|
| 157 |
+
IMPORTANT. If you press crtl+c while it is saving, it will likely corrupt that checkpoint. So wait until it is done saving
|
| 158 |
+
|
| 159 |
+
### Need help?
|
| 160 |
+
|
| 161 |
+
Please do not open a bug report unless it is a bug in the code. You are welcome to [Join my Discord](https://discord.gg/VXmU2f5WEU)
|
| 162 |
+
and ask for help there. However, please refrain from PMing me directly with general question or support. Ask in the discord
|
| 163 |
+
and I will answer when I can.
|
| 164 |
+
|
| 165 |
+
## Ostris Cloud
|
| 166 |
+
|
| 167 |
+
You can use many cloud providers to rent GPUs. If you want to help support this project in the largest way possible, please consider using [Ostris Cloud](https://cloud.ostris.com). Ostris Cloud is owned and operated by me, Ostris, and every dollar earned goes directly back into funding the development of this project.
|
| 168 |
+
|
| 169 |
+
<a href="https://cloud.ostris.com" target="_blank"><img src="https://cloud.ostris.com/api/og" alt="Ostris Cloud" style="max-width:100%;width:600px;height:auto;"></a>
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
## Training in RunPod
|
| 173 |
+
If you would like to use Runpod, but have not signed up yet, please consider using [my Runpod affiliate link](https://runpod.io?ref=h0y9jyr2) to help support this project.
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
I maintain an official Runpod Pod template here which can be accessed [here](https://console.runpod.io/deploy?template=0fqzfjy6f3&ref=h0y9jyr2).
|
| 177 |
+
|
| 178 |
+
I have also created a short video showing how to get started using AI Toolkit with Runpod [here](https://youtu.be/HBNeS-F6Zz8).
|
| 179 |
+
|
| 180 |
+
## Training in Modal
|
| 181 |
+
|
| 182 |
+
### 1. Setup
|
| 183 |
+
#### ai-toolkit:
|
| 184 |
+
```
|
| 185 |
+
git clone https://github.com/ostris/ai-toolkit.git
|
| 186 |
+
cd ai-toolkit
|
| 187 |
+
git submodule update --init --recursive
|
| 188 |
+
python -m venv venv
|
| 189 |
+
source venv/bin/activate
|
| 190 |
+
pip install torch
|
| 191 |
+
pip install -r requirements.txt
|
| 192 |
+
pip install --upgrade accelerate transformers diffusers huggingface_hub #Optional, run it if you run into issues
|
| 193 |
+
```
|
| 194 |
+
#### Modal:
|
| 195 |
+
- Run `pip install modal` to install the modal Python package.
|
| 196 |
+
- Run `modal setup` to authenticate (if this doesn’t work, try `python -m modal setup`).
|
| 197 |
+
|
| 198 |
+
#### Hugging Face:
|
| 199 |
+
- Get a READ token from [here](https://huggingface.co/settings/tokens) and request access to Flux.1-dev model from [here](https://huggingface.co/black-forest-labs/FLUX.1-dev).
|
| 200 |
+
- Run `huggingface-cli login` and paste your token.
|
| 201 |
+
|
| 202 |
+
### 2. Upload your dataset
|
| 203 |
+
- Drag and drop your dataset folder containing the .jpg, .jpeg, or .png images and .txt files in `ai-toolkit`.
|
| 204 |
+
|
| 205 |
+
### 3. Configs
|
| 206 |
+
- Copy an example config file located at ```config/examples/modal``` to the `config` folder and rename it to ```whatever_you_want.yml```.
|
| 207 |
+
- Edit the config following the comments in the file, **<ins>be careful and follow the example `/root/ai-toolkit` paths</ins>**.
|
| 208 |
+
|
| 209 |
+
### 4. Edit run_modal.py
|
| 210 |
+
- Set your entire local `ai-toolkit` path at `code_mount = modal.Mount.from_local_dir` like:
|
| 211 |
+
|
| 212 |
+
```
|
| 213 |
+
code_mount = modal.Mount.from_local_dir("/Users/username/ai-toolkit", remote_path="/root/ai-toolkit")
|
| 214 |
+
```
|
| 215 |
+
- Choose a `GPU` and `Timeout` in `@app.function` _(default is A100 40GB and 2 hour timeout)_.
|
| 216 |
+
|
| 217 |
+
### 5. Training
|
| 218 |
+
- Run the config file in your terminal: `modal run run_modal.py --config-file-list-str=/root/ai-toolkit/config/whatever_you_want.yml`.
|
| 219 |
+
- You can monitor your training in your local terminal, or on [modal.com](https://modal.com/).
|
| 220 |
+
- Models, samples and optimizer will be stored in `Storage > flux-lora-models`.
|
| 221 |
+
|
| 222 |
+
### 6. Saving the model
|
| 223 |
+
- Check contents of the volume by running `modal volume ls flux-lora-models`.
|
| 224 |
+
- Download the content by running `modal volume get flux-lora-models your-model-name`.
|
| 225 |
+
- Example: `modal volume get flux-lora-models my_first_flux_lora_v1`.
|
| 226 |
+
|
| 227 |
+
### Screenshot from Modal
|
| 228 |
+
|
| 229 |
+
<img width="1728" alt="Modal Traning Screenshot" src="https://github.com/user-attachments/assets/7497eb38-0090-49d6-8ad9-9c8ea7b5388b">
|
| 230 |
+
|
| 231 |
+
---
|
| 232 |
+
|
| 233 |
+
## Dataset Preparation
|
| 234 |
+
|
| 235 |
+
Datasets generally need to be a folder containing images and associated text files. Currently, the only supported
|
| 236 |
+
formats are jpg, jpeg, and png. Webp currently has issues. The text files should be named the same as the images
|
| 237 |
+
but with a `.txt` extension. For example `image2.jpg` and `image2.txt`. The text file should contain only the caption.
|
| 238 |
+
You can add the word `[trigger]` in the caption file and if you have `trigger_word` in your config, it will be automatically
|
| 239 |
+
replaced.
|
| 240 |
+
|
| 241 |
+
Images are never upscaled but they are downscaled and placed in buckets for batching. **You do not need to crop/resize your images**.
|
| 242 |
+
The loader will automatically resize them and can handle varying aspect ratios.
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
## Training Specific Layers
|
| 246 |
+
|
| 247 |
+
To train specific layers with LoRA, you can use the `only_if_contains` network kwargs. For instance, if you want to train only the 2 layers
|
| 248 |
+
used by The Last Ben, [mentioned in this post](https://x.com/__TheBen/status/1829554120270987740), you can adjust your
|
| 249 |
+
network kwargs like so:
|
| 250 |
+
|
| 251 |
+
```yaml
|
| 252 |
+
network:
|
| 253 |
+
type: "lora"
|
| 254 |
+
linear: 128
|
| 255 |
+
linear_alpha: 128
|
| 256 |
+
network_kwargs:
|
| 257 |
+
only_if_contains:
|
| 258 |
+
- "transformer.single_transformer_blocks.7.proj_out"
|
| 259 |
+
- "transformer.single_transformer_blocks.20.proj_out"
|
| 260 |
+
```
|
| 261 |
+
|
| 262 |
+
The naming conventions of the layers are in diffusers format, so checking the state dict of a model will reveal
|
| 263 |
+
the suffix of the name of the layers you want to train. You can also use this method to only train specific groups of weights.
|
| 264 |
+
For instance to only train the `single_transformer` for FLUX.1, you can use the following:
|
| 265 |
+
|
| 266 |
+
```yaml
|
| 267 |
+
network:
|
| 268 |
+
type: "lora"
|
| 269 |
+
linear: 128
|
| 270 |
+
linear_alpha: 128
|
| 271 |
+
network_kwargs:
|
| 272 |
+
only_if_contains:
|
| 273 |
+
- "transformer.single_transformer_blocks."
|
| 274 |
+
```
|
| 275 |
+
|
| 276 |
+
You can also exclude layers by their names by using `ignore_if_contains` network kwarg. So to exclude all the single transformer blocks,
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
```yaml
|
| 280 |
+
network:
|
| 281 |
+
type: "lora"
|
| 282 |
+
linear: 128
|
| 283 |
+
linear_alpha: 128
|
| 284 |
+
network_kwargs:
|
| 285 |
+
ignore_if_contains:
|
| 286 |
+
- "transformer.single_transformer_blocks."
|
| 287 |
+
```
|
| 288 |
+
|
| 289 |
+
`ignore_if_contains` takes priority over `only_if_contains`. So if a weight is covered by both,
|
| 290 |
+
if will be ignored.
|
| 291 |
+
|
| 292 |
+
## LoKr Training
|
| 293 |
+
|
| 294 |
+
To learn more about LoKr, read more about it at [KohakuBlueleaf/LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS/blob/main/docs/Guidelines.md). To train a LoKr model, you can adjust the network type in the config file like so:
|
| 295 |
+
|
| 296 |
+
```yaml
|
| 297 |
+
network:
|
| 298 |
+
type: "lokr"
|
| 299 |
+
lokr_full_rank: true
|
| 300 |
+
lokr_factor: 8
|
| 301 |
+
```
|
| 302 |
+
|
| 303 |
+
Everything else should work the same including layer targeting.
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
## Support My Work
|
| 307 |
+
|
| 308 |
+
If you enjoy my projects or use them commercially, please consider sponsoring me. Every bit helps! 💖
|
| 309 |
+
|
| 310 |
+
<a href="https://ostris.com/sponsors" target="_blank"><img src="https://ostris.com/wp-content/uploads/2025/05/support-banner2.png" alt="Support my work" style="max-width:100%;height:auto;"></a>
|
| 311 |
+
|
| 312 |
+
### Current Sponsors
|
| 313 |
+
|
| 314 |
+
All of these people / organizations are the ones who selflessly make this project possible. Thank you!!
|
| 315 |
+
|
| 316 |
+
<a href="https://ostris.com/sponsors"><img src="https://ostris.com/sponsors.svg" alt="Sponsors" style="width:100%;height:auto;"></a>
|
ai-toolkit/build_and_push_docker
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
|
| 3 |
+
# Extract version from version.py
|
| 4 |
+
if [ -f "version.py" ]; then
|
| 5 |
+
VERSION=$(python3 -c "from version import VERSION; print(VERSION)")
|
| 6 |
+
echo "Building version: $VERSION"
|
| 7 |
+
else
|
| 8 |
+
echo "Error: version.py not found. Please create a version.py file with VERSION defined."
|
| 9 |
+
exit 1
|
| 10 |
+
fi
|
| 11 |
+
|
| 12 |
+
echo "Docker builds from the repo, not this dir. Make sure changes are pushed to the repo."
|
| 13 |
+
echo "Building version: $VERSION and latest"
|
| 14 |
+
# wait 2 seconds
|
| 15 |
+
sleep 2
|
| 16 |
+
|
| 17 |
+
# Build the image with cache busting
|
| 18 |
+
docker build --build-arg CACHEBUST=$(date +%s) -t aitoolkit:$VERSION -f docker/Dockerfile .
|
| 19 |
+
|
| 20 |
+
# Tag with version and latest
|
| 21 |
+
docker tag aitoolkit:$VERSION ostris/aitoolkit:$VERSION
|
| 22 |
+
docker tag aitoolkit:$VERSION ostris/aitoolkit:latest
|
| 23 |
+
|
| 24 |
+
# Push both tags
|
| 25 |
+
echo "Pushing images to Docker Hub..."
|
| 26 |
+
docker push ostris/aitoolkit:$VERSION
|
| 27 |
+
docker push ostris/aitoolkit:latest
|
| 28 |
+
|
| 29 |
+
echo "Successfully built and pushed ostris/aitoolkit:$VERSION and ostris/aitoolkit:latest"
|
ai-toolkit/build_and_push_docker_dev
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
|
| 3 |
+
VERSION=dev
|
| 4 |
+
GIT_COMMIT=dev
|
| 5 |
+
|
| 6 |
+
echo "Docker builds from the repo, not this dir. Make sure changes are pushed to the repo."
|
| 7 |
+
echo "Building version: $VERSION"
|
| 8 |
+
# wait 2 seconds
|
| 9 |
+
sleep 2
|
| 10 |
+
|
| 11 |
+
# Build the image with cache busting
|
| 12 |
+
docker build --build-arg CACHEBUST=$(date +%s) -t aitoolkit:$VERSION -f docker/Dockerfile .
|
| 13 |
+
|
| 14 |
+
# Tag with version and latest
|
| 15 |
+
docker tag aitoolkit:$VERSION ostris/aitoolkit:$VERSION
|
| 16 |
+
|
| 17 |
+
# Push both tags
|
| 18 |
+
echo "Pushing images to Docker Hub..."
|
| 19 |
+
docker push ostris/aitoolkit:$VERSION
|
| 20 |
+
|
| 21 |
+
echo "Successfully built and pushed ostris/aitoolkit:$VERSION"
|
ai-toolkit/config/examples/extract.example.yml
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
# this is in yaml format. You can use json if you prefer
|
| 3 |
+
# I like both but yaml is easier to read and write
|
| 4 |
+
# plus it has comments which is nice for documentation
|
| 5 |
+
job: extract # tells the runner what to do
|
| 6 |
+
config:
|
| 7 |
+
# the name will be used to create a folder in the output folder
|
| 8 |
+
# it will also replace any [name] token in the rest of this config
|
| 9 |
+
name: name_of_your_model
|
| 10 |
+
# can be hugging face model, a .ckpt, or a .safetensors
|
| 11 |
+
base_model: "/path/to/base/model.safetensors"
|
| 12 |
+
# can be hugging face model, a .ckpt, or a .safetensors
|
| 13 |
+
extract_model: "/path/to/model/to/extract/trained.safetensors"
|
| 14 |
+
# we will create folder here with name above so. This will create /path/to/output/folder/name_of_your_model
|
| 15 |
+
output_folder: "/path/to/output/folder"
|
| 16 |
+
is_v2: false
|
| 17 |
+
dtype: fp16 # saved dtype
|
| 18 |
+
device: cpu # cpu, cuda:0, etc
|
| 19 |
+
|
| 20 |
+
# processes can be chained like this to run multiple in a row
|
| 21 |
+
# they must all use same models above, but great for testing different
|
| 22 |
+
# sizes and typed of extractions. It is much faster as we already have the models loaded
|
| 23 |
+
process:
|
| 24 |
+
# process 1
|
| 25 |
+
- type: locon # locon or lora (locon is lycoris)
|
| 26 |
+
filename: "[name]_64_32.safetensors" # will be put in output folder
|
| 27 |
+
dtype: fp16
|
| 28 |
+
mode: fixed
|
| 29 |
+
linear: 64
|
| 30 |
+
conv: 32
|
| 31 |
+
|
| 32 |
+
# process 2
|
| 33 |
+
- type: locon
|
| 34 |
+
output_path: "/absolute/path/for/this/output.safetensors" # can be absolute
|
| 35 |
+
mode: ratio
|
| 36 |
+
linear: 0.2
|
| 37 |
+
conv: 0.2
|
| 38 |
+
|
| 39 |
+
# process 3
|
| 40 |
+
- type: locon
|
| 41 |
+
filename: "[name]_ratio_02.safetensors"
|
| 42 |
+
mode: quantile
|
| 43 |
+
linear: 0.5
|
| 44 |
+
conv: 0.5
|
| 45 |
+
|
| 46 |
+
# process 4
|
| 47 |
+
- type: lora # traditional lora extraction (lierla) with linear layers only
|
| 48 |
+
filename: "[name]_4.safetensors"
|
| 49 |
+
mode: fixed # fixed, ratio, quantile supported for lora as well
|
| 50 |
+
linear: 4 # lora dim or rank
|
| 51 |
+
# no conv for lora
|
| 52 |
+
|
| 53 |
+
# process 5
|
| 54 |
+
- type: lora
|
| 55 |
+
filename: "[name]_q05.safetensors"
|
| 56 |
+
mode: quantile
|
| 57 |
+
linear: 0.5
|
| 58 |
+
|
| 59 |
+
# you can put any information you want here, and it will be saved in the model
|
| 60 |
+
# the below is an example. I recommend doing trigger words at a minimum
|
| 61 |
+
# in the metadata. The software will include this plus some other information
|
| 62 |
+
meta:
|
| 63 |
+
name: "[name]" # [name] gets replaced with the name above
|
| 64 |
+
description: A short description of your model
|
| 65 |
+
trigger_words:
|
| 66 |
+
- put
|
| 67 |
+
- trigger
|
| 68 |
+
- words
|
| 69 |
+
- here
|
| 70 |
+
version: '0.1'
|
| 71 |
+
creator:
|
| 72 |
+
name: Your Name
|
| 73 |
+
email: your@email.com
|
| 74 |
+
website: https://yourwebsite.com
|
| 75 |
+
any: All meta data above is arbitrary, it can be whatever you want.
|
ai-toolkit/config/examples/generate.example.yaml
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
|
| 3 |
+
job: generate # tells the runner what to do
|
| 4 |
+
config:
|
| 5 |
+
name: "generate" # this is not really used anywhere currently but required by runner
|
| 6 |
+
process:
|
| 7 |
+
# process 1
|
| 8 |
+
- type: to_folder # process images to a folder
|
| 9 |
+
output_folder: "output/gen"
|
| 10 |
+
device: cuda:0 # cpu, cuda:0, etc
|
| 11 |
+
generate:
|
| 12 |
+
# these are your defaults you can override most of them with flags
|
| 13 |
+
sampler: "ddpm" # ignored for now, will add later though ddpm is used regardless for now
|
| 14 |
+
width: 1024
|
| 15 |
+
height: 1024
|
| 16 |
+
neg: "cartoon, fake, drawing, illustration, cgi, animated, anime"
|
| 17 |
+
seed: -1 # -1 is random
|
| 18 |
+
guidance_scale: 7
|
| 19 |
+
sample_steps: 20
|
| 20 |
+
ext: ".png" # .png, .jpg, .jpeg, .webp
|
| 21 |
+
|
| 22 |
+
# here ate the flags you can use for prompts. Always start with
|
| 23 |
+
# your prompt first then add these flags after. You can use as many
|
| 24 |
+
# like
|
| 25 |
+
# photo of a baseball --n painting, ugly --w 1024 --h 1024 --seed 42 --cfg 7 --steps 20
|
| 26 |
+
# we will try to support all sd-scripts flags where we can
|
| 27 |
+
|
| 28 |
+
# FROM SD-SCRIPTS
|
| 29 |
+
# --n Treat everything until the next option as a negative prompt.
|
| 30 |
+
# --w Specify the width of the generated image.
|
| 31 |
+
# --h Specify the height of the generated image.
|
| 32 |
+
# --d Specify the seed for the generated image.
|
| 33 |
+
# --l Specify the CFG scale for the generated image.
|
| 34 |
+
# --s Specify the number of steps during generation.
|
| 35 |
+
|
| 36 |
+
# OURS and some QOL additions
|
| 37 |
+
# --p2 Prompt for the second text encoder (SDXL only)
|
| 38 |
+
# --n2 Negative prompt for the second text encoder (SDXL only)
|
| 39 |
+
# --gr Specify the guidance rescale for the generated image (SDXL only)
|
| 40 |
+
# --seed Specify the seed for the generated image same as --d
|
| 41 |
+
# --cfg Specify the CFG scale for the generated image same as --l
|
| 42 |
+
# --steps Specify the number of steps during generation same as --s
|
| 43 |
+
|
| 44 |
+
prompt_file: false # if true a txt file will be created next to images with prompt strings used
|
| 45 |
+
# prompts can also be a path to a text file with one prompt per line
|
| 46 |
+
# prompts: "/path/to/prompts.txt"
|
| 47 |
+
prompts:
|
| 48 |
+
- "photo of batman"
|
| 49 |
+
- "photo of superman"
|
| 50 |
+
- "photo of spiderman"
|
| 51 |
+
- "photo of a superhero --n batman superman spiderman"
|
| 52 |
+
|
| 53 |
+
model:
|
| 54 |
+
# huggingface name, relative prom project path, or absolute path to .safetensors or .ckpt
|
| 55 |
+
# name_or_path: "runwayml/stable-diffusion-v1-5"
|
| 56 |
+
name_or_path: "/mnt/Models/stable-diffusion/models/stable-diffusion/Ostris/Ostris_Real_v1.safetensors"
|
| 57 |
+
is_v2: false # for v2 models
|
| 58 |
+
is_v_pred: false # for v-prediction models (most v2 models)
|
| 59 |
+
is_xl: false # for SDXL models
|
| 60 |
+
dtype: bf16
|
ai-toolkit/config/examples/mod_lora_scale.yaml
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
job: mod
|
| 3 |
+
config:
|
| 4 |
+
name: name_of_your_model_v1
|
| 5 |
+
process:
|
| 6 |
+
- type: rescale_lora
|
| 7 |
+
# path to your current lora model
|
| 8 |
+
input_path: "/path/to/lora/lora.safetensors"
|
| 9 |
+
# output path for your new lora model, can be the same as input_path to replace
|
| 10 |
+
output_path: "/path/to/lora/output_lora_v1.safetensors"
|
| 11 |
+
# replaces meta with the meta below (plus minimum meta fields)
|
| 12 |
+
# if false, we will leave the meta alone except for updating hashes (sd-script hashes)
|
| 13 |
+
replace_meta: true
|
| 14 |
+
# how to adjust, we can scale the up_down weights or the alpha
|
| 15 |
+
# up_down is the default and probably the best, they will both net the same outputs
|
| 16 |
+
# would only affect rare NaN cases and maybe merging with old merge tools
|
| 17 |
+
scale_target: 'up_down'
|
| 18 |
+
# precision to save, fp16 is the default and standard
|
| 19 |
+
save_dtype: fp16
|
| 20 |
+
# current_weight is the ideal weight you use as a multiplier when using the lora
|
| 21 |
+
# IE in automatic1111 <lora:my_lora:6.0> the 6.0 is the current_weight
|
| 22 |
+
# you can do negatives here too if you want to flip the lora
|
| 23 |
+
current_weight: 6.0
|
| 24 |
+
# target_weight is the ideal weight you use as a multiplier when using the lora
|
| 25 |
+
# instead of the one above. IE in automatic1111 instead of using <lora:my_lora:6.0>
|
| 26 |
+
# we want to use <lora:my_lora:1.0> so 1.0 is the target_weight
|
| 27 |
+
target_weight: 1.0
|
| 28 |
+
|
| 29 |
+
# base model for the lora
|
| 30 |
+
# this is just used to add meta so automatic111 knows which model it is for
|
| 31 |
+
# assume v1.5 if these are not set
|
| 32 |
+
is_xl: false
|
| 33 |
+
is_v2: false
|
| 34 |
+
meta:
|
| 35 |
+
# this is only used if you set replace_meta to true above
|
| 36 |
+
name: "[name]" # [name] gets replaced with the name above
|
| 37 |
+
description: A short description of your lora
|
| 38 |
+
trigger_words:
|
| 39 |
+
- put
|
| 40 |
+
- trigger
|
| 41 |
+
- words
|
| 42 |
+
- here
|
| 43 |
+
version: '0.1'
|
| 44 |
+
creator:
|
| 45 |
+
name: Your Name
|
| 46 |
+
email: your@email.com
|
| 47 |
+
website: https://yourwebsite.com
|
| 48 |
+
any: All meta data above is arbitrary, it can be whatever you want.
|
ai-toolkit/config/examples/modal/modal_train_lora_flux_24gb.yaml
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
job: extension
|
| 3 |
+
config:
|
| 4 |
+
# this name will be the folder and filename name
|
| 5 |
+
name: "my_first_flux_lora_v1"
|
| 6 |
+
process:
|
| 7 |
+
- type: 'sd_trainer'
|
| 8 |
+
# root folder to save training sessions/samples/weights
|
| 9 |
+
training_folder: "/root/ai-toolkit/modal_output" # must match MOUNT_DIR from run_modal.py
|
| 10 |
+
# uncomment to see performance stats in the terminal every N steps
|
| 11 |
+
# performance_log_every: 1000
|
| 12 |
+
device: cuda:0
|
| 13 |
+
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
|
| 14 |
+
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
|
| 15 |
+
# trigger_word: "p3r5on"
|
| 16 |
+
network:
|
| 17 |
+
type: "lora"
|
| 18 |
+
linear: 16
|
| 19 |
+
linear_alpha: 16
|
| 20 |
+
save:
|
| 21 |
+
dtype: float16 # precision to save
|
| 22 |
+
save_every: 250 # save every this many steps
|
| 23 |
+
max_step_saves_to_keep: 4 # how many intermittent saves to keep
|
| 24 |
+
datasets:
|
| 25 |
+
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
| 26 |
+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
| 27 |
+
# images will automatically be resized and bucketed into the resolution specified
|
| 28 |
+
# on windows, escape back slashes with another backslash so
|
| 29 |
+
# "C:\\path\\to\\images\\folder"
|
| 30 |
+
# your dataset must be placed in /ai-toolkit and /root is for modal to find the dir:
|
| 31 |
+
- folder_path: "/root/ai-toolkit/your-dataset"
|
| 32 |
+
caption_ext: "txt"
|
| 33 |
+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
|
| 34 |
+
shuffle_tokens: false # shuffle caption order, split by commas
|
| 35 |
+
cache_latents_to_disk: true # leave this true unless you know what you're doing
|
| 36 |
+
resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions
|
| 37 |
+
train:
|
| 38 |
+
batch_size: 1
|
| 39 |
+
steps: 2000 # total number of steps to train 500 - 4000 is a good range
|
| 40 |
+
gradient_accumulation_steps: 1
|
| 41 |
+
train_unet: true
|
| 42 |
+
train_text_encoder: false # probably won't work with flux
|
| 43 |
+
gradient_checkpointing: true # need the on unless you have a ton of vram
|
| 44 |
+
noise_scheduler: "flowmatch" # for training only
|
| 45 |
+
optimizer: "adamw8bit"
|
| 46 |
+
lr: 1e-4
|
| 47 |
+
# uncomment this to skip the pre training sample
|
| 48 |
+
# skip_first_sample: true
|
| 49 |
+
# uncomment to completely disable sampling
|
| 50 |
+
# disable_sampling: true
|
| 51 |
+
# uncomment to use new vell curved weighting. Experimental but may produce better results
|
| 52 |
+
# linear_timesteps: true
|
| 53 |
+
|
| 54 |
+
# ema will smooth out learning, but could slow it down. Recommended to leave on.
|
| 55 |
+
ema_config:
|
| 56 |
+
use_ema: true
|
| 57 |
+
ema_decay: 0.99
|
| 58 |
+
|
| 59 |
+
# will probably need this if gpu supports it for flux, other dtypes may not work correctly
|
| 60 |
+
dtype: bf16
|
| 61 |
+
model:
|
| 62 |
+
# huggingface model name or path
|
| 63 |
+
# if you get an error, or get stuck while downloading,
|
| 64 |
+
# check https://github.com/ostris/ai-toolkit/issues/84, download the model locally and
|
| 65 |
+
# place it like "/root/ai-toolkit/FLUX.1-dev"
|
| 66 |
+
name_or_path: "black-forest-labs/FLUX.1-dev"
|
| 67 |
+
is_flux: true
|
| 68 |
+
quantize: true # run 8bit mixed precision
|
| 69 |
+
# low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
|
| 70 |
+
sample:
|
| 71 |
+
sampler: "flowmatch" # must match train.noise_scheduler
|
| 72 |
+
sample_every: 250 # sample every this many steps
|
| 73 |
+
width: 1024
|
| 74 |
+
height: 1024
|
| 75 |
+
prompts:
|
| 76 |
+
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
|
| 77 |
+
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
|
| 78 |
+
- "woman with red hair, playing chess at the park, bomb going off in the background"
|
| 79 |
+
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
|
| 80 |
+
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
|
| 81 |
+
- "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
|
| 82 |
+
- "a bear building a log cabin in the snow covered mountains"
|
| 83 |
+
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
|
| 84 |
+
- "hipster man with a beard, building a chair, in a wood shop"
|
| 85 |
+
- "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
|
| 86 |
+
- "a man holding a sign that says, 'this is a sign'"
|
| 87 |
+
- "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
|
| 88 |
+
neg: "" # not used on flux
|
| 89 |
+
seed: 42
|
| 90 |
+
walk_seed: true
|
| 91 |
+
guidance_scale: 4
|
| 92 |
+
sample_steps: 20
|
| 93 |
+
# you can add any additional meta info here. [name] is replaced with config name at top
|
| 94 |
+
meta:
|
| 95 |
+
name: "[name]"
|
| 96 |
+
version: '1.0'
|
ai-toolkit/config/examples/modal/modal_train_lora_flux_schnell_24gb.yaml
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
job: extension
|
| 3 |
+
config:
|
| 4 |
+
# this name will be the folder and filename name
|
| 5 |
+
name: "my_first_flux_lora_v1"
|
| 6 |
+
process:
|
| 7 |
+
- type: 'sd_trainer'
|
| 8 |
+
# root folder to save training sessions/samples/weights
|
| 9 |
+
training_folder: "/root/ai-toolkit/modal_output" # must match MOUNT_DIR from run_modal.py
|
| 10 |
+
# uncomment to see performance stats in the terminal every N steps
|
| 11 |
+
# performance_log_every: 1000
|
| 12 |
+
device: cuda:0
|
| 13 |
+
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
|
| 14 |
+
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
|
| 15 |
+
# trigger_word: "p3r5on"
|
| 16 |
+
network:
|
| 17 |
+
type: "lora"
|
| 18 |
+
linear: 16
|
| 19 |
+
linear_alpha: 16
|
| 20 |
+
save:
|
| 21 |
+
dtype: float16 # precision to save
|
| 22 |
+
save_every: 250 # save every this many steps
|
| 23 |
+
max_step_saves_to_keep: 4 # how many intermittent saves to keep
|
| 24 |
+
datasets:
|
| 25 |
+
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
| 26 |
+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
| 27 |
+
# images will automatically be resized and bucketed into the resolution specified
|
| 28 |
+
# on windows, escape back slashes with another backslash so
|
| 29 |
+
# "C:\\path\\to\\images\\folder"
|
| 30 |
+
# your dataset must be placed in /ai-toolkit and /root is for modal to find the dir:
|
| 31 |
+
- folder_path: "/root/ai-toolkit/your-dataset"
|
| 32 |
+
caption_ext: "txt"
|
| 33 |
+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
|
| 34 |
+
shuffle_tokens: false # shuffle caption order, split by commas
|
| 35 |
+
cache_latents_to_disk: true # leave this true unless you know what you're doing
|
| 36 |
+
resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions
|
| 37 |
+
train:
|
| 38 |
+
batch_size: 1
|
| 39 |
+
steps: 2000 # total number of steps to train 500 - 4000 is a good range
|
| 40 |
+
gradient_accumulation_steps: 1
|
| 41 |
+
train_unet: true
|
| 42 |
+
train_text_encoder: false # probably won't work with flux
|
| 43 |
+
gradient_checkpointing: true # need the on unless you have a ton of vram
|
| 44 |
+
noise_scheduler: "flowmatch" # for training only
|
| 45 |
+
optimizer: "adamw8bit"
|
| 46 |
+
lr: 1e-4
|
| 47 |
+
# uncomment this to skip the pre training sample
|
| 48 |
+
# skip_first_sample: true
|
| 49 |
+
# uncomment to completely disable sampling
|
| 50 |
+
# disable_sampling: true
|
| 51 |
+
# uncomment to use new vell curved weighting. Experimental but may produce better results
|
| 52 |
+
# linear_timesteps: true
|
| 53 |
+
|
| 54 |
+
# ema will smooth out learning, but could slow it down. Recommended to leave on.
|
| 55 |
+
ema_config:
|
| 56 |
+
use_ema: true
|
| 57 |
+
ema_decay: 0.99
|
| 58 |
+
|
| 59 |
+
# will probably need this if gpu supports it for flux, other dtypes may not work correctly
|
| 60 |
+
dtype: bf16
|
| 61 |
+
model:
|
| 62 |
+
# huggingface model name or path
|
| 63 |
+
# if you get an error, or get stuck while downloading,
|
| 64 |
+
# check https://github.com/ostris/ai-toolkit/issues/84, download the models locally and
|
| 65 |
+
# place them like "/root/ai-toolkit/FLUX.1-schnell" and "/root/ai-toolkit/FLUX.1-schnell-training-adapter"
|
| 66 |
+
name_or_path: "black-forest-labs/FLUX.1-schnell"
|
| 67 |
+
assistant_lora_path: "ostris/FLUX.1-schnell-training-adapter" # Required for flux schnell training
|
| 68 |
+
is_flux: true
|
| 69 |
+
quantize: true # run 8bit mixed precision
|
| 70 |
+
# low_vram is painfully slow to fuse in the adapter avoid it unless absolutely necessary
|
| 71 |
+
# low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
|
| 72 |
+
sample:
|
| 73 |
+
sampler: "flowmatch" # must match train.noise_scheduler
|
| 74 |
+
sample_every: 250 # sample every this many steps
|
| 75 |
+
width: 1024
|
| 76 |
+
height: 1024
|
| 77 |
+
prompts:
|
| 78 |
+
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
|
| 79 |
+
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
|
| 80 |
+
- "woman with red hair, playing chess at the park, bomb going off in the background"
|
| 81 |
+
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
|
| 82 |
+
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
|
| 83 |
+
- "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
|
| 84 |
+
- "a bear building a log cabin in the snow covered mountains"
|
| 85 |
+
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
|
| 86 |
+
- "hipster man with a beard, building a chair, in a wood shop"
|
| 87 |
+
- "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
|
| 88 |
+
- "a man holding a sign that says, 'this is a sign'"
|
| 89 |
+
- "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
|
| 90 |
+
neg: "" # not used on flux
|
| 91 |
+
seed: 42
|
| 92 |
+
walk_seed: true
|
| 93 |
+
guidance_scale: 1 # schnell does not do guidance
|
| 94 |
+
sample_steps: 4 # 1 - 4 works well
|
| 95 |
+
# you can add any additional meta info here. [name] is replaced with config name at top
|
| 96 |
+
meta:
|
| 97 |
+
name: "[name]"
|
| 98 |
+
version: '1.0'
|
ai-toolkit/config/examples/train_flex_redux.yaml
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
job: extension
|
| 3 |
+
config:
|
| 4 |
+
# this name will be the folder and filename name
|
| 5 |
+
name: "my_first_flex_redux_finetune_v1"
|
| 6 |
+
process:
|
| 7 |
+
- type: 'sd_trainer'
|
| 8 |
+
# root folder to save training sessions/samples/weights
|
| 9 |
+
training_folder: "output"
|
| 10 |
+
# uncomment to see performance stats in the terminal every N steps
|
| 11 |
+
# performance_log_every: 1000
|
| 12 |
+
device: cuda:0
|
| 13 |
+
adapter:
|
| 14 |
+
type: "redux"
|
| 15 |
+
# you can finetune an existing adapter or start from scratch. Set to null to start from scratch
|
| 16 |
+
name_or_path: '/local/path/to/redux_adapter_to_finetune.safetensors'
|
| 17 |
+
# name_or_path: null
|
| 18 |
+
# image_encoder_path: 'google/siglip-so400m-patch14-384' # Flux.1 redux adapter
|
| 19 |
+
image_encoder_path: 'google/siglip2-so400m-patch16-512' # Flex.1 512 redux adapter
|
| 20 |
+
# image_encoder_arch: 'siglip' # for Flux.1
|
| 21 |
+
image_encoder_arch: 'siglip2'
|
| 22 |
+
# You need a control input for each sample. Best to do squares for both images
|
| 23 |
+
test_img_path:
|
| 24 |
+
- "/path/to/x_01.jpg"
|
| 25 |
+
- "/path/to/x_02.jpg"
|
| 26 |
+
- "/path/to/x_03.jpg"
|
| 27 |
+
- "/path/to/x_04.jpg"
|
| 28 |
+
- "/path/to/x_05.jpg"
|
| 29 |
+
- "/path/to/x_06.jpg"
|
| 30 |
+
- "/path/to/x_07.jpg"
|
| 31 |
+
- "/path/to/x_08.jpg"
|
| 32 |
+
- "/path/to/x_09.jpg"
|
| 33 |
+
- "/path/to/x_10.jpg"
|
| 34 |
+
clip_layer: 'last_hidden_state'
|
| 35 |
+
train: true
|
| 36 |
+
save:
|
| 37 |
+
dtype: bf16 # precision to save
|
| 38 |
+
save_every: 250 # save every this many steps
|
| 39 |
+
max_step_saves_to_keep: 4
|
| 40 |
+
datasets:
|
| 41 |
+
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
| 42 |
+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
| 43 |
+
# images will automatically be resized and bucketed into the resolution specified
|
| 44 |
+
# on windows, escape back slashes with another backslash so
|
| 45 |
+
# "C:\\path\\to\\images\\folder"
|
| 46 |
+
- folder_path: "/path/to/images/folder"
|
| 47 |
+
# clip_image_path is directory containting your control images. They must have filename as their train image. (extension does not matter)
|
| 48 |
+
# for normal redux, we are just recreating the same image, so you can use the same folder path above
|
| 49 |
+
clip_image_path: "/path/to/control/images/folder"
|
| 50 |
+
caption_ext: "txt"
|
| 51 |
+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
|
| 52 |
+
resolution: [ 512, 768, 1024 ] # flex enjoys multiple resolutions
|
| 53 |
+
train:
|
| 54 |
+
# this is what I used for the 24GB card, but feel free to adjust
|
| 55 |
+
# total batch size is 6 here
|
| 56 |
+
batch_size: 3
|
| 57 |
+
gradient_accumulation: 2
|
| 58 |
+
|
| 59 |
+
# captions are not needed for this training, we cache a blank proompt and rely on the vision encoder
|
| 60 |
+
unload_text_encoder: true
|
| 61 |
+
|
| 62 |
+
loss_type: "mse"
|
| 63 |
+
train_unet: true
|
| 64 |
+
train_text_encoder: false
|
| 65 |
+
steps: 4000000 # I set this very high and stop when I like the results
|
| 66 |
+
content_or_style: balanced # content, style, balanced
|
| 67 |
+
gradient_checkpointing: true
|
| 68 |
+
noise_scheduler: "flowmatch" # or "ddpm", "lms", "euler_a"
|
| 69 |
+
timestep_type: "flux_shift"
|
| 70 |
+
optimizer: "adamw8bit"
|
| 71 |
+
lr: 1e-4
|
| 72 |
+
|
| 73 |
+
# this is for Flex.1, comment this out for FLUX.1-dev
|
| 74 |
+
bypass_guidance_embedding: true
|
| 75 |
+
|
| 76 |
+
dtype: bf16
|
| 77 |
+
ema_config:
|
| 78 |
+
use_ema: true
|
| 79 |
+
ema_decay: 0.99
|
| 80 |
+
model:
|
| 81 |
+
name_or_path: "ostris/Flex.1-alpha"
|
| 82 |
+
is_flux: true
|
| 83 |
+
quantize: true
|
| 84 |
+
text_encoder_bits: 8
|
| 85 |
+
sample:
|
| 86 |
+
sampler: "flowmatch" # must match train.noise_scheduler
|
| 87 |
+
sample_every: 250 # sample every this many steps
|
| 88 |
+
width: 1024
|
| 89 |
+
height: 1024
|
| 90 |
+
# I leave half blank to test prompt and unprompted
|
| 91 |
+
prompts:
|
| 92 |
+
- "woman with red hair, playing chess at the park, bomb going off in the background"
|
| 93 |
+
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
|
| 94 |
+
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
|
| 95 |
+
- "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
|
| 96 |
+
- "a bear building a log cabin in the snow covered mountains"
|
| 97 |
+
- ""
|
| 98 |
+
- ""
|
| 99 |
+
- ""
|
| 100 |
+
- ""
|
| 101 |
+
- ""
|
| 102 |
+
neg: ""
|
| 103 |
+
seed: 42
|
| 104 |
+
walk_seed: true
|
| 105 |
+
guidance_scale: 4
|
| 106 |
+
sample_steps: 25
|
| 107 |
+
network_multiplier: 1.0
|
| 108 |
+
|
| 109 |
+
# you can add any additional meta info here. [name] is replaced with config name at top
|
| 110 |
+
meta:
|
| 111 |
+
name: "[name]"
|
| 112 |
+
version: '1.0'
|
ai-toolkit/config/examples/train_full_fine_tune_flex.yaml
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
# This configuration requires 48GB of VRAM or more to operate
|
| 3 |
+
job: extension
|
| 4 |
+
config:
|
| 5 |
+
# this name will be the folder and filename name
|
| 6 |
+
name: "my_first_flex_finetune_v1"
|
| 7 |
+
process:
|
| 8 |
+
- type: 'sd_trainer'
|
| 9 |
+
# root folder to save training sessions/samples/weights
|
| 10 |
+
training_folder: "output"
|
| 11 |
+
# uncomment to see performance stats in the terminal every N steps
|
| 12 |
+
# performance_log_every: 1000
|
| 13 |
+
device: cuda:0
|
| 14 |
+
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
|
| 15 |
+
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
|
| 16 |
+
# trigger_word: "p3r5on"
|
| 17 |
+
save:
|
| 18 |
+
dtype: bf16 # precision to save
|
| 19 |
+
save_every: 250 # save every this many steps
|
| 20 |
+
max_step_saves_to_keep: 2 # how many intermittent saves to keep
|
| 21 |
+
save_format: 'diffusers' # 'diffusers'
|
| 22 |
+
datasets:
|
| 23 |
+
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
| 24 |
+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
| 25 |
+
# images will automatically be resized and bucketed into the resolution specified
|
| 26 |
+
# on windows, escape back slashes with another backslash so
|
| 27 |
+
# "C:\\path\\to\\images\\folder"
|
| 28 |
+
- folder_path: "/path/to/images/folder"
|
| 29 |
+
caption_ext: "txt"
|
| 30 |
+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
|
| 31 |
+
shuffle_tokens: false # shuffle caption order, split by commas
|
| 32 |
+
# cache_latents_to_disk: true # leave this true unless you know what you're doing
|
| 33 |
+
resolution: [ 512, 768, 1024 ] # flex enjoys multiple resolutions
|
| 34 |
+
train:
|
| 35 |
+
batch_size: 1
|
| 36 |
+
# IMPORTANT! For Flex, you must bypass the guidance embedder during training
|
| 37 |
+
bypass_guidance_embedding: true
|
| 38 |
+
|
| 39 |
+
# can be 'sigmoid', 'linear', or 'lognorm_blend'
|
| 40 |
+
timestep_type: 'sigmoid'
|
| 41 |
+
|
| 42 |
+
steps: 2000 # total number of steps to train 500 - 4000 is a good range
|
| 43 |
+
gradient_accumulation: 1
|
| 44 |
+
train_unet: true
|
| 45 |
+
train_text_encoder: false # probably won't work with flex
|
| 46 |
+
gradient_checkpointing: true # need the on unless you have a ton of vram
|
| 47 |
+
noise_scheduler: "flowmatch" # for training only
|
| 48 |
+
optimizer: "adafactor"
|
| 49 |
+
lr: 3e-5
|
| 50 |
+
|
| 51 |
+
# Paramiter swapping can reduce vram requirements. Set factor from 1.0 to 0.0.
|
| 52 |
+
# 0.1 is 10% of paramiters active at easc step. Only works with adafactor
|
| 53 |
+
|
| 54 |
+
# do_paramiter_swapping: true
|
| 55 |
+
# paramiter_swapping_factor: 0.9
|
| 56 |
+
|
| 57 |
+
# uncomment this to skip the pre training sample
|
| 58 |
+
# skip_first_sample: true
|
| 59 |
+
# uncomment to completely disable sampling
|
| 60 |
+
# disable_sampling: true
|
| 61 |
+
|
| 62 |
+
# ema will smooth out learning, but could slow it down. Recommended to leave on if you have the vram
|
| 63 |
+
ema_config:
|
| 64 |
+
use_ema: true
|
| 65 |
+
ema_decay: 0.99
|
| 66 |
+
|
| 67 |
+
# will probably need this if gpu supports it for flex, other dtypes may not work correctly
|
| 68 |
+
dtype: bf16
|
| 69 |
+
model:
|
| 70 |
+
# huggingface model name or path
|
| 71 |
+
name_or_path: "ostris/Flex.1-alpha"
|
| 72 |
+
is_flux: true # flex is flux architecture
|
| 73 |
+
# full finetuning quantized models is a crapshoot and results in subpar outputs
|
| 74 |
+
# quantize: true
|
| 75 |
+
# you can quantize just the T5 text encoder here to save vram
|
| 76 |
+
quantize_te: true
|
| 77 |
+
# only train the transformer blocks
|
| 78 |
+
only_if_contains:
|
| 79 |
+
- "transformer.transformer_blocks."
|
| 80 |
+
- "transformer.single_transformer_blocks."
|
| 81 |
+
sample:
|
| 82 |
+
sampler: "flowmatch" # must match train.noise_scheduler
|
| 83 |
+
sample_every: 250 # sample every this many steps
|
| 84 |
+
width: 1024
|
| 85 |
+
height: 1024
|
| 86 |
+
prompts:
|
| 87 |
+
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
|
| 88 |
+
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
|
| 89 |
+
- "woman with red hair, playing chess at the park, bomb going off in the background"
|
| 90 |
+
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
|
| 91 |
+
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
|
| 92 |
+
- "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
|
| 93 |
+
- "a bear building a log cabin in the snow covered mountains"
|
| 94 |
+
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
|
| 95 |
+
- "hipster man with a beard, building a chair, in a wood shop"
|
| 96 |
+
- "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
|
| 97 |
+
- "a man holding a sign that says, 'this is a sign'"
|
| 98 |
+
- "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
|
| 99 |
+
neg: "" # not used on flex
|
| 100 |
+
seed: 42
|
| 101 |
+
walk_seed: true
|
| 102 |
+
guidance_scale: 4
|
| 103 |
+
sample_steps: 25
|
| 104 |
+
# you can add any additional meta info here. [name] is replaced with config name at top
|
| 105 |
+
meta:
|
| 106 |
+
name: "[name]"
|
| 107 |
+
version: '1.0'
|
ai-toolkit/config/examples/train_full_fine_tune_lumina.yaml
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
# This configuration requires 24GB of VRAM or more to operate
|
| 3 |
+
job: extension
|
| 4 |
+
config:
|
| 5 |
+
# this name will be the folder and filename name
|
| 6 |
+
name: "my_first_lumina_finetune_v1"
|
| 7 |
+
process:
|
| 8 |
+
- type: 'sd_trainer'
|
| 9 |
+
# root folder to save training sessions/samples/weights
|
| 10 |
+
training_folder: "output"
|
| 11 |
+
# uncomment to see performance stats in the terminal every N steps
|
| 12 |
+
# performance_log_every: 1000
|
| 13 |
+
device: cuda:0
|
| 14 |
+
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
|
| 15 |
+
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
|
| 16 |
+
# trigger_word: "p3r5on"
|
| 17 |
+
save:
|
| 18 |
+
dtype: bf16 # precision to save
|
| 19 |
+
save_every: 250 # save every this many steps
|
| 20 |
+
max_step_saves_to_keep: 2 # how many intermittent saves to keep
|
| 21 |
+
save_format: 'diffusers' # 'diffusers'
|
| 22 |
+
datasets:
|
| 23 |
+
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
| 24 |
+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
| 25 |
+
# images will automatically be resized and bucketed into the resolution specified
|
| 26 |
+
# on windows, escape back slashes with another backslash so
|
| 27 |
+
# "C:\\path\\to\\images\\folder"
|
| 28 |
+
- folder_path: "/path/to/images/folder"
|
| 29 |
+
caption_ext: "txt"
|
| 30 |
+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
|
| 31 |
+
shuffle_tokens: false # shuffle caption order, split by commas
|
| 32 |
+
# cache_latents_to_disk: true # leave this true unless you know what you're doing
|
| 33 |
+
resolution: [ 512, 768, 1024 ] # lumina2 enjoys multiple resolutions
|
| 34 |
+
train:
|
| 35 |
+
batch_size: 1
|
| 36 |
+
|
| 37 |
+
# can be 'sigmoid', 'linear', or 'lumina2_shift'
|
| 38 |
+
timestep_type: 'lumina2_shift'
|
| 39 |
+
|
| 40 |
+
steps: 2000 # total number of steps to train 500 - 4000 is a good range
|
| 41 |
+
gradient_accumulation: 1
|
| 42 |
+
train_unet: true
|
| 43 |
+
train_text_encoder: false # probably won't work with lumina2
|
| 44 |
+
gradient_checkpointing: true # need the on unless you have a ton of vram
|
| 45 |
+
noise_scheduler: "flowmatch" # for training only
|
| 46 |
+
optimizer: "adafactor"
|
| 47 |
+
lr: 3e-5
|
| 48 |
+
|
| 49 |
+
# Paramiter swapping can reduce vram requirements. Set factor from 1.0 to 0.0.
|
| 50 |
+
# 0.1 is 10% of paramiters active at easc step. Only works with adafactor
|
| 51 |
+
|
| 52 |
+
# do_paramiter_swapping: true
|
| 53 |
+
# paramiter_swapping_factor: 0.9
|
| 54 |
+
|
| 55 |
+
# uncomment this to skip the pre training sample
|
| 56 |
+
# skip_first_sample: true
|
| 57 |
+
# uncomment to completely disable sampling
|
| 58 |
+
# disable_sampling: true
|
| 59 |
+
|
| 60 |
+
# ema will smooth out learning, but could slow it down. Recommended to leave on if you have the vram
|
| 61 |
+
# ema_config:
|
| 62 |
+
# use_ema: true
|
| 63 |
+
# ema_decay: 0.99
|
| 64 |
+
|
| 65 |
+
# will probably need this if gpu supports it for lumina2, other dtypes may not work correctly
|
| 66 |
+
dtype: bf16
|
| 67 |
+
model:
|
| 68 |
+
# huggingface model name or path
|
| 69 |
+
name_or_path: "Alpha-VLLM/Lumina-Image-2.0"
|
| 70 |
+
is_lumina2: true # lumina2 architecture
|
| 71 |
+
# you can quantize just the Gemma2 text encoder here to save vram
|
| 72 |
+
quantize_te: true
|
| 73 |
+
sample:
|
| 74 |
+
sampler: "flowmatch" # must match train.noise_scheduler
|
| 75 |
+
sample_every: 250 # sample every this many steps
|
| 76 |
+
width: 1024
|
| 77 |
+
height: 1024
|
| 78 |
+
prompts:
|
| 79 |
+
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
|
| 80 |
+
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
|
| 81 |
+
- "woman with red hair, playing chess at the park, bomb going off in the background"
|
| 82 |
+
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
|
| 83 |
+
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
|
| 84 |
+
- "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
|
| 85 |
+
- "a bear building a log cabin in the snow covered mountains"
|
| 86 |
+
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
|
| 87 |
+
- "hipster man with a beard, building a chair, in a wood shop"
|
| 88 |
+
- "photo of a cat that is half black and half orange tabby, split down the middle. The cat has on a blue tophat. They are holding a martini glass with a pink ball of yarn in it with green knitting needles sticking out, in one paw. In the other paw, they are holding a DVD case for a movie titled, \"This is a test\" that has a golden robot on it. In the background is a busy night club with a giant mushroom man dancing with a bear."
|
| 89 |
+
- "a man holding a sign that says, 'this is a sign'"
|
| 90 |
+
- "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
|
| 91 |
+
neg: ""
|
| 92 |
+
seed: 42
|
| 93 |
+
walk_seed: true
|
| 94 |
+
guidance_scale: 4.0
|
| 95 |
+
sample_steps: 25
|
| 96 |
+
# you can add any additional meta info here. [name] is replaced with config name at top
|
| 97 |
+
meta:
|
| 98 |
+
name: "[name]"
|
| 99 |
+
version: '1.0'
|
ai-toolkit/config/examples/train_lora_chroma_24gb.yaml
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
job: extension
|
| 3 |
+
config:
|
| 4 |
+
# this name will be the folder and filename name
|
| 5 |
+
name: "my_first_chroma_lora_v1"
|
| 6 |
+
process:
|
| 7 |
+
- type: 'sd_trainer'
|
| 8 |
+
# root folder to save training sessions/samples/weights
|
| 9 |
+
training_folder: "output"
|
| 10 |
+
# uncomment to see performance stats in the terminal every N steps
|
| 11 |
+
# performance_log_every: 1000
|
| 12 |
+
device: cuda:0
|
| 13 |
+
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
|
| 14 |
+
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
|
| 15 |
+
# trigger_word: "p3r5on"
|
| 16 |
+
network:
|
| 17 |
+
type: "lora"
|
| 18 |
+
linear: 16
|
| 19 |
+
linear_alpha: 16
|
| 20 |
+
save:
|
| 21 |
+
dtype: float16 # precision to save
|
| 22 |
+
save_every: 250 # save every this many steps
|
| 23 |
+
max_step_saves_to_keep: 4 # how many intermittent saves to keep
|
| 24 |
+
push_to_hub: false #change this to True to push your trained model to Hugging Face.
|
| 25 |
+
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
|
| 26 |
+
# hf_repo_id: your-username/your-model-slug
|
| 27 |
+
# hf_private: true #whether the repo is private or public
|
| 28 |
+
datasets:
|
| 29 |
+
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
| 30 |
+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
| 31 |
+
# images will automatically be resized and bucketed into the resolution specified
|
| 32 |
+
# on windows, escape back slashes with another backslash so
|
| 33 |
+
# "C:\\path\\to\\images\\folder"
|
| 34 |
+
- folder_path: "/path/to/images/folder"
|
| 35 |
+
caption_ext: "txt"
|
| 36 |
+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
|
| 37 |
+
shuffle_tokens: false # shuffle caption order, split by commas
|
| 38 |
+
cache_latents_to_disk: true # leave this true unless you know what you're doing
|
| 39 |
+
resolution: [ 512, 768, 1024 ] # chroma enjoys multiple resolutions
|
| 40 |
+
train:
|
| 41 |
+
batch_size: 1
|
| 42 |
+
steps: 2000 # total number of steps to train 500 - 4000 is a good range
|
| 43 |
+
gradient_accumulation: 1
|
| 44 |
+
train_unet: true
|
| 45 |
+
train_text_encoder: false # probably won't work with chroma
|
| 46 |
+
gradient_checkpointing: true # need the on unless you have a ton of vram
|
| 47 |
+
noise_scheduler: "flowmatch" # for training only
|
| 48 |
+
optimizer: "adamw8bit"
|
| 49 |
+
lr: 1e-4
|
| 50 |
+
# uncomment this to skip the pre training sample
|
| 51 |
+
# skip_first_sample: true
|
| 52 |
+
# uncomment to completely disable sampling
|
| 53 |
+
# disable_sampling: true
|
| 54 |
+
# uncomment to use new vell curved weighting. Experimental but may produce better results
|
| 55 |
+
# linear_timesteps: true
|
| 56 |
+
|
| 57 |
+
# ema will smooth out learning, but could slow it down. Recommended to leave on.
|
| 58 |
+
ema_config:
|
| 59 |
+
use_ema: true
|
| 60 |
+
ema_decay: 0.99
|
| 61 |
+
|
| 62 |
+
# will probably need this if gpu supports it for chroma, other dtypes may not work correctly
|
| 63 |
+
dtype: bf16
|
| 64 |
+
model:
|
| 65 |
+
# Download the whichever model you prefer from the Chroma repo
|
| 66 |
+
# https://huggingface.co/lodestones/Chroma/tree/main
|
| 67 |
+
# point to it here.
|
| 68 |
+
# name_or_path: "/path/to/chroma/chroma-unlocked-vVERSION.safetensors"
|
| 69 |
+
|
| 70 |
+
# using lodestones/Chroma will automatically use the latest version
|
| 71 |
+
name_or_path: "lodestones/Chroma"
|
| 72 |
+
|
| 73 |
+
# # You can also select a version of Chroma like so
|
| 74 |
+
# name_or_path: "lodestones/Chroma/v28"
|
| 75 |
+
|
| 76 |
+
arch: "chroma"
|
| 77 |
+
quantize: true # run 8bit mixed precision
|
| 78 |
+
sample:
|
| 79 |
+
sampler: "flowmatch" # must match train.noise_scheduler
|
| 80 |
+
sample_every: 250 # sample every this many steps
|
| 81 |
+
width: 1024
|
| 82 |
+
height: 1024
|
| 83 |
+
prompts:
|
| 84 |
+
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
|
| 85 |
+
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
|
| 86 |
+
- "woman with red hair, playing chess at the park, bomb going off in the background"
|
| 87 |
+
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
|
| 88 |
+
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
|
| 89 |
+
- "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
|
| 90 |
+
- "a bear building a log cabin in the snow covered mountains"
|
| 91 |
+
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
|
| 92 |
+
- "hipster man with a beard, building a chair, in a wood shop"
|
| 93 |
+
- "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
|
| 94 |
+
- "a man holding a sign that says, 'this is a sign'"
|
| 95 |
+
- "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
|
| 96 |
+
neg: "" # negative prompt, optional
|
| 97 |
+
seed: 42
|
| 98 |
+
walk_seed: true
|
| 99 |
+
guidance_scale: 4
|
| 100 |
+
sample_steps: 25
|
| 101 |
+
# you can add any additional meta info here. [name] is replaced with config name at top
|
| 102 |
+
meta:
|
| 103 |
+
name: "[name]"
|
| 104 |
+
version: '1.0'
|
ai-toolkit/config/examples/train_lora_flex2_24gb.yaml
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Note, Flex2 is a highly experimental WIP model. Finetuning a model with built in controls and inpainting has not
|
| 2 |
+
# been done before, so you will be experimenting with me on how to do it. This is my recommended setup, but this is highly
|
| 3 |
+
# subject to change as we learn more about how Flex2 works.
|
| 4 |
+
|
| 5 |
+
---
|
| 6 |
+
job: extension
|
| 7 |
+
config:
|
| 8 |
+
# this name will be the folder and filename name
|
| 9 |
+
name: "my_first_flex2_lora_v1"
|
| 10 |
+
process:
|
| 11 |
+
- type: 'sd_trainer'
|
| 12 |
+
# root folder to save training sessions/samples/weights
|
| 13 |
+
training_folder: "output"
|
| 14 |
+
# uncomment to see performance stats in the terminal every N steps
|
| 15 |
+
# performance_log_every: 1000
|
| 16 |
+
device: cuda:0
|
| 17 |
+
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
|
| 18 |
+
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
|
| 19 |
+
# trigger_word: "p3r5on"
|
| 20 |
+
network:
|
| 21 |
+
type: "lora"
|
| 22 |
+
linear: 32
|
| 23 |
+
linear_alpha: 32
|
| 24 |
+
save:
|
| 25 |
+
dtype: float16 # precision to save
|
| 26 |
+
save_every: 250 # save every this many steps
|
| 27 |
+
max_step_saves_to_keep: 4 # how many intermittent saves to keep
|
| 28 |
+
push_to_hub: false #change this to True to push your trained model to Hugging Face.
|
| 29 |
+
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
|
| 30 |
+
# hf_repo_id: your-username/your-model-slug
|
| 31 |
+
# hf_private: true #whether the repo is private or public
|
| 32 |
+
datasets:
|
| 33 |
+
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
| 34 |
+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
| 35 |
+
# images will automatically be resized and bucketed into the resolution specified
|
| 36 |
+
# on windows, escape back slashes with another backslash so
|
| 37 |
+
# "C:\\path\\to\\images\\folder"
|
| 38 |
+
- folder_path: "/path/to/images/folder"
|
| 39 |
+
# Flex2 is trained with controls and inpainting. If you want the model to truely understand how the
|
| 40 |
+
# controls function with your dataset, it is a good idea to keep doing controls during training.
|
| 41 |
+
# this will automatically generate the controls for you before training. The current script is not
|
| 42 |
+
# fully optimized so this could be rather slow for large datasets, but it caches them to disk so it
|
| 43 |
+
# only needs to be done once. If you want to skip this step, you can set the controls to [] and it will
|
| 44 |
+
controls:
|
| 45 |
+
- "depth"
|
| 46 |
+
- "line"
|
| 47 |
+
- "pose"
|
| 48 |
+
- "inpaint"
|
| 49 |
+
|
| 50 |
+
# you can make custom inpainting images as well. These images must be webp or png format with an alpha.
|
| 51 |
+
# just erase the part of the image you want to inpaint and save it as a webp or png. Again, erase your
|
| 52 |
+
# train target. So the person if training a person. The automatic controls above with inpaint will
|
| 53 |
+
# just run a background remover mask and erase the foreground, which works well for subjects.
|
| 54 |
+
|
| 55 |
+
# inpaint_path: "/my/impaint/images"
|
| 56 |
+
|
| 57 |
+
# you can also specify existing control image pairs. It can handle multiple groups and will randomly
|
| 58 |
+
# select one for each step.
|
| 59 |
+
|
| 60 |
+
# control_path:
|
| 61 |
+
# - "/my/custom/control/images"
|
| 62 |
+
# - "/my/custom/control/images2"
|
| 63 |
+
|
| 64 |
+
caption_ext: "txt"
|
| 65 |
+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
|
| 66 |
+
resolution: [ 512, 768, 1024 ] # flex2 enjoys multiple resolutions
|
| 67 |
+
train:
|
| 68 |
+
batch_size: 1
|
| 69 |
+
# IMPORTANT! For Flex2, you must bypass the guidance embedder during training
|
| 70 |
+
bypass_guidance_embedding: true
|
| 71 |
+
|
| 72 |
+
steps: 3000 # total number of steps to train 500 - 4000 is a good range
|
| 73 |
+
gradient_accumulation: 1
|
| 74 |
+
train_unet: true
|
| 75 |
+
train_text_encoder: false # probably won't work with flex2
|
| 76 |
+
gradient_checkpointing: true # need the on unless you have a ton of vram
|
| 77 |
+
noise_scheduler: "flowmatch" # for training only
|
| 78 |
+
# shift works well for training fast and learning composition and style.
|
| 79 |
+
# for just subject, you may want to change this to sigmoid
|
| 80 |
+
timestep_type: 'shift' # 'linear', 'sigmoid', 'shift'
|
| 81 |
+
optimizer: "adamw8bit"
|
| 82 |
+
lr: 1e-4
|
| 83 |
+
|
| 84 |
+
optimizer_params:
|
| 85 |
+
weight_decay: 1e-5
|
| 86 |
+
# uncomment this to skip the pre training sample
|
| 87 |
+
# skip_first_sample: true
|
| 88 |
+
# uncomment to completely disable sampling
|
| 89 |
+
# disable_sampling: true
|
| 90 |
+
# uncomment to use new vell curved weighting. Experimental but may produce better results
|
| 91 |
+
# linear_timesteps: true
|
| 92 |
+
|
| 93 |
+
# ema will smooth out learning, but could slow it down. Defaults off
|
| 94 |
+
ema_config:
|
| 95 |
+
use_ema: false
|
| 96 |
+
ema_decay: 0.99
|
| 97 |
+
|
| 98 |
+
# will probably need this if gpu supports it for flex, other dtypes may not work correctly
|
| 99 |
+
dtype: bf16
|
| 100 |
+
model:
|
| 101 |
+
# huggingface model name or path
|
| 102 |
+
name_or_path: "ostris/Flex.2-preview"
|
| 103 |
+
arch: "flex2"
|
| 104 |
+
quantize: true # run 8bit mixed precision
|
| 105 |
+
quantize_te: true
|
| 106 |
+
|
| 107 |
+
# you can pass special training infor for controls to the model here
|
| 108 |
+
# percentages are decimal based so 0.0 is 0% and 1.0 is 100% of the time.
|
| 109 |
+
model_kwargs:
|
| 110 |
+
# inverts the inpainting mask, good to learn outpainting as well, recommended 0.0 for characters
|
| 111 |
+
invert_inpaint_mask_chance: 0.5
|
| 112 |
+
# this will do a normal t2i training step without inpaint when dropped out. REcommended if you want
|
| 113 |
+
# your lora to be able to inference with and without inpainting.
|
| 114 |
+
inpaint_dropout: 0.5
|
| 115 |
+
# randomly drops out the control image. Dropout recvommended if your want it to work without controls as well.
|
| 116 |
+
control_dropout: 0.5
|
| 117 |
+
# does a random inpaint blob. Usually a good idea to keep. Without it, the model will learn to always 100%
|
| 118 |
+
# fill the inpaint area with your subject. This is not always a good thing.
|
| 119 |
+
inpaint_random_chance: 0.5
|
| 120 |
+
# generates random inpaint blobs if you did not provide an inpaint image for your dataset. Inpaint breaks down fast
|
| 121 |
+
# if you are not training with it. Controls are a little more robust and can be left out,
|
| 122 |
+
# but when in doubt, always leave this on
|
| 123 |
+
do_random_inpainting: false
|
| 124 |
+
# does random blurring of the inpaint mask. Helps prevent weird edge artifacts for real workd inpainting. Leave on.
|
| 125 |
+
random_blur_mask: true
|
| 126 |
+
# applies a small amount of random dialition and restriction to the inpaint mask. Helps with edge artifacts.
|
| 127 |
+
# Leave on.
|
| 128 |
+
random_dialate_mask: true
|
| 129 |
+
sample:
|
| 130 |
+
sampler: "flowmatch" # must match train.noise_scheduler
|
| 131 |
+
sample_every: 250 # sample every this many steps
|
| 132 |
+
width: 1024
|
| 133 |
+
height: 1024
|
| 134 |
+
prompts:
|
| 135 |
+
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
|
| 136 |
+
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
|
| 137 |
+
|
| 138 |
+
# you can use a single inpaint or single control image on your samples.
|
| 139 |
+
# for controls, the ctrl_idx is 1, the images can be any name and image format.
|
| 140 |
+
# use either a pose/line/depth image or whatever you are training with. An example is
|
| 141 |
+
# - "photo of [trigger] --ctrl_idx 1 --ctrl_img /path/to/control/image.jpg"
|
| 142 |
+
|
| 143 |
+
# for an inpainting image, it must be png/webp. Erase the part of the image you want to inpaint
|
| 144 |
+
# IMPORTANT! the inpaint images must be ctrl_idx 0 and have .inpaint.{ext} in the name for this to work right.
|
| 145 |
+
# - "photo of [trigger] --ctrl_idx 0 --ctrl_img /path/to/inpaint/image.inpaint.png"
|
| 146 |
+
|
| 147 |
+
- "woman with red hair, playing chess at the park, bomb going off in the background"
|
| 148 |
+
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
|
| 149 |
+
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
|
| 150 |
+
- "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
|
| 151 |
+
- "a bear building a log cabin in the snow covered mountains"
|
| 152 |
+
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
|
| 153 |
+
- "hipster man with a beard, building a chair, in a wood shop"
|
| 154 |
+
- "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
|
| 155 |
+
- "a man holding a sign that says, 'this is a sign'"
|
| 156 |
+
- "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
|
| 157 |
+
neg: "" # not used on flex2
|
| 158 |
+
seed: 42
|
| 159 |
+
walk_seed: true
|
| 160 |
+
guidance_scale: 4
|
| 161 |
+
sample_steps: 25
|
| 162 |
+
# you can add any additional meta info here. [name] is replaced with config name at top
|
| 163 |
+
meta:
|
| 164 |
+
name: "[name]"
|
| 165 |
+
version: '1.0'
|
ai-toolkit/config/examples/train_lora_flex_24gb.yaml
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
job: extension
|
| 3 |
+
config:
|
| 4 |
+
# this name will be the folder and filename name
|
| 5 |
+
name: "my_first_flex_lora_v1"
|
| 6 |
+
process:
|
| 7 |
+
- type: 'sd_trainer'
|
| 8 |
+
# root folder to save training sessions/samples/weights
|
| 9 |
+
training_folder: "output"
|
| 10 |
+
# uncomment to see performance stats in the terminal every N steps
|
| 11 |
+
# performance_log_every: 1000
|
| 12 |
+
device: cuda:0
|
| 13 |
+
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
|
| 14 |
+
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
|
| 15 |
+
# trigger_word: "p3r5on"
|
| 16 |
+
network:
|
| 17 |
+
type: "lora"
|
| 18 |
+
linear: 16
|
| 19 |
+
linear_alpha: 16
|
| 20 |
+
save:
|
| 21 |
+
dtype: float16 # precision to save
|
| 22 |
+
save_every: 250 # save every this many steps
|
| 23 |
+
max_step_saves_to_keep: 4 # how many intermittent saves to keep
|
| 24 |
+
push_to_hub: false #change this to True to push your trained model to Hugging Face.
|
| 25 |
+
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
|
| 26 |
+
# hf_repo_id: your-username/your-model-slug
|
| 27 |
+
# hf_private: true #whether the repo is private or public
|
| 28 |
+
datasets:
|
| 29 |
+
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
| 30 |
+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
| 31 |
+
# images will automatically be resized and bucketed into the resolution specified
|
| 32 |
+
# on windows, escape back slashes with another backslash so
|
| 33 |
+
# "C:\\path\\to\\images\\folder"
|
| 34 |
+
- folder_path: "/path/to/images/folder"
|
| 35 |
+
caption_ext: "txt"
|
| 36 |
+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
|
| 37 |
+
shuffle_tokens: false # shuffle caption order, split by commas
|
| 38 |
+
cache_latents_to_disk: true # leave this true unless you know what you're doing
|
| 39 |
+
resolution: [ 512, 768, 1024 ] # flex enjoys multiple resolutions
|
| 40 |
+
train:
|
| 41 |
+
batch_size: 1
|
| 42 |
+
# IMPORTANT! For Flex, you must bypass the guidance embedder during training
|
| 43 |
+
bypass_guidance_embedding: true
|
| 44 |
+
|
| 45 |
+
steps: 2000 # total number of steps to train 500 - 4000 is a good range
|
| 46 |
+
gradient_accumulation: 1
|
| 47 |
+
train_unet: true
|
| 48 |
+
train_text_encoder: false # probably won't work with flex
|
| 49 |
+
gradient_checkpointing: true # need the on unless you have a ton of vram
|
| 50 |
+
noise_scheduler: "flowmatch" # for training only
|
| 51 |
+
optimizer: "adamw8bit"
|
| 52 |
+
lr: 1e-4
|
| 53 |
+
# uncomment this to skip the pre training sample
|
| 54 |
+
# skip_first_sample: true
|
| 55 |
+
# uncomment to completely disable sampling
|
| 56 |
+
# disable_sampling: true
|
| 57 |
+
# uncomment to use new vell curved weighting. Experimental but may produce better results
|
| 58 |
+
# linear_timesteps: true
|
| 59 |
+
|
| 60 |
+
# ema will smooth out learning, but could slow it down. Recommended to leave on.
|
| 61 |
+
ema_config:
|
| 62 |
+
use_ema: true
|
| 63 |
+
ema_decay: 0.99
|
| 64 |
+
|
| 65 |
+
# will probably need this if gpu supports it for flex, other dtypes may not work correctly
|
| 66 |
+
dtype: bf16
|
| 67 |
+
model:
|
| 68 |
+
# huggingface model name or path
|
| 69 |
+
name_or_path: "ostris/Flex.1-alpha"
|
| 70 |
+
is_flux: true
|
| 71 |
+
quantize: true # run 8bit mixed precision
|
| 72 |
+
quantize_kwargs:
|
| 73 |
+
exclude:
|
| 74 |
+
- "*time_text_embed*" # exclude the time text embedder from quantization
|
| 75 |
+
sample:
|
| 76 |
+
sampler: "flowmatch" # must match train.noise_scheduler
|
| 77 |
+
sample_every: 250 # sample every this many steps
|
| 78 |
+
width: 1024
|
| 79 |
+
height: 1024
|
| 80 |
+
prompts:
|
| 81 |
+
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
|
| 82 |
+
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
|
| 83 |
+
- "woman with red hair, playing chess at the park, bomb going off in the background"
|
| 84 |
+
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
|
| 85 |
+
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
|
| 86 |
+
- "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
|
| 87 |
+
- "a bear building a log cabin in the snow covered mountains"
|
| 88 |
+
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
|
| 89 |
+
- "hipster man with a beard, building a chair, in a wood shop"
|
| 90 |
+
- "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
|
| 91 |
+
- "a man holding a sign that says, 'this is a sign'"
|
| 92 |
+
- "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
|
| 93 |
+
neg: "" # not used on flex
|
| 94 |
+
seed: 42
|
| 95 |
+
walk_seed: true
|
| 96 |
+
guidance_scale: 4
|
| 97 |
+
sample_steps: 25
|
| 98 |
+
# you can add any additional meta info here. [name] is replaced with config name at top
|
| 99 |
+
meta:
|
| 100 |
+
name: "[name]"
|
| 101 |
+
version: '1.0'
|
ai-toolkit/config/examples/train_lora_flux_24gb.yaml
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
job: extension
|
| 3 |
+
config:
|
| 4 |
+
# this name will be the folder and filename name
|
| 5 |
+
name: "my_first_flux_lora_v1"
|
| 6 |
+
process:
|
| 7 |
+
- type: 'sd_trainer'
|
| 8 |
+
# root folder to save training sessions/samples/weights
|
| 9 |
+
training_folder: "output"
|
| 10 |
+
# uncomment to see performance stats in the terminal every N steps
|
| 11 |
+
# performance_log_every: 1000
|
| 12 |
+
device: cuda:0
|
| 13 |
+
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
|
| 14 |
+
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
|
| 15 |
+
# trigger_word: "p3r5on"
|
| 16 |
+
network:
|
| 17 |
+
type: "lora"
|
| 18 |
+
linear: 16
|
| 19 |
+
linear_alpha: 16
|
| 20 |
+
save:
|
| 21 |
+
dtype: float16 # precision to save
|
| 22 |
+
save_every: 250 # save every this many steps
|
| 23 |
+
max_step_saves_to_keep: 4 # how many intermittent saves to keep
|
| 24 |
+
push_to_hub: false #change this to True to push your trained model to Hugging Face.
|
| 25 |
+
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
|
| 26 |
+
# hf_repo_id: your-username/your-model-slug
|
| 27 |
+
# hf_private: true #whether the repo is private or public
|
| 28 |
+
datasets:
|
| 29 |
+
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
| 30 |
+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
| 31 |
+
# images will automatically be resized and bucketed into the resolution specified
|
| 32 |
+
# on windows, escape back slashes with another backslash so
|
| 33 |
+
# "C:\\path\\to\\images\\folder"
|
| 34 |
+
- folder_path: "/path/to/images/folder"
|
| 35 |
+
caption_ext: "txt"
|
| 36 |
+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
|
| 37 |
+
shuffle_tokens: false # shuffle caption order, split by commas
|
| 38 |
+
cache_latents_to_disk: true # leave this true unless you know what you're doing
|
| 39 |
+
resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions
|
| 40 |
+
train:
|
| 41 |
+
batch_size: 1
|
| 42 |
+
steps: 2000 # total number of steps to train 500 - 4000 is a good range
|
| 43 |
+
gradient_accumulation_steps: 1
|
| 44 |
+
train_unet: true
|
| 45 |
+
train_text_encoder: false # probably won't work with flux
|
| 46 |
+
gradient_checkpointing: true # need the on unless you have a ton of vram
|
| 47 |
+
noise_scheduler: "flowmatch" # for training only
|
| 48 |
+
optimizer: "adamw8bit"
|
| 49 |
+
lr: 1e-4
|
| 50 |
+
# uncomment this to skip the pre training sample
|
| 51 |
+
# skip_first_sample: true
|
| 52 |
+
# uncomment to completely disable sampling
|
| 53 |
+
# disable_sampling: true
|
| 54 |
+
# uncomment to use new vell curved weighting. Experimental but may produce better results
|
| 55 |
+
# linear_timesteps: true
|
| 56 |
+
|
| 57 |
+
# ema will smooth out learning, but could slow it down. Recommended to leave on.
|
| 58 |
+
ema_config:
|
| 59 |
+
use_ema: true
|
| 60 |
+
ema_decay: 0.99
|
| 61 |
+
|
| 62 |
+
# will probably need this if gpu supports it for flux, other dtypes may not work correctly
|
| 63 |
+
dtype: bf16
|
| 64 |
+
model:
|
| 65 |
+
# huggingface model name or path
|
| 66 |
+
name_or_path: "black-forest-labs/FLUX.1-dev"
|
| 67 |
+
is_flux: true
|
| 68 |
+
quantize: true # run 8bit mixed precision
|
| 69 |
+
# low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
|
| 70 |
+
sample:
|
| 71 |
+
sampler: "flowmatch" # must match train.noise_scheduler
|
| 72 |
+
sample_every: 250 # sample every this many steps
|
| 73 |
+
width: 1024
|
| 74 |
+
height: 1024
|
| 75 |
+
prompts:
|
| 76 |
+
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
|
| 77 |
+
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
|
| 78 |
+
- "woman with red hair, playing chess at the park, bomb going off in the background"
|
| 79 |
+
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
|
| 80 |
+
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
|
| 81 |
+
- "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
|
| 82 |
+
- "a bear building a log cabin in the snow covered mountains"
|
| 83 |
+
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
|
| 84 |
+
- "hipster man with a beard, building a chair, in a wood shop"
|
| 85 |
+
- "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
|
| 86 |
+
- "a man holding a sign that says, 'this is a sign'"
|
| 87 |
+
- "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
|
| 88 |
+
neg: "" # not used on flux
|
| 89 |
+
seed: 42
|
| 90 |
+
walk_seed: true
|
| 91 |
+
guidance_scale: 4
|
| 92 |
+
sample_steps: 20
|
| 93 |
+
# you can add any additional meta info here. [name] is replaced with config name at top
|
| 94 |
+
meta:
|
| 95 |
+
name: "[name]"
|
| 96 |
+
version: '1.0'
|
ai-toolkit/config/examples/train_lora_flux_kontext_24gb.yaml
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
job: extension
|
| 3 |
+
config:
|
| 4 |
+
# this name will be the folder and filename name
|
| 5 |
+
name: "my_first_flux_kontext_lora_v1"
|
| 6 |
+
process:
|
| 7 |
+
- type: 'sd_trainer'
|
| 8 |
+
# root folder to save training sessions/samples/weights
|
| 9 |
+
training_folder: "output"
|
| 10 |
+
# uncomment to see performance stats in the terminal every N steps
|
| 11 |
+
# performance_log_every: 1000
|
| 12 |
+
device: cuda:0
|
| 13 |
+
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
|
| 14 |
+
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
|
| 15 |
+
# trigger_word: "p3r5on"
|
| 16 |
+
network:
|
| 17 |
+
type: "lora"
|
| 18 |
+
linear: 16
|
| 19 |
+
linear_alpha: 16
|
| 20 |
+
save:
|
| 21 |
+
dtype: float16 # precision to save
|
| 22 |
+
save_every: 250 # save every this many steps
|
| 23 |
+
max_step_saves_to_keep: 4 # how many intermittent saves to keep
|
| 24 |
+
push_to_hub: false #change this to True to push your trained model to Hugging Face.
|
| 25 |
+
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
|
| 26 |
+
# hf_repo_id: your-username/your-model-slug
|
| 27 |
+
# hf_private: true #whether the repo is private or public
|
| 28 |
+
datasets:
|
| 29 |
+
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
| 30 |
+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
| 31 |
+
# images will automatically be resized and bucketed into the resolution specified
|
| 32 |
+
# on windows, escape back slashes with another backslash so
|
| 33 |
+
# "C:\\path\\to\\images\\folder"
|
| 34 |
+
- folder_path: "/path/to/images/folder"
|
| 35 |
+
# control path is the input images for kontext for a paired dataset. These are the source images you want to change.
|
| 36 |
+
# You can comment this out and only use normal images if you don't have a paired dataset.
|
| 37 |
+
# Control images need to match the filenames on the folder path but in
|
| 38 |
+
# a different folder. These do not need captions.
|
| 39 |
+
control_path: "/path/to/control/folder"
|
| 40 |
+
caption_ext: "txt"
|
| 41 |
+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
|
| 42 |
+
shuffle_tokens: false # shuffle caption order, split by commas
|
| 43 |
+
cache_latents_to_disk: true # leave this true unless you know what you're doing
|
| 44 |
+
# Kontext runs images in at 2x the latent size. It may OOM at 1024 resolution with 24GB vram.
|
| 45 |
+
resolution: [ 512, 768 ] # flux enjoys multiple resolutions
|
| 46 |
+
# resolution: [ 512, 768, 1024 ]
|
| 47 |
+
train:
|
| 48 |
+
batch_size: 1
|
| 49 |
+
steps: 3000 # total number of steps to train 500 - 4000 is a good range
|
| 50 |
+
gradient_accumulation_steps: 1
|
| 51 |
+
train_unet: true
|
| 52 |
+
train_text_encoder: false # probably won't work with flux
|
| 53 |
+
gradient_checkpointing: true # need the on unless you have a ton of vram
|
| 54 |
+
noise_scheduler: "flowmatch" # for training only
|
| 55 |
+
optimizer: "adamw8bit"
|
| 56 |
+
lr: 1e-4
|
| 57 |
+
timestep_type: "weighted" # sigmoid, linear, or weighted.
|
| 58 |
+
# uncomment this to skip the pre training sample
|
| 59 |
+
# skip_first_sample: true
|
| 60 |
+
# uncomment to completely disable sampling
|
| 61 |
+
# disable_sampling: true
|
| 62 |
+
|
| 63 |
+
# ema will smooth out learning, but could slow it down.
|
| 64 |
+
|
| 65 |
+
# ema_config:
|
| 66 |
+
# use_ema: true
|
| 67 |
+
# ema_decay: 0.99
|
| 68 |
+
|
| 69 |
+
# will probably need this if gpu supports it for flux, other dtypes may not work correctly
|
| 70 |
+
dtype: bf16
|
| 71 |
+
model:
|
| 72 |
+
# huggingface model name or path. This model is gated.
|
| 73 |
+
# visit https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev to accept the terms and conditions
|
| 74 |
+
# and then you can use this model.
|
| 75 |
+
name_or_path: "black-forest-labs/FLUX.1-Kontext-dev"
|
| 76 |
+
arch: "flux_kontext"
|
| 77 |
+
quantize: true # run 8bit mixed precision
|
| 78 |
+
# low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
|
| 79 |
+
sample:
|
| 80 |
+
sampler: "flowmatch" # must match train.noise_scheduler
|
| 81 |
+
sample_every: 250 # sample every this many steps
|
| 82 |
+
width: 1024
|
| 83 |
+
height: 1024
|
| 84 |
+
prompts:
|
| 85 |
+
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
|
| 86 |
+
# the --ctrl_img path is the one loaded to apply the kontext editing to
|
| 87 |
+
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
|
| 88 |
+
- "make the person smile --ctrl_img /path/to/control/folder/person1.jpg"
|
| 89 |
+
- "give the person an afro --ctrl_img /path/to/control/folder/person1.jpg"
|
| 90 |
+
- "turn this image into a cartoon --ctrl_img /path/to/control/folder/person1.jpg"
|
| 91 |
+
- "put this person in an action film --ctrl_img /path/to/control/folder/person1.jpg"
|
| 92 |
+
- "make this person a rapper in a rap music video --ctrl_img /path/to/control/folder/person1.jpg"
|
| 93 |
+
- "make the person smile --ctrl_img /path/to/control/folder/person1.jpg"
|
| 94 |
+
- "give the person an afro --ctrl_img /path/to/control/folder/person1.jpg"
|
| 95 |
+
- "turn this image into a cartoon --ctrl_img /path/to/control/folder/person1.jpg"
|
| 96 |
+
- "put this person in an action film --ctrl_img /path/to/control/folder/person1.jpg"
|
| 97 |
+
- "make this person a rapper in a rap music video --ctrl_img /path/to/control/folder/person1.jpg"
|
| 98 |
+
neg: "" # not used on flux
|
| 99 |
+
seed: 42
|
| 100 |
+
walk_seed: true
|
| 101 |
+
guidance_scale: 4
|
| 102 |
+
sample_steps: 20
|
| 103 |
+
# you can add any additional meta info here. [name] is replaced with config name at top
|
| 104 |
+
meta:
|
| 105 |
+
name: "[name]"
|
| 106 |
+
version: '1.0'
|
ai-toolkit/config/examples/train_lora_flux_schnell_24gb.yaml
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
job: extension
|
| 3 |
+
config:
|
| 4 |
+
# this name will be the folder and filename name
|
| 5 |
+
name: "my_first_flux_lora_v1"
|
| 6 |
+
process:
|
| 7 |
+
- type: 'sd_trainer'
|
| 8 |
+
# root folder to save training sessions/samples/weights
|
| 9 |
+
training_folder: "output"
|
| 10 |
+
# uncomment to see performance stats in the terminal every N steps
|
| 11 |
+
# performance_log_every: 1000
|
| 12 |
+
device: cuda:0
|
| 13 |
+
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
|
| 14 |
+
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
|
| 15 |
+
# trigger_word: "p3r5on"
|
| 16 |
+
network:
|
| 17 |
+
type: "lora"
|
| 18 |
+
linear: 16
|
| 19 |
+
linear_alpha: 16
|
| 20 |
+
save:
|
| 21 |
+
dtype: float16 # precision to save
|
| 22 |
+
save_every: 250 # save every this many steps
|
| 23 |
+
max_step_saves_to_keep: 4 # how many intermittent saves to keep
|
| 24 |
+
push_to_hub: false #change this to True to push your trained model to Hugging Face.
|
| 25 |
+
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
|
| 26 |
+
# hf_repo_id: your-username/your-model-slug
|
| 27 |
+
# hf_private: true #whether the repo is private or public
|
| 28 |
+
datasets:
|
| 29 |
+
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
| 30 |
+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
| 31 |
+
# images will automatically be resized and bucketed into the resolution specified
|
| 32 |
+
# on windows, escape back slashes with another backslash so
|
| 33 |
+
# "C:\\path\\to\\images\\folder"
|
| 34 |
+
- folder_path: "/path/to/images/folder"
|
| 35 |
+
caption_ext: "txt"
|
| 36 |
+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
|
| 37 |
+
shuffle_tokens: false # shuffle caption order, split by commas
|
| 38 |
+
cache_latents_to_disk: true # leave this true unless you know what you're doing
|
| 39 |
+
resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions
|
| 40 |
+
train:
|
| 41 |
+
batch_size: 1
|
| 42 |
+
steps: 2000 # total number of steps to train 500 - 4000 is a good range
|
| 43 |
+
gradient_accumulation_steps: 1
|
| 44 |
+
train_unet: true
|
| 45 |
+
train_text_encoder: false # probably won't work with flux
|
| 46 |
+
gradient_checkpointing: true # need the on unless you have a ton of vram
|
| 47 |
+
noise_scheduler: "flowmatch" # for training only
|
| 48 |
+
optimizer: "adamw8bit"
|
| 49 |
+
lr: 1e-4
|
| 50 |
+
# uncomment this to skip the pre training sample
|
| 51 |
+
# skip_first_sample: true
|
| 52 |
+
# uncomment to completely disable sampling
|
| 53 |
+
# disable_sampling: true
|
| 54 |
+
# uncomment to use new bell curved weighting. Experimental but may produce better results
|
| 55 |
+
# linear_timesteps: true
|
| 56 |
+
|
| 57 |
+
# ema will smooth out learning, but could slow it down. Recommended to leave on.
|
| 58 |
+
ema_config:
|
| 59 |
+
use_ema: true
|
| 60 |
+
ema_decay: 0.99
|
| 61 |
+
|
| 62 |
+
# will probably need this if gpu supports it for flux, other dtypes may not work correctly
|
| 63 |
+
dtype: bf16
|
| 64 |
+
model:
|
| 65 |
+
# huggingface model name or path
|
| 66 |
+
name_or_path: "black-forest-labs/FLUX.1-schnell"
|
| 67 |
+
assistant_lora_path: "ostris/FLUX.1-schnell-training-adapter" # Required for flux schnell training
|
| 68 |
+
is_flux: true
|
| 69 |
+
quantize: true # run 8bit mixed precision
|
| 70 |
+
# low_vram is painfully slow to fuse in the adapter avoid it unless absolutely necessary
|
| 71 |
+
# low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
|
| 72 |
+
sample:
|
| 73 |
+
sampler: "flowmatch" # must match train.noise_scheduler
|
| 74 |
+
sample_every: 250 # sample every this many steps
|
| 75 |
+
width: 1024
|
| 76 |
+
height: 1024
|
| 77 |
+
prompts:
|
| 78 |
+
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
|
| 79 |
+
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
|
| 80 |
+
- "woman with red hair, playing chess at the park, bomb going off in the background"
|
| 81 |
+
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
|
| 82 |
+
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
|
| 83 |
+
- "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
|
| 84 |
+
- "a bear building a log cabin in the snow covered mountains"
|
| 85 |
+
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
|
| 86 |
+
- "hipster man with a beard, building a chair, in a wood shop"
|
| 87 |
+
- "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
|
| 88 |
+
- "a man holding a sign that says, 'this is a sign'"
|
| 89 |
+
- "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
|
| 90 |
+
neg: "" # not used on flux
|
| 91 |
+
seed: 42
|
| 92 |
+
walk_seed: true
|
| 93 |
+
guidance_scale: 1 # schnell does not do guidance
|
| 94 |
+
sample_steps: 4 # 1 - 4 works well
|
| 95 |
+
# you can add any additional meta info here. [name] is replaced with config name at top
|
| 96 |
+
meta:
|
| 97 |
+
name: "[name]"
|
| 98 |
+
version: '1.0'
|
ai-toolkit/config/examples/train_lora_hidream_48.yaml
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# HiDream training is still highly experimental. The settings here will take ~35.2GB of vram to train.
|
| 2 |
+
# It is not possible to train on a single 24GB card yet, but I am working on it. If you have more VRAM
|
| 3 |
+
# I highly recommend first disabling quantization on the model itself if you can. You can leave the TEs quantized.
|
| 4 |
+
# HiDream has a mixture of experts that may take special training considerations that I do not
|
| 5 |
+
# have implemented properly. The current implementation seems to work well for LoRA training, but
|
| 6 |
+
# may not be effective for longer training runs. The implementation could change in future updates
|
| 7 |
+
# so your results may vary when this happens.
|
| 8 |
+
|
| 9 |
+
---
|
| 10 |
+
job: extension
|
| 11 |
+
config:
|
| 12 |
+
# this name will be the folder and filename name
|
| 13 |
+
name: "my_first_hidream_lora_v1"
|
| 14 |
+
process:
|
| 15 |
+
- type: 'sd_trainer'
|
| 16 |
+
# root folder to save training sessions/samples/weights
|
| 17 |
+
training_folder: "output"
|
| 18 |
+
# uncomment to see performance stats in the terminal every N steps
|
| 19 |
+
# performance_log_every: 1000
|
| 20 |
+
device: cuda:0
|
| 21 |
+
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
|
| 22 |
+
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
|
| 23 |
+
# trigger_word: "p3r5on"
|
| 24 |
+
network:
|
| 25 |
+
type: "lora"
|
| 26 |
+
linear: 32
|
| 27 |
+
linear_alpha: 32
|
| 28 |
+
network_kwargs:
|
| 29 |
+
# it is probably best to ignore the mixture of experts since only 2 are active each block. It works activating it, but I wouldnt.
|
| 30 |
+
# proper training of it is not fully implemented
|
| 31 |
+
ignore_if_contains:
|
| 32 |
+
- "ff_i.experts"
|
| 33 |
+
- "ff_i.gate"
|
| 34 |
+
save:
|
| 35 |
+
dtype: bfloat16 # precision to save
|
| 36 |
+
save_every: 250 # save every this many steps
|
| 37 |
+
max_step_saves_to_keep: 4 # how many intermittent saves to keep
|
| 38 |
+
datasets:
|
| 39 |
+
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
| 40 |
+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
| 41 |
+
# images will automatically be resized and bucketed into the resolution specified
|
| 42 |
+
# on windows, escape back slashes with another backslash so
|
| 43 |
+
# "C:\\path\\to\\images\\folder"
|
| 44 |
+
- folder_path: "/path/to/images/folder"
|
| 45 |
+
caption_ext: "txt"
|
| 46 |
+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
|
| 47 |
+
resolution: [ 512, 768, 1024 ] # hidream enjoys multiple resolutions
|
| 48 |
+
train:
|
| 49 |
+
batch_size: 1
|
| 50 |
+
steps: 3000 # total number of steps to train 500 - 4000 is a good range
|
| 51 |
+
gradient_accumulation_steps: 1
|
| 52 |
+
train_unet: true
|
| 53 |
+
train_text_encoder: false # wont work with hidream
|
| 54 |
+
gradient_checkpointing: true # need the on unless you have a ton of vram
|
| 55 |
+
noise_scheduler: "flowmatch" # for training only
|
| 56 |
+
timestep_type: shift # sigmoid, shift, linear
|
| 57 |
+
optimizer: "adamw8bit"
|
| 58 |
+
lr: 2e-4
|
| 59 |
+
# uncomment this to skip the pre training sample
|
| 60 |
+
# skip_first_sample: true
|
| 61 |
+
# uncomment to completely disable sampling
|
| 62 |
+
# disable_sampling: true
|
| 63 |
+
# uncomment to use new vell curved weighting. Experimental but may produce better results
|
| 64 |
+
# linear_timesteps: true
|
| 65 |
+
|
| 66 |
+
# ema will smooth out learning, but could slow it down. Defaults off
|
| 67 |
+
ema_config:
|
| 68 |
+
use_ema: false
|
| 69 |
+
ema_decay: 0.99
|
| 70 |
+
|
| 71 |
+
# will probably need this if gpu supports it for hidream, other dtypes may not work correctly
|
| 72 |
+
dtype: bf16
|
| 73 |
+
model:
|
| 74 |
+
# the transformer will get grabbed from this hf repo
|
| 75 |
+
# warning ONLY train on Full. The dev and fast models are distilled and will break
|
| 76 |
+
name_or_path: "HiDream-ai/HiDream-I1-Full"
|
| 77 |
+
# the extras will be grabbed from this hf repo. (text encoder, vae)
|
| 78 |
+
extras_name_or_path: "HiDream-ai/HiDream-I1-Full"
|
| 79 |
+
arch: "hidream"
|
| 80 |
+
# both need to be quantized to train on 48GB currently
|
| 81 |
+
quantize: true
|
| 82 |
+
quantize_te: true
|
| 83 |
+
model_kwargs:
|
| 84 |
+
# llama is a gated model, It defaults to unsloth version, but you can set the llama path here
|
| 85 |
+
llama_model_path: "unsloth/Meta-Llama-3.1-8B-Instruct"
|
| 86 |
+
sample:
|
| 87 |
+
sampler: "flowmatch" # must match train.noise_scheduler
|
| 88 |
+
sample_every: 250 # sample every this many steps
|
| 89 |
+
width: 1024
|
| 90 |
+
height: 1024
|
| 91 |
+
prompts:
|
| 92 |
+
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
|
| 93 |
+
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
|
| 94 |
+
- "woman with red hair, playing chess at the park, bomb going off in the background"
|
| 95 |
+
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
|
| 96 |
+
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
|
| 97 |
+
- "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
|
| 98 |
+
- "a bear building a log cabin in the snow covered mountains"
|
| 99 |
+
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
|
| 100 |
+
- "hipster man with a beard, building a chair, in a wood shop"
|
| 101 |
+
- "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
|
| 102 |
+
- "a man holding a sign that says, 'this is a sign'"
|
| 103 |
+
- "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
|
| 104 |
+
neg: ""
|
| 105 |
+
seed: 42
|
| 106 |
+
walk_seed: true
|
| 107 |
+
guidance_scale: 4
|
| 108 |
+
sample_steps: 25
|
| 109 |
+
# you can add any additional meta info here. [name] is replaced with config name at top
|
| 110 |
+
meta:
|
| 111 |
+
name: "[name]"
|
| 112 |
+
version: '1.0'
|
ai-toolkit/config/examples/train_lora_lumina.yaml
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
# This configuration requires 20GB of VRAM or more to operate
|
| 3 |
+
job: extension
|
| 4 |
+
config:
|
| 5 |
+
# this name will be the folder and filename name
|
| 6 |
+
name: "my_first_lumina_lora_v1"
|
| 7 |
+
process:
|
| 8 |
+
- type: 'sd_trainer'
|
| 9 |
+
# root folder to save training sessions/samples/weights
|
| 10 |
+
training_folder: "output"
|
| 11 |
+
# uncomment to see performance stats in the terminal every N steps
|
| 12 |
+
# performance_log_every: 1000
|
| 13 |
+
device: cuda:0
|
| 14 |
+
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
|
| 15 |
+
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
|
| 16 |
+
# trigger_word: "p3r5on"
|
| 17 |
+
network:
|
| 18 |
+
type: "lora"
|
| 19 |
+
linear: 16
|
| 20 |
+
linear_alpha: 16
|
| 21 |
+
save:
|
| 22 |
+
dtype: bf16 # precision to save
|
| 23 |
+
save_every: 250 # save every this many steps
|
| 24 |
+
max_step_saves_to_keep: 2 # how many intermittent saves to keep
|
| 25 |
+
save_format: 'diffusers' # 'diffusers'
|
| 26 |
+
datasets:
|
| 27 |
+
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
| 28 |
+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
| 29 |
+
# images will automatically be resized and bucketed into the resolution specified
|
| 30 |
+
# on windows, escape back slashes with another backslash so
|
| 31 |
+
# "C:\\path\\to\\images\\folder"
|
| 32 |
+
- folder_path: "/path/to/images/folder"
|
| 33 |
+
caption_ext: "txt"
|
| 34 |
+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
|
| 35 |
+
shuffle_tokens: false # shuffle caption order, split by commas
|
| 36 |
+
# cache_latents_to_disk: true # leave this true unless you know what you're doing
|
| 37 |
+
resolution: [ 512, 768, 1024 ] # lumina2 enjoys multiple resolutions
|
| 38 |
+
train:
|
| 39 |
+
batch_size: 1
|
| 40 |
+
|
| 41 |
+
# can be 'sigmoid', 'linear', or 'lumina2_shift'
|
| 42 |
+
timestep_type: 'lumina2_shift'
|
| 43 |
+
|
| 44 |
+
steps: 2000 # total number of steps to train 500 - 4000 is a good range
|
| 45 |
+
gradient_accumulation: 1
|
| 46 |
+
train_unet: true
|
| 47 |
+
train_text_encoder: false # probably won't work with lumina2
|
| 48 |
+
gradient_checkpointing: true # need the on unless you have a ton of vram
|
| 49 |
+
noise_scheduler: "flowmatch" # for training only
|
| 50 |
+
optimizer: "adamw8bit"
|
| 51 |
+
lr: 1e-4
|
| 52 |
+
# uncomment this to skip the pre training sample
|
| 53 |
+
# skip_first_sample: true
|
| 54 |
+
# uncomment to completely disable sampling
|
| 55 |
+
# disable_sampling: true
|
| 56 |
+
|
| 57 |
+
# ema will smooth out learning, but could slow it down. Recommended to leave on if you have the vram
|
| 58 |
+
ema_config:
|
| 59 |
+
use_ema: true
|
| 60 |
+
ema_decay: 0.99
|
| 61 |
+
|
| 62 |
+
# will probably need this if gpu supports it for lumina2, other dtypes may not work correctly
|
| 63 |
+
dtype: bf16
|
| 64 |
+
model:
|
| 65 |
+
# huggingface model name or path
|
| 66 |
+
name_or_path: "Alpha-VLLM/Lumina-Image-2.0"
|
| 67 |
+
is_lumina2: true # lumina2 architecture
|
| 68 |
+
# you can quantize just the Gemma2 text encoder here to save vram
|
| 69 |
+
quantize_te: true
|
| 70 |
+
sample:
|
| 71 |
+
sampler: "flowmatch" # must match train.noise_scheduler
|
| 72 |
+
sample_every: 250 # sample every this many steps
|
| 73 |
+
width: 1024
|
| 74 |
+
height: 1024
|
| 75 |
+
prompts:
|
| 76 |
+
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
|
| 77 |
+
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
|
| 78 |
+
- "woman with red hair, playing chess at the park, bomb going off in the background"
|
| 79 |
+
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
|
| 80 |
+
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
|
| 81 |
+
- "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
|
| 82 |
+
- "a bear building a log cabin in the snow covered mountains"
|
| 83 |
+
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
|
| 84 |
+
- "hipster man with a beard, building a chair, in a wood shop"
|
| 85 |
+
- "photo of a cat that is half black and half orange tabby, split down the middle. The cat has on a blue tophat. They are holding a martini glass with a pink ball of yarn in it with green knitting needles sticking out, in one paw. In the other paw, they are holding a DVD case for a movie titled, \"This is a test\" that has a golden robot on it. In the background is a busy night club with a giant mushroom man dancing with a bear."
|
| 86 |
+
- "a man holding a sign that says, 'this is a sign'"
|
| 87 |
+
- "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
|
| 88 |
+
neg: ""
|
| 89 |
+
seed: 42
|
| 90 |
+
walk_seed: true
|
| 91 |
+
guidance_scale: 4.0
|
| 92 |
+
sample_steps: 25
|
| 93 |
+
# you can add any additional meta info here. [name] is replaced with config name at top
|
| 94 |
+
meta:
|
| 95 |
+
name: "[name]"
|
| 96 |
+
version: '1.0'
|
ai-toolkit/config/examples/train_lora_omnigen2_24gb.yaml
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
job: extension
|
| 3 |
+
config:
|
| 4 |
+
# this name will be the folder and filename name
|
| 5 |
+
name: "my_first_omnigen2_lora_v1"
|
| 6 |
+
process:
|
| 7 |
+
- type: 'sd_trainer'
|
| 8 |
+
# root folder to save training sessions/samples/weights
|
| 9 |
+
training_folder: "output"
|
| 10 |
+
# uncomment to see performance stats in the terminal every N steps
|
| 11 |
+
# performance_log_every: 1000
|
| 12 |
+
device: cuda:0
|
| 13 |
+
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
|
| 14 |
+
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
|
| 15 |
+
# trigger_word: "p3r5on"
|
| 16 |
+
network:
|
| 17 |
+
type: "lora"
|
| 18 |
+
linear: 16
|
| 19 |
+
linear_alpha: 16
|
| 20 |
+
save:
|
| 21 |
+
dtype: float16 # precision to save
|
| 22 |
+
save_every: 250 # save every this many steps
|
| 23 |
+
max_step_saves_to_keep: 4 # how many intermittent saves to keep
|
| 24 |
+
push_to_hub: false #change this to True to push your trained model to Hugging Face.
|
| 25 |
+
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
|
| 26 |
+
# hf_repo_id: your-username/your-model-slug
|
| 27 |
+
# hf_private: true #whether the repo is private or public
|
| 28 |
+
datasets:
|
| 29 |
+
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
| 30 |
+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
| 31 |
+
# images will automatically be resized and bucketed into the resolution specified
|
| 32 |
+
# on windows, escape back slashes with another backslash so
|
| 33 |
+
# "C:\\path\\to\\images\\folder"
|
| 34 |
+
- folder_path: "/path/to/images/folder"
|
| 35 |
+
caption_ext: "txt"
|
| 36 |
+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
|
| 37 |
+
shuffle_tokens: false # shuffle caption order, split by commas
|
| 38 |
+
cache_latents_to_disk: true # leave this true unless you know what you're doing
|
| 39 |
+
resolution: [ 512, 768, 1024 ] # omnigen2 should work with multiple resolutions
|
| 40 |
+
train:
|
| 41 |
+
batch_size: 1
|
| 42 |
+
steps: 3000 # total number of steps to train 500 - 4000 is a good range
|
| 43 |
+
gradient_accumulation: 1
|
| 44 |
+
train_unet: true
|
| 45 |
+
train_text_encoder: false # probably won't work with omnigen2
|
| 46 |
+
gradient_checkpointing: true # need the on unless you have a ton of vram
|
| 47 |
+
noise_scheduler: "flowmatch" # for training only
|
| 48 |
+
optimizer: "adamw8bit"
|
| 49 |
+
lr: 1e-4
|
| 50 |
+
timestep_type: 'sigmoid' # sigmoid, linear, shift
|
| 51 |
+
# uncomment this to skip the pre training sample
|
| 52 |
+
# skip_first_sample: true
|
| 53 |
+
# uncomment to completely disable sampling
|
| 54 |
+
# disable_sampling: true
|
| 55 |
+
|
| 56 |
+
# ema will smooth out learning, but could slow it down.
|
| 57 |
+
# ema_config:
|
| 58 |
+
# use_ema: true
|
| 59 |
+
# ema_decay: 0.99
|
| 60 |
+
|
| 61 |
+
# will probably need this if gpu supports it for omnigen2, other dtypes may not work correctly
|
| 62 |
+
dtype: bf16
|
| 63 |
+
model:
|
| 64 |
+
name_or_path: "OmniGen2/OmniGen2
|
| 65 |
+
arch: "omnigen2"
|
| 66 |
+
quantize_te: true # quantize_only te
|
| 67 |
+
# quantize: true # quantize transformer
|
| 68 |
+
sample:
|
| 69 |
+
sampler: "flowmatch" # must match train.noise_scheduler
|
| 70 |
+
sample_every: 250 # sample every this many steps
|
| 71 |
+
width: 1024
|
| 72 |
+
height: 1024
|
| 73 |
+
prompts:
|
| 74 |
+
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
|
| 75 |
+
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
|
| 76 |
+
- "woman with red hair, playing chess at the park, bomb going off in the background"
|
| 77 |
+
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
|
| 78 |
+
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
|
| 79 |
+
- "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
|
| 80 |
+
- "a bear building a log cabin in the snow covered mountains"
|
| 81 |
+
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
|
| 82 |
+
- "hipster man with a beard, building a chair, in a wood shop"
|
| 83 |
+
- "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
|
| 84 |
+
- "a man holding a sign that says, 'this is a sign'"
|
| 85 |
+
- "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
|
| 86 |
+
neg: "" # negative prompt, optional
|
| 87 |
+
seed: 42
|
| 88 |
+
walk_seed: true
|
| 89 |
+
guidance_scale: 4
|
| 90 |
+
sample_steps: 25
|
| 91 |
+
# you can add any additional meta info here. [name] is replaced with config name at top
|
| 92 |
+
meta:
|
| 93 |
+
name: "[name]"
|
| 94 |
+
version: '1.0'
|
ai-toolkit/config/examples/train_lora_qwen_image_24gb.yaml
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
job: extension
|
| 3 |
+
config:
|
| 4 |
+
# this name will be the folder and filename name
|
| 5 |
+
name: "my_first_qwen_image_lora_v1"
|
| 6 |
+
process:
|
| 7 |
+
- type: 'sd_trainer'
|
| 8 |
+
# root folder to save training sessions/samples/weights
|
| 9 |
+
training_folder: "output"
|
| 10 |
+
# uncomment to see performance stats in the terminal every N steps
|
| 11 |
+
# performance_log_every: 1000
|
| 12 |
+
device: cuda:0
|
| 13 |
+
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
|
| 14 |
+
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
|
| 15 |
+
# Trigger words will not work when caching text embeddings
|
| 16 |
+
# trigger_word: "p3r5on"
|
| 17 |
+
network:
|
| 18 |
+
type: "lora"
|
| 19 |
+
linear: 16
|
| 20 |
+
linear_alpha: 16
|
| 21 |
+
save:
|
| 22 |
+
dtype: float16 # precision to save
|
| 23 |
+
save_every: 250 # save every this many steps
|
| 24 |
+
max_step_saves_to_keep: 4 # how many intermittent saves to keep
|
| 25 |
+
datasets:
|
| 26 |
+
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
| 27 |
+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
| 28 |
+
# images will automatically be resized and bucketed into the resolution specified
|
| 29 |
+
# on windows, escape back slashes with another backslash so
|
| 30 |
+
# "C:\\path\\to\\images\\folder"
|
| 31 |
+
- folder_path: "/path/to/images/folder"
|
| 32 |
+
caption_ext: "txt"
|
| 33 |
+
# default_caption: "a person" # if caching text embeddings, if you dont have captions, this will get cached
|
| 34 |
+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
|
| 35 |
+
shuffle_tokens: false # shuffle caption order, split by commas
|
| 36 |
+
cache_latents_to_disk: true # leave this true unless you have a large dataset
|
| 37 |
+
# if you OOM, 1024 may be too much, but should work
|
| 38 |
+
resolution: [ 512, 768, 1024 ] # qwen image enjoys multiple resolutions
|
| 39 |
+
train:
|
| 40 |
+
batch_size: 1
|
| 41 |
+
# caching text embeddings is required for 24GB
|
| 42 |
+
cache_text_embeddings: true
|
| 43 |
+
|
| 44 |
+
steps: 2000 # total number of steps to train 500 - 4000 is a good range
|
| 45 |
+
gradient_accumulation: 1
|
| 46 |
+
train_unet: true
|
| 47 |
+
train_text_encoder: false # probably won't work with qwen image
|
| 48 |
+
gradient_checkpointing: true # need the on unless you have a ton of vram
|
| 49 |
+
noise_scheduler: "flowmatch" # for training only
|
| 50 |
+
optimizer: "adamw8bit"
|
| 51 |
+
lr: 1e-4
|
| 52 |
+
# uncomment this to skip the pre training sample
|
| 53 |
+
# skip_first_sample: true
|
| 54 |
+
# uncomment to completely disable sampling
|
| 55 |
+
# disable_sampling: true
|
| 56 |
+
dtype: bf16
|
| 57 |
+
model:
|
| 58 |
+
# huggingface model name or path
|
| 59 |
+
name_or_path: "Qwen/Qwen-Image"
|
| 60 |
+
arch: "qwen_image"
|
| 61 |
+
quantize: true
|
| 62 |
+
# qtype_te: "qfloat8" Default float8 qquantization
|
| 63 |
+
# to use the ARA use the | pipe to point to hf path, or a local path if you have one.
|
| 64 |
+
# 3bit is required for 24GB
|
| 65 |
+
qtype: "uint3|ostris/accuracy_recovery_adapters/qwen_image_torchao_uint3.safetensors"
|
| 66 |
+
quantize_te: true
|
| 67 |
+
qtype_te: "qfloat8"
|
| 68 |
+
low_vram: true
|
| 69 |
+
sample:
|
| 70 |
+
sampler: "flowmatch" # must match train.noise_scheduler
|
| 71 |
+
sample_every: 250 # sample every this many steps
|
| 72 |
+
width: 1024
|
| 73 |
+
height: 1024
|
| 74 |
+
prompts:
|
| 75 |
+
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
|
| 76 |
+
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
|
| 77 |
+
- "woman with red hair, playing chess at the park, bomb going off in the background"
|
| 78 |
+
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
|
| 79 |
+
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
|
| 80 |
+
- "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
|
| 81 |
+
- "a bear building a log cabin in the snow covered mountains"
|
| 82 |
+
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
|
| 83 |
+
- "hipster man with a beard, building a chair, in a wood shop"
|
| 84 |
+
- "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
|
| 85 |
+
- "a man holding a sign that says, 'this is a sign'"
|
| 86 |
+
- "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
|
| 87 |
+
neg: ""
|
| 88 |
+
seed: 42
|
| 89 |
+
walk_seed: true
|
| 90 |
+
guidance_scale: 3
|
| 91 |
+
sample_steps: 25
|
| 92 |
+
# you can add any additional meta info here. [name] is replaced with config name at top
|
| 93 |
+
meta:
|
| 94 |
+
name: "[name]"
|
| 95 |
+
version: '1.0'
|
ai-toolkit/config/examples/train_lora_qwen_image_edit_2509_32gb.yaml
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
job: extension
|
| 3 |
+
config:
|
| 4 |
+
# this name will be the folder and filename name
|
| 5 |
+
name: "my_first_qwen_image_edit_2509_lora_v1"
|
| 6 |
+
process:
|
| 7 |
+
- type: 'diffusion_trainer'
|
| 8 |
+
# root folder to save training sessions/samples/weights
|
| 9 |
+
training_folder: "output"
|
| 10 |
+
# uncomment to see performance stats in the terminal every N steps
|
| 11 |
+
# performance_log_every: 1000
|
| 12 |
+
device: cuda:0
|
| 13 |
+
network:
|
| 14 |
+
type: "lora"
|
| 15 |
+
linear: 16
|
| 16 |
+
linear_alpha: 16
|
| 17 |
+
save:
|
| 18 |
+
dtype: float16 # precision to save
|
| 19 |
+
save_every: 250 # save every this many steps
|
| 20 |
+
max_step_saves_to_keep: 4 # how many intermittent saves to keep
|
| 21 |
+
datasets:
|
| 22 |
+
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
| 23 |
+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
| 24 |
+
# images will automatically be resized and bucketed into the resolution specified
|
| 25 |
+
# on windows, escape back slashes with another backslash so
|
| 26 |
+
# "C:\\path\\to\\images\\folder"
|
| 27 |
+
- folder_path: "/path/to/images/folder"
|
| 28 |
+
# can do up to 3 control image folders, file names must match target file names, but aspect/size can be different
|
| 29 |
+
control_path:
|
| 30 |
+
- "/path/to/control/images/folder1"
|
| 31 |
+
- "/path/to/control/images/folder2"
|
| 32 |
+
- "/path/to/control/images/folder3"
|
| 33 |
+
caption_ext: "txt"
|
| 34 |
+
# default_caption: "a person" # if caching text embeddings, if you don't have captions, this will get cached
|
| 35 |
+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
|
| 36 |
+
resolution: [ 512, 768, 1024 ] # qwen image enjoys multiple resolutions
|
| 37 |
+
# a trigger word that can be cached with the text embeddings
|
| 38 |
+
# trigger_word: "optional trigger word"
|
| 39 |
+
train:
|
| 40 |
+
batch_size: 1
|
| 41 |
+
# caching text embeddings is required for 32GB
|
| 42 |
+
cache_text_embeddings: true
|
| 43 |
+
# unload_text_encoder: true
|
| 44 |
+
|
| 45 |
+
steps: 3000 # total number of steps to train 500 - 4000 is a good range
|
| 46 |
+
gradient_accumulation: 1
|
| 47 |
+
timestep_type: "weighted"
|
| 48 |
+
train_unet: true
|
| 49 |
+
train_text_encoder: false # probably won't work with qwen image
|
| 50 |
+
gradient_checkpointing: true # need the on unless you have a ton of vram
|
| 51 |
+
noise_scheduler: "flowmatch" # for training only
|
| 52 |
+
optimizer: "adamw8bit"
|
| 53 |
+
lr: 1e-4
|
| 54 |
+
# uncomment this to skip the pre training sample
|
| 55 |
+
# skip_first_sample: true
|
| 56 |
+
# uncomment to completely disable sampling
|
| 57 |
+
# disable_sampling: true
|
| 58 |
+
dtype: bf16
|
| 59 |
+
model:
|
| 60 |
+
# huggingface model name or path
|
| 61 |
+
name_or_path: "Qwen/Qwen-Image-Edit-2509"
|
| 62 |
+
arch: "qwen_image_edit_plus"
|
| 63 |
+
quantize: true
|
| 64 |
+
# to use the ARA use the | pipe to point to hf path, or a local path if you have one.
|
| 65 |
+
# 3bit is required for 32GB
|
| 66 |
+
qtype: "uint3|ostris/accuracy_recovery_adapters/qwen_image_edit_2509_torchao_uint3.safetensors"
|
| 67 |
+
quantize_te: true
|
| 68 |
+
qtype_te: "qfloat8"
|
| 69 |
+
low_vram: true
|
| 70 |
+
sample:
|
| 71 |
+
sampler: "flowmatch" # must match train.noise_scheduler
|
| 72 |
+
sample_every: 250 # sample every this many steps
|
| 73 |
+
width: 1024
|
| 74 |
+
height: 1024
|
| 75 |
+
# you can provide up to 3 control images here
|
| 76 |
+
samples:
|
| 77 |
+
- prompt: "Do whatever with Image1 and Image2"
|
| 78 |
+
ctrl_img_1: "/path/to/image1.png"
|
| 79 |
+
ctrl_img_2: "/path/to/image2.png"
|
| 80 |
+
# ctrl_img_3: "/path/to/image3.png"
|
| 81 |
+
- prompt: "Do whatever with Image1 and Image2"
|
| 82 |
+
ctrl_img_1: "/path/to/image1.png"
|
| 83 |
+
ctrl_img_2: "/path/to/image2.png"
|
| 84 |
+
# ctrl_img_3: "/path/to/image3.png"
|
| 85 |
+
- prompt: "Do whatever with Image1 and Image2"
|
| 86 |
+
ctrl_img_1: "/path/to/image1.png"
|
| 87 |
+
ctrl_img_2: "/path/to/image2.png"
|
| 88 |
+
# ctrl_img_3: "/path/to/image3.png"
|
| 89 |
+
- prompt: "Do whatever with Image1 and Image2"
|
| 90 |
+
ctrl_img_1: "/path/to/image1.png"
|
| 91 |
+
ctrl_img_2: "/path/to/image2.png"
|
| 92 |
+
# ctrl_img_3: "/path/to/image3.png"
|
| 93 |
+
- prompt: "Do whatever with Image1 and Image2"
|
| 94 |
+
ctrl_img_1: "/path/to/image1.png"
|
| 95 |
+
ctrl_img_2: "/path/to/image2.png"
|
| 96 |
+
# ctrl_img_3: "/path/to/image3.png"
|
| 97 |
+
neg: ""
|
| 98 |
+
seed: 42
|
| 99 |
+
walk_seed: true
|
| 100 |
+
guidance_scale: 3
|
| 101 |
+
sample_steps: 25
|
| 102 |
+
# you can add any additional meta info here. [name] is replaced with config name at top
|
| 103 |
+
meta:
|
| 104 |
+
name: "[name]"
|
| 105 |
+
version: '1.0'
|
ai-toolkit/config/examples/train_lora_qwen_image_edit_32gb.yaml
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
job: extension
|
| 3 |
+
config:
|
| 4 |
+
# this name will be the folder and filename name
|
| 5 |
+
name: "my_first_qwen_image_edit_lora_v1"
|
| 6 |
+
process:
|
| 7 |
+
- type: 'sd_trainer'
|
| 8 |
+
# root folder to save training sessions/samples/weights
|
| 9 |
+
training_folder: "output"
|
| 10 |
+
# uncomment to see performance stats in the terminal every N steps
|
| 11 |
+
# performance_log_every: 1000
|
| 12 |
+
device: cuda:0
|
| 13 |
+
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
|
| 14 |
+
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
|
| 15 |
+
# Trigger words will not work when caching text embeddings
|
| 16 |
+
# trigger_word: "p3r5on"
|
| 17 |
+
network:
|
| 18 |
+
type: "lora"
|
| 19 |
+
linear: 16
|
| 20 |
+
linear_alpha: 16
|
| 21 |
+
save:
|
| 22 |
+
dtype: float16 # precision to save
|
| 23 |
+
save_every: 250 # save every this many steps
|
| 24 |
+
max_step_saves_to_keep: 4 # how many intermittent saves to keep
|
| 25 |
+
datasets:
|
| 26 |
+
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
| 27 |
+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
| 28 |
+
# images will automatically be resized and bucketed into the resolution specified
|
| 29 |
+
# on windows, escape back slashes with another backslash so
|
| 30 |
+
# "C:\\path\\to\\images\\folder"
|
| 31 |
+
- folder_path: "/path/to/images/folder"
|
| 32 |
+
control_path: "/path/to/control/images/folder"
|
| 33 |
+
caption_ext: "txt"
|
| 34 |
+
# default_caption: "a person" # if caching text embeddings, if you don't have captions, this will get cached
|
| 35 |
+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
|
| 36 |
+
resolution: [ 512, 768, 1024 ] # qwen image enjoys multiple resolutions
|
| 37 |
+
train:
|
| 38 |
+
batch_size: 1
|
| 39 |
+
# caching text embeddings is required for 32GB
|
| 40 |
+
cache_text_embeddings: true
|
| 41 |
+
|
| 42 |
+
steps: 3000 # total number of steps to train 500 - 4000 is a good range
|
| 43 |
+
gradient_accumulation: 1
|
| 44 |
+
timestep_type: "weighted"
|
| 45 |
+
train_unet: true
|
| 46 |
+
train_text_encoder: false # probably won't work with qwen image
|
| 47 |
+
gradient_checkpointing: true # need the on unless you have a ton of vram
|
| 48 |
+
noise_scheduler: "flowmatch" # for training only
|
| 49 |
+
optimizer: "adamw8bit"
|
| 50 |
+
lr: 1e-4
|
| 51 |
+
# uncomment this to skip the pre training sample
|
| 52 |
+
# skip_first_sample: true
|
| 53 |
+
# uncomment to completely disable sampling
|
| 54 |
+
# disable_sampling: true
|
| 55 |
+
dtype: bf16
|
| 56 |
+
model:
|
| 57 |
+
# huggingface model name or path
|
| 58 |
+
name_or_path: "Qwen/Qwen-Image-Edit"
|
| 59 |
+
arch: "qwen_image_edit"
|
| 60 |
+
quantize: true
|
| 61 |
+
# qtype_te: "qfloat8" Default float8 qquantization
|
| 62 |
+
# to use the ARA use the | pipe to point to hf path, or a local path if you have one.
|
| 63 |
+
# 3bit is required for 32GB
|
| 64 |
+
qtype: "uint3|qwen_image_edit_torchao_uint3.safetensors"
|
| 65 |
+
quantize_te: true
|
| 66 |
+
qtype_te: "qfloat8"
|
| 67 |
+
low_vram: true
|
| 68 |
+
sample:
|
| 69 |
+
sampler: "flowmatch" # must match train.noise_scheduler
|
| 70 |
+
sample_every: 250 # sample every this many steps
|
| 71 |
+
width: 1024
|
| 72 |
+
height: 1024
|
| 73 |
+
samples:
|
| 74 |
+
- prompt: "do the thing to it"
|
| 75 |
+
ctrl_img: "/path/to/control/image.jpg"
|
| 76 |
+
- prompt: "do the thing to it"
|
| 77 |
+
ctrl_img: "/path/to/control/image.jpg"
|
| 78 |
+
- prompt: "do the thing to it"
|
| 79 |
+
ctrl_img: "/path/to/control/image.jpg"
|
| 80 |
+
- prompt: "do the thing to it"
|
| 81 |
+
ctrl_img: "/path/to/control/image.jpg"
|
| 82 |
+
- prompt: "do the thing to it"
|
| 83 |
+
ctrl_img: "/path/to/control/image.jpg"
|
| 84 |
+
- prompt: "do the thing to it"
|
| 85 |
+
ctrl_img: "/path/to/control/image.jpg"
|
| 86 |
+
- prompt: "do the thing to it"
|
| 87 |
+
ctrl_img: "/path/to/control/image.jpg"
|
| 88 |
+
- prompt: "do the thing to it"
|
| 89 |
+
ctrl_img: "/path/to/control/image.jpg"
|
| 90 |
+
- prompt: "do the thing to it"
|
| 91 |
+
ctrl_img: "/path/to/control/image.jpg"
|
| 92 |
+
- prompt: "do the thing to it"
|
| 93 |
+
ctrl_img: "/path/to/control/image.jpg"
|
| 94 |
+
neg: ""
|
| 95 |
+
seed: 42
|
| 96 |
+
walk_seed: true
|
| 97 |
+
guidance_scale: 3
|
| 98 |
+
sample_steps: 25
|
| 99 |
+
# you can add any additional meta info here. [name] is replaced with config name at top
|
| 100 |
+
meta:
|
| 101 |
+
name: "[name]"
|
| 102 |
+
version: '1.0'
|
ai-toolkit/config/examples/train_lora_sd35_large_24gb.yaml
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
# NOTE!! THIS IS CURRENTLY EXPERIMENTAL AND UNDER DEVELOPMENT. SOME THINGS WILL CHANGE
|
| 3 |
+
job: extension
|
| 4 |
+
config:
|
| 5 |
+
# this name will be the folder and filename name
|
| 6 |
+
name: "my_first_sd3l_lora_v1"
|
| 7 |
+
process:
|
| 8 |
+
- type: 'sd_trainer'
|
| 9 |
+
# root folder to save training sessions/samples/weights
|
| 10 |
+
training_folder: "output"
|
| 11 |
+
# uncomment to see performance stats in the terminal every N steps
|
| 12 |
+
# performance_log_every: 1000
|
| 13 |
+
device: cuda:0
|
| 14 |
+
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
|
| 15 |
+
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
|
| 16 |
+
# trigger_word: "p3r5on"
|
| 17 |
+
network:
|
| 18 |
+
type: "lora"
|
| 19 |
+
linear: 16
|
| 20 |
+
linear_alpha: 16
|
| 21 |
+
save:
|
| 22 |
+
dtype: float16 # precision to save
|
| 23 |
+
save_every: 250 # save every this many steps
|
| 24 |
+
max_step_saves_to_keep: 4 # how many intermittent saves to keep
|
| 25 |
+
push_to_hub: false #change this to True to push your trained model to Hugging Face.
|
| 26 |
+
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
|
| 27 |
+
# hf_repo_id: your-username/your-model-slug
|
| 28 |
+
# hf_private: true #whether the repo is private or public
|
| 29 |
+
datasets:
|
| 30 |
+
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
| 31 |
+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
| 32 |
+
# images will automatically be resized and bucketed into the resolution specified
|
| 33 |
+
# on windows, escape back slashes with another backslash so
|
| 34 |
+
# "C:\\path\\to\\images\\folder"
|
| 35 |
+
- folder_path: "/path/to/images/folder"
|
| 36 |
+
caption_ext: "txt"
|
| 37 |
+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
|
| 38 |
+
shuffle_tokens: false # shuffle caption order, split by commas
|
| 39 |
+
cache_latents_to_disk: true # leave this true unless you know what you're doing
|
| 40 |
+
resolution: [ 1024 ]
|
| 41 |
+
train:
|
| 42 |
+
batch_size: 1
|
| 43 |
+
steps: 2000 # total number of steps to train 500 - 4000 is a good range
|
| 44 |
+
gradient_accumulation_steps: 1
|
| 45 |
+
train_unet: true
|
| 46 |
+
train_text_encoder: false # May not fully work with SD3 yet
|
| 47 |
+
gradient_checkpointing: true # need the on unless you have a ton of vram
|
| 48 |
+
noise_scheduler: "flowmatch"
|
| 49 |
+
timestep_type: "linear" # linear or sigmoid
|
| 50 |
+
optimizer: "adamw8bit"
|
| 51 |
+
lr: 1e-4
|
| 52 |
+
# uncomment this to skip the pre training sample
|
| 53 |
+
# skip_first_sample: true
|
| 54 |
+
# uncomment to completely disable sampling
|
| 55 |
+
# disable_sampling: true
|
| 56 |
+
# uncomment to use new vell curved weighting. Experimental but may produce better results
|
| 57 |
+
# linear_timesteps: true
|
| 58 |
+
|
| 59 |
+
# ema will smooth out learning, but could slow it down. Recommended to leave on.
|
| 60 |
+
ema_config:
|
| 61 |
+
use_ema: true
|
| 62 |
+
ema_decay: 0.99
|
| 63 |
+
|
| 64 |
+
# will probably need this if gpu supports it for sd3, other dtypes may not work correctly
|
| 65 |
+
dtype: bf16
|
| 66 |
+
model:
|
| 67 |
+
# huggingface model name or path
|
| 68 |
+
name_or_path: "stabilityai/stable-diffusion-3.5-large"
|
| 69 |
+
is_v3: true
|
| 70 |
+
quantize: true # run 8bit mixed precision
|
| 71 |
+
sample:
|
| 72 |
+
sampler: "flowmatch" # must match train.noise_scheduler
|
| 73 |
+
sample_every: 250 # sample every this many steps
|
| 74 |
+
width: 1024
|
| 75 |
+
height: 1024
|
| 76 |
+
prompts:
|
| 77 |
+
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
|
| 78 |
+
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
|
| 79 |
+
- "woman with red hair, playing chess at the park, bomb going off in the background"
|
| 80 |
+
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
|
| 81 |
+
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
|
| 82 |
+
- "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
|
| 83 |
+
- "a bear building a log cabin in the snow covered mountains"
|
| 84 |
+
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
|
| 85 |
+
- "hipster man with a beard, building a chair, in a wood shop"
|
| 86 |
+
- "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
|
| 87 |
+
- "a man holding a sign that says, 'this is a sign'"
|
| 88 |
+
- "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
|
| 89 |
+
neg: ""
|
| 90 |
+
seed: 42
|
| 91 |
+
walk_seed: true
|
| 92 |
+
guidance_scale: 4
|
| 93 |
+
sample_steps: 25
|
| 94 |
+
# you can add any additional meta info here. [name] is replaced with config name at top
|
| 95 |
+
meta:
|
| 96 |
+
name: "[name]"
|
| 97 |
+
version: '1.0'
|
ai-toolkit/config/examples/train_lora_wan21_14b_24gb.yaml
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# IMPORTANT: The Wan2.1 14B model is huge. This config should work on 24GB GPUs. It cannot
|
| 2 |
+
# support keeping the text encoder on GPU while training with 24GB, so it is only good
|
| 3 |
+
# for training on a single prompt, for example a person with a trigger word.
|
| 4 |
+
# to train on captions, you need more vran for now.
|
| 5 |
+
---
|
| 6 |
+
job: extension
|
| 7 |
+
config:
|
| 8 |
+
# this name will be the folder and filename name
|
| 9 |
+
name: "my_first_wan21_14b_lora_v1"
|
| 10 |
+
process:
|
| 11 |
+
- type: 'sd_trainer'
|
| 12 |
+
# root folder to save training sessions/samples/weights
|
| 13 |
+
training_folder: "output"
|
| 14 |
+
# uncomment to see performance stats in the terminal every N steps
|
| 15 |
+
# performance_log_every: 1000
|
| 16 |
+
device: cuda:0
|
| 17 |
+
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
|
| 18 |
+
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
|
| 19 |
+
# this is probably needed for 24GB cards when offloading TE to CPU
|
| 20 |
+
trigger_word: "p3r5on"
|
| 21 |
+
network:
|
| 22 |
+
type: "lora"
|
| 23 |
+
linear: 32
|
| 24 |
+
linear_alpha: 32
|
| 25 |
+
save:
|
| 26 |
+
dtype: float16 # precision to save
|
| 27 |
+
save_every: 250 # save every this many steps
|
| 28 |
+
max_step_saves_to_keep: 4 # how many intermittent saves to keep
|
| 29 |
+
push_to_hub: false #change this to True to push your trained model to Hugging Face.
|
| 30 |
+
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
|
| 31 |
+
# hf_repo_id: your-username/your-model-slug
|
| 32 |
+
# hf_private: true #whether the repo is private or public
|
| 33 |
+
datasets:
|
| 34 |
+
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
| 35 |
+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
| 36 |
+
# images will automatically be resized and bucketed into the resolution specified
|
| 37 |
+
# on windows, escape back slashes with another backslash so
|
| 38 |
+
# "C:\\path\\to\\images\\folder"
|
| 39 |
+
# AI-Toolkit does not currently support video datasets, we will train on 1 frame at a time
|
| 40 |
+
# it works well for characters, but not as well for "actions"
|
| 41 |
+
- folder_path: "/path/to/images/folder"
|
| 42 |
+
caption_ext: "txt"
|
| 43 |
+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
|
| 44 |
+
shuffle_tokens: false # shuffle caption order, split by commas
|
| 45 |
+
cache_latents_to_disk: true # leave this true unless you know what you're doing
|
| 46 |
+
resolution: [ 632 ] # will be around 480p
|
| 47 |
+
train:
|
| 48 |
+
batch_size: 1
|
| 49 |
+
steps: 2000 # total number of steps to train 500 - 4000 is a good range
|
| 50 |
+
gradient_accumulation: 1
|
| 51 |
+
train_unet: true
|
| 52 |
+
train_text_encoder: false # probably won't work with wan
|
| 53 |
+
gradient_checkpointing: true # need the on unless you have a ton of vram
|
| 54 |
+
noise_scheduler: "flowmatch" # for training only
|
| 55 |
+
timestep_type: 'sigmoid'
|
| 56 |
+
optimizer: "adamw8bit"
|
| 57 |
+
lr: 1e-4
|
| 58 |
+
optimizer_params:
|
| 59 |
+
weight_decay: 1e-4
|
| 60 |
+
# uncomment this to skip the pre training sample
|
| 61 |
+
# skip_first_sample: true
|
| 62 |
+
# uncomment to completely disable sampling
|
| 63 |
+
# disable_sampling: true
|
| 64 |
+
# ema will smooth out learning, but could slow it down. Recommended to leave on.
|
| 65 |
+
ema_config:
|
| 66 |
+
use_ema: true
|
| 67 |
+
ema_decay: 0.99
|
| 68 |
+
dtype: bf16
|
| 69 |
+
# required for 24GB cards
|
| 70 |
+
# this will encode your trigger word and use those embeddings for every image in the dataset
|
| 71 |
+
unload_text_encoder: true
|
| 72 |
+
model:
|
| 73 |
+
# huggingface model name or path
|
| 74 |
+
name_or_path: "Wan-AI/Wan2.1-T2V-14B-Diffusers"
|
| 75 |
+
arch: 'wan21'
|
| 76 |
+
# these settings will save as much vram as possible
|
| 77 |
+
quantize: true
|
| 78 |
+
quantize_te: true
|
| 79 |
+
low_vram: true
|
| 80 |
+
sample:
|
| 81 |
+
sampler: "flowmatch"
|
| 82 |
+
sample_every: 250 # sample every this many steps
|
| 83 |
+
width: 832
|
| 84 |
+
height: 480
|
| 85 |
+
num_frames: 40
|
| 86 |
+
fps: 15
|
| 87 |
+
# samples take a long time. so use them sparingly
|
| 88 |
+
# samples will be animated webp files, if you don't see them animated, open in a browser.
|
| 89 |
+
prompts:
|
| 90 |
+
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
|
| 91 |
+
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
|
| 92 |
+
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
|
| 93 |
+
neg: ""
|
| 94 |
+
seed: 42
|
| 95 |
+
walk_seed: true
|
| 96 |
+
guidance_scale: 5
|
| 97 |
+
sample_steps: 30
|
| 98 |
+
# you can add any additional meta info here. [name] is replaced with config name at top
|
| 99 |
+
meta:
|
| 100 |
+
name: "[name]"
|
| 101 |
+
version: '1.0'
|
ai-toolkit/config/examples/train_lora_wan21_1b_24gb.yaml
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
job: extension
|
| 3 |
+
config:
|
| 4 |
+
# this name will be the folder and filename name
|
| 5 |
+
name: "my_first_wan21_1b_lora_v1"
|
| 6 |
+
process:
|
| 7 |
+
- type: 'sd_trainer'
|
| 8 |
+
# root folder to save training sessions/samples/weights
|
| 9 |
+
training_folder: "output"
|
| 10 |
+
# uncomment to see performance stats in the terminal every N steps
|
| 11 |
+
# performance_log_every: 1000
|
| 12 |
+
device: cuda:0
|
| 13 |
+
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
|
| 14 |
+
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
|
| 15 |
+
# trigger_word: "p3r5on"
|
| 16 |
+
network:
|
| 17 |
+
type: "lora"
|
| 18 |
+
linear: 32
|
| 19 |
+
linear_alpha: 32
|
| 20 |
+
save:
|
| 21 |
+
dtype: float16 # precision to save
|
| 22 |
+
save_every: 250 # save every this many steps
|
| 23 |
+
max_step_saves_to_keep: 4 # how many intermittent saves to keep
|
| 24 |
+
push_to_hub: false #change this to True to push your trained model to Hugging Face.
|
| 25 |
+
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
|
| 26 |
+
# hf_repo_id: your-username/your-model-slug
|
| 27 |
+
# hf_private: true #whether the repo is private or public
|
| 28 |
+
datasets:
|
| 29 |
+
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
| 30 |
+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
| 31 |
+
# images will automatically be resized and bucketed into the resolution specified
|
| 32 |
+
# on windows, escape back slashes with another backslash so
|
| 33 |
+
# "C:\\path\\to\\images\\folder"
|
| 34 |
+
# AI-Toolkit does not currently support video datasets, we will train on 1 frame at a time
|
| 35 |
+
# it works well for characters, but not as well for "actions"
|
| 36 |
+
- folder_path: "/path/to/images/folder"
|
| 37 |
+
caption_ext: "txt"
|
| 38 |
+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
|
| 39 |
+
shuffle_tokens: false # shuffle caption order, split by commas
|
| 40 |
+
cache_latents_to_disk: true # leave this true unless you know what you're doing
|
| 41 |
+
resolution: [ 632 ] # will be around 480p
|
| 42 |
+
train:
|
| 43 |
+
batch_size: 1
|
| 44 |
+
steps: 2000 # total number of steps to train 500 - 4000 is a good range
|
| 45 |
+
gradient_accumulation: 1
|
| 46 |
+
train_unet: true
|
| 47 |
+
train_text_encoder: false # probably won't work with wan
|
| 48 |
+
gradient_checkpointing: true # need the on unless you have a ton of vram
|
| 49 |
+
noise_scheduler: "flowmatch" # for training only
|
| 50 |
+
timestep_type: 'sigmoid'
|
| 51 |
+
optimizer: "adamw8bit"
|
| 52 |
+
lr: 1e-4
|
| 53 |
+
optimizer_params:
|
| 54 |
+
weight_decay: 1e-4
|
| 55 |
+
# uncomment this to skip the pre training sample
|
| 56 |
+
# skip_first_sample: true
|
| 57 |
+
# uncomment to completely disable sampling
|
| 58 |
+
# disable_sampling: true
|
| 59 |
+
# ema will smooth out learning, but could slow it down. Recommended to leave on.
|
| 60 |
+
ema_config:
|
| 61 |
+
use_ema: true
|
| 62 |
+
ema_decay: 0.99
|
| 63 |
+
dtype: bf16
|
| 64 |
+
model:
|
| 65 |
+
# huggingface model name or path
|
| 66 |
+
name_or_path: "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
|
| 67 |
+
arch: 'wan21'
|
| 68 |
+
quantize_te: true # saves vram
|
| 69 |
+
sample:
|
| 70 |
+
sampler: "flowmatch"
|
| 71 |
+
sample_every: 250 # sample every this many steps
|
| 72 |
+
width: 832
|
| 73 |
+
height: 480
|
| 74 |
+
num_frames: 40
|
| 75 |
+
fps: 15
|
| 76 |
+
# samples take a long time. so use them sparingly
|
| 77 |
+
# samples will be animated webp files, if you don't see them animated, open in a browser.
|
| 78 |
+
prompts:
|
| 79 |
+
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
|
| 80 |
+
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
|
| 81 |
+
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
|
| 82 |
+
neg: ""
|
| 83 |
+
seed: 42
|
| 84 |
+
walk_seed: true
|
| 85 |
+
guidance_scale: 5
|
| 86 |
+
sample_steps: 30
|
| 87 |
+
# you can add any additional meta info here. [name] is replaced with config name at top
|
| 88 |
+
meta:
|
| 89 |
+
name: "[name]"
|
| 90 |
+
version: '1.0'
|
ai-toolkit/config/examples/train_lora_wan22_14b_24gb.yaml
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# this example focuses mainly for training Wan2.2 14b on images. It will work for video as well by increasing
|
| 2 |
+
# the number of frames in the dataset and samples. Training on and generating video is very VRAM intensive.
|
| 3 |
+
---
|
| 4 |
+
job: extension
|
| 5 |
+
config:
|
| 6 |
+
# this name will be the folder and filename name
|
| 7 |
+
name: "my_first_wan22_14b_lora_v1"
|
| 8 |
+
process:
|
| 9 |
+
- type: 'sd_trainer'
|
| 10 |
+
# root folder to save training sessions/samples/weights
|
| 11 |
+
training_folder: "output"
|
| 12 |
+
# uncomment to see performance stats in the terminal every N steps
|
| 13 |
+
# performance_log_every: 1000
|
| 14 |
+
device: cuda:0
|
| 15 |
+
# Use a trigger word if train.unload_text_encoder is true, however, if caching text embeddings, do not use a trigger word
|
| 16 |
+
# trigger_word: "p3r5on"
|
| 17 |
+
network:
|
| 18 |
+
type: "lora"
|
| 19 |
+
linear: 32
|
| 20 |
+
linear_alpha: 32
|
| 21 |
+
save:
|
| 22 |
+
dtype: float16 # precision to save
|
| 23 |
+
save_every: 250 # save every this many steps
|
| 24 |
+
max_step_saves_to_keep: 4 # how many intermittent saves to keep
|
| 25 |
+
datasets:
|
| 26 |
+
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
| 27 |
+
# for instance image2.jpg and image2.txt.
|
| 28 |
+
# "C:\\path\\to\\images\\folder"
|
| 29 |
+
- folder_path: "/path/to/images/or/video/folder"
|
| 30 |
+
caption_ext: "txt"
|
| 31 |
+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
|
| 32 |
+
# number of frames to extract from your video. It will automatically extract them evenly spaced
|
| 33 |
+
# set to 1 frame for images
|
| 34 |
+
num_frames: 1
|
| 35 |
+
resolution: [ 512, 768, 1024]
|
| 36 |
+
train:
|
| 37 |
+
batch_size: 1
|
| 38 |
+
steps: 2000 # total number of steps to train 500 - 4000 is a good range
|
| 39 |
+
gradient_accumulation: 1
|
| 40 |
+
train_unet: true
|
| 41 |
+
train_text_encoder: false # probably won't work with wan
|
| 42 |
+
gradient_checkpointing: true # need the on unless you have a ton of vram
|
| 43 |
+
noise_scheduler: "flowmatch" # for training only
|
| 44 |
+
timestep_type: 'linear'
|
| 45 |
+
optimizer: "adamw8bit"
|
| 46 |
+
lr: 1e-4
|
| 47 |
+
optimizer_params:
|
| 48 |
+
weight_decay: 1e-4
|
| 49 |
+
# uncomment this to skip the pre training sample
|
| 50 |
+
# skip_first_sample: true
|
| 51 |
+
# uncomment to completely disable sampling
|
| 52 |
+
# disable_sampling: true
|
| 53 |
+
dtype: bf16
|
| 54 |
+
|
| 55 |
+
# IMPORTANT: this is for Wan 2.2 MOE. It will switch training one stage or the other every this many steps
|
| 56 |
+
switch_boundary_every: 10
|
| 57 |
+
|
| 58 |
+
# required for 24GB cards. You must do either unload_text_encoder or cache_text_embeddings but not both
|
| 59 |
+
|
| 60 |
+
# this will encode your trigger word and use those embeddings for every image in the dataset, captions will be ignored
|
| 61 |
+
# unload_text_encoder: true
|
| 62 |
+
|
| 63 |
+
# this will cache all captions in your dataset.
|
| 64 |
+
cache_text_embeddings: true
|
| 65 |
+
|
| 66 |
+
model:
|
| 67 |
+
# huggingface model name or path, this one if bf16, vs the float32 of the official repo
|
| 68 |
+
name_or_path: "ai-toolkit/Wan2.2-T2V-A14B-Diffusers-bf16"
|
| 69 |
+
arch: 'wan22_14b'
|
| 70 |
+
quantize: true
|
| 71 |
+
# This will pull and use a custom Accuracy Recovery Adapter to train at 4bit
|
| 72 |
+
qtype: "uint4|ostris/accuracy_recovery_adapters/wan22_14b_t2i_torchao_uint4.safetensors"
|
| 73 |
+
quantize_te: true
|
| 74 |
+
qtype_te: "qfloat8"
|
| 75 |
+
low_vram: true
|
| 76 |
+
model_kwargs:
|
| 77 |
+
# you can train high noise, low noise, or both. With low vram it will automatically unload the one not being trained.
|
| 78 |
+
train_high_noise: true
|
| 79 |
+
train_low_noise: true
|
| 80 |
+
sample:
|
| 81 |
+
sampler: "flowmatch"
|
| 82 |
+
sample_every: 250 # sample every this many steps
|
| 83 |
+
width: 1024
|
| 84 |
+
height: 1024
|
| 85 |
+
# set to 1 for images
|
| 86 |
+
num_frames: 1
|
| 87 |
+
fps: 16
|
| 88 |
+
# samples take a long time. so use them sparingly
|
| 89 |
+
# samples will be animated webp files, if you don't see them animated, open in a browser.
|
| 90 |
+
prompts:
|
| 91 |
+
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
|
| 92 |
+
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
|
| 93 |
+
- "woman with red hair, playing chess at the park, bomb going off in the background"
|
| 94 |
+
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
|
| 95 |
+
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
|
| 96 |
+
- "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
|
| 97 |
+
- "a bear building a log cabin in the snow covered mountains"
|
| 98 |
+
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
|
| 99 |
+
- "hipster man with a beard, building a chair, in a wood shop"
|
| 100 |
+
- "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
|
| 101 |
+
- "a man holding a sign that says, 'this is a sign'"
|
| 102 |
+
- "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
|
| 103 |
+
neg: ""
|
| 104 |
+
seed: 42
|
| 105 |
+
walk_seed: true
|
| 106 |
+
guidance_scale: 3.5
|
| 107 |
+
sample_steps: 25
|
| 108 |
+
# you can add any additional meta info here. [name] is replaced with config name at top
|
| 109 |
+
meta:
|
| 110 |
+
name: "[name]"
|
| 111 |
+
version: '1.0'
|
ai-toolkit/config/examples/train_slider.example.yml
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
# This is in yaml format. You can use json if you prefer
|
| 3 |
+
# I like both but yaml is easier to write
|
| 4 |
+
# Plus it has comments which is nice for documentation
|
| 5 |
+
# This is the config I use on my sliders, It is solid and tested
|
| 6 |
+
job: train
|
| 7 |
+
config:
|
| 8 |
+
# the name will be used to create a folder in the output folder
|
| 9 |
+
# it will also replace any [name] token in the rest of this config
|
| 10 |
+
name: detail_slider_v1
|
| 11 |
+
# folder will be created with name above in folder below
|
| 12 |
+
# it can be relative to the project root or absolute
|
| 13 |
+
training_folder: "output/LoRA"
|
| 14 |
+
device: cuda:0 # cpu, cuda:0, etc
|
| 15 |
+
# for tensorboard logging, we will make a subfolder for this job
|
| 16 |
+
log_dir: "output/.tensorboard"
|
| 17 |
+
# you can stack processes for other jobs, It is not tested with sliders though
|
| 18 |
+
# just use one for now
|
| 19 |
+
process:
|
| 20 |
+
- type: slider # tells runner to run the slider process
|
| 21 |
+
# network is the LoRA network for a slider, I recommend to leave this be
|
| 22 |
+
network:
|
| 23 |
+
# network type lierla is traditional LoRA that works everywhere, only linear layers
|
| 24 |
+
type: "lierla"
|
| 25 |
+
# rank / dim of the network. Bigger is not always better. Especially for sliders. 8 is good
|
| 26 |
+
linear: 8
|
| 27 |
+
linear_alpha: 4 # Do about half of rank
|
| 28 |
+
# training config
|
| 29 |
+
train:
|
| 30 |
+
# this is also used in sampling. Stick with ddpm unless you know what you are doing
|
| 31 |
+
noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a"
|
| 32 |
+
# how many steps to train. More is not always better. I rarely go over 1000
|
| 33 |
+
steps: 500
|
| 34 |
+
# I have had good results with 4e-4 to 1e-4 at 500 steps
|
| 35 |
+
lr: 2e-4
|
| 36 |
+
# enables gradient checkpoint, saves vram, leave it on
|
| 37 |
+
gradient_checkpointing: true
|
| 38 |
+
# train the unet. I recommend leaving this true
|
| 39 |
+
train_unet: true
|
| 40 |
+
# train the text encoder. I don't recommend this unless you have a special use case
|
| 41 |
+
# for sliders we are adjusting representation of the concept (unet),
|
| 42 |
+
# not the description of it (text encoder)
|
| 43 |
+
train_text_encoder: false
|
| 44 |
+
# same as from sd-scripts, not fully tested but should speed up training
|
| 45 |
+
min_snr_gamma: 5.0
|
| 46 |
+
# just leave unless you know what you are doing
|
| 47 |
+
# also supports "dadaptation" but set lr to 1 if you use that,
|
| 48 |
+
# but it learns too fast and I don't recommend it
|
| 49 |
+
optimizer: "adamw"
|
| 50 |
+
# only constant for now
|
| 51 |
+
lr_scheduler: "constant"
|
| 52 |
+
# we randomly denoise random num of steps form 1 to this number
|
| 53 |
+
# while training. Just leave it
|
| 54 |
+
max_denoising_steps: 40
|
| 55 |
+
# works great at 1. I do 1 even with my 4090.
|
| 56 |
+
# higher may not work right with newer single batch stacking code anyway
|
| 57 |
+
batch_size: 1
|
| 58 |
+
# bf16 works best if your GPU supports it (modern)
|
| 59 |
+
dtype: bf16 # fp32, bf16, fp16
|
| 60 |
+
# if you have it, use it. It is faster and better
|
| 61 |
+
# torch 2.0 doesnt need xformers anymore, only use if you have lower version
|
| 62 |
+
# xformers: true
|
| 63 |
+
# I don't recommend using unless you are trying to make a darker lora. Then do 0.1 MAX
|
| 64 |
+
# although, the way we train sliders is comparative, so it probably won't work anyway
|
| 65 |
+
noise_offset: 0.0
|
| 66 |
+
# noise_offset: 0.0357 # SDXL was trained with offset of 0.0357. So use that when training on SDXL
|
| 67 |
+
|
| 68 |
+
# the model to train the LoRA network on
|
| 69 |
+
model:
|
| 70 |
+
# huggingface name, relative prom project path, or absolute path to .safetensors or .ckpt
|
| 71 |
+
name_or_path: "runwayml/stable-diffusion-v1-5"
|
| 72 |
+
is_v2: false # for v2 models
|
| 73 |
+
is_v_pred: false # for v-prediction models (most v2 models)
|
| 74 |
+
# has some issues with the dual text encoder and the way we train sliders
|
| 75 |
+
# it works bit weights need to probably be higher to see it.
|
| 76 |
+
is_xl: false # for SDXL models
|
| 77 |
+
|
| 78 |
+
# saving config
|
| 79 |
+
save:
|
| 80 |
+
dtype: float16 # precision to save. I recommend float16
|
| 81 |
+
save_every: 50 # save every this many steps
|
| 82 |
+
# this will remove step counts more than this number
|
| 83 |
+
# allows you to save more often in case of a crash without filling up your drive
|
| 84 |
+
max_step_saves_to_keep: 2
|
| 85 |
+
|
| 86 |
+
# sampling config
|
| 87 |
+
sample:
|
| 88 |
+
# must match train.noise_scheduler, this is not used here
|
| 89 |
+
# but may be in future and in other processes
|
| 90 |
+
sampler: "ddpm"
|
| 91 |
+
# sample every this many steps
|
| 92 |
+
sample_every: 20
|
| 93 |
+
# image size
|
| 94 |
+
width: 512
|
| 95 |
+
height: 512
|
| 96 |
+
# prompts to use for sampling. Do as many as you want, but it slows down training
|
| 97 |
+
# pick ones that will best represent the concept you are trying to adjust
|
| 98 |
+
# allows some flags after the prompt
|
| 99 |
+
# --m [number] # network multiplier. LoRA weight. -3 for the negative slide, 3 for the positive
|
| 100 |
+
# slide are good tests. will inherit sample.network_multiplier if not set
|
| 101 |
+
# --n [string] # negative prompt, will inherit sample.neg if not set
|
| 102 |
+
# Only 75 tokens allowed currently
|
| 103 |
+
# I like to do a wide positive and negative spread so I can see a good range and stop
|
| 104 |
+
# early if the network is braking down
|
| 105 |
+
prompts:
|
| 106 |
+
- "a woman in a coffee shop, black hat, blonde hair, blue jacket --m -5"
|
| 107 |
+
- "a woman in a coffee shop, black hat, blonde hair, blue jacket --m -3"
|
| 108 |
+
- "a woman in a coffee shop, black hat, blonde hair, blue jacket --m 3"
|
| 109 |
+
- "a woman in a coffee shop, black hat, blonde hair, blue jacket --m 5"
|
| 110 |
+
- "a golden retriever sitting on a leather couch, --m -5"
|
| 111 |
+
- "a golden retriever sitting on a leather couch --m -3"
|
| 112 |
+
- "a golden retriever sitting on a leather couch --m 3"
|
| 113 |
+
- "a golden retriever sitting on a leather couch --m 5"
|
| 114 |
+
- "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m -5"
|
| 115 |
+
- "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m -3"
|
| 116 |
+
- "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m 3"
|
| 117 |
+
- "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m 5"
|
| 118 |
+
# negative prompt used on all prompts above as default if they don't have one
|
| 119 |
+
neg: "cartoon, fake, drawing, illustration, cgi, animated, anime, monochrome"
|
| 120 |
+
# seed for sampling. 42 is the answer for everything
|
| 121 |
+
seed: 42
|
| 122 |
+
# walks the seed so s1 is 42, s2 is 43, s3 is 44, etc
|
| 123 |
+
# will start over on next sample_every so s1 is always seed
|
| 124 |
+
# works well if you use same prompt but want different results
|
| 125 |
+
walk_seed: false
|
| 126 |
+
# cfg scale (4 to 10 is good)
|
| 127 |
+
guidance_scale: 7
|
| 128 |
+
# sampler steps (20 to 30 is good)
|
| 129 |
+
sample_steps: 20
|
| 130 |
+
# default network multiplier for all prompts
|
| 131 |
+
# since we are training a slider, I recommend overriding this with --m [number]
|
| 132 |
+
# in the prompts above to get both sides of the slider
|
| 133 |
+
network_multiplier: 1.0
|
| 134 |
+
|
| 135 |
+
# logging information
|
| 136 |
+
logging:
|
| 137 |
+
log_every: 10 # log every this many steps
|
| 138 |
+
use_wandb: false # not supported yet
|
| 139 |
+
verbose: false # probably done need unless you are debugging
|
| 140 |
+
|
| 141 |
+
# slider training config, best for last
|
| 142 |
+
slider:
|
| 143 |
+
# resolutions to train on. [ width, height ]. This is less important for sliders
|
| 144 |
+
# as we are not teaching the model anything it doesn't already know
|
| 145 |
+
# but must be a size it understands [ 512, 512 ] for sd_v1.5 and [ 768, 768 ] for sd_v2.1
|
| 146 |
+
# and [ 1024, 1024 ] for sd_xl
|
| 147 |
+
# you can do as many as you want here
|
| 148 |
+
resolutions:
|
| 149 |
+
- [ 512, 512 ]
|
| 150 |
+
# - [ 512, 768 ]
|
| 151 |
+
# - [ 768, 768 ]
|
| 152 |
+
# slider training uses 4 combined steps for a single round. This will do it in one gradient
|
| 153 |
+
# step. It is highly optimized and shouldn't take anymore vram than doing without it,
|
| 154 |
+
# since we break down batches for gradient accumulation now. so just leave it on.
|
| 155 |
+
batch_full_slide: true
|
| 156 |
+
# These are the concepts to train on. You can do as many as you want here,
|
| 157 |
+
# but they can conflict outweigh each other. Other than experimenting, I recommend
|
| 158 |
+
# just doing one for good results
|
| 159 |
+
targets:
|
| 160 |
+
# target_class is the base concept we are adjusting the representation of
|
| 161 |
+
# for example, if we are adjusting the representation of a person, we would use "person"
|
| 162 |
+
# if we are adjusting the representation of a cat, we would use "cat" It is not
|
| 163 |
+
# a keyword necessarily but what the model understands the concept to represent.
|
| 164 |
+
# "person" will affect men, women, children, etc but will not affect cats, dogs, etc
|
| 165 |
+
# it is the models base general understanding of the concept and everything it represents
|
| 166 |
+
# you can leave it blank to affect everything. In this example, we are adjusting
|
| 167 |
+
# detail, so we will leave it blank to affect everything
|
| 168 |
+
- target_class: ""
|
| 169 |
+
# positive is the prompt for the positive side of the slider.
|
| 170 |
+
# It is the concept that will be excited and amplified in the model when we slide the slider
|
| 171 |
+
# to the positive side and forgotten / inverted when we slide
|
| 172 |
+
# the slider to the negative side. It is generally best to include the target_class in
|
| 173 |
+
# the prompt. You want it to be the extreme of what you want to train on. For example,
|
| 174 |
+
# if you want to train on fat people, you would use "an extremely fat, morbidly obese person"
|
| 175 |
+
# as the prompt. Not just "fat person"
|
| 176 |
+
# max 75 tokens for now
|
| 177 |
+
positive: "high detail, 8k, intricate, detailed, high resolution, high res, high quality"
|
| 178 |
+
# negative is the prompt for the negative side of the slider and works the same as positive
|
| 179 |
+
# it does not necessarily work the same as a negative prompt when generating images
|
| 180 |
+
# these need to be polar opposites.
|
| 181 |
+
# max 76 tokens for now
|
| 182 |
+
negative: "blurry, boring, fuzzy, low detail, low resolution, low res, low quality"
|
| 183 |
+
# the loss for this target is multiplied by this number.
|
| 184 |
+
# if you are doing more than one target it may be good to set less important ones
|
| 185 |
+
# to a lower number like 0.1 so they don't outweigh the primary target
|
| 186 |
+
weight: 1.0
|
| 187 |
+
# shuffle the prompts split by the comma. We will run every combination randomly
|
| 188 |
+
# this will make the LoRA more robust. You probably want this on unless prompt order
|
| 189 |
+
# is important for some reason
|
| 190 |
+
shuffle: true
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
# anchors are prompts that we will try to hold on to while training the slider
|
| 194 |
+
# these are NOT necessary and can prevent the slider from converging if not done right
|
| 195 |
+
# leave them off if you are having issues, but they can help lock the network
|
| 196 |
+
# on certain concepts to help prevent catastrophic forgetting
|
| 197 |
+
# you want these to generate an image that is not your target_class, but close to it
|
| 198 |
+
# is fine as long as it does not directly overlap it.
|
| 199 |
+
# For example, if you are training on a person smiling,
|
| 200 |
+
# you could use "a person with a face mask" as an anchor. It is a person, the image is the same
|
| 201 |
+
# regardless if they are smiling or not, however, the closer the concept is to the target_class
|
| 202 |
+
# the less the multiplier needs to be. Keep multipliers less than 1.0 for anchors usually
|
| 203 |
+
# for close concepts, you want to be closer to 0.1 or 0.2
|
| 204 |
+
# these will slow down training. I am leaving them off for the demo
|
| 205 |
+
|
| 206 |
+
# anchors:
|
| 207 |
+
# - prompt: "a woman"
|
| 208 |
+
# neg_prompt: "animal"
|
| 209 |
+
# # the multiplier applied to the LoRA when this is run.
|
| 210 |
+
# # higher will give it more weight but also help keep the lora from collapsing
|
| 211 |
+
# multiplier: 1.0
|
| 212 |
+
# - prompt: "a man"
|
| 213 |
+
# neg_prompt: "animal"
|
| 214 |
+
# multiplier: 1.0
|
| 215 |
+
# - prompt: "a person"
|
| 216 |
+
# neg_prompt: "animal"
|
| 217 |
+
# multiplier: 1.0
|
| 218 |
+
|
| 219 |
+
# You can put any information you want here, and it will be saved in the model.
|
| 220 |
+
# The below is an example, but you can put your grocery list in it if you want.
|
| 221 |
+
# It is saved in the model so be aware of that. The software will include this
|
| 222 |
+
# plus some other information for you automatically
|
| 223 |
+
meta:
|
| 224 |
+
# [name] gets replaced with the name above
|
| 225 |
+
name: "[name]"
|
| 226 |
+
# version: '1.0'
|
| 227 |
+
# creator:
|
| 228 |
+
# name: Your Name
|
| 229 |
+
# email: your@gmail.com
|
| 230 |
+
# website: https://your.website
|
ai-toolkit/dgx_instructions.md
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# AI Toolkit by Ostris
|
| 2 |
+
|
| 3 |
+
## DGX OS installation instructions
|
| 4 |
+
|
| 5 |
+
You need to use Python 3.11 to run AI Toolkit on DGX OS. The easiest way to do this without affecting the system installation of Python is to create a virtual environment with **miniconda**, which allows you to specify the version of Python to use in the environment.
|
| 6 |
+
|
| 7 |
+
This guide will assume you have a fresh installation of DGX OS, and will guide you through the installation of all requirements.
|
| 8 |
+
|
| 9 |
+
### Installation instructions for DGX OS:
|
| 10 |
+
|
| 11 |
+
**1) Get Python 3.11 (via miniconda)**
|
| 12 |
+
|
| 13 |
+
Install the latest version of miniconda:
|
| 14 |
+
```
|
| 15 |
+
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-aarch64.sh
|
| 16 |
+
chmod u+x Miniconda3-latest-Linux-aarch64.sh
|
| 17 |
+
./Miniconda3-latest-Linux-aarch64.sh
|
| 18 |
+
```
|
| 19 |
+
|
| 20 |
+
Restart your bash or ssh session. If miniconda was installed successfully, it will automatically load the 'base' environment by default. If you want to disable this behaviour, run:
|
| 21 |
+
```
|
| 22 |
+
conda config --set auto_activate_base false
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
Now you can create a Python 3.11 environment for ai-toolkit:
|
| 26 |
+
```
|
| 27 |
+
conda create --name ai-toolkit python=3.11
|
| 28 |
+
```
|
| 29 |
+
|
| 30 |
+
Then activate the environment with:
|
| 31 |
+
|
| 32 |
+
```
|
| 33 |
+
conda activate ai-toolkit
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
**2) Install PyTorch**
|
| 38 |
+
|
| 39 |
+
```
|
| 40 |
+
pip3 install torch==2.9.1 torchvision==0.24.1 torchaudio==2.9.1 --index-url https://download.pytorch.org/whl/cu130
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
**3) Install the remaining requirements (dgx_requirements.txt)**
|
| 45 |
+
|
| 46 |
+
```
|
| 47 |
+
pip3 install -r dgx_requirements.txt
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
### Running the UI on DGX OS:
|
| 51 |
+
|
| 52 |
+
Running the UI is not that different from doing it on other systems, however, you need to install the ARM64 version of NodeJS for Linux, which is compatible with the NVIDIA Grace CPU.
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
**1) Install Node.js**
|
| 56 |
+
|
| 57 |
+
Download a Linux ARM64 build of Node.js from: https://nodejs.org (for example: https://nodejs.org/dist/v24.11.1/node-v24.11.1-linux-arm64.tar.xz)
|
| 58 |
+
|
| 59 |
+
Extract it and add the bin directory to your path. I extracted it to **/opt** and added the following to my ~/.bashrc file:
|
| 60 |
+
```
|
| 61 |
+
export PATH=“/opt/node-v24.11.1-linux-arm64/bin:$PATH”
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
**2) Compile and run the Node.js UI**
|
| 66 |
+
|
| 67 |
+
Change to the ui directory, then build and run the UI:
|
| 68 |
+
```
|
| 69 |
+
cd ui
|
| 70 |
+
npm run build_and_start
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
If all went well, you’ll be able to access the UI on port 8675 and start training.
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
<details>
|
| 77 |
+
<summary>Troubleshooting issues</summary>
|
| 78 |
+
If you’re not getting any output when starting a training job from the UI, it’s probably crashing before the process started, the best way to debug these issues is to run the python training script directly (which is normally started by the UI). To do this, set up a training job in the UI, go to the advanced config screen, copy and paste the configuration into a file like train.yaml, then run the training script like this with the conda virtual environment active:
|
| 79 |
+
|
| 80 |
+
```
|
| 81 |
+
python run.py path/to/train.yaml
|
| 82 |
+
```
|
| 83 |
+
</details>
|
| 84 |
+
<br>
|
ai-toolkit/dgx_requirements.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# You need to use Python 3.11, the easiest way to get this on DGX OS without impacting the system version of Python is to create an environment with miniconda.
|
| 2 |
+
|
| 3 |
+
# specific dependency versions needed on DGX OS devices:
|
| 4 |
+
scipy==1.16.0
|
| 5 |
+
tifffile==2025.6.11
|
| 6 |
+
imageio==2.37.0
|
| 7 |
+
scikit_image==0.25.2
|
| 8 |
+
clean_fid==0.1.35
|
| 9 |
+
pywavelets==1.9.0
|
| 10 |
+
contourpy==1.3.3
|
| 11 |
+
opencv_python_headless==4.11.0.86
|
| 12 |
+
|
| 13 |
+
-r requirements_base.txt
|
ai-toolkit/docker-compose.yml
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version: "3.8"
|
| 2 |
+
|
| 3 |
+
services:
|
| 4 |
+
ai-toolkit:
|
| 5 |
+
image: ostris/aitoolkit:latest
|
| 6 |
+
restart: unless-stopped
|
| 7 |
+
ports:
|
| 8 |
+
- "8675:8675"
|
| 9 |
+
volumes:
|
| 10 |
+
- ~/.cache/huggingface/hub:/root/.cache/huggingface/hub
|
| 11 |
+
- ./aitk_db.db:/app/ai-toolkit/aitk_db.db
|
| 12 |
+
- ./datasets:/app/ai-toolkit/datasets
|
| 13 |
+
- ./output:/app/ai-toolkit/output
|
| 14 |
+
- ./config:/app/ai-toolkit/config
|
| 15 |
+
environment:
|
| 16 |
+
- AI_TOOLKIT_AUTH=${AI_TOOLKIT_AUTH:-password}
|
| 17 |
+
- NODE_ENV=production
|
| 18 |
+
- TZ=UTC
|
| 19 |
+
deploy:
|
| 20 |
+
resources:
|
| 21 |
+
reservations:
|
| 22 |
+
devices:
|
| 23 |
+
- driver: nvidia
|
| 24 |
+
count: all
|
| 25 |
+
capabilities: [gpu]
|
ai-toolkit/docker/Dockerfile
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM nvidia/cuda:12.8.1-devel-ubuntu24.04
|
| 2 |
+
|
| 3 |
+
LABEL authors="jaret"
|
| 4 |
+
|
| 5 |
+
# Set noninteractive to avoid timezone prompts
|
| 6 |
+
ENV DEBIAN_FRONTEND=noninteractive
|
| 7 |
+
|
| 8 |
+
# ref https://en.wikipedia.org/wiki/CUDA
|
| 9 |
+
ENV TORCH_CUDA_ARCH_LIST="8.0 8.6 8.9 9.0 10.0 12.0"
|
| 10 |
+
|
| 11 |
+
# Install dependencies
|
| 12 |
+
RUN apt-get update && apt-get install --no-install-recommends -y \
|
| 13 |
+
git \
|
| 14 |
+
curl \
|
| 15 |
+
build-essential \
|
| 16 |
+
cmake \
|
| 17 |
+
wget \
|
| 18 |
+
python3.12 \
|
| 19 |
+
python3-pip \
|
| 20 |
+
python3-dev \
|
| 21 |
+
python3-setuptools \
|
| 22 |
+
python3-wheel \
|
| 23 |
+
python3-venv \
|
| 24 |
+
ffmpeg \
|
| 25 |
+
tmux \
|
| 26 |
+
htop \
|
| 27 |
+
nvtop \
|
| 28 |
+
python3-opencv \
|
| 29 |
+
openssh-client \
|
| 30 |
+
openssh-server \
|
| 31 |
+
openssl \
|
| 32 |
+
rsync \
|
| 33 |
+
unzip \
|
| 34 |
+
&& apt-get clean \
|
| 35 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 36 |
+
|
| 37 |
+
# Install nodejs
|
| 38 |
+
WORKDIR /tmp
|
| 39 |
+
RUN curl -sL https://deb.nodesource.com/setup_23.x -o nodesource_setup.sh && \
|
| 40 |
+
bash nodesource_setup.sh && \
|
| 41 |
+
apt-get update && \
|
| 42 |
+
apt-get install -y nodejs && \
|
| 43 |
+
apt-get clean && \
|
| 44 |
+
rm -rf /var/lib/apt/lists/*
|
| 45 |
+
|
| 46 |
+
WORKDIR /app
|
| 47 |
+
|
| 48 |
+
# Set aliases for python and pip
|
| 49 |
+
RUN ln -s /usr/bin/python3 /usr/bin/python
|
| 50 |
+
|
| 51 |
+
# install pytorch before cache bust to avoid redownloading pytorch
|
| 52 |
+
RUN pip install --no-cache-dir torch==2.9.1 torchvision==0.24.1 torchaudio==2.9.1 --index-url https://download.pytorch.org/whl/cu128 --break-system-packages
|
| 53 |
+
|
| 54 |
+
WORKDIR /app/ai-toolkit
|
| 55 |
+
|
| 56 |
+
# ---------------------------------------------------------------------------- #
|
| 57 |
+
# Dependency layers come BEFORE the source clone so they are only rebuilt (and
|
| 58 |
+
# only need to be re-pulled by servers) when the dependency manifests change,
|
| 59 |
+
# not on every code change.
|
| 60 |
+
# ---------------------------------------------------------------------------- #
|
| 61 |
+
|
| 62 |
+
# Install Python dependencies (only re-runs when the requirements files change)
|
| 63 |
+
COPY requirements.txt requirements_base.txt /app/ai-toolkit/
|
| 64 |
+
RUN pip install --no-cache-dir --break-system-packages -r requirements.txt && \
|
| 65 |
+
pip install setuptools==69.5.1 --no-cache-dir --break-system-packages
|
| 66 |
+
|
| 67 |
+
# Install Node dependencies (only re-runs when package.json / package-lock.json change)
|
| 68 |
+
COPY ui/package.json ui/package-lock.json /app/ai-toolkit/ui/
|
| 69 |
+
RUN cd /app/ai-toolkit/ui && npm ci
|
| 70 |
+
|
| 71 |
+
# ---------------------------------------------------------------------------- #
|
| 72 |
+
# Source code comes LAST. Only this layer (plus the UI build below) is rebuilt
|
| 73 |
+
# on a code change, so servers only re-pull the (small) source, not the deps.
|
| 74 |
+
# Clone to a temp dir and rsync the source in, preserving the dependency dirs
|
| 75 |
+
# already populated above (ui/node_modules) and the manifests already used.
|
| 76 |
+
# ---------------------------------------------------------------------------- #
|
| 77 |
+
ARG CACHEBUST=1234
|
| 78 |
+
ARG GIT_COMMIT=main
|
| 79 |
+
RUN echo "Cache bust: ${CACHEBUST}" && \
|
| 80 |
+
git clone https://github.com/ostris/ai-toolkit.git /tmp/ai-toolkit-src && \
|
| 81 |
+
cd /tmp/ai-toolkit-src && \
|
| 82 |
+
git checkout ${GIT_COMMIT} && \
|
| 83 |
+
rsync -a --delete \
|
| 84 |
+
--exclude 'ui/node_modules' \
|
| 85 |
+
--exclude 'requirements.txt' \
|
| 86 |
+
--exclude 'ui/package.json' \
|
| 87 |
+
--exclude 'ui/package-lock.json' \
|
| 88 |
+
/tmp/ai-toolkit-src/ /app/ai-toolkit/ && \
|
| 89 |
+
rm -rf /tmp/ai-toolkit-src
|
| 90 |
+
|
| 91 |
+
# Build UI (re-runs on code change, but reuses the cached node_modules above).
|
| 92 |
+
# update_db runs first because it does `prisma generate`, which creates the
|
| 93 |
+
# @prisma/client types the TS build needs. In the old layout generate happened
|
| 94 |
+
# as a side effect of npm install seeing the schema; now the source arrives
|
| 95 |
+
# after npm ci, so run it explicitly before the build.
|
| 96 |
+
RUN cd /app/ai-toolkit/ui && \
|
| 97 |
+
npm run update_db && \
|
| 98 |
+
npm run build
|
| 99 |
+
|
| 100 |
+
# Expose port (assuming the application runs on port 3000)
|
| 101 |
+
EXPOSE 8675
|
| 102 |
+
|
| 103 |
+
WORKDIR /
|
| 104 |
+
|
| 105 |
+
COPY docker/start.sh /start.sh
|
| 106 |
+
RUN chmod +x /start.sh
|
| 107 |
+
|
| 108 |
+
CMD ["/start.sh"]
|
ai-toolkit/docker/start.sh
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
set -e # Exit the script if any statement returns a non-true return value
|
| 3 |
+
|
| 4 |
+
# ref https://github.com/runpod/containers/blob/main/container-template/start.sh
|
| 5 |
+
|
| 6 |
+
# ---------------------------------------------------------------------------- #
|
| 7 |
+
# Function Definitions #
|
| 8 |
+
# ---------------------------------------------------------------------------- #
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# Setup ssh
|
| 12 |
+
setup_ssh() {
|
| 13 |
+
if [[ $PUBLIC_KEY ]]; then
|
| 14 |
+
echo "Setting up SSH..."
|
| 15 |
+
mkdir -p ~/.ssh
|
| 16 |
+
echo "$PUBLIC_KEY" >> ~/.ssh/authorized_keys
|
| 17 |
+
chmod 700 -R ~/.ssh
|
| 18 |
+
|
| 19 |
+
if [ ! -f /etc/ssh/ssh_host_rsa_key ]; then
|
| 20 |
+
ssh-keygen -t rsa -f /etc/ssh/ssh_host_rsa_key -q -N ''
|
| 21 |
+
echo "RSA key fingerprint:"
|
| 22 |
+
ssh-keygen -lf /etc/ssh/ssh_host_rsa_key.pub
|
| 23 |
+
fi
|
| 24 |
+
|
| 25 |
+
if [ ! -f /etc/ssh/ssh_host_dsa_key ]; then
|
| 26 |
+
ssh-keygen -t dsa -f /etc/ssh/ssh_host_dsa_key -q -N ''
|
| 27 |
+
echo "DSA key fingerprint:"
|
| 28 |
+
ssh-keygen -lf /etc/ssh/ssh_host_dsa_key.pub
|
| 29 |
+
fi
|
| 30 |
+
|
| 31 |
+
if [ ! -f /etc/ssh/ssh_host_ecdsa_key ]; then
|
| 32 |
+
ssh-keygen -t ecdsa -f /etc/ssh/ssh_host_ecdsa_key -q -N ''
|
| 33 |
+
echo "ECDSA key fingerprint:"
|
| 34 |
+
ssh-keygen -lf /etc/ssh/ssh_host_ecdsa_key.pub
|
| 35 |
+
fi
|
| 36 |
+
|
| 37 |
+
if [ ! -f /etc/ssh/ssh_host_ed25519_key ]; then
|
| 38 |
+
ssh-keygen -t ed25519 -f /etc/ssh/ssh_host_ed25519_key -q -N ''
|
| 39 |
+
echo "ED25519 key fingerprint:"
|
| 40 |
+
ssh-keygen -lf /etc/ssh/ssh_host_ed25519_key.pub
|
| 41 |
+
fi
|
| 42 |
+
|
| 43 |
+
service ssh start
|
| 44 |
+
|
| 45 |
+
echo "SSH host keys:"
|
| 46 |
+
for key in /etc/ssh/*.pub; do
|
| 47 |
+
echo "Key: $key"
|
| 48 |
+
ssh-keygen -lf $key
|
| 49 |
+
done
|
| 50 |
+
fi
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
# Export env vars
|
| 54 |
+
export_env_vars() {
|
| 55 |
+
echo "Exporting environment variables..."
|
| 56 |
+
printenv | grep -E '^RUNPOD_|^PATH=|^_=' | awk -F = '{ print "export " $1 "=\"" $2 "\"" }' >> /etc/rp_environment
|
| 57 |
+
echo 'source /etc/rp_environment' >> ~/.bashrc
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
# ---------------------------------------------------------------------------- #
|
| 61 |
+
# Main Program #
|
| 62 |
+
# ---------------------------------------------------------------------------- #
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
echo "Pod Started"
|
| 66 |
+
|
| 67 |
+
setup_ssh
|
| 68 |
+
export_env_vars
|
| 69 |
+
echo "Starting AI Toolkit UI..."
|
| 70 |
+
cd /app/ai-toolkit/ui && npm run start
|
ai-toolkit/extensions/example/ExampleMergeModels.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import gc
|
| 3 |
+
from collections import OrderedDict
|
| 4 |
+
from typing import TYPE_CHECKING
|
| 5 |
+
from jobs.process import BaseExtensionProcess
|
| 6 |
+
from toolkit.config_modules import ModelConfig
|
| 7 |
+
from toolkit.stable_diffusion_model import StableDiffusion
|
| 8 |
+
from toolkit.train_tools import get_torch_dtype
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
|
| 11 |
+
# Type check imports. Prevents circular imports
|
| 12 |
+
if TYPE_CHECKING:
|
| 13 |
+
from jobs import ExtensionJob
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# extend standard config classes to add weight
|
| 17 |
+
class ModelInputConfig(ModelConfig):
|
| 18 |
+
def __init__(self, **kwargs):
|
| 19 |
+
super().__init__(**kwargs)
|
| 20 |
+
self.weight = kwargs.get('weight', 1.0)
|
| 21 |
+
# overwrite default dtype unless user specifies otherwise
|
| 22 |
+
# float 32 will give up better precision on the merging functions
|
| 23 |
+
self.dtype: str = kwargs.get('dtype', 'float32')
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def flush():
|
| 27 |
+
torch.cuda.empty_cache()
|
| 28 |
+
gc.collect()
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# this is our main class process
|
| 32 |
+
class ExampleMergeModels(BaseExtensionProcess):
|
| 33 |
+
def __init__(
|
| 34 |
+
self,
|
| 35 |
+
process_id: int,
|
| 36 |
+
job: 'ExtensionJob',
|
| 37 |
+
config: OrderedDict
|
| 38 |
+
):
|
| 39 |
+
super().__init__(process_id, job, config)
|
| 40 |
+
# this is the setup process, do not do process intensive stuff here, just variable setup and
|
| 41 |
+
# checking requirements. This is called before the run() function
|
| 42 |
+
# no loading models or anything like that, it is just for setting up the process
|
| 43 |
+
# all of your process intensive stuff should be done in the run() function
|
| 44 |
+
# config will have everything from the process item in the config file
|
| 45 |
+
|
| 46 |
+
# convince methods exist on BaseProcess to get config values
|
| 47 |
+
# if required is set to true and the value is not found it will throw an error
|
| 48 |
+
# you can pass a default value to get_conf() as well if it was not in the config file
|
| 49 |
+
# as well as a type to cast the value to
|
| 50 |
+
self.save_path = self.get_conf('save_path', required=True)
|
| 51 |
+
self.save_dtype = self.get_conf('save_dtype', default='float16', as_type=get_torch_dtype)
|
| 52 |
+
self.device = self.get_conf('device', default='cpu', as_type=torch.device)
|
| 53 |
+
|
| 54 |
+
# build models to merge list
|
| 55 |
+
models_to_merge = self.get_conf('models_to_merge', required=True, as_type=list)
|
| 56 |
+
# build list of ModelInputConfig objects. I find it is a good idea to make a class for each config
|
| 57 |
+
# this way you can add methods to it and it is easier to read and code. There are a lot of
|
| 58 |
+
# inbuilt config classes located in toolkit.config_modules as well
|
| 59 |
+
self.models_to_merge = [ModelInputConfig(**model) for model in models_to_merge]
|
| 60 |
+
# setup is complete. Don't load anything else here, just setup variables and stuff
|
| 61 |
+
|
| 62 |
+
# this is the entire run process be sure to call super().run() first
|
| 63 |
+
def run(self):
|
| 64 |
+
# always call first
|
| 65 |
+
super().run()
|
| 66 |
+
print(f"Running process: {self.__class__.__name__}")
|
| 67 |
+
|
| 68 |
+
# let's adjust our weights first to normalize them so the total is 1.0
|
| 69 |
+
total_weight = sum([model.weight for model in self.models_to_merge])
|
| 70 |
+
weight_adjust = 1.0 / total_weight
|
| 71 |
+
for model in self.models_to_merge:
|
| 72 |
+
model.weight *= weight_adjust
|
| 73 |
+
|
| 74 |
+
output_model: StableDiffusion = None
|
| 75 |
+
# let's do the merge, it is a good idea to use tqdm to show progress
|
| 76 |
+
for model_config in tqdm(self.models_to_merge, desc="Merging models"):
|
| 77 |
+
# setup model class with our helper class
|
| 78 |
+
sd_model = StableDiffusion(
|
| 79 |
+
device=self.device,
|
| 80 |
+
model_config=model_config,
|
| 81 |
+
dtype="float32"
|
| 82 |
+
)
|
| 83 |
+
# load the model
|
| 84 |
+
sd_model.load_model()
|
| 85 |
+
|
| 86 |
+
# adjust the weight of the text encoder
|
| 87 |
+
if isinstance(sd_model.text_encoder, list):
|
| 88 |
+
# sdxl model
|
| 89 |
+
for text_encoder in sd_model.text_encoder:
|
| 90 |
+
for key, value in text_encoder.state_dict().items():
|
| 91 |
+
value *= model_config.weight
|
| 92 |
+
else:
|
| 93 |
+
# normal model
|
| 94 |
+
for key, value in sd_model.text_encoder.state_dict().items():
|
| 95 |
+
value *= model_config.weight
|
| 96 |
+
# adjust the weights of the unet
|
| 97 |
+
for key, value in sd_model.unet.state_dict().items():
|
| 98 |
+
value *= model_config.weight
|
| 99 |
+
|
| 100 |
+
if output_model is None:
|
| 101 |
+
# use this one as the base
|
| 102 |
+
output_model = sd_model
|
| 103 |
+
else:
|
| 104 |
+
# merge the models
|
| 105 |
+
# text encoder
|
| 106 |
+
if isinstance(output_model.text_encoder, list):
|
| 107 |
+
# sdxl model
|
| 108 |
+
for i, text_encoder in enumerate(output_model.text_encoder):
|
| 109 |
+
for key, value in text_encoder.state_dict().items():
|
| 110 |
+
value += sd_model.text_encoder[i].state_dict()[key]
|
| 111 |
+
else:
|
| 112 |
+
# normal model
|
| 113 |
+
for key, value in output_model.text_encoder.state_dict().items():
|
| 114 |
+
value += sd_model.text_encoder.state_dict()[key]
|
| 115 |
+
# unet
|
| 116 |
+
for key, value in output_model.unet.state_dict().items():
|
| 117 |
+
value += sd_model.unet.state_dict()[key]
|
| 118 |
+
|
| 119 |
+
# remove the model to free memory
|
| 120 |
+
del sd_model
|
| 121 |
+
flush()
|
| 122 |
+
|
| 123 |
+
# merge loop is done, let's save the model
|
| 124 |
+
print(f"Saving merged model to {self.save_path}")
|
| 125 |
+
output_model.save(self.save_path, meta=self.meta, save_dtype=self.save_dtype)
|
| 126 |
+
print(f"Saved merged model to {self.save_path}")
|
| 127 |
+
# do cleanup here
|
| 128 |
+
del output_model
|
| 129 |
+
flush()
|
ai-toolkit/extensions/example/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This is an example extension for custom training. It is great for experimenting with new ideas.
|
| 2 |
+
from toolkit.extension import Extension
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# We make a subclass of Extension
|
| 6 |
+
class ExampleMergeExtension(Extension):
|
| 7 |
+
# uid must be unique, it is how the extension is identified
|
| 8 |
+
uid = "example_merge_extension"
|
| 9 |
+
|
| 10 |
+
# name is the name of the extension for printing
|
| 11 |
+
name = "Example Merge Extension"
|
| 12 |
+
|
| 13 |
+
# This is where your process class is loaded
|
| 14 |
+
# keep your imports in here so they don't slow down the rest of the program
|
| 15 |
+
@classmethod
|
| 16 |
+
def get_process(cls):
|
| 17 |
+
# import your process class here so it is only loaded when needed and return it
|
| 18 |
+
from .ExampleMergeModels import ExampleMergeModels
|
| 19 |
+
return ExampleMergeModels
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
AI_TOOLKIT_EXTENSIONS = [
|
| 23 |
+
# you can put a list of extensions here
|
| 24 |
+
ExampleMergeExtension
|
| 25 |
+
]
|
ai-toolkit/extensions/example/config/config.example.yaml
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
# Always include at least one example config file to show how to use your extension.
|
| 3 |
+
# use plenty of comments so users know how to use it and what everything does
|
| 4 |
+
|
| 5 |
+
# all extensions will use this job name
|
| 6 |
+
job: extension
|
| 7 |
+
config:
|
| 8 |
+
name: 'my_awesome_merge'
|
| 9 |
+
process:
|
| 10 |
+
# Put your example processes here. This will be passed
|
| 11 |
+
# to your extension process in the config argument.
|
| 12 |
+
# the type MUST match your extension uid
|
| 13 |
+
- type: "example_merge_extension"
|
| 14 |
+
# save path for the merged model
|
| 15 |
+
save_path: "output/merge/[name].safetensors"
|
| 16 |
+
# save type
|
| 17 |
+
dtype: fp16
|
| 18 |
+
# device to run it on
|
| 19 |
+
device: cuda:0
|
| 20 |
+
# input models can only be SD1.x and SD2.x models for this example (currently)
|
| 21 |
+
models_to_merge:
|
| 22 |
+
# weights are relative, total weights will be normalized
|
| 23 |
+
# for example. If you have 2 models with weight 1.0, they will
|
| 24 |
+
# both be weighted 0.5. If you have 1 model with weight 1.0 and
|
| 25 |
+
# another with weight 2.0, the first will be weighted 1/3 and the
|
| 26 |
+
# second will be weighted 2/3
|
| 27 |
+
- name_or_path: "input/model1.safetensors"
|
| 28 |
+
weight: 1.0
|
| 29 |
+
- name_or_path: "input/model2.safetensors"
|
| 30 |
+
weight: 1.0
|
| 31 |
+
- name_or_path: "input/model3.safetensors"
|
| 32 |
+
weight: 0.3
|
| 33 |
+
- name_or_path: "input/model4.safetensors"
|
| 34 |
+
weight: 1.0
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# you can put any information you want here, and it will be saved in the model
|
| 38 |
+
# the below is an example. I recommend doing trigger words at a minimum
|
| 39 |
+
# in the metadata. The software will include this plus some other information
|
| 40 |
+
meta:
|
| 41 |
+
name: "[name]" # [name] gets replaced with the name above
|
| 42 |
+
description: A short description of your model
|
| 43 |
+
version: '0.1'
|
| 44 |
+
creator:
|
| 45 |
+
name: Your Name
|
| 46 |
+
email: your@email.com
|
| 47 |
+
website: https://yourwebsite.com
|
| 48 |
+
any: All meta data above is arbitrary, it can be whatever you want.
|
ai-toolkit/extensions_built_in/advanced_generator/Img2ImgGenerator.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
from collections import OrderedDict
|
| 5 |
+
from typing import List
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
from PIL import Image
|
| 9 |
+
from diffusers import T2IAdapter
|
| 10 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 11 |
+
from torch.utils.data import DataLoader
|
| 12 |
+
from diffusers import StableDiffusionXLImg2ImgPipeline, PixArtSigmaPipeline
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
from toolkit.config_modules import ModelConfig, GenerateImageConfig, preprocess_dataset_raw_config, DatasetConfig
|
| 16 |
+
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
|
| 17 |
+
from toolkit.sampler import get_sampler
|
| 18 |
+
from toolkit.stable_diffusion_model import StableDiffusion
|
| 19 |
+
import gc
|
| 20 |
+
import torch
|
| 21 |
+
from jobs.process import BaseExtensionProcess
|
| 22 |
+
from toolkit.data_loader import get_dataloader_from_datasets
|
| 23 |
+
from toolkit.train_tools import get_torch_dtype
|
| 24 |
+
from controlnet_aux.midas import MidasDetector
|
| 25 |
+
from diffusers.utils import load_image
|
| 26 |
+
from torchvision.transforms import ToTensor
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def flush():
|
| 30 |
+
torch.cuda.empty_cache()
|
| 31 |
+
gc.collect()
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class GenerateConfig:
|
| 38 |
+
|
| 39 |
+
def __init__(self, **kwargs):
|
| 40 |
+
self.prompts: List[str]
|
| 41 |
+
self.sampler = kwargs.get('sampler', 'ddpm')
|
| 42 |
+
self.neg = kwargs.get('neg', '')
|
| 43 |
+
self.seed = kwargs.get('seed', -1)
|
| 44 |
+
self.walk_seed = kwargs.get('walk_seed', False)
|
| 45 |
+
self.guidance_scale = kwargs.get('guidance_scale', 7)
|
| 46 |
+
self.sample_steps = kwargs.get('sample_steps', 20)
|
| 47 |
+
self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
|
| 48 |
+
self.ext = kwargs.get('ext', 'png')
|
| 49 |
+
self.denoise_strength = kwargs.get('denoise_strength', 0.5)
|
| 50 |
+
self.trigger_word = kwargs.get('trigger_word', None)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class Img2ImgGenerator(BaseExtensionProcess):
|
| 54 |
+
|
| 55 |
+
def __init__(self, process_id: int, job, config: OrderedDict):
|
| 56 |
+
super().__init__(process_id, job, config)
|
| 57 |
+
self.output_folder = self.get_conf('output_folder', required=True)
|
| 58 |
+
self.copy_inputs_to = self.get_conf('copy_inputs_to', None)
|
| 59 |
+
self.device = self.get_conf('device', 'cuda')
|
| 60 |
+
self.model_config = ModelConfig(**self.get_conf('model', required=True))
|
| 61 |
+
self.generate_config = GenerateConfig(**self.get_conf('generate', required=True))
|
| 62 |
+
self.is_latents_cached = True
|
| 63 |
+
raw_datasets = self.get_conf('datasets', None)
|
| 64 |
+
if raw_datasets is not None and len(raw_datasets) > 0:
|
| 65 |
+
raw_datasets = preprocess_dataset_raw_config(raw_datasets)
|
| 66 |
+
self.datasets = None
|
| 67 |
+
self.datasets_reg = None
|
| 68 |
+
self.dtype = self.get_conf('dtype', 'float16')
|
| 69 |
+
self.torch_dtype = get_torch_dtype(self.dtype)
|
| 70 |
+
self.params = []
|
| 71 |
+
if raw_datasets is not None and len(raw_datasets) > 0:
|
| 72 |
+
for raw_dataset in raw_datasets:
|
| 73 |
+
dataset = DatasetConfig(**raw_dataset)
|
| 74 |
+
is_caching = dataset.cache_latents or dataset.cache_latents_to_disk
|
| 75 |
+
if not is_caching:
|
| 76 |
+
self.is_latents_cached = False
|
| 77 |
+
if dataset.is_reg:
|
| 78 |
+
if self.datasets_reg is None:
|
| 79 |
+
self.datasets_reg = []
|
| 80 |
+
self.datasets_reg.append(dataset)
|
| 81 |
+
else:
|
| 82 |
+
if self.datasets is None:
|
| 83 |
+
self.datasets = []
|
| 84 |
+
self.datasets.append(dataset)
|
| 85 |
+
|
| 86 |
+
self.progress_bar = None
|
| 87 |
+
self.sd = StableDiffusion(
|
| 88 |
+
device=self.device,
|
| 89 |
+
model_config=self.model_config,
|
| 90 |
+
dtype=self.dtype,
|
| 91 |
+
)
|
| 92 |
+
print(f"Using device {self.device}")
|
| 93 |
+
self.data_loader: DataLoader = None
|
| 94 |
+
self.adapter: T2IAdapter = None
|
| 95 |
+
|
| 96 |
+
def to_pil(self, img):
|
| 97 |
+
# image comes in -1 to 1. convert to a PIL RGB image
|
| 98 |
+
img = (img + 1) / 2
|
| 99 |
+
img = img.clamp(0, 1)
|
| 100 |
+
img = img[0].permute(1, 2, 0).cpu().numpy()
|
| 101 |
+
img = (img * 255).astype(np.uint8)
|
| 102 |
+
image = Image.fromarray(img)
|
| 103 |
+
return image
|
| 104 |
+
|
| 105 |
+
def run(self):
|
| 106 |
+
with torch.no_grad():
|
| 107 |
+
super().run()
|
| 108 |
+
print("Loading model...")
|
| 109 |
+
self.sd.load_model()
|
| 110 |
+
device = torch.device(self.device)
|
| 111 |
+
|
| 112 |
+
if self.model_config.is_xl:
|
| 113 |
+
pipe = StableDiffusionXLImg2ImgPipeline(
|
| 114 |
+
vae=self.sd.vae,
|
| 115 |
+
unet=self.sd.unet,
|
| 116 |
+
text_encoder=self.sd.text_encoder[0],
|
| 117 |
+
text_encoder_2=self.sd.text_encoder[1],
|
| 118 |
+
tokenizer=self.sd.tokenizer[0],
|
| 119 |
+
tokenizer_2=self.sd.tokenizer[1],
|
| 120 |
+
scheduler=get_sampler(self.generate_config.sampler),
|
| 121 |
+
).to(device, dtype=self.torch_dtype)
|
| 122 |
+
elif self.model_config.is_pixart:
|
| 123 |
+
pipe = self.sd.pipeline.to(device, dtype=self.torch_dtype)
|
| 124 |
+
else:
|
| 125 |
+
raise NotImplementedError("Only XL models are supported")
|
| 126 |
+
pipe.set_progress_bar_config(disable=True)
|
| 127 |
+
|
| 128 |
+
# pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
| 129 |
+
# midas_depth = torch.compile(midas_depth, mode="reduce-overhead", fullgraph=True)
|
| 130 |
+
|
| 131 |
+
self.data_loader = get_dataloader_from_datasets(self.datasets, 1, self.sd)
|
| 132 |
+
|
| 133 |
+
num_batches = len(self.data_loader)
|
| 134 |
+
pbar = tqdm(total=num_batches, desc="Generating images")
|
| 135 |
+
seed = self.generate_config.seed
|
| 136 |
+
# load images from datasets, use tqdm
|
| 137 |
+
for i, batch in enumerate(self.data_loader):
|
| 138 |
+
batch: DataLoaderBatchDTO = batch
|
| 139 |
+
|
| 140 |
+
gen_seed = seed if seed > 0 else random.randint(0, 2 ** 32 - 1)
|
| 141 |
+
generator = torch.manual_seed(gen_seed)
|
| 142 |
+
|
| 143 |
+
file_item: FileItemDTO = batch.file_items[0]
|
| 144 |
+
img_path = file_item.path
|
| 145 |
+
img_filename = os.path.basename(img_path)
|
| 146 |
+
img_filename_no_ext = os.path.splitext(img_filename)[0]
|
| 147 |
+
img_filename = img_filename_no_ext + '.' + self.generate_config.ext
|
| 148 |
+
output_path = os.path.join(self.output_folder, img_filename)
|
| 149 |
+
output_caption_path = os.path.join(self.output_folder, img_filename_no_ext + '.txt')
|
| 150 |
+
|
| 151 |
+
if self.copy_inputs_to is not None:
|
| 152 |
+
output_inputs_path = os.path.join(self.copy_inputs_to, img_filename)
|
| 153 |
+
output_inputs_caption_path = os.path.join(self.copy_inputs_to, img_filename_no_ext + '.txt')
|
| 154 |
+
else:
|
| 155 |
+
output_inputs_path = None
|
| 156 |
+
output_inputs_caption_path = None
|
| 157 |
+
|
| 158 |
+
caption = batch.get_caption_list()[0]
|
| 159 |
+
if self.generate_config.trigger_word is not None:
|
| 160 |
+
caption = caption.replace('[trigger]', self.generate_config.trigger_word)
|
| 161 |
+
|
| 162 |
+
img: torch.Tensor = batch.tensor.clone()
|
| 163 |
+
image = self.to_pil(img)
|
| 164 |
+
|
| 165 |
+
# image.save(output_depth_path)
|
| 166 |
+
if self.model_config.is_pixart:
|
| 167 |
+
pipe: PixArtSigmaPipeline = pipe
|
| 168 |
+
|
| 169 |
+
# Encode the full image once
|
| 170 |
+
encoded_image = pipe.vae.encode(
|
| 171 |
+
pipe.image_processor.preprocess(image).to(device=pipe.device, dtype=pipe.dtype))
|
| 172 |
+
if hasattr(encoded_image, "latent_dist"):
|
| 173 |
+
latents = encoded_image.latent_dist.sample(generator)
|
| 174 |
+
elif hasattr(encoded_image, "latents"):
|
| 175 |
+
latents = encoded_image.latents
|
| 176 |
+
else:
|
| 177 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
| 178 |
+
latents = pipe.vae.config.scaling_factor * latents
|
| 179 |
+
|
| 180 |
+
# latents = self.sd.encode_images(img)
|
| 181 |
+
|
| 182 |
+
# self.sd.noise_scheduler.set_timesteps(self.generate_config.sample_steps)
|
| 183 |
+
# start_step = math.floor(self.generate_config.sample_steps * self.generate_config.denoise_strength)
|
| 184 |
+
# timestep = self.sd.noise_scheduler.timesteps[start_step].unsqueeze(0)
|
| 185 |
+
# timestep = timestep.to(device, dtype=torch.int32)
|
| 186 |
+
# latent = latent.to(device, dtype=self.torch_dtype)
|
| 187 |
+
# noise = torch.randn_like(latent, device=device, dtype=self.torch_dtype)
|
| 188 |
+
# latent = self.sd.add_noise(latent, noise, timestep)
|
| 189 |
+
# timesteps_to_use = self.sd.noise_scheduler.timesteps[start_step + 1:]
|
| 190 |
+
batch_size = 1
|
| 191 |
+
num_images_per_prompt = 1
|
| 192 |
+
|
| 193 |
+
shape = (batch_size, pipe.transformer.config.in_channels, image.height // pipe.vae_scale_factor,
|
| 194 |
+
image.width // pipe.vae_scale_factor)
|
| 195 |
+
noise = randn_tensor(shape, generator=generator, device=pipe.device, dtype=pipe.dtype)
|
| 196 |
+
|
| 197 |
+
# noise = torch.randn_like(latents, device=device, dtype=self.torch_dtype)
|
| 198 |
+
num_inference_steps = self.generate_config.sample_steps
|
| 199 |
+
strength = self.generate_config.denoise_strength
|
| 200 |
+
# Get timesteps
|
| 201 |
+
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
| 202 |
+
t_start = max(num_inference_steps - init_timestep, 0)
|
| 203 |
+
pipe.scheduler.set_timesteps(num_inference_steps, device="cpu")
|
| 204 |
+
timesteps = pipe.scheduler.timesteps[t_start:]
|
| 205 |
+
timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
| 206 |
+
latents = pipe.scheduler.add_noise(latents, noise, timestep)
|
| 207 |
+
|
| 208 |
+
gen_images = pipe.__call__(
|
| 209 |
+
prompt=caption,
|
| 210 |
+
negative_prompt=self.generate_config.neg,
|
| 211 |
+
latents=latents,
|
| 212 |
+
timesteps=timesteps,
|
| 213 |
+
width=image.width,
|
| 214 |
+
height=image.height,
|
| 215 |
+
num_inference_steps=num_inference_steps,
|
| 216 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 217 |
+
guidance_scale=self.generate_config.guidance_scale,
|
| 218 |
+
# strength=self.generate_config.denoise_strength,
|
| 219 |
+
use_resolution_binning=False,
|
| 220 |
+
output_type="np"
|
| 221 |
+
).images[0]
|
| 222 |
+
gen_images = (gen_images * 255).clip(0, 255).astype(np.uint8)
|
| 223 |
+
gen_images = Image.fromarray(gen_images)
|
| 224 |
+
else:
|
| 225 |
+
pipe: StableDiffusionXLImg2ImgPipeline = pipe
|
| 226 |
+
|
| 227 |
+
gen_images = pipe.__call__(
|
| 228 |
+
prompt=caption,
|
| 229 |
+
negative_prompt=self.generate_config.neg,
|
| 230 |
+
image=image,
|
| 231 |
+
num_inference_steps=self.generate_config.sample_steps,
|
| 232 |
+
guidance_scale=self.generate_config.guidance_scale,
|
| 233 |
+
strength=self.generate_config.denoise_strength,
|
| 234 |
+
).images[0]
|
| 235 |
+
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
| 236 |
+
gen_images.save(output_path)
|
| 237 |
+
|
| 238 |
+
# save caption
|
| 239 |
+
with open(output_caption_path, 'w') as f:
|
| 240 |
+
f.write(caption)
|
| 241 |
+
|
| 242 |
+
if output_inputs_path is not None:
|
| 243 |
+
os.makedirs(os.path.dirname(output_inputs_path), exist_ok=True)
|
| 244 |
+
image.save(output_inputs_path)
|
| 245 |
+
with open(output_inputs_caption_path, 'w') as f:
|
| 246 |
+
f.write(caption)
|
| 247 |
+
|
| 248 |
+
pbar.update(1)
|
| 249 |
+
batch.cleanup()
|
| 250 |
+
|
| 251 |
+
pbar.close()
|
| 252 |
+
print("Done generating images")
|
| 253 |
+
# cleanup
|
| 254 |
+
del self.sd
|
| 255 |
+
gc.collect()
|
| 256 |
+
torch.cuda.empty_cache()
|
ai-toolkit/extensions_built_in/advanced_generator/PureLoraGenerator.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from collections import OrderedDict
|
| 3 |
+
|
| 4 |
+
from toolkit.config_modules import ModelConfig, GenerateImageConfig, SampleConfig, LoRMConfig
|
| 5 |
+
from toolkit.lorm import ExtractMode, convert_diffusers_unet_to_lorm
|
| 6 |
+
from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
|
| 7 |
+
from toolkit.stable_diffusion_model import StableDiffusion
|
| 8 |
+
import gc
|
| 9 |
+
import torch
|
| 10 |
+
from jobs.process import BaseExtensionProcess
|
| 11 |
+
from toolkit.train_tools import get_torch_dtype
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def flush():
|
| 15 |
+
torch.cuda.empty_cache()
|
| 16 |
+
gc.collect()
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class PureLoraGenerator(BaseExtensionProcess):
|
| 20 |
+
|
| 21 |
+
def __init__(self, process_id: int, job, config: OrderedDict):
|
| 22 |
+
super().__init__(process_id, job, config)
|
| 23 |
+
self.output_folder = self.get_conf('output_folder', required=True)
|
| 24 |
+
self.device = self.get_conf('device', 'cuda')
|
| 25 |
+
self.device_torch = torch.device(self.device)
|
| 26 |
+
self.model_config = ModelConfig(**self.get_conf('model', required=True))
|
| 27 |
+
self.generate_config = SampleConfig(**self.get_conf('sample', required=True))
|
| 28 |
+
self.dtype = self.get_conf('dtype', 'float16')
|
| 29 |
+
self.torch_dtype = get_torch_dtype(self.dtype)
|
| 30 |
+
lorm_config = self.get_conf('lorm', None)
|
| 31 |
+
self.lorm_config = LoRMConfig(**lorm_config) if lorm_config is not None else None
|
| 32 |
+
|
| 33 |
+
self.device_state_preset = get_train_sd_device_state_preset(
|
| 34 |
+
device=torch.device(self.device),
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
self.progress_bar = None
|
| 38 |
+
self.sd = StableDiffusion(
|
| 39 |
+
device=self.device,
|
| 40 |
+
model_config=self.model_config,
|
| 41 |
+
dtype=self.dtype,
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
def run(self):
|
| 45 |
+
super().run()
|
| 46 |
+
print("Loading model...")
|
| 47 |
+
with torch.no_grad():
|
| 48 |
+
self.sd.load_model()
|
| 49 |
+
self.sd.unet.eval()
|
| 50 |
+
self.sd.unet.to(self.device_torch)
|
| 51 |
+
if isinstance(self.sd.text_encoder, list):
|
| 52 |
+
for te in self.sd.text_encoder:
|
| 53 |
+
te.eval()
|
| 54 |
+
te.to(self.device_torch)
|
| 55 |
+
else:
|
| 56 |
+
self.sd.text_encoder.eval()
|
| 57 |
+
self.sd.to(self.device_torch)
|
| 58 |
+
|
| 59 |
+
print(f"Converting to LoRM UNet")
|
| 60 |
+
# replace the unet with LoRMUnet
|
| 61 |
+
convert_diffusers_unet_to_lorm(
|
| 62 |
+
self.sd.unet,
|
| 63 |
+
config=self.lorm_config,
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
sample_folder = os.path.join(self.output_folder)
|
| 67 |
+
gen_img_config_list = []
|
| 68 |
+
|
| 69 |
+
sample_config = self.generate_config
|
| 70 |
+
start_seed = sample_config.seed
|
| 71 |
+
current_seed = start_seed
|
| 72 |
+
for i in range(len(sample_config.prompts)):
|
| 73 |
+
if sample_config.walk_seed:
|
| 74 |
+
current_seed = start_seed + i
|
| 75 |
+
|
| 76 |
+
filename = f"[time]_[count].{self.generate_config.ext}"
|
| 77 |
+
output_path = os.path.join(sample_folder, filename)
|
| 78 |
+
prompt = sample_config.prompts[i]
|
| 79 |
+
extra_args = {}
|
| 80 |
+
gen_img_config_list.append(GenerateImageConfig(
|
| 81 |
+
prompt=prompt, # it will autoparse the prompt
|
| 82 |
+
width=sample_config.width,
|
| 83 |
+
height=sample_config.height,
|
| 84 |
+
negative_prompt=sample_config.neg,
|
| 85 |
+
seed=current_seed,
|
| 86 |
+
guidance_scale=sample_config.guidance_scale,
|
| 87 |
+
guidance_rescale=sample_config.guidance_rescale,
|
| 88 |
+
num_inference_steps=sample_config.sample_steps,
|
| 89 |
+
network_multiplier=sample_config.network_multiplier,
|
| 90 |
+
output_path=output_path,
|
| 91 |
+
output_ext=sample_config.ext,
|
| 92 |
+
adapter_conditioning_scale=sample_config.adapter_conditioning_scale,
|
| 93 |
+
**extra_args
|
| 94 |
+
))
|
| 95 |
+
|
| 96 |
+
# send to be generated
|
| 97 |
+
self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler)
|
| 98 |
+
print("Done generating images")
|
| 99 |
+
# cleanup
|
| 100 |
+
del self.sd
|
| 101 |
+
gc.collect()
|
| 102 |
+
torch.cuda.empty_cache()
|
ai-toolkit/extensions_built_in/advanced_generator/ReferenceGenerator.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
from collections import OrderedDict
|
| 4 |
+
from typing import List
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from diffusers import T2IAdapter
|
| 9 |
+
from torch.utils.data import DataLoader
|
| 10 |
+
from diffusers import StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
from toolkit.config_modules import ModelConfig, GenerateImageConfig, preprocess_dataset_raw_config, DatasetConfig
|
| 14 |
+
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
|
| 15 |
+
from toolkit.sampler import get_sampler
|
| 16 |
+
from toolkit.stable_diffusion_model import StableDiffusion
|
| 17 |
+
import gc
|
| 18 |
+
import torch
|
| 19 |
+
from jobs.process import BaseExtensionProcess
|
| 20 |
+
from toolkit.data_loader import get_dataloader_from_datasets
|
| 21 |
+
from toolkit.train_tools import get_torch_dtype
|
| 22 |
+
from controlnet_aux.midas import MidasDetector
|
| 23 |
+
from diffusers.utils import load_image
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def flush():
|
| 27 |
+
torch.cuda.empty_cache()
|
| 28 |
+
gc.collect()
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class GenerateConfig:
|
| 32 |
+
|
| 33 |
+
def __init__(self, **kwargs):
|
| 34 |
+
self.prompts: List[str]
|
| 35 |
+
self.sampler = kwargs.get('sampler', 'ddpm')
|
| 36 |
+
self.neg = kwargs.get('neg', '')
|
| 37 |
+
self.seed = kwargs.get('seed', -1)
|
| 38 |
+
self.walk_seed = kwargs.get('walk_seed', False)
|
| 39 |
+
self.t2i_adapter_path = kwargs.get('t2i_adapter_path', None)
|
| 40 |
+
self.guidance_scale = kwargs.get('guidance_scale', 7)
|
| 41 |
+
self.sample_steps = kwargs.get('sample_steps', 20)
|
| 42 |
+
self.prompt_2 = kwargs.get('prompt_2', None)
|
| 43 |
+
self.neg_2 = kwargs.get('neg_2', None)
|
| 44 |
+
self.prompts = kwargs.get('prompts', None)
|
| 45 |
+
self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
|
| 46 |
+
self.ext = kwargs.get('ext', 'png')
|
| 47 |
+
self.adapter_conditioning_scale = kwargs.get('adapter_conditioning_scale', 1.0)
|
| 48 |
+
if kwargs.get('shuffle', False):
|
| 49 |
+
# shuffle the prompts
|
| 50 |
+
random.shuffle(self.prompts)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class ReferenceGenerator(BaseExtensionProcess):
|
| 54 |
+
|
| 55 |
+
def __init__(self, process_id: int, job, config: OrderedDict):
|
| 56 |
+
super().__init__(process_id, job, config)
|
| 57 |
+
self.output_folder = self.get_conf('output_folder', required=True)
|
| 58 |
+
self.device = self.get_conf('device', 'cuda')
|
| 59 |
+
self.model_config = ModelConfig(**self.get_conf('model', required=True))
|
| 60 |
+
self.generate_config = GenerateConfig(**self.get_conf('generate', required=True))
|
| 61 |
+
self.is_latents_cached = True
|
| 62 |
+
raw_datasets = self.get_conf('datasets', None)
|
| 63 |
+
if raw_datasets is not None and len(raw_datasets) > 0:
|
| 64 |
+
raw_datasets = preprocess_dataset_raw_config(raw_datasets)
|
| 65 |
+
self.datasets = None
|
| 66 |
+
self.datasets_reg = None
|
| 67 |
+
self.dtype = self.get_conf('dtype', 'float16')
|
| 68 |
+
self.torch_dtype = get_torch_dtype(self.dtype)
|
| 69 |
+
self.params = []
|
| 70 |
+
if raw_datasets is not None and len(raw_datasets) > 0:
|
| 71 |
+
for raw_dataset in raw_datasets:
|
| 72 |
+
dataset = DatasetConfig(**raw_dataset)
|
| 73 |
+
is_caching = dataset.cache_latents or dataset.cache_latents_to_disk
|
| 74 |
+
if not is_caching:
|
| 75 |
+
self.is_latents_cached = False
|
| 76 |
+
if dataset.is_reg:
|
| 77 |
+
if self.datasets_reg is None:
|
| 78 |
+
self.datasets_reg = []
|
| 79 |
+
self.datasets_reg.append(dataset)
|
| 80 |
+
else:
|
| 81 |
+
if self.datasets is None:
|
| 82 |
+
self.datasets = []
|
| 83 |
+
self.datasets.append(dataset)
|
| 84 |
+
|
| 85 |
+
self.progress_bar = None
|
| 86 |
+
self.sd = StableDiffusion(
|
| 87 |
+
device=self.device,
|
| 88 |
+
model_config=self.model_config,
|
| 89 |
+
dtype=self.dtype,
|
| 90 |
+
)
|
| 91 |
+
print(f"Using device {self.device}")
|
| 92 |
+
self.data_loader: DataLoader = None
|
| 93 |
+
self.adapter: T2IAdapter = None
|
| 94 |
+
|
| 95 |
+
def run(self):
|
| 96 |
+
super().run()
|
| 97 |
+
print("Loading model...")
|
| 98 |
+
self.sd.load_model()
|
| 99 |
+
device = torch.device(self.device)
|
| 100 |
+
|
| 101 |
+
if self.generate_config.t2i_adapter_path is not None:
|
| 102 |
+
self.adapter = T2IAdapter.from_pretrained(
|
| 103 |
+
self.generate_config.t2i_adapter_path,
|
| 104 |
+
torch_dtype=self.torch_dtype,
|
| 105 |
+
varient="fp16"
|
| 106 |
+
).to(device)
|
| 107 |
+
|
| 108 |
+
midas_depth = MidasDetector.from_pretrained(
|
| 109 |
+
"valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
|
| 110 |
+
).to(device)
|
| 111 |
+
|
| 112 |
+
if self.model_config.is_xl:
|
| 113 |
+
pipe = StableDiffusionXLAdapterPipeline(
|
| 114 |
+
vae=self.sd.vae,
|
| 115 |
+
unet=self.sd.unet,
|
| 116 |
+
text_encoder=self.sd.text_encoder[0],
|
| 117 |
+
text_encoder_2=self.sd.text_encoder[1],
|
| 118 |
+
tokenizer=self.sd.tokenizer[0],
|
| 119 |
+
tokenizer_2=self.sd.tokenizer[1],
|
| 120 |
+
scheduler=get_sampler(self.generate_config.sampler),
|
| 121 |
+
adapter=self.adapter,
|
| 122 |
+
).to(device, dtype=self.torch_dtype)
|
| 123 |
+
else:
|
| 124 |
+
pipe = StableDiffusionAdapterPipeline(
|
| 125 |
+
vae=self.sd.vae,
|
| 126 |
+
unet=self.sd.unet,
|
| 127 |
+
text_encoder=self.sd.text_encoder,
|
| 128 |
+
tokenizer=self.sd.tokenizer,
|
| 129 |
+
scheduler=get_sampler(self.generate_config.sampler),
|
| 130 |
+
safety_checker=None,
|
| 131 |
+
feature_extractor=None,
|
| 132 |
+
requires_safety_checker=False,
|
| 133 |
+
adapter=self.adapter,
|
| 134 |
+
).to(device, dtype=self.torch_dtype)
|
| 135 |
+
pipe.set_progress_bar_config(disable=True)
|
| 136 |
+
|
| 137 |
+
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
| 138 |
+
# midas_depth = torch.compile(midas_depth, mode="reduce-overhead", fullgraph=True)
|
| 139 |
+
|
| 140 |
+
self.data_loader = get_dataloader_from_datasets(self.datasets, 1, self.sd)
|
| 141 |
+
|
| 142 |
+
num_batches = len(self.data_loader)
|
| 143 |
+
pbar = tqdm(total=num_batches, desc="Generating images")
|
| 144 |
+
seed = self.generate_config.seed
|
| 145 |
+
# load images from datasets, use tqdm
|
| 146 |
+
for i, batch in enumerate(self.data_loader):
|
| 147 |
+
batch: DataLoaderBatchDTO = batch
|
| 148 |
+
|
| 149 |
+
file_item: FileItemDTO = batch.file_items[0]
|
| 150 |
+
img_path = file_item.path
|
| 151 |
+
img_filename = os.path.basename(img_path)
|
| 152 |
+
img_filename_no_ext = os.path.splitext(img_filename)[0]
|
| 153 |
+
output_path = os.path.join(self.output_folder, img_filename)
|
| 154 |
+
output_caption_path = os.path.join(self.output_folder, img_filename_no_ext + '.txt')
|
| 155 |
+
output_depth_path = os.path.join(self.output_folder, img_filename_no_ext + '.depth.png')
|
| 156 |
+
|
| 157 |
+
caption = batch.get_caption_list()[0]
|
| 158 |
+
|
| 159 |
+
img: torch.Tensor = batch.tensor.clone()
|
| 160 |
+
# image comes in -1 to 1. convert to a PIL RGB image
|
| 161 |
+
img = (img + 1) / 2
|
| 162 |
+
img = img.clamp(0, 1)
|
| 163 |
+
img = img[0].permute(1, 2, 0).cpu().numpy()
|
| 164 |
+
img = (img * 255).astype(np.uint8)
|
| 165 |
+
image = Image.fromarray(img)
|
| 166 |
+
|
| 167 |
+
width, height = image.size
|
| 168 |
+
min_res = min(width, height)
|
| 169 |
+
|
| 170 |
+
if self.generate_config.walk_seed:
|
| 171 |
+
seed = seed + 1
|
| 172 |
+
|
| 173 |
+
if self.generate_config.seed == -1:
|
| 174 |
+
# random
|
| 175 |
+
seed = random.randint(0, 1000000)
|
| 176 |
+
|
| 177 |
+
torch.manual_seed(seed)
|
| 178 |
+
torch.cuda.manual_seed(seed)
|
| 179 |
+
|
| 180 |
+
# generate depth map
|
| 181 |
+
image = midas_depth(
|
| 182 |
+
image,
|
| 183 |
+
detect_resolution=min_res, # do 512 ?
|
| 184 |
+
image_resolution=min_res
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
# image.save(output_depth_path)
|
| 188 |
+
|
| 189 |
+
gen_images = pipe(
|
| 190 |
+
prompt=caption,
|
| 191 |
+
negative_prompt=self.generate_config.neg,
|
| 192 |
+
image=image,
|
| 193 |
+
num_inference_steps=self.generate_config.sample_steps,
|
| 194 |
+
adapter_conditioning_scale=self.generate_config.adapter_conditioning_scale,
|
| 195 |
+
guidance_scale=self.generate_config.guidance_scale,
|
| 196 |
+
).images[0]
|
| 197 |
+
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
| 198 |
+
gen_images.save(output_path)
|
| 199 |
+
|
| 200 |
+
# save caption
|
| 201 |
+
with open(output_caption_path, 'w') as f:
|
| 202 |
+
f.write(caption)
|
| 203 |
+
|
| 204 |
+
pbar.update(1)
|
| 205 |
+
batch.cleanup()
|
| 206 |
+
|
| 207 |
+
pbar.close()
|
| 208 |
+
print("Done generating images")
|
| 209 |
+
# cleanup
|
| 210 |
+
del self.sd
|
| 211 |
+
gc.collect()
|
| 212 |
+
torch.cuda.empty_cache()
|
ai-toolkit/extensions_built_in/advanced_generator/__init__.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This is an example extension for custom training. It is great for experimenting with new ideas.
|
| 2 |
+
from toolkit.extension import Extension
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# This is for generic training (LoRA, Dreambooth, FineTuning)
|
| 6 |
+
class AdvancedReferenceGeneratorExtension(Extension):
|
| 7 |
+
# uid must be unique, it is how the extension is identified
|
| 8 |
+
uid = "reference_generator"
|
| 9 |
+
|
| 10 |
+
# name is the name of the extension for printing
|
| 11 |
+
name = "Reference Generator"
|
| 12 |
+
|
| 13 |
+
# This is where your process class is loaded
|
| 14 |
+
# keep your imports in here so they don't slow down the rest of the program
|
| 15 |
+
@classmethod
|
| 16 |
+
def get_process(cls):
|
| 17 |
+
# import your process class here so it is only loaded when needed and return it
|
| 18 |
+
from .ReferenceGenerator import ReferenceGenerator
|
| 19 |
+
return ReferenceGenerator
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# This is for generic training (LoRA, Dreambooth, FineTuning)
|
| 23 |
+
class PureLoraGenerator(Extension):
|
| 24 |
+
# uid must be unique, it is how the extension is identified
|
| 25 |
+
uid = "pure_lora_generator"
|
| 26 |
+
|
| 27 |
+
# name is the name of the extension for printing
|
| 28 |
+
name = "Pure LoRA Generator"
|
| 29 |
+
|
| 30 |
+
# This is where your process class is loaded
|
| 31 |
+
# keep your imports in here so they don't slow down the rest of the program
|
| 32 |
+
@classmethod
|
| 33 |
+
def get_process(cls):
|
| 34 |
+
# import your process class here so it is only loaded when needed and return it
|
| 35 |
+
from .PureLoraGenerator import PureLoraGenerator
|
| 36 |
+
return PureLoraGenerator
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# This is for generic training (LoRA, Dreambooth, FineTuning)
|
| 40 |
+
class Img2ImgGeneratorExtension(Extension):
|
| 41 |
+
# uid must be unique, it is how the extension is identified
|
| 42 |
+
uid = "batch_img2img"
|
| 43 |
+
|
| 44 |
+
# name is the name of the extension for printing
|
| 45 |
+
name = "Img2ImgGeneratorExtension"
|
| 46 |
+
|
| 47 |
+
# This is where your process class is loaded
|
| 48 |
+
# keep your imports in here so they don't slow down the rest of the program
|
| 49 |
+
@classmethod
|
| 50 |
+
def get_process(cls):
|
| 51 |
+
# import your process class here so it is only loaded when needed and return it
|
| 52 |
+
from .Img2ImgGenerator import Img2ImgGenerator
|
| 53 |
+
return Img2ImgGenerator
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
AI_TOOLKIT_EXTENSIONS = [
|
| 57 |
+
# you can put a list of extensions here
|
| 58 |
+
AdvancedReferenceGeneratorExtension, PureLoraGenerator, Img2ImgGeneratorExtension
|
| 59 |
+
]
|
ai-toolkit/extensions_built_in/advanced_generator/config/train.example.yaml
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
job: extension
|
| 3 |
+
config:
|
| 4 |
+
name: test_v1
|
| 5 |
+
process:
|
| 6 |
+
- type: 'textual_inversion_trainer'
|
| 7 |
+
training_folder: "out/TI"
|
| 8 |
+
device: cuda:0
|
| 9 |
+
# for tensorboard logging
|
| 10 |
+
log_dir: "out/.tensorboard"
|
| 11 |
+
embedding:
|
| 12 |
+
trigger: "your_trigger_here"
|
| 13 |
+
tokens: 12
|
| 14 |
+
init_words: "man with short brown hair"
|
| 15 |
+
save_format: "safetensors" # 'safetensors' or 'pt'
|
| 16 |
+
save:
|
| 17 |
+
dtype: float16 # precision to save
|
| 18 |
+
save_every: 100 # save every this many steps
|
| 19 |
+
max_step_saves_to_keep: 5 # only affects step counts
|
| 20 |
+
datasets:
|
| 21 |
+
- folder_path: "/path/to/dataset"
|
| 22 |
+
caption_ext: "txt"
|
| 23 |
+
default_caption: "[trigger]"
|
| 24 |
+
buckets: true
|
| 25 |
+
resolution: 512
|
| 26 |
+
train:
|
| 27 |
+
noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a"
|
| 28 |
+
steps: 3000
|
| 29 |
+
weight_jitter: 0.0
|
| 30 |
+
lr: 5e-5
|
| 31 |
+
train_unet: false
|
| 32 |
+
gradient_checkpointing: true
|
| 33 |
+
train_text_encoder: false
|
| 34 |
+
optimizer: "adamw"
|
| 35 |
+
# optimizer: "prodigy"
|
| 36 |
+
optimizer_params:
|
| 37 |
+
weight_decay: 1e-2
|
| 38 |
+
lr_scheduler: "constant"
|
| 39 |
+
max_denoising_steps: 1000
|
| 40 |
+
batch_size: 4
|
| 41 |
+
dtype: bf16
|
| 42 |
+
xformers: true
|
| 43 |
+
min_snr_gamma: 5.0
|
| 44 |
+
# skip_first_sample: true
|
| 45 |
+
noise_offset: 0.0 # not needed for this
|
| 46 |
+
model:
|
| 47 |
+
# objective reality v2
|
| 48 |
+
name_or_path: "https://civitai.com/models/128453?modelVersionId=142465"
|
| 49 |
+
is_v2: false # for v2 models
|
| 50 |
+
is_xl: false # for SDXL models
|
| 51 |
+
is_v_pred: false # for v-prediction models (most v2 models)
|
| 52 |
+
sample:
|
| 53 |
+
sampler: "ddpm" # must match train.noise_scheduler
|
| 54 |
+
sample_every: 100 # sample every this many steps
|
| 55 |
+
width: 512
|
| 56 |
+
height: 512
|
| 57 |
+
prompts:
|
| 58 |
+
- "photo of [trigger] laughing"
|
| 59 |
+
- "photo of [trigger] smiling"
|
| 60 |
+
- "[trigger] close up"
|
| 61 |
+
- "dark scene [trigger] frozen"
|
| 62 |
+
- "[trigger] nighttime"
|
| 63 |
+
- "a painting of [trigger]"
|
| 64 |
+
- "a drawing of [trigger]"
|
| 65 |
+
- "a cartoon of [trigger]"
|
| 66 |
+
- "[trigger] pixar style"
|
| 67 |
+
- "[trigger] costume"
|
| 68 |
+
neg: ""
|
| 69 |
+
seed: 42
|
| 70 |
+
walk_seed: false
|
| 71 |
+
guidance_scale: 7
|
| 72 |
+
sample_steps: 20
|
| 73 |
+
network_multiplier: 1.0
|
| 74 |
+
|
| 75 |
+
logging:
|
| 76 |
+
log_every: 10 # log every this many steps
|
| 77 |
+
use_wandb: false # not supported yet
|
| 78 |
+
verbose: false
|
| 79 |
+
|
| 80 |
+
# You can put any information you want here, and it will be saved in the model.
|
| 81 |
+
# The below is an example, but you can put your grocery list in it if you want.
|
| 82 |
+
# It is saved in the model so be aware of that. The software will include this
|
| 83 |
+
# plus some other information for you automatically
|
| 84 |
+
meta:
|
| 85 |
+
# [name] gets replaced with the name above
|
| 86 |
+
name: "[name]"
|
| 87 |
+
# version: '1.0'
|
| 88 |
+
# creator:
|
| 89 |
+
# name: Your Name
|
| 90 |
+
# email: your@gmail.com
|
| 91 |
+
# website: https://your.website
|
ai-toolkit/extensions_built_in/audio_models/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .ace_step import AceStep15Model, AceStep15XLModel
|
| 2 |
+
|
| 3 |
+
AI_TOOLKIT_MODELS = [
|
| 4 |
+
# put a list of models here
|
| 5 |
+
AceStep15Model,
|
| 6 |
+
AceStep15XLModel,
|
| 7 |
+
]
|
ai-toolkit/extensions_built_in/audio_models/ace_step/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .ace_step_15_model import AceStep15Model, AceStep15XLModel
|
ai-toolkit/extensions_built_in/audio_models/ace_step/ace_step_15_model.py
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
from typing import List, Optional
|
| 4 |
+
import huggingface_hub
|
| 5 |
+
import torch
|
| 6 |
+
from safetensors.torch import load_file, save_file
|
| 7 |
+
from extensions_built_in.audio_models.base_audio_model import BaseAudioModel
|
| 8 |
+
from toolkit.basic import flush
|
| 9 |
+
from toolkit.config_modules import GenerateImageConfig
|
| 10 |
+
from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
|
| 11 |
+
from toolkit.samplers.custom_flowmatch_sampler import (
|
| 12 |
+
CustomFlowMatchEulerDiscreteScheduler,
|
| 13 |
+
)
|
| 14 |
+
from toolkit.util.quantize import get_qtype, quantize, quantize_model
|
| 15 |
+
|
| 16 |
+
from optimum.quanto import freeze
|
| 17 |
+
from .src.model import (
|
| 18 |
+
AceStep15,
|
| 19 |
+
OobleckVAE,
|
| 20 |
+
TextEncoder,
|
| 21 |
+
get_silence_latent,
|
| 22 |
+
load_models,
|
| 23 |
+
)
|
| 24 |
+
from transformers import AutoTokenizer
|
| 25 |
+
from .src.pipeline import AceStep15Pipeline
|
| 26 |
+
|
| 27 |
+
scheduler_config = {
|
| 28 |
+
"num_train_timesteps": 1000,
|
| 29 |
+
"shift": 3.0,
|
| 30 |
+
"use_dynamic_shifting": False,
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
def to_number(str_or_number, default):
|
| 34 |
+
if isinstance(str_or_number, (int, float)):
|
| 35 |
+
return str_or_number
|
| 36 |
+
if str_or_number is None:
|
| 37 |
+
return default
|
| 38 |
+
if str_or_number == "":
|
| 39 |
+
return default
|
| 40 |
+
try:
|
| 41 |
+
return float(str_or_number)
|
| 42 |
+
except ValueError:
|
| 43 |
+
try:
|
| 44 |
+
return int(str_or_number)
|
| 45 |
+
except ValueError as e:
|
| 46 |
+
raise ValueError(f"Could not convert {str_or_number} to a number") from e
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def parse_ace_step_caption(text):
|
| 50 |
+
"""Parse a tagged caption file back into a dict."""
|
| 51 |
+
import re
|
| 52 |
+
|
| 53 |
+
def tag(name):
|
| 54 |
+
m = re.search(rf"<{name}>(.*?)</{name}>", text, re.DOTALL)
|
| 55 |
+
return m.group(1).strip() if m else ""
|
| 56 |
+
|
| 57 |
+
return {
|
| 58 |
+
"caption": tag("CAPTION"),
|
| 59 |
+
"lyrics": tag("LYRICS"),
|
| 60 |
+
"bpm": to_number(tag("BPM"), 120),
|
| 61 |
+
"keyscale": tag("KEYSCALE"),
|
| 62 |
+
"timesignature": tag("TIMESIGNATURE"),
|
| 63 |
+
"duration": to_number(tag("DURATION"), 1.0),
|
| 64 |
+
"language": tag("LANGUAGE"),
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class AceStep15Model(BaseAudioModel):
|
| 69 |
+
arch = "ace_step_15"
|
| 70 |
+
sample_rate = 48000
|
| 71 |
+
|
| 72 |
+
def __init__(
|
| 73 |
+
self,
|
| 74 |
+
device,
|
| 75 |
+
model_config,
|
| 76 |
+
dtype="bf16",
|
| 77 |
+
custom_pipeline=None,
|
| 78 |
+
noise_scheduler=None,
|
| 79 |
+
**kwargs,
|
| 80 |
+
):
|
| 81 |
+
super().__init__(
|
| 82 |
+
device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs
|
| 83 |
+
)
|
| 84 |
+
self.is_flow_matching = True
|
| 85 |
+
self.is_transformer = True
|
| 86 |
+
# self.target_lora_modules = ['AceStep15']
|
| 87 |
+
self.target_lora_modules = ["DiTModel"]
|
| 88 |
+
|
| 89 |
+
# static method to get the noise scheduler
|
| 90 |
+
@staticmethod
|
| 91 |
+
def get_train_scheduler():
|
| 92 |
+
return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
|
| 93 |
+
|
| 94 |
+
def load_model(self):
|
| 95 |
+
dtype = self.torch_dtype
|
| 96 |
+
device = self.device_torch
|
| 97 |
+
|
| 98 |
+
model_path = self.model_config.name_or_path
|
| 99 |
+
|
| 100 |
+
if not os.path.exists(model_path):
|
| 101 |
+
# assume it is a hf repo like org/repo/filename.safetensors
|
| 102 |
+
path_parts = model_path.split("/")
|
| 103 |
+
if len(path_parts) != 3:
|
| 104 |
+
raise ValueError(
|
| 105 |
+
f"Model path {model_path} does not exist and is not a valid Hugging Face repo path"
|
| 106 |
+
)
|
| 107 |
+
model_path = huggingface_hub.hf_hub_download(
|
| 108 |
+
repo_id=f"{path_parts[0]}/{path_parts[1]}",
|
| 109 |
+
filename=path_parts[2],
|
| 110 |
+
)
|
| 111 |
+
# load the models from the single safetensors file
|
| 112 |
+
load_device = device
|
| 113 |
+
if self.model_config.low_vram:
|
| 114 |
+
load_device = "cpu"
|
| 115 |
+
|
| 116 |
+
models = load_models(model_path, device=load_device, dtype=dtype)
|
| 117 |
+
|
| 118 |
+
self.model = models["model"]
|
| 119 |
+
|
| 120 |
+
if self.model_config.quantize:
|
| 121 |
+
self.print_and_status_update("Quantizing Transformer")
|
| 122 |
+
# quantize_model(self, self.model.decoder)
|
| 123 |
+
quantize(self.model, weights=get_qtype(self.model_config.qtype))
|
| 124 |
+
freeze(self.model)
|
| 125 |
+
flush()
|
| 126 |
+
|
| 127 |
+
if self.model_config.low_vram:
|
| 128 |
+
self.print_and_status_update("Moving transformer to CPU")
|
| 129 |
+
self.model.to("cpu")
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
if (
|
| 133 |
+
self.model_config.layer_offloading
|
| 134 |
+
and self.model_config.layer_offloading_transformer_percent > 0
|
| 135 |
+
):
|
| 136 |
+
raise NotImplementedError("Layer offloading not yet implemented for AceStep15Model")
|
| 137 |
+
|
| 138 |
+
self.text_encoder = models["text_encoder"]
|
| 139 |
+
|
| 140 |
+
if self.model_config.quantize_te:
|
| 141 |
+
self.print_and_status_update("Quantizing Text Encoder")
|
| 142 |
+
quantize(self.text_encoder, weights=get_qtype(self.model_config.qtype_te))
|
| 143 |
+
freeze(self.text_encoder)
|
| 144 |
+
flush()
|
| 145 |
+
|
| 146 |
+
self.vae = models["vae"]
|
| 147 |
+
|
| 148 |
+
# move back to device
|
| 149 |
+
self.model.to(device)
|
| 150 |
+
self.text_encoder.to(device)
|
| 151 |
+
self.vae.to(device)
|
| 152 |
+
self.tokenizer = models["tokenizer"]
|
| 153 |
+
|
| 154 |
+
self.pipeline = AceStep15Pipeline(
|
| 155 |
+
transformer=self.model,
|
| 156 |
+
vae=self.vae,
|
| 157 |
+
text_encoder=self.text_encoder,
|
| 158 |
+
tokenizer=self.tokenizer,
|
| 159 |
+
scheduler=self.get_train_scheduler(),
|
| 160 |
+
)
|
| 161 |
+
if self.model_config.low_vram:
|
| 162 |
+
self.pipeline.do_tiled_decoding = True
|
| 163 |
+
|
| 164 |
+
def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
|
| 165 |
+
if isinstance(prompt, str):
|
| 166 |
+
prompts = [prompt]
|
| 167 |
+
else:
|
| 168 |
+
prompts = prompt
|
| 169 |
+
|
| 170 |
+
if self.text_encoder.device == torch.device("cpu"):
|
| 171 |
+
self.text_encoder.to(self.device_torch)
|
| 172 |
+
# we need the encoder from the model
|
| 173 |
+
if self.model.encoder.device == torch.device("cpu"):
|
| 174 |
+
self.model.encoder.to(self.device_torch)
|
| 175 |
+
|
| 176 |
+
# the prompt should be json as a string. Try to parse it.
|
| 177 |
+
json_prompts = []
|
| 178 |
+
for p in prompts:
|
| 179 |
+
try:
|
| 180 |
+
json_prompts.append(parse_ace_step_caption(p))
|
| 181 |
+
except json.JSONDecodeError:
|
| 182 |
+
raise ValueError(
|
| 183 |
+
f"Prompt {p} is not a valid JSON string. Prompts must be JSON for this model"
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
if self.pipeline.text_encoder.device == torch.device("cpu"):
|
| 187 |
+
self.pipeline.text_encoder.to(self.device_torch)
|
| 188 |
+
|
| 189 |
+
device = self.text_encoder.device
|
| 190 |
+
dtype = self.text_encoder.dtype
|
| 191 |
+
|
| 192 |
+
batch_pe = None
|
| 193 |
+
# TODO not sure this will allow for proper batching
|
| 194 |
+
|
| 195 |
+
for json_prompt in json_prompts:
|
| 196 |
+
prompt = json_prompt.get("caption", "")
|
| 197 |
+
lyrics = json_prompt.get("lyrics", "")
|
| 198 |
+
bpm = json_prompt.get("bpm", 120)
|
| 199 |
+
key = json_prompt.get("key", "C")
|
| 200 |
+
time_sig = json_prompt.get("time_sig", "4/4")
|
| 201 |
+
duration = json_prompt.get("duration", 10)
|
| 202 |
+
duration = int(duration) if isinstance(duration, (int, float)) else 10
|
| 203 |
+
language = json_prompt.get("language", "en")
|
| 204 |
+
|
| 205 |
+
text_embeddings, text_mask, lyric_embeddings, lyric_mask = (
|
| 206 |
+
self.pipeline.get_text_embedings(
|
| 207 |
+
prompt, lyrics, bpm, key, time_sig, duration, language
|
| 208 |
+
)
|
| 209 |
+
)
|
| 210 |
+
latent_len = int(duration * self.pipeline.LATENT_RATE)
|
| 211 |
+
# Silence as source latent [1, 64, T] -> [1, T, 64] for DiT
|
| 212 |
+
sil = get_silence_latent(latent_len, device, dtype) # [1, 64, T]
|
| 213 |
+
src = sil.transpose(1, 2) # [1, T, 64]
|
| 214 |
+
chunk_masks = torch.ones_like(src)
|
| 215 |
+
|
| 216 |
+
# Reference audio (silence)
|
| 217 |
+
ref = sil[:, :, :750].transpose(1, 2) # [1, 750, 64]
|
| 218 |
+
ref_order = torch.zeros(1, device=device, dtype=torch.long)
|
| 219 |
+
enc_h, enc_m, _ = self.pipeline.transformer.prepare_condition(
|
| 220 |
+
text_embeddings,
|
| 221 |
+
text_mask,
|
| 222 |
+
lyric_embeddings,
|
| 223 |
+
lyric_mask,
|
| 224 |
+
ref,
|
| 225 |
+
ref_order,
|
| 226 |
+
src,
|
| 227 |
+
chunk_masks,
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
pe = PromptEmbeds(enc_h, attention_mask=enc_m)
|
| 231 |
+
if batch_pe is None:
|
| 232 |
+
batch_pe = pe
|
| 233 |
+
else:
|
| 234 |
+
batch_pe = concat_prompt_embeds(batch_pe, pe)
|
| 235 |
+
return batch_pe
|
| 236 |
+
|
| 237 |
+
def get_transformer_block_names(self) -> Optional[List[str]]:
|
| 238 |
+
return ["layers"]
|
| 239 |
+
|
| 240 |
+
def get_generation_pipeline(self):
|
| 241 |
+
return self.pipeline
|
| 242 |
+
|
| 243 |
+
def generate_single_audio(
|
| 244 |
+
self,
|
| 245 |
+
pipeline,
|
| 246 |
+
gen_config: GenerateImageConfig,
|
| 247 |
+
conditional_embeds: PromptEmbeds,
|
| 248 |
+
unconditional_embeds: PromptEmbeds,
|
| 249 |
+
generator: torch.Generator,
|
| 250 |
+
extra: dict,
|
| 251 |
+
):
|
| 252 |
+
if self.model.device == torch.device("cpu"):
|
| 253 |
+
self.model.to(self.device_torch)
|
| 254 |
+
# make sure gen config is setup for audio
|
| 255 |
+
if gen_config.output_ext not in ['mp3', 'wav']:
|
| 256 |
+
gen_config.output_ext = 'mp3'
|
| 257 |
+
prompt = gen_config.prompt
|
| 258 |
+
json_prompt = parse_ace_step_caption(prompt)
|
| 259 |
+
prompt = json_prompt.get("caption", "")
|
| 260 |
+
lyrics = json_prompt.get("lyrics", "")
|
| 261 |
+
bpm = json_prompt.get("bpm", 120)
|
| 262 |
+
key = json_prompt.get("key", "C")
|
| 263 |
+
time_sig = json_prompt.get("time_sig", "4/4")
|
| 264 |
+
duration = json_prompt.get("duration", 0)
|
| 265 |
+
language = json_prompt.get("language", "en")
|
| 266 |
+
|
| 267 |
+
output = self.pipeline(
|
| 268 |
+
prompt=None, # we are passing in the embeds directly, so no need for a prompt
|
| 269 |
+
encoder_embeddings=conditional_embeds.text_embeds.to(self.device_torch, dtype=self.torch_dtype),
|
| 270 |
+
encoder_mask=conditional_embeds.attention_mask.to(self.device_torch, dtype=torch.bool),
|
| 271 |
+
num_inference_steps=gen_config.num_inference_steps,
|
| 272 |
+
duration=duration,
|
| 273 |
+
generator=generator,
|
| 274 |
+
bpm=bpm,
|
| 275 |
+
key=key,
|
| 276 |
+
time_sig=time_sig,
|
| 277 |
+
language=language,
|
| 278 |
+
guidance_scale=gen_config.guidance_scale,
|
| 279 |
+
)
|
| 280 |
+
return output
|
| 281 |
+
|
| 282 |
+
def get_noise_prediction(
|
| 283 |
+
self,
|
| 284 |
+
latent_model_input: torch.Tensor, #(1, 300, 64)
|
| 285 |
+
timestep: torch.Tensor, # 0 to 1000 scale
|
| 286 |
+
text_embeddings: PromptEmbeds,
|
| 287 |
+
**kwargs,
|
| 288 |
+
):
|
| 289 |
+
if self.model.decoder.device == torch.device("cpu"):
|
| 290 |
+
self.model.decoder.to(self.device_torch)
|
| 291 |
+
with torch.no_grad():
|
| 292 |
+
model: AceStep15 = self.model
|
| 293 |
+
tt = timestep.to(self.device_torch, dtype=torch.long) / 1000
|
| 294 |
+
latent_len = latent_model_input.shape[1]
|
| 295 |
+
device = self.device_torch
|
| 296 |
+
dtype = self.torch_dtype
|
| 297 |
+
attn = torch.ones(1, latent_len, device=device, dtype=dtype)
|
| 298 |
+
|
| 299 |
+
# build context from silence latent matching the actual input length
|
| 300 |
+
sil = get_silence_latent(latent_len, device, dtype) # [1, 64, T]
|
| 301 |
+
src = sil.transpose(1, 2) # [1, T, 64]
|
| 302 |
+
chunk_masks = torch.ones_like(src)
|
| 303 |
+
context = torch.cat([src, chunk_masks], dim=-1) # [1, T, 128]
|
| 304 |
+
|
| 305 |
+
pred = model.decoder(
|
| 306 |
+
x=latent_model_input.detach(),
|
| 307 |
+
timestep=tt.detach(),
|
| 308 |
+
timestep_r=tt.detach(),
|
| 309 |
+
attention_mask=attn.detach(),
|
| 310 |
+
enc_h=text_embeddings.text_embeds.to(self.device_torch, dtype=self.torch_dtype).detach(),
|
| 311 |
+
enc_m=text_embeddings.attention_mask.to(self.device_torch, dtype=torch.bool).detach(),
|
| 312 |
+
context=context.detach(),
|
| 313 |
+
)
|
| 314 |
+
return pred
|
| 315 |
+
|
| 316 |
+
def get_loss_target(self, *args, **kwargs):
|
| 317 |
+
noise = kwargs.get("noise")
|
| 318 |
+
batch = kwargs.get("batch")
|
| 319 |
+
return (noise - batch.latents).detach()
|
| 320 |
+
|
| 321 |
+
def encode_audio(self, audio_tensor: torch.Tensor, device=None, dtype=None):
|
| 322 |
+
if device is None:
|
| 323 |
+
device = self.device_torch
|
| 324 |
+
if dtype is None:
|
| 325 |
+
dtype = self.torch_dtype
|
| 326 |
+
if self.vae.device == torch.device("cpu"):
|
| 327 |
+
self.vae.to(device)
|
| 328 |
+
output = self.vae.encode(audio_tensor.to(device=device, dtype=dtype))
|
| 329 |
+
# transpose from [B, 64, T] to [B, T, 64] for DiT
|
| 330 |
+
output = output.transpose(1, 2).contiguous()
|
| 331 |
+
return output
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
class AceStep15XLModel(AceStep15Model):
|
| 335 |
+
arch = "ace_step_15_xl"
|
ai-toolkit/extensions_built_in/audio_models/ace_step/src/__init__.py
ADDED
|
File without changes
|