Text-to-Image
Diffusers
Safetensors
English
image-generation
class-conditional
imagenet
pixelflow
flow-matching
Instructions to use BiliSakura/PixelFlow-diffusers with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use BiliSakura/PixelFlow-diffusers with Diffusers:
pip install -U diffusers transformers accelerate
import torch from diffusers import DiffusionPipeline # switch to "mps" for apple devices pipe = DiffusionPipeline.from_pretrained("BiliSakura/PixelFlow-diffusers", dtype=torch.bfloat16, device_map="cuda") prompt = "golden retriever" image = pipe(prompt).images[0] - Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- Draw Things
- DiffusionBee
Update PixelFlow-256/pipeline.py
Browse files- PixelFlow-256/pipeline.py +127 -77
PixelFlow-256/pipeline.py
CHANGED
|
@@ -12,6 +12,10 @@
|
|
| 12 |
# See the License for the specific language governing permissions and
|
| 13 |
# limitations under the License.
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
import importlib
|
| 16 |
import json
|
| 17 |
import math
|
|
@@ -19,7 +23,6 @@ import sys
|
|
| 19 |
from pathlib import Path
|
| 20 |
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 21 |
|
| 22 |
-
import numpy as np
|
| 23 |
import torch
|
| 24 |
import torch.nn.functional as F
|
| 25 |
from einops import rearrange
|
|
@@ -27,8 +30,6 @@ from einops import rearrange
|
|
| 27 |
from diffusers.image_processor import VaeImageProcessor
|
| 28 |
from diffusers.models.embeddings import get_2d_rotary_pos_embed
|
| 29 |
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
| 30 |
-
from diffusers.schedulers import KarrasDiffusionSchedulers
|
| 31 |
-
from diffusers.utils import replace_example_docstring
|
| 32 |
from diffusers.utils.torch_utils import randn_tensor
|
| 33 |
|
| 34 |
|
|
@@ -38,32 +39,19 @@ EXAMPLE_DOC_STRING = """
|
|
| 38 |
Examples:
|
| 39 |
```py
|
| 40 |
>>> from pathlib import Path
|
|
|
|
| 41 |
>>> import torch
|
| 42 |
-
>>> from diffusers import DiffusionPipeline
|
| 43 |
|
| 44 |
>>> model_dir = Path("./PixelFlow-256").resolve()
|
| 45 |
-
>>>
|
|
|
|
|
|
|
|
|
|
| 46 |
... str(model_dir),
|
| 47 |
... local_files_only=True,
|
| 48 |
-
... custom_pipeline=str(model_dir / "pipeline.py"),
|
| 49 |
-
... trust_remote_code=True,
|
| 50 |
... torch_dtype=torch.bfloat16,
|
| 51 |
... )
|
| 52 |
-
>>> pipe
|
| 53 |
-
|
| 54 |
-
>>> print(pipe.id2label[207])
|
| 55 |
-
>>> print(pipe.get_label_ids("golden retriever"))
|
| 56 |
-
|
| 57 |
-
>>> generator = torch.Generator(device="cuda").manual_seed(42)
|
| 58 |
-
>>> image = pipe(
|
| 59 |
-
... class_labels="golden retriever",
|
| 60 |
-
... height=256,
|
| 61 |
-
... width=256,
|
| 62 |
-
... num_inference_steps=[10, 10, 10, 10],
|
| 63 |
-
... guidance_scale=4.0,
|
| 64 |
-
... generator=generator,
|
| 65 |
-
... ).images[0]
|
| 66 |
-
>>> image.save("demo.png")
|
| 67 |
```
|
| 68 |
"""
|
| 69 |
|
|
@@ -76,12 +64,27 @@ class PixelFlowPipeline(DiffusionPipeline):
|
|
| 76 |
Parameters:
|
| 77 |
transformer ([`PixelFlowTransformer2DModel`]):
|
| 78 |
Class-conditional PixelFlow transformer operating in pixel space.
|
| 79 |
-
scheduler ([`PixelFlowScheduler`]
|
| 80 |
-
Multi-stage flow scheduler used by PixelFlow
|
| 81 |
id2label (`dict[int, str]`, *optional*):
|
| 82 |
ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
|
| 83 |
"""
|
| 84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
model_cpu_offload_seq = "transformer"
|
| 86 |
|
| 87 |
def __init__(
|
|
@@ -128,7 +131,6 @@ class PixelFlowPipeline(DiffusionPipeline):
|
|
| 128 |
variant = variant / subfolder
|
| 129 |
|
| 130 |
id2label_override = kwargs.pop("id2label", None)
|
| 131 |
-
kwargs.pop("trust_remote_code", None)
|
| 132 |
model_kwargs = dict(kwargs)
|
| 133 |
scheduler_kwargs = model_kwargs.pop("scheduler_kwargs", {})
|
| 134 |
inserted = []
|
|
@@ -147,14 +149,15 @@ class PixelFlowPipeline(DiffusionPipeline):
|
|
| 147 |
transformer_cls = getattr(importlib.import_module("transformer_pixelflow"), "PixelFlowTransformer2DModel")
|
| 148 |
transformer = transformer_cls.from_pretrained(str(transformer_dir), **model_kwargs)
|
| 149 |
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
|
|
|
| 153 |
|
| 154 |
-
_ensure_path(str(
|
| 155 |
scheduler_cls = getattr(importlib.import_module("scheduling_pixelflow"), "PixelFlowScheduler")
|
| 156 |
try:
|
| 157 |
-
scheduler = scheduler_cls.from_pretrained(str(
|
| 158 |
except Exception:
|
| 159 |
scheduler = scheduler_cls(**scheduler_kwargs)
|
| 160 |
|
|
@@ -168,11 +171,41 @@ class PixelFlowPipeline(DiffusionPipeline):
|
|
| 168 |
if comp_path in sys.path:
|
| 169 |
sys.path.remove(comp_path)
|
| 170 |
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
|
| 177 |
def _ensure_labels_loaded(self) -> None:
|
| 178 |
if self._labels_loaded_from_model_index:
|
|
@@ -183,6 +216,12 @@ class PixelFlowPipeline(DiffusionPipeline):
|
|
| 183 |
self.labels = self._build_label2id(self._id2label)
|
| 184 |
self._labels_loaded_from_model_index = True
|
| 185 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
@staticmethod
|
| 187 |
def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
|
| 188 |
if not variant_path:
|
|
@@ -304,15 +343,23 @@ class PixelFlowPipeline(DiffusionPipeline):
|
|
| 304 |
channels: int,
|
| 305 |
height: int,
|
| 306 |
width: int,
|
|
|
|
|
|
|
|
|
|
| 307 |
eps: float = 1e-6,
|
| 308 |
) -> torch.Tensor:
|
| 309 |
gamma = self.scheduler.gamma
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
)
|
| 314 |
block_number = batch_size * channels * (height // 2) * (width // 2)
|
| 315 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 316 |
return rearrange(
|
| 317 |
noise,
|
| 318 |
"(b c h w) (p q) -> b c (h p) (w q)",
|
|
@@ -331,6 +378,7 @@ class PixelFlowPipeline(DiffusionPipeline):
|
|
| 331 |
height: int,
|
| 332 |
width: int,
|
| 333 |
device: torch.device,
|
|
|
|
| 334 |
) -> torch.Tensor:
|
| 335 |
latents = F.interpolate(latents, size=(height, width), mode="nearest")
|
| 336 |
original_start_t = self.scheduler.original_start_t[stage_idx]
|
|
@@ -338,8 +386,12 @@ class PixelFlowPipeline(DiffusionPipeline):
|
|
| 338 |
alpha = 1 / (math.sqrt(1 - (1 / gamma)) * (1 - original_start_t) + original_start_t)
|
| 339 |
beta = alpha * (1 - original_start_t) / math.sqrt(-gamma)
|
| 340 |
|
| 341 |
-
noise = self._sample_block_noise(
|
| 342 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 343 |
return alpha * latents + beta * noise
|
| 344 |
|
| 345 |
def _prepare_rope_pos_embed(self, latents: torch.Tensor, device: torch.device) -> torch.Tensor:
|
|
@@ -378,7 +430,6 @@ class PixelFlowPipeline(DiffusionPipeline):
|
|
| 378 |
raise ValueError(f"output_type must be one of: 'pil', 'np', 'pt', 'latent'. Got {output_type}.")
|
| 379 |
|
| 380 |
@torch.inference_mode()
|
| 381 |
-
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 382 |
def __call__(
|
| 383 |
self,
|
| 384 |
class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
|
|
@@ -394,9 +445,6 @@ class PixelFlowPipeline(DiffusionPipeline):
|
|
| 394 |
r"""
|
| 395 |
Generate class-conditional images with PixelFlow.
|
| 396 |
|
| 397 |
-
Examples:
|
| 398 |
-
<!-- this section is replaced by replace_example_docstring -->
|
| 399 |
-
|
| 400 |
Args:
|
| 401 |
class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.LongTensor`):
|
| 402 |
ImageNet class indices or human-readable English label strings.
|
|
@@ -435,37 +483,39 @@ class PixelFlowPipeline(DiffusionPipeline):
|
|
| 435 |
autocast_enabled = device.type == "cuda"
|
| 436 |
autocast_dtype = torch.bfloat16 if autocast_enabled else torch.float32
|
| 437 |
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
|
|
|
|
|
|
| 469 |
|
| 470 |
image = self.decode_latents(latents, output_type=output_type)
|
| 471 |
self.maybe_free_model_hooks()
|
|
|
|
| 12 |
# See the License for the specific language governing permissions and
|
| 13 |
# limitations under the License.
|
| 14 |
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import inspect
|
| 18 |
+
|
| 19 |
import importlib
|
| 20 |
import json
|
| 21 |
import math
|
|
|
|
| 23 |
from pathlib import Path
|
| 24 |
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 25 |
|
|
|
|
| 26 |
import torch
|
| 27 |
import torch.nn.functional as F
|
| 28 |
from einops import rearrange
|
|
|
|
| 30 |
from diffusers.image_processor import VaeImageProcessor
|
| 31 |
from diffusers.models.embeddings import get_2d_rotary_pos_embed
|
| 32 |
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
|
|
|
|
|
|
| 33 |
from diffusers.utils.torch_utils import randn_tensor
|
| 34 |
|
| 35 |
|
|
|
|
| 39 |
Examples:
|
| 40 |
```py
|
| 41 |
>>> from pathlib import Path
|
| 42 |
+
>>> import sys
|
| 43 |
>>> import torch
|
|
|
|
| 44 |
|
| 45 |
>>> model_dir = Path("./PixelFlow-256").resolve()
|
| 46 |
+
>>> sys.path.insert(0, str(model_dir))
|
| 47 |
+
>>> from pipeline import PixelFlowPipeline
|
| 48 |
+
|
| 49 |
+
>>> pipe = PixelFlowPipeline.from_pretrained(
|
| 50 |
... str(model_dir),
|
| 51 |
... local_files_only=True,
|
|
|
|
|
|
|
| 52 |
... torch_dtype=torch.bfloat16,
|
| 53 |
... )
|
| 54 |
+
>>> pipe.to("cuda")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
```
|
| 56 |
"""
|
| 57 |
|
|
|
|
| 64 |
Parameters:
|
| 65 |
transformer ([`PixelFlowTransformer2DModel`]):
|
| 66 |
Class-conditional PixelFlow transformer operating in pixel space.
|
| 67 |
+
scheduler ([`PixelFlowScheduler`]):
|
| 68 |
+
Multi-stage flow scheduler used by PixelFlow.
|
| 69 |
id2label (`dict[int, str]`, *optional*):
|
| 70 |
ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
|
| 71 |
"""
|
| 72 |
|
| 73 |
+
@staticmethod
|
| 74 |
+
def prepare_extra_step_kwargs(
|
| 75 |
+
scheduler,
|
| 76 |
+
generator=None,
|
| 77 |
+
eta: float | None = None,
|
| 78 |
+
):
|
| 79 |
+
kwargs = {}
|
| 80 |
+
step_params = set(inspect.signature(scheduler.step).parameters.keys())
|
| 81 |
+
if "generator" in step_params:
|
| 82 |
+
kwargs["generator"] = generator
|
| 83 |
+
if eta is not None and "eta" in step_params:
|
| 84 |
+
kwargs["eta"] = eta
|
| 85 |
+
return kwargs
|
| 86 |
+
|
| 87 |
+
|
| 88 |
model_cpu_offload_seq = "transformer"
|
| 89 |
|
| 90 |
def __init__(
|
|
|
|
| 131 |
variant = variant / subfolder
|
| 132 |
|
| 133 |
id2label_override = kwargs.pop("id2label", None)
|
|
|
|
| 134 |
model_kwargs = dict(kwargs)
|
| 135 |
scheduler_kwargs = model_kwargs.pop("scheduler_kwargs", {})
|
| 136 |
inserted = []
|
|
|
|
| 149 |
transformer_cls = getattr(importlib.import_module("transformer_pixelflow"), "PixelFlowTransformer2DModel")
|
| 150 |
transformer = transformer_cls.from_pretrained(str(transformer_dir), **model_kwargs)
|
| 151 |
|
| 152 |
+
scheduling_py = variant / "scheduling_pixelflow.py"
|
| 153 |
+
scheduler_cfg_dir = variant / "scheduler"
|
| 154 |
+
if not scheduling_py.is_file() or not (scheduler_cfg_dir / "scheduler_config.json").exists():
|
| 155 |
+
raise FileNotFoundError(f"Expected scheduler module at {scheduling_py} and config in {scheduler_cfg_dir}")
|
| 156 |
|
| 157 |
+
_ensure_path(str(variant.resolve()))
|
| 158 |
scheduler_cls = getattr(importlib.import_module("scheduling_pixelflow"), "PixelFlowScheduler")
|
| 159 |
try:
|
| 160 |
+
scheduler = scheduler_cls.from_pretrained(str(scheduler_cfg_dir), **scheduler_kwargs)
|
| 161 |
except Exception:
|
| 162 |
scheduler = scheduler_cls(**scheduler_kwargs)
|
| 163 |
|
|
|
|
| 171 |
if comp_path in sys.path:
|
| 172 |
sys.path.remove(comp_path)
|
| 173 |
|
| 174 |
+
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
|
| 175 |
+
model_kwargs = dict(kwargs)
|
| 176 |
+
transformer_subfolder = model_kwargs.pop("transformer_subfolder", None)
|
| 177 |
+
scheduler_subfolder = model_kwargs.pop("scheduler_subfolder", None)
|
| 178 |
+
scheduler_kwargs = model_kwargs.pop("scheduler_kwargs", {})
|
| 179 |
+
base_path = Path(pretrained_model_name_or_path)
|
| 180 |
+
|
| 181 |
+
if transformer_subfolder is None and (base_path / "transformer").exists():
|
| 182 |
+
transformer_subfolder = "transformer"
|
| 183 |
+
if scheduler_subfolder is None and (base_path / "scheduler").exists():
|
| 184 |
+
scheduler_subfolder = "scheduler"
|
| 185 |
+
|
| 186 |
+
try:
|
| 187 |
+
return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
|
| 188 |
+
except Exception:
|
| 189 |
+
if transformer_subfolder is not None:
|
| 190 |
+
transformer_path = str(base_path / transformer_subfolder)
|
| 191 |
+
else:
|
| 192 |
+
transformer_path = pretrained_model_name_or_path
|
| 193 |
+
|
| 194 |
+
transformer = PixelFlowTransformer2DModel.from_pretrained(transformer_path, **model_kwargs)
|
| 195 |
+
try:
|
| 196 |
+
scheduler = PixelFlowScheduler.from_pretrained(
|
| 197 |
+
pretrained_model_name_or_path,
|
| 198 |
+
subfolder=scheduler_subfolder,
|
| 199 |
+
**scheduler_kwargs,
|
| 200 |
+
)
|
| 201 |
+
except Exception:
|
| 202 |
+
scheduler = PixelFlowScheduler(**scheduler_kwargs)
|
| 203 |
+
|
| 204 |
+
id2label = cls._read_id2label_from_model_index(str(base_path))
|
| 205 |
+
pipe = cls(transformer=transformer, scheduler=scheduler, id2label=id2label)
|
| 206 |
+
if hasattr(pipe, "register_to_config"):
|
| 207 |
+
pipe.register_to_config(_name_or_path=str(base_path))
|
| 208 |
+
return pipe
|
| 209 |
|
| 210 |
def _ensure_labels_loaded(self) -> None:
|
| 211 |
if self._labels_loaded_from_model_index:
|
|
|
|
| 216 |
self.labels = self._build_label2id(self._id2label)
|
| 217 |
self._labels_loaded_from_model_index = True
|
| 218 |
|
| 219 |
+
@staticmethod
|
| 220 |
+
def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
|
| 221 |
+
if not id2label:
|
| 222 |
+
return {}
|
| 223 |
+
return {int(key): value for key, value in id2label.items()}
|
| 224 |
+
|
| 225 |
@staticmethod
|
| 226 |
def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
|
| 227 |
if not variant_path:
|
|
|
|
| 343 |
channels: int,
|
| 344 |
height: int,
|
| 345 |
width: int,
|
| 346 |
+
device: torch.device,
|
| 347 |
+
dtype: torch.dtype,
|
| 348 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 349 |
eps: float = 1e-6,
|
| 350 |
) -> torch.Tensor:
|
| 351 |
gamma = self.scheduler.gamma
|
| 352 |
+
cov = torch.eye(4, dtype=torch.float32) * (1 - gamma) + torch.ones(4, 4, dtype=torch.float32) * gamma
|
| 353 |
+
cov = cov + eps * torch.eye(4, dtype=torch.float32)
|
| 354 |
+
chol = torch.linalg.cholesky(cov).to(device=device, dtype=dtype)
|
|
|
|
| 355 |
block_number = batch_size * channels * (height // 2) * (width // 2)
|
| 356 |
+
standard = randn_tensor(
|
| 357 |
+
(block_number, 4),
|
| 358 |
+
generator=generator,
|
| 359 |
+
device=device,
|
| 360 |
+
dtype=dtype,
|
| 361 |
+
)
|
| 362 |
+
noise = standard @ chol.T
|
| 363 |
return rearrange(
|
| 364 |
noise,
|
| 365 |
"(b c h w) (p q) -> b c (h p) (w q)",
|
|
|
|
| 378 |
height: int,
|
| 379 |
width: int,
|
| 380 |
device: torch.device,
|
| 381 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 382 |
) -> torch.Tensor:
|
| 383 |
latents = F.interpolate(latents, size=(height, width), mode="nearest")
|
| 384 |
original_start_t = self.scheduler.original_start_t[stage_idx]
|
|
|
|
| 386 |
alpha = 1 / (math.sqrt(1 - (1 / gamma)) * (1 - original_start_t) + original_start_t)
|
| 387 |
beta = alpha * (1 - original_start_t) / math.sqrt(-gamma)
|
| 388 |
|
| 389 |
+
noise = self._sample_block_noise(
|
| 390 |
+
*latents.shape,
|
| 391 |
+
device=device,
|
| 392 |
+
dtype=latents.dtype,
|
| 393 |
+
generator=generator,
|
| 394 |
+
)
|
| 395 |
return alpha * latents + beta * noise
|
| 396 |
|
| 397 |
def _prepare_rope_pos_embed(self, latents: torch.Tensor, device: torch.device) -> torch.Tensor:
|
|
|
|
| 430 |
raise ValueError(f"output_type must be one of: 'pil', 'np', 'pt', 'latent'. Got {output_type}.")
|
| 431 |
|
| 432 |
@torch.inference_mode()
|
|
|
|
| 433 |
def __call__(
|
| 434 |
self,
|
| 435 |
class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
|
|
|
|
| 445 |
r"""
|
| 446 |
Generate class-conditional images with PixelFlow.
|
| 447 |
|
|
|
|
|
|
|
|
|
|
| 448 |
Args:
|
| 449 |
class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.LongTensor`):
|
| 450 |
ImageNet class indices or human-readable English label strings.
|
|
|
|
| 483 |
autocast_enabled = device.type == "cuda"
|
| 484 |
autocast_dtype = torch.bfloat16 if autocast_enabled else torch.float32
|
| 485 |
|
| 486 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(self.scheduler, generator=generator)
|
| 487 |
+
|
| 488 |
+
for stage_idx in range(self.scheduler.num_stages):
|
| 489 |
+
self.scheduler.set_timesteps(stage_steps[stage_idx], stage_idx, device=device, shift=shift)
|
| 490 |
+
timesteps = self.scheduler.Timesteps
|
| 491 |
+
|
| 492 |
+
if stage_idx > 0:
|
| 493 |
+
height, width = height * 2, width * 2
|
| 494 |
+
latents = self._upsample_latents_for_stage(
|
| 495 |
+
latents, stage_idx, height, width, device, generator=generator
|
| 496 |
+
)
|
| 497 |
+
size_tensor = torch.tensor([latents.shape[-1] // self.transformer.patch_size], dtype=torch.int32, device=device)
|
| 498 |
+
|
| 499 |
+
rope_pos = self._prepare_rope_pos_embed(latents, device)
|
| 500 |
+
|
| 501 |
+
for timestep in timesteps:
|
| 502 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 503 |
+
timestep_batch = timestep.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
|
| 504 |
+
with torch.autocast(device.type, enabled=autocast_enabled, dtype=autocast_dtype):
|
| 505 |
+
noise_pred = self.transformer(
|
| 506 |
+
latent_model_input,
|
| 507 |
+
timestep=timestep_batch,
|
| 508 |
+
class_labels=conditioning,
|
| 509 |
+
latent_size=size_tensor,
|
| 510 |
+
pos_embed=rope_pos,
|
| 511 |
+
).sample
|
| 512 |
+
|
| 513 |
+
if do_classifier_free_guidance:
|
| 514 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 515 |
+
stage_scale = self._stage_guidance_scale(stage_idx, guidance_scale)
|
| 516 |
+
noise_pred = noise_pred_uncond + stage_scale * (noise_pred_text - noise_pred_uncond)
|
| 517 |
+
|
| 518 |
+
latents = self.scheduler.step(model_output=noise_pred, sample=latents, **extra_step_kwargs).prev_sample
|
| 519 |
|
| 520 |
image = self.decode_latents(latents, output_type=output_type)
|
| 521 |
self.maybe_free_model_hooks()
|