| import math |
| from dataclasses import dataclass |
|
|
| import jax |
| import jax.numpy as jnp |
| import numpy as np |
| from flax import nnx |
|
|
|
|
@dataclass
class DiTConfig:
    """Hyper-parameters shared by every sub-module of the DiT model."""

    # Channel count fed to the patch convolution and produced by the final
    # transposed convolution.
    input_dim: int
    # Feature width of the token stream.
    hidden_dim: int
    # Number of TransformerBlock instances stacked in DiT.
    num_blocks: int
    # Number of attention heads; hidden_dim must be divisible by it.
    num_heads: int
    # Square kernel side of the patch conv / final transposed conv.
    patch_size: int
    # Stride (both spatial axes) of the patch conv / final transposed conv.
    patch_stride: int
    # Width of the sinusoidal timestep features (must be even).
    time_freq_dim: int
    # Period constant used to derive the timestep frequencies.
    time_max_period: int
    # Hidden-layer expansion factor of the MLP.
    mlp_ratio: int
    # Whether the Linear / Conv layers carry a bias term.
    use_bias: bool
    # Padding argument forwarded to the convolution layers.
    padding: str
    # Forwarded to get_2d_sincos_pos_embed as ``cls_token``.
    pos_embed_cls_token: bool
    # Forwarded to get_2d_sincos_pos_embed as ``extra_tokens``.
    pos_embed_extra_tokens: int
|
|
|
|
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """Build a fixed 2D sine/cosine positional embedding for a square grid.

    Args:
        embed_dim: Width of each embedding vector. The helpers this calls
            assert that it is even and that its half is even as well.
        grid_size: Side length of the square grid (height == width).
        cls_token: Zero rows are prepended only when this is true *and*
            ``extra_tokens`` is positive.
        extra_tokens: Number of all-zero rows to prepend.

    Returns:
        An array of shape ``[grid_size * grid_size, embed_dim]``, or
        ``[extra_tokens + grid_size * grid_size, embed_dim]`` when zero rows
        are prepended.
    """
    axis_h = jnp.arange(grid_size, dtype=jnp.float32)
    axis_w = jnp.arange(grid_size, dtype=jnp.float32)
    # The width axis is passed to meshgrid first; the stacked result is then
    # viewed as two coordinate planes of shape [1, grid_size, grid_size].
    grid = jnp.stack(jnp.meshgrid(axis_w, axis_h), axis=0)
    grid = grid.reshape([2, 1, grid_size, grid_size])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        # Stay in jax.numpy (the original used NumPy here) so the function
        # can run while being traced and the result keeps the embedding's
        # dtype instead of being promoted to a float64 NumPy array.
        zero_rows = jnp.zeros([extra_tokens, embed_dim], dtype=pos_embed.dtype)
        pos_embed = jnp.concatenate([zero_rows, pos_embed], axis=0)
    return pos_embed
|
|
|
|
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a two-plane coordinate grid into sine/cosine features.

    ``grid[0]`` and ``grid[1]`` are each encoded with ``embed_dim // 2``
    features by ``get_1d_sincos_pos_embed_from_grid``; the two results are
    joined along axis 1, so every row has ``embed_dim`` features.

    Args:
        embed_dim: Total feature width; must be even.
        grid: Array whose first axis indexes the two coordinate planes.

    Returns:
        The concatenated embedding, one row per grid position.
    """
    assert embed_dim % 2 == 0

    per_axis_dim = embed_dim // 2
    # First half of the features comes from grid[0], second half from grid[1].
    per_axis = [
        get_1d_sincos_pos_embed_from_grid(per_axis_dim, grid[plane])
        for plane in (0, 1)
    ]
    return jnp.concatenate(per_axis, axis=1)
|
|
|
|
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Encode scalar positions as sine/cosine features.

    Args:
        embed_dim: Number of output features per position; must be even.
        pos: Array of positions; it is flattened to ``M`` entries first.

    Returns:
        Array of shape ``(M, embed_dim)``: the first ``embed_dim // 2``
        columns are sines, the remaining columns are the matching cosines.
    """
    assert embed_dim % 2 == 0

    # Exponents run over [0, 1) in steps of 2 / embed_dim; the angular
    # frequencies are 16 ** -exponent.
    exponents = jnp.arange(embed_dim // 2, dtype=np.float32)
    exponents /= embed_dim / 2.0
    frequencies = 1.0 / 16**exponents

    flat_pos = pos.reshape(-1)
    # Outer product: one row per position, one column per frequency.
    angles = jnp.einsum("m,d->md", flat_pos, frequencies)

    return jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=1)
|
|
|
|
class PatchEmbedding(nnx.Module):
    """Turn an input into patch features with a single strided convolution.

    The convolution maps ``config.input_dim`` channels to
    ``config.hidden_dim`` channels, using a square ``patch_size`` kernel and
    a square ``patch_stride`` stride, and computes in bfloat16.
    """

    def __init__(self, config: DiTConfig, *, rngs: nnx.Rngs):
        """Create the patch convolution from ``config``."""
        super().__init__()
        kernel = (config.patch_size, config.patch_size)
        stride = (config.patch_stride, config.patch_stride)
        self.cnn = nnx.Conv(
            in_features=config.input_dim,
            out_features=config.hidden_dim,
            kernel_size=kernel,
            strides=stride,
            padding=config.padding,
            use_bias=config.use_bias,
            dtype=jnp.bfloat16,
            rngs=rngs,
        )

    def __call__(self, x):
        """Apply the patch convolution to ``x`` and return its output."""
        patches = self.cnn(x)
        return patches
|
|
|
|
class TimeEmbedding(nnx.Module):
    """Embed timesteps: sinusoidal features followed by a two-layer MLP."""

    def __init__(self, config: DiTConfig, *, rngs: nnx.Rngs):
        """Store the frequency settings and create the two linear layers."""
        super().__init__()
        self.freq_dim = config.time_freq_dim
        self.max_period = config.time_max_period

        # Options shared by both linear layers.
        linear_opts = dict(use_bias=config.use_bias, rngs=rngs, dtype=jnp.bfloat16)
        self.fc1 = nnx.Linear(self.freq_dim, config.hidden_dim, **linear_opts)
        self.fc2 = nnx.Linear(config.hidden_dim, config.hidden_dim, **linear_opts)

    @staticmethod
    def cosine_embedding(t, dim, max_period):
        """Return ``[cos(args), sin(args)]`` features for timesteps ``t``.

        Args:
            t: 1-D array of timesteps (it is indexed as ``t[:, None]``).
            dim: Output feature width; must be even.
            max_period: Frequencies decay as ``max_period ** (-i / half)``
                for ``i`` in ``range(dim // 2)``.

        Returns:
            Array of shape ``(len(t), dim)``: cosines in the first half of
            the last axis, sines in the second half.
        """
        assert dim % 2 == 0
        half = dim // 2
        index = jnp.arange(start=0, stop=half, dtype=jnp.float32)
        freqs = jnp.exp(-math.log(max_period) * index / half)
        # Timesteps are multiplied by a fixed factor of 1024 before the
        # trigonometric functions are applied.
        angles = t[:, None] * freqs[None, :] * 1024
        return jnp.concatenate([jnp.cos(angles), jnp.sin(angles)], axis=-1)

    def __call__(self, t):
        """Map timesteps ``t`` to embeddings: features -> fc1 -> SiLU -> fc2."""
        features = self.cosine_embedding(t, self.freq_dim, self.max_period)
        hidden = nnx.silu(self.fc1(features))
        return self.fc2(hidden)
|
|
|
|
class MLP(nnx.Module):
    """Two-layer feed-forward network with a SiLU activation in between.

    The inner width is ``config.hidden_dim * config.mlp_ratio``; input and
    output widths are both ``config.hidden_dim``.
    """

    def __init__(self, config: DiTConfig, *, rngs: nnx.Rngs):
        """Create the expanding (fc1) and contracting (fc2) linear layers."""
        super().__init__()
        outer_dim = config.hidden_dim
        inner_dim = config.hidden_dim * config.mlp_ratio
        self.fc1 = nnx.Linear(
            outer_dim,
            inner_dim,
            use_bias=config.use_bias,
            dtype=jnp.bfloat16,
            rngs=rngs,
        )
        self.fc2 = nnx.Linear(
            inner_dim,
            outer_dim,
            use_bias=config.use_bias,
            dtype=jnp.bfloat16,
            rngs=rngs,
        )

    def __call__(self, x):
        """Return ``fc2(silu(fc1(x)))``."""
        return self.fc2(nnx.silu(self.fc1(x)))
|
|
|
|
class SelfAttention(nnx.Module):
    """Multi-head self-attention with RMS-normalised queries and keys.

    One fused linear layer produces queries, keys and values. There is no
    output projection: the concatenated head outputs are returned directly.
    """

    def __init__(self, config: DiTConfig, *, rngs: nnx.Rngs):
        """Create the fused QKV projection and the per-head q/k norms."""
        super().__init__()
        # Validate before deriving head_dim so a bad config fails up front.
        assert config.hidden_dim % config.num_heads == 0
        self.fc = nnx.Linear(
            config.hidden_dim,
            3 * config.hidden_dim,
            use_bias=config.use_bias,
            rngs=rngs,
            dtype=jnp.bfloat16,
        )
        self.heads = config.num_heads
        self.head_dim = config.hidden_dim // config.num_heads
        self.q_norm = nnx.RMSNorm(num_features=self.head_dim, use_scale=True, rngs=rngs)
        self.k_norm = nnx.RMSNorm(num_features=self.head_dim, use_scale=True, rngs=rngs)

    def __call__(self, x):
        """Attend over axis 1 of ``x``.

        Args:
            x: Array whose first two axes are kept and whose last axis has
                ``heads * head_dim`` features after the QKV split.

        Returns:
            Array of shape ``(x.shape[0], x.shape[1], heads * head_dim)``.
        """
        q, k, v = jnp.split(self.fc(x), 3, axis=-1)

        # Split the feature axis into (heads, head_dim).
        q = q.reshape(q.shape[0], q.shape[1], self.heads, self.head_dim)
        k = k.reshape(k.shape[0], k.shape[1], self.heads, self.head_dim)
        v = v.reshape(v.shape[0], v.shape[1], self.heads, self.head_dim)

        # NOTE(review): the norms are built without an explicit dtype, so
        # their output may be promoted to the float32 scale-parameter dtype
        # while v keeps the projection's bfloat16. Casting back keeps q, k
        # and v in one dtype for the attention call — confirm against the
        # installed flax / jax versions.
        q = self.q_norm(q).astype(v.dtype)
        k = self.k_norm(k).astype(v.dtype)

        o = jax.nn.dot_product_attention(q, k, v, is_causal=False)
        # Merge the heads back into a single feature axis.
        o = o.reshape(o.shape[0], o.shape[1], self.heads * self.head_dim)
        return o
|
|
|
|
def modulate(x, shift, scale):
    """Return ``x * (1 + scale) + shift`` with per-sample shift and scale.

    ``shift`` and ``scale`` are 2-D; a singleton axis is inserted at
    position 1 so they broadcast across axis 1 of ``x``.
    """
    gain = 1 + scale[:, None, :]
    offset = shift[:, None, :]
    return x * gain + offset
|
|
|
|
class TransformerBlock(nnx.Module):
    """Attention + MLP block whose branches are modulated by a conditioning vector.

    A SiLU + linear head maps the conditioning vector to six chunks
    (shift / scale / gate for the attention branch and for the MLP branch).
    Each branch is: scale-free RMSNorm -> modulate -> sub-layer -> gated
    residual add.
    """

    def __init__(self, config: DiTConfig, *, rngs: nnx.Rngs):
        """Create the norms, attention, MLP and the modulation head."""
        super().__init__()
        width = config.hidden_dim
        self.norm1 = nnx.RMSNorm(num_features=width, use_scale=False, rngs=rngs)
        self.attn = SelfAttention(config, rngs=rngs)
        self.norm2 = nnx.RMSNorm(num_features=width, use_scale=False, rngs=rngs)
        self.mlp = MLP(config, rngs=rngs)
        # Produces all six modulation chunks in a single projection.
        modulation_proj = nnx.Linear(
            width,
            6 * width,
            use_bias=config.use_bias,
            rngs=rngs,
            dtype=jnp.bfloat16,
        )
        self.adalm_modulation = nnx.Sequential(nnx.silu, modulation_proj)

    def __call__(self, x, c):
        """Apply the block to tokens ``x`` given conditioning ``c``.

        Args:
            x: Token array; modulation terms broadcast over its axis 1.
            c: Conditioning array fed to the modulation head; its output is
                split into six equal chunks along the last axis.

        Returns:
            The updated token array.
        """
        chunks = jnp.split(self.adalm_modulation(c), 6, axis=-1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = chunks

        # Attention branch.
        attn_in = modulate(self.norm1(x), shift_msa, scale_msa)
        x = x + gate_msa[:, None, :] * self.attn(attn_in)

        # MLP branch, computed from the already-updated tokens.
        mlp_in = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = x + gate_mlp[:, None, :] * self.mlp(mlp_in)
        return x
|
|
|
|
class FinalLayer(nnx.Module):
    """Modulated norm followed by a transposed convolution back to ``input_dim`` channels."""

    def __init__(self, config: DiTConfig, *, rngs: nnx.Rngs):
        """Create the norm, the transposed convolution and the modulation head."""
        super().__init__()
        self.norm = nnx.RMSNorm(
            num_features=config.hidden_dim, use_scale=False, rngs=rngs
        )
        # Uses the same kernel / stride / padding settings as the patch
        # convolution, with input and output channel counts swapped.
        self.conv = nnx.ConvTranspose(
            config.hidden_dim,
            config.input_dim,
            kernel_size=(config.patch_size, config.patch_size),
            strides=(config.patch_stride, config.patch_stride),
            padding=config.padding,
            use_bias=config.use_bias,
            rngs=rngs,
            dtype=jnp.bfloat16,
        )
        self.adalm_modulation = nnx.Sequential(
            nnx.silu,
            nnx.Linear(
                config.hidden_dim,
                2 * config.hidden_dim,
                use_bias=config.use_bias,
                rngs=rngs,
                dtype=jnp.bfloat16,
            ),
        )

    def __call__(self, x, c):
        """Project tokens back onto a square spatial grid.

        Args:
            x: Token array of shape ``(batch, tokens, features)``; ``tokens``
                must be a perfect square because the grid is rebuilt as
                ``side x side``.
            c: Conditioning array; the modulation head's output is split
                into a shift and a scale along the last axis.

        Returns:
            The output of the transposed convolution.
        """
        shift, scale = jnp.split(self.adalm_modulation(c), 2, axis=-1)
        x = self.norm(x)
        x = modulate(x, shift, scale)

        # Exact integer square root instead of a float power followed by
        # truncation. If the token count is not a perfect square the reshape
        # below fails, exactly as it did before.
        H = W = math.isqrt(x.shape[1])
        x = x.reshape(x.shape[0], H, W, x.shape[-1])
        x = self.conv(x)
        return x
|
|
|
|
class DiT(nnx.Module):
    """Diffusion Transformer.

    Pipeline: patch convolution -> flatten to tokens -> add fixed 2D
    sine/cosine positional embedding -> ``num_blocks`` transformer blocks
    conditioned on the timestep embedding -> final layer.
    """

    def __init__(self, config: DiTConfig, *, rngs: nnx.Rngs):
        """Create all sub-modules from ``config``."""
        super().__init__()
        self.config = config
        self.time_embedding = TimeEmbedding(config, rngs=rngs)
        self.patch_embedding = PatchEmbedding(config, rngs=rngs)
        self.blocks = [
            TransformerBlock(config, rngs=rngs) for _ in range(config.num_blocks)
        ]
        self.final_layer = FinalLayer(config, rngs=rngs)

    def __call__(self, xt, t):
        """Run the model.

        Args:
            xt: Input passed to the patch embedding, whose output is unpacked
                as ``(N, H, W, D)``.
            t: Timesteps passed to the time embedding.

        Returns:
            The output of the final layer.
        """
        c = self.time_embedding(t)

        x = self.patch_embedding(xt)
        N, H, W, D = x.shape
        x = x.reshape(N, H * W, D)

        pos_embed = get_2d_sincos_pos_embed(
            D,
            H,
            cls_token=self.config.pos_embed_cls_token,
            extra_tokens=self.config.pos_embed_extra_tokens,
        )
        # The table covers an H x H grid, optionally preceded by zero rows
        # for extra tokens. The token sequence here has no such tokens, so
        # drop those leading rows; previously they made the reshape below
        # fail whenever they were enabled. A non-square patch grid (H != W)
        # still fails at the reshape, as before.
        pos_embed = pos_embed[pos_embed.shape[0] - H * H:]
        x = x + pos_embed.reshape(1, H * W, D)

        for block in self.blocks:
            x = block(x, c)
        return self.final_layer(x, c)
|
|