BiliSakura commited on
Commit
a43ace8
·
verified ·
1 Parent(s): 4fc1a18

Update PixelFlow-256/pipeline.py

Browse files
Files changed (1) hide show
  1. PixelFlow-256/pipeline.py +127 -77
PixelFlow-256/pipeline.py CHANGED
@@ -12,6 +12,10 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
 
 
 
 
15
  import importlib
16
  import json
17
  import math
@@ -19,7 +23,6 @@ import sys
19
  from pathlib import Path
20
  from typing import Any, Dict, List, Optional, Tuple, Union
21
 
22
- import numpy as np
23
  import torch
24
  import torch.nn.functional as F
25
  from einops import rearrange
@@ -27,8 +30,6 @@ from einops import rearrange
27
  from diffusers.image_processor import VaeImageProcessor
28
  from diffusers.models.embeddings import get_2d_rotary_pos_embed
29
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
30
- from diffusers.schedulers import KarrasDiffusionSchedulers
31
- from diffusers.utils import replace_example_docstring
32
  from diffusers.utils.torch_utils import randn_tensor
33
 
34
 
@@ -38,32 +39,19 @@ EXAMPLE_DOC_STRING = """
38
  Examples:
39
  ```py
40
  >>> from pathlib import Path
 
41
  >>> import torch
42
- >>> from diffusers import DiffusionPipeline
43
 
44
  >>> model_dir = Path("./PixelFlow-256").resolve()
45
- >>> pipe = DiffusionPipeline.from_pretrained(
 
 
 
46
  ... str(model_dir),
47
  ... local_files_only=True,
48
- ... custom_pipeline=str(model_dir / "pipeline.py"),
49
- ... trust_remote_code=True,
50
  ... torch_dtype=torch.bfloat16,
51
  ... )
52
- >>> pipe = pipe.to("cuda")
53
-
54
- >>> print(pipe.id2label[207])
55
- >>> print(pipe.get_label_ids("golden retriever"))
56
-
57
- >>> generator = torch.Generator(device="cuda").manual_seed(42)
58
- >>> image = pipe(
59
- ... class_labels="golden retriever",
60
- ... height=256,
61
- ... width=256,
62
- ... num_inference_steps=[10, 10, 10, 10],
63
- ... guidance_scale=4.0,
64
- ... generator=generator,
65
- ... ).images[0]
66
- >>> image.save("demo.png")
67
  ```
68
  """
69
 
@@ -76,12 +64,27 @@ class PixelFlowPipeline(DiffusionPipeline):
76
  Parameters:
77
  transformer ([`PixelFlowTransformer2DModel`]):
78
  Class-conditional PixelFlow transformer operating in pixel space.
79
- scheduler ([`PixelFlowScheduler`] or [`KarrasDiffusionSchedulers`]):
80
- Multi-stage flow scheduler used by PixelFlow cascade denoising.
81
  id2label (`dict[int, str]`, *optional*):
82
  ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
83
  """
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  model_cpu_offload_seq = "transformer"
86
 
87
  def __init__(
@@ -128,7 +131,6 @@ class PixelFlowPipeline(DiffusionPipeline):
128
  variant = variant / subfolder
129
 
130
  id2label_override = kwargs.pop("id2label", None)
131
- kwargs.pop("trust_remote_code", None)
132
  model_kwargs = dict(kwargs)
133
  scheduler_kwargs = model_kwargs.pop("scheduler_kwargs", {})
134
  inserted = []
@@ -147,14 +149,15 @@ class PixelFlowPipeline(DiffusionPipeline):
147
  transformer_cls = getattr(importlib.import_module("transformer_pixelflow"), "PixelFlowTransformer2DModel")
148
  transformer = transformer_cls.from_pretrained(str(transformer_dir), **model_kwargs)
149
 
150
- scheduler_dir = variant / "scheduler"
151
- if not (scheduler_dir / "scheduler_config.json").exists():
152
- raise FileNotFoundError(f"Expected scheduler config in {scheduler_dir}")
 
153
 
154
- _ensure_path(str(scheduler_dir))
155
  scheduler_cls = getattr(importlib.import_module("scheduling_pixelflow"), "PixelFlowScheduler")
156
  try:
157
- scheduler = scheduler_cls.from_pretrained(str(scheduler_dir), **scheduler_kwargs)
158
  except Exception:
159
  scheduler = scheduler_cls(**scheduler_kwargs)
160
 
@@ -168,11 +171,41 @@ class PixelFlowPipeline(DiffusionPipeline):
168
  if comp_path in sys.path:
169
  sys.path.remove(comp_path)
170
 
171
- @staticmethod
172
- def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
173
- if not id2label:
174
- return {}
175
- return {int(key): value for key, value in id2label.items()}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
 
177
  def _ensure_labels_loaded(self) -> None:
178
  if self._labels_loaded_from_model_index:
@@ -183,6 +216,12 @@ class PixelFlowPipeline(DiffusionPipeline):
183
  self.labels = self._build_label2id(self._id2label)
184
  self._labels_loaded_from_model_index = True
185
 
 
 
 
 
 
 
186
  @staticmethod
187
  def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
188
  if not variant_path:
@@ -304,15 +343,23 @@ class PixelFlowPipeline(DiffusionPipeline):
304
  channels: int,
305
  height: int,
306
  width: int,
 
 
 
307
  eps: float = 1e-6,
308
  ) -> torch.Tensor:
309
  gamma = self.scheduler.gamma
310
- dist = torch.distributions.multivariate_normal.MultivariateNormal(
311
- torch.zeros(4),
312
- torch.eye(4) * (1 - gamma) + torch.ones(4, 4) * gamma + eps * torch.eye(4),
313
- )
314
  block_number = batch_size * channels * (height // 2) * (width // 2)
315
- noise = torch.stack([dist.sample() for _ in range(block_number)])
 
 
 
 
 
 
316
  return rearrange(
317
  noise,
318
  "(b c h w) (p q) -> b c (h p) (w q)",
@@ -331,6 +378,7 @@ class PixelFlowPipeline(DiffusionPipeline):
331
  height: int,
332
  width: int,
333
  device: torch.device,
 
334
  ) -> torch.Tensor:
335
  latents = F.interpolate(latents, size=(height, width), mode="nearest")
336
  original_start_t = self.scheduler.original_start_t[stage_idx]
@@ -338,8 +386,12 @@ class PixelFlowPipeline(DiffusionPipeline):
338
  alpha = 1 / (math.sqrt(1 - (1 / gamma)) * (1 - original_start_t) + original_start_t)
339
  beta = alpha * (1 - original_start_t) / math.sqrt(-gamma)
340
 
341
- noise = self._sample_block_noise(*latents.shape)
342
- noise = noise.to(device=device, dtype=latents.dtype)
 
 
 
 
343
  return alpha * latents + beta * noise
344
 
345
  def _prepare_rope_pos_embed(self, latents: torch.Tensor, device: torch.device) -> torch.Tensor:
@@ -378,7 +430,6 @@ class PixelFlowPipeline(DiffusionPipeline):
378
  raise ValueError(f"output_type must be one of: 'pil', 'np', 'pt', 'latent'. Got {output_type}.")
379
 
380
  @torch.inference_mode()
381
- @replace_example_docstring(EXAMPLE_DOC_STRING)
382
  def __call__(
383
  self,
384
  class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
@@ -394,9 +445,6 @@ class PixelFlowPipeline(DiffusionPipeline):
394
  r"""
395
  Generate class-conditional images with PixelFlow.
396
 
397
- Examples:
398
- <!-- this section is replaced by replace_example_docstring -->
399
-
400
  Args:
401
  class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.LongTensor`):
402
  ImageNet class indices or human-readable English label strings.
@@ -435,37 +483,39 @@ class PixelFlowPipeline(DiffusionPipeline):
435
  autocast_enabled = device.type == "cuda"
436
  autocast_dtype = torch.bfloat16 if autocast_enabled else torch.float32
437
 
438
- with self.progress_bar(total=sum(stage_steps)) as progress_bar:
439
- for stage_idx in range(self.scheduler.num_stages):
440
- self.scheduler.set_timesteps(stage_steps[stage_idx], stage_idx, device=device, shift=shift)
441
- timesteps = self.scheduler.Timesteps
442
-
443
- if stage_idx > 0:
444
- height, width = height * 2, width * 2
445
- latents = self._upsample_latents_for_stage(latents, stage_idx, height, width, device)
446
- size_tensor = torch.tensor([latents.shape[-1] // self.transformer.patch_size], dtype=torch.int32, device=device)
447
-
448
- rope_pos = self._prepare_rope_pos_embed(latents, device)
449
-
450
- for timestep in timesteps:
451
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
452
- timestep_batch = timestep.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
453
- with torch.autocast(device.type, enabled=autocast_enabled, dtype=autocast_dtype):
454
- noise_pred = self.transformer(
455
- latent_model_input,
456
- timestep=timestep_batch,
457
- class_labels=conditioning,
458
- latent_size=size_tensor,
459
- pos_embed=rope_pos,
460
- ).sample
461
-
462
- if do_classifier_free_guidance:
463
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
464
- stage_scale = self._stage_guidance_scale(stage_idx, guidance_scale)
465
- noise_pred = noise_pred_uncond + stage_scale * (noise_pred_text - noise_pred_uncond)
466
-
467
- latents = self.scheduler.step(model_output=noise_pred, sample=latents).prev_sample
468
- progress_bar.update()
 
 
469
 
470
  image = self.decode_latents(latents, output_type=output_type)
471
  self.maybe_free_model_hooks()
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
+ from __future__ import annotations
16
+
17
+ import inspect
18
+
19
  import importlib
20
  import json
21
  import math
 
23
  from pathlib import Path
24
  from typing import Any, Dict, List, Optional, Tuple, Union
25
 
 
26
  import torch
27
  import torch.nn.functional as F
28
  from einops import rearrange
 
30
  from diffusers.image_processor import VaeImageProcessor
31
  from diffusers.models.embeddings import get_2d_rotary_pos_embed
32
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
 
33
  from diffusers.utils.torch_utils import randn_tensor
34
 
35
 
 
39
  Examples:
40
  ```py
41
  >>> from pathlib import Path
42
+ >>> import sys
43
  >>> import torch
 
44
 
45
  >>> model_dir = Path("./PixelFlow-256").resolve()
46
+ >>> sys.path.insert(0, str(model_dir))
47
+ >>> from pipeline import PixelFlowPipeline
48
+
49
+ >>> pipe = PixelFlowPipeline.from_pretrained(
50
  ... str(model_dir),
51
  ... local_files_only=True,
 
 
52
  ... torch_dtype=torch.bfloat16,
53
  ... )
54
+ >>> pipe.to("cuda")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  ```
56
  """
57
 
 
64
  Parameters:
65
  transformer ([`PixelFlowTransformer2DModel`]):
66
  Class-conditional PixelFlow transformer operating in pixel space.
67
+ scheduler ([`PixelFlowScheduler`]):
68
+ Multi-stage flow scheduler used by PixelFlow.
69
  id2label (`dict[int, str]`, *optional*):
70
  ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
71
  """
72
 
73
+ @staticmethod
74
+ def prepare_extra_step_kwargs(
75
+ scheduler,
76
+ generator=None,
77
+ eta: float | None = None,
78
+ ):
79
+ kwargs = {}
80
+ step_params = set(inspect.signature(scheduler.step).parameters.keys())
81
+ if "generator" in step_params:
82
+ kwargs["generator"] = generator
83
+ if eta is not None and "eta" in step_params:
84
+ kwargs["eta"] = eta
85
+ return kwargs
86
+
87
+
88
  model_cpu_offload_seq = "transformer"
89
 
90
  def __init__(
 
131
  variant = variant / subfolder
132
 
133
  id2label_override = kwargs.pop("id2label", None)
 
134
  model_kwargs = dict(kwargs)
135
  scheduler_kwargs = model_kwargs.pop("scheduler_kwargs", {})
136
  inserted = []
 
149
  transformer_cls = getattr(importlib.import_module("transformer_pixelflow"), "PixelFlowTransformer2DModel")
150
  transformer = transformer_cls.from_pretrained(str(transformer_dir), **model_kwargs)
151
 
152
+ scheduling_py = variant / "scheduling_pixelflow.py"
153
+ scheduler_cfg_dir = variant / "scheduler"
154
+ if not scheduling_py.is_file() or not (scheduler_cfg_dir / "scheduler_config.json").exists():
155
+ raise FileNotFoundError(f"Expected scheduler module at {scheduling_py} and config in {scheduler_cfg_dir}")
156
 
157
+ _ensure_path(str(variant.resolve()))
158
  scheduler_cls = getattr(importlib.import_module("scheduling_pixelflow"), "PixelFlowScheduler")
159
  try:
160
+ scheduler = scheduler_cls.from_pretrained(str(scheduler_cfg_dir), **scheduler_kwargs)
161
  except Exception:
162
  scheduler = scheduler_cls(**scheduler_kwargs)
163
 
 
171
  if comp_path in sys.path:
172
  sys.path.remove(comp_path)
173
 
174
+ def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
175
+ model_kwargs = dict(kwargs)
176
+ transformer_subfolder = model_kwargs.pop("transformer_subfolder", None)
177
+ scheduler_subfolder = model_kwargs.pop("scheduler_subfolder", None)
178
+ scheduler_kwargs = model_kwargs.pop("scheduler_kwargs", {})
179
+ base_path = Path(pretrained_model_name_or_path)
180
+
181
+ if transformer_subfolder is None and (base_path / "transformer").exists():
182
+ transformer_subfolder = "transformer"
183
+ if scheduler_subfolder is None and (base_path / "scheduler").exists():
184
+ scheduler_subfolder = "scheduler"
185
+
186
+ try:
187
+ return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
188
+ except Exception:
189
+ if transformer_subfolder is not None:
190
+ transformer_path = str(base_path / transformer_subfolder)
191
+ else:
192
+ transformer_path = pretrained_model_name_or_path
193
+
194
+ transformer = PixelFlowTransformer2DModel.from_pretrained(transformer_path, **model_kwargs)
195
+ try:
196
+ scheduler = PixelFlowScheduler.from_pretrained(
197
+ pretrained_model_name_or_path,
198
+ subfolder=scheduler_subfolder,
199
+ **scheduler_kwargs,
200
+ )
201
+ except Exception:
202
+ scheduler = PixelFlowScheduler(**scheduler_kwargs)
203
+
204
+ id2label = cls._read_id2label_from_model_index(str(base_path))
205
+ pipe = cls(transformer=transformer, scheduler=scheduler, id2label=id2label)
206
+ if hasattr(pipe, "register_to_config"):
207
+ pipe.register_to_config(_name_or_path=str(base_path))
208
+ return pipe
209
 
210
  def _ensure_labels_loaded(self) -> None:
211
  if self._labels_loaded_from_model_index:
 
216
  self.labels = self._build_label2id(self._id2label)
217
  self._labels_loaded_from_model_index = True
218
 
219
+ @staticmethod
220
+ def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
221
+ if not id2label:
222
+ return {}
223
+ return {int(key): value for key, value in id2label.items()}
224
+
225
  @staticmethod
226
  def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
227
  if not variant_path:
 
343
  channels: int,
344
  height: int,
345
  width: int,
346
+ device: torch.device,
347
+ dtype: torch.dtype,
348
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
349
  eps: float = 1e-6,
350
  ) -> torch.Tensor:
351
  gamma = self.scheduler.gamma
352
+ cov = torch.eye(4, dtype=torch.float32) * (1 - gamma) + torch.ones(4, 4, dtype=torch.float32) * gamma
353
+ cov = cov + eps * torch.eye(4, dtype=torch.float32)
354
+ chol = torch.linalg.cholesky(cov).to(device=device, dtype=dtype)
 
355
  block_number = batch_size * channels * (height // 2) * (width // 2)
356
+ standard = randn_tensor(
357
+ (block_number, 4),
358
+ generator=generator,
359
+ device=device,
360
+ dtype=dtype,
361
+ )
362
+ noise = standard @ chol.T
363
  return rearrange(
364
  noise,
365
  "(b c h w) (p q) -> b c (h p) (w q)",
 
378
  height: int,
379
  width: int,
380
  device: torch.device,
381
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
382
  ) -> torch.Tensor:
383
  latents = F.interpolate(latents, size=(height, width), mode="nearest")
384
  original_start_t = self.scheduler.original_start_t[stage_idx]
 
386
  alpha = 1 / (math.sqrt(1 - (1 / gamma)) * (1 - original_start_t) + original_start_t)
387
  beta = alpha * (1 - original_start_t) / math.sqrt(-gamma)
388
 
389
+ noise = self._sample_block_noise(
390
+ *latents.shape,
391
+ device=device,
392
+ dtype=latents.dtype,
393
+ generator=generator,
394
+ )
395
  return alpha * latents + beta * noise
396
 
397
  def _prepare_rope_pos_embed(self, latents: torch.Tensor, device: torch.device) -> torch.Tensor:
 
430
  raise ValueError(f"output_type must be one of: 'pil', 'np', 'pt', 'latent'. Got {output_type}.")
431
 
432
  @torch.inference_mode()
 
433
  def __call__(
434
  self,
435
  class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
 
445
  r"""
446
  Generate class-conditional images with PixelFlow.
447
 
 
 
 
448
  Args:
449
  class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.LongTensor`):
450
  ImageNet class indices or human-readable English label strings.
 
483
  autocast_enabled = device.type == "cuda"
484
  autocast_dtype = torch.bfloat16 if autocast_enabled else torch.float32
485
 
486
+ extra_step_kwargs = self.prepare_extra_step_kwargs(self.scheduler, generator=generator)
487
+
488
+ for stage_idx in range(self.scheduler.num_stages):
489
+ self.scheduler.set_timesteps(stage_steps[stage_idx], stage_idx, device=device, shift=shift)
490
+ timesteps = self.scheduler.Timesteps
491
+
492
+ if stage_idx > 0:
493
+ height, width = height * 2, width * 2
494
+ latents = self._upsample_latents_for_stage(
495
+ latents, stage_idx, height, width, device, generator=generator
496
+ )
497
+ size_tensor = torch.tensor([latents.shape[-1] // self.transformer.patch_size], dtype=torch.int32, device=device)
498
+
499
+ rope_pos = self._prepare_rope_pos_embed(latents, device)
500
+
501
+ for timestep in timesteps:
502
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
503
+ timestep_batch = timestep.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
504
+ with torch.autocast(device.type, enabled=autocast_enabled, dtype=autocast_dtype):
505
+ noise_pred = self.transformer(
506
+ latent_model_input,
507
+ timestep=timestep_batch,
508
+ class_labels=conditioning,
509
+ latent_size=size_tensor,
510
+ pos_embed=rope_pos,
511
+ ).sample
512
+
513
+ if do_classifier_free_guidance:
514
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
515
+ stage_scale = self._stage_guidance_scale(stage_idx, guidance_scale)
516
+ noise_pred = noise_pred_uncond + stage_scale * (noise_pred_text - noise_pred_uncond)
517
+
518
+ latents = self.scheduler.step(model_output=noise_pred, sample=latents, **extra_step_kwargs).prev_sample
519
 
520
  image = self.decode_latents(latents, output_type=output_type)
521
  self.maybe_free_model_hooks()