| """ |
| A simple implementation of conditional flow matching for generating anime faces. |
| """ |
|
|
| import argparse |
| import pickle |
| import random |
| import time |
| from pathlib import Path |
|
|
| import jax |
| import jax.numpy as jnp |
| import kagglehub |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import optax |
| import ot |
| import yaml |
| from flax import nnx |
| from jax.experimental import ode |
| from PIL import Image |
| from tqdm.cli import tqdm |
|
|
| from model import DiT, DiTConfig |
|
|
|
|
def parse_args():
    """Parse the command-line options of the training script.

    Returns:
        The ``argparse.Namespace`` holding ``config``, the path to the YAML
        configuration file (``"config.yaml"`` when ``--config`` is omitted).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--config",
        default="config.yaml",
        type=str,
        help="Path to config file",
    )
    namespace = cli.parse_args()
    return namespace
|
|
|
|
def load_config(config_path):
    """Load the YAML configuration file at ``config_path``.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file contents.

    Raises:
        OSError: If the file cannot be opened.
    """
    # Pin the encoding so parsing does not depend on the platform's locale.
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)
    return config
|
|
|
|
def gen_data_batches(data, batch_size):
    """Yield an endless stream of random mini-batches drawn from ``data``.

    Each batch holds ``batch_size`` distinct rows of ``data`` (sampled without
    replacement using NumPy's global random state), converted to ``float32``
    and divided by 256.

    Args:
        data: Array indexed along its first axis.
        batch_size: Number of rows per batch.

    Yields:
        A ``float32`` array whose first dimension is ``batch_size``.
    """
    num_samples = data.shape[0]
    while True:
        picks = np.random.choice(num_samples, size=batch_size, replace=False)
        yield data[picks].astype(np.float32) / 256
|
|
|
|
def loss_fn(flow, batch):
    """Mean squared error between the predicted and the target velocity.

    Args:
        flow: Callable invoked as ``flow(xt, t)`` that returns a velocity.
        batch: Tuple ``(xt, t, vt)`` of inputs, times and target velocities.

    Returns:
        The scalar mean of the squared difference ``flow(xt, t) - vt``.
    """
    xt, t, vt = batch
    residual = flow(xt, t) - vt
    return jnp.mean(jnp.square(residual))
|
|
|
|
def train_step(flow, optimizer, rngs, batch):
    """Run one optimisation step on a batch of ``(x0, x1)`` pairs.

    Args:
        flow: Model called as ``flow(xt, t)``.
        optimizer: Optimizer whose ``update(grads)`` applies the gradients.
        rngs: Callable returning a fresh PRNG key on every call.
        batch: Tuple ``(x0, x1)`` of start and end points with the same shape.

    Returns:
        The loss value computed before the parameter update.
    """
    source, target = batch

    # Uniform jitter in [0, 1/256) on the end points (presumably to dequantize
    # 8-bit pixel values; confirm against the data pipeline).
    jitter = jax.random.uniform(rngs(), shape=target.shape, minval=0, maxval=1 / 256)
    target = target + jitter

    # One interpolation time per example, uniform in [0, 1).
    num_pairs = target.shape[0]
    t = jax.random.uniform(rngs(), (num_pairs,), minval=0, maxval=1)

    # Point on the straight line from source to target at time t, and the
    # constant velocity along that line.
    xt = source + (target - source) * t[:, None, None, None]
    vt = target - source

    value_and_grad = nnx.value_and_grad(loss_fn)
    loss, grads = value_and_grad(flow, (xt, t, vt))
    optimizer.update(grads)
    return loss
|
|
|
|
@jax.jit
def train_step_raw(graphdef, state, batch):
    """Jitted, functional wrapper around ``train_step``.

    Rebuilds ``(flow, optimizer, rngs)`` from ``graphdef`` and ``state``, runs
    one training step on them, and splits the objects again to obtain the
    updated state.

    Args:
        graphdef: Graph definition produced by ``nnx.split``.
        state: State matching ``graphdef``.
        batch: Tuple ``(x0, x1)`` forwarded to ``train_step``.

    Returns:
        Tuple ``(new_state, loss)``.
    """
    flow, optimizer, rngs = nnx.merge(graphdef, state)
    step_loss = train_step(flow, optimizer, rngs, batch)
    # Only the state half of the split is needed; the graphdef is discarded.
    new_state = nnx.split((flow, optimizer, rngs))[1]
    return new_state, step_loss
|
|
|
|
@jax.jit
def sample_images(graphdef, state):
    """Generate 16 images by integrating the learned flow from t=0 to t=1.

    The initial noise is drawn from a PRNG key seeded with 0 on every call, so
    successive calls start from the same noise.

    Args:
        graphdef: Graph definition produced by ``nnx.split``.
        state: State matching ``graphdef``.

    Returns:
        Array of shape ``(16, 64, 64, 3)`` with values clipped to ``[0, 1]``.
    """
    flow, _, _ = nnx.merge(graphdef, state)

    def flow_fn(y, t):
        # The model is called with a time array that has a leading axis.
        return flow(y, t[None])

    x = jax.random.normal(nnx.Rngs(0)(), shape=(16, 64, 64, 3), dtype=jnp.float32)
    # Only the final state is used, so request just the two end points instead
    # of 1000 intermediate outputs. The adaptive solver picks its own internal
    # steps regardless of the requested output times, so the end state is
    # unchanged while the output buffer shrinks from 1000 states to 2.
    trajectory = ode.odeint(flow_fn, x, jnp.linspace(0, 1, 2))
    return jnp.clip(trajectory[-1], 0, 1)
|
|
|
|
def generate_ot_pairs(x1):
    """Pair a batch of data points with Gaussian noise via optimal transport.

    Fresh standard-normal noise of the same shape as ``x1`` is drawn from
    NumPy's global random state. A transport plan between the flattened noise
    and the flattened data is solved with unit mass on every sample, and the
    data rows are combined by that plan so that row ``i`` of the returned data
    corresponds to row ``i`` of the noise.

    Args:
        x1: Batch of data points; the first axis indexes samples.

    Returns:
        Tuple ``(x0, x1)`` of the noise and the plan-transformed data, both
        with the shape of the input.
    """
    n = x1.shape[0]
    x0 = np.random.randn(*x1.shape)

    flat_noise = x0.reshape(n, -1)
    flat_data = x1.reshape(n, -1)

    # Pairwise cost between every noise row and every data row.
    cost = ot.dist(flat_noise, flat_data)
    source_mass = np.ones((n,))
    target_mass = np.ones((n,))
    plan = ot.emd(source_mass, target_mass, cost)

    # Apply the plan to the data rows (presumably a permutation here, given
    # unit masses on both sides; confirm against the POT documentation).
    matched = (plan @ flat_data).reshape(x1.shape)
    return x0, matched
|
|
|
|
def plot_new_images(step: int, graphdef, state):
    """Sample images from the current model and save them as a 4x4 grid.

    The figure is written to ``images_{step:06d}.png`` in the working
    directory.

    Args:
        step: Training step, used only in the output file name.
        graphdef: Graph definition forwarded to ``sample_images``.
        state: State forwarded to ``sample_images``.
    """
    samples = sample_images(graphdef, state)

    grid_rows, grid_cols = 4, 4
    plt.figure(figsize=(2, 2))
    for cell in range(grid_rows * grid_cols):
        plt.subplot(grid_rows, grid_cols, cell + 1)
        plt.imshow(samples[cell])
        plt.axis("off")
    # Remove all margins and gaps so the tiles touch.
    plt.subplots_adjust(left=0, bottom=0, top=1, right=1, wspace=0, hspace=0)
    plt.savefig(f"images_{step:06d}.png")
    plt.close()
|
|
|
|
# Read the command-line options and the YAML configuration.
args = parse_args()
config = load_config(args.config)

# Download the dataset through kagglehub and locate the 64x64 image folder.
path = kagglehub.dataset_download("thimac/anime-face-64")
data_path = Path(path) / "64x64"
print("Path to dataset files:", data_path)

# Collect the image files in a deterministic order, then shuffle them with the
# configured seed so the train/test split is reproducible.
data_dir = data_path
image_files = sorted(data_dir.glob("*.jpg"))
if not image_files:
    # Fail early with a clear message instead of a cryptic sampling error later.
    raise FileNotFoundError(f"No .jpg images found in {data_dir}")
random.Random(config["data"]["random_seed"]).shuffle(image_files)

# Decode every image into one uint8 array of shape (N, 64, 64, 3).
N = len(image_files)
dataset = np.empty((N, 64, 64, 3), dtype=np.uint8)
for i, file_path in enumerate(tqdm(image_files)):
    # Close each file handle deterministically and force three channels so
    # that grayscale or palette JPEGs still fit the (64, 64, 3) slot.
    with Image.open(file_path) as image:
        dataset[i] = np.asarray(image.convert("RGB"))

# The first `train_split` fraction is used for training, the rest is held out.
L = int(N * config["data"]["train_split"])
train_data = dataset[:L]
test_data = dataset[L:]
|
|
# Save a 4x4 preview grid of the first 16 training images.
plt.figure(figsize=(2, 2))
for i in range(4 * 4):
    plt.subplot(4, 4, i + 1)
    preview_image = train_data[i]
    plt.imshow(preview_image)
    plt.axis("off")
# Remove all margins and gaps so the tiles touch.
plt.subplots_adjust(left=0, bottom=0, top=1, right=1, wspace=0, hspace=0)
plt.savefig("train_data_samples.png")
plt.close()
|
|
# Learning-rate schedule and gradient transformation, both driven by the
# "training" section of the configuration.
_training_cfg = config["training"]

scheduler = optax.cosine_onecycle_schedule(
    transition_steps=_training_cfg["num_steps"],
    peak_value=_training_cfg["learning_rate"],
    pct_start=_training_cfg["warmup_pct"],
)

# Order: clip gradients, Adam scaling, multiply by the schedule, add the
# weight-decay term, negate so the update descends.
# NOTE(review): the decay term is added after the schedule scaling, so it is
# not multiplied by the learning rate (unlike the usual AdamW ordering).
# Confirm this is intended for the configured `weight_decay`.
gradient_transform = optax.chain(
    optax.clip_by_global_norm(_training_cfg["grad_clip_norm"]),
    optax.scale_by_adam(),
    optax.scale_by_schedule(scheduler),
    optax.add_decayed_weights(_training_cfg["weight_decay"]),
    optax.scale(-1.0),
)
|
|
# Build the DiT configuration by copying each listed field from the "model"
# section of the configuration file.
_DIT_CONFIG_FIELDS = (
    "input_dim",
    "hidden_dim",
    "num_blocks",
    "num_heads",
    "patch_size",
    "patch_stride",
    "time_freq_dim",
    "time_max_period",
    "mlp_ratio",
    "use_bias",
    "padding",
    "pos_embed_cls_token",
    "pos_embed_extra_tokens",
)
dit_config = DiTConfig(
    **{field: config["model"][field] for field in _DIT_CONFIG_FIELDS}
)

# Model and optimizer; the model parameters are initialised from seed 0.
flow = DiT(dit_config, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(flow, gradient_transform)

# Split model, optimizer and the training RNG stream into a graph definition
# plus a state, which is the form the jitted step functions take.
rngs = nnx.Rngs(0)
graphdef, state = nnx.split((flow, optimizer, rngs))

# Endless iterator over random training batches.
train_data_iter = gen_data_batches(train_data, config["training"]["batch_size"])
|
|
# Timer and loss buffer for the periodic log line.
start = time.perf_counter()
losses = []

# Optionally resume: the checkpoint replaces the freshly initialised state and
# the step counter continues after the step encoded in the file name
# (e.g. "state_001000.ckpt" resumes at step 1001).
ckpt_path = config["checkpointing"].get("resume_from_checkpoint")
if ckpt_path:
    del state
    # SECURITY: pickle.load can execute arbitrary code from the file. Only
    # resume from checkpoints you created yourself or otherwise trust.
    with open(ckpt_path, "rb") as f:
        state = pickle.load(f)
    print(f"Resuming from checkpoint {ckpt_path}")
    step_str = Path(ckpt_path).stem.split("_")[-1]
    start_step = int(step_str) + 1
else:
    start_step = 1

# The batch generator is infinite, so the loop needs its own stopping rule:
# stop once the number of steps the schedule was built for has been run.
num_steps = config["training"]["num_steps"]

for step, batch in enumerate(train_data_iter, start=start_step):
    if step > num_steps:
        break

    x0, x1 = generate_ot_pairs(batch)
    state, loss = train_step_raw(graphdef, state, (x0, x1))

    # Sample the loss every 100 steps for the running average.
    if step % 100 == 0:
        losses.append(loss.item())

    if step % config["checkpointing"]["log_every"] == 0:
        end = time.perf_counter()
        duration = end - start
        # If no loss was sampled in this window (e.g. log_every < 100), report
        # the current loss instead of dividing by zero.
        loss = sum(losses) / len(losses) if losses else loss.item()
        start = time.perf_counter()
        losses = []
        print(f"step {step:06d} loss {loss:.3f} duration {duration:.3f}s", flush=True)

    if step % config["checkpointing"]["plot_every"] == 0:
        plot_new_images(step, graphdef, state)

    if step % config["checkpointing"]["save_every"] == 0:
        with open(f"state_{step:06d}.ckpt", "wb") as f:
            pickle.dump(state, f)
|
|