# train_mnist.py
# ---------------------------------------------------------------------
# FP32 training for MNIST with:
#   - Sparse W&B logging: loss and global gradient norm
#   - Training speed: iterations/sec and ms/iter
#   - Periodic sampling + sample speed
#   - Optional FID (clean-fid) and IS (torch-fidelity)
#   - Logs diffusion hyperparameters (T, beta schedule, sampling steps, eta)
# ---------------------------------------------------------------------
import os
import time
import math

import numpy as np
import torch
import wandb
import yaml
from ema_pytorch import EMA
from torch.optim import Adam
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image
from tqdm.auto import tqdm

# Optional metrics: availability is checked at runtime.
try:
    from cleanfid import fid as clean_fid
    HAS_CLEANFID = True
except Exception:
    HAS_CLEANFID = False

try:
    from torch_fidelity import calculate_metrics as tf_calculate_metrics
    HAS_TORCH_FIDELITY = True
except Exception:
    HAS_TORCH_FIDELITY = False

from unet import UNet
from diffusion import GaussianDiffusion  # your current file name


# ---------------------------
# Utility functions for making videos
# ---------------------------
def frames_to_wandb_video(frames, nrow=8, fps=6):
    """Convert a list of [B, C, H, W] tensors (values in [0, 1]) into a W&B Video.

    For each time step the batch is tiled into a grid with ``nrow`` images per
    row and quantized to uint8. The grids are stacked along time into a
    (T, C, H, W) array, which is the layout ``wandb.Video`` expects for numpy
    input (a (T, H, W, C) array would be misread as having H channels).
    """
    np_frames = []
    for frame in frames:
        grid = make_grid(frame.detach().float().clamp(0, 1), nrow=nrow)  # [C, H, W]
        # Round to nearest, matching torchvision.utils.save_image quantization.
        grid = grid.mul(255.0).add_(0.5).clamp_(0, 255).to(torch.uint8)
        np_frames.append(grid.cpu().numpy())
    video = np.stack(np_frames, axis=0)  # [T, C, H, W]
    return wandb.Video(video, fps=fps, format="mp4")


# ---------------------------
# Speedups on CUDA (still FP32 storage; TF32 matmuls unless disabled)
# ---------------------------
def maybe_enable_cuda_speedups(cfg):
    """Enable cuDNN autotuning and, unless ``compute.enable_tf32`` is false, TF32 math."""
    if not torch.cuda.is_available():
        return
    if cfg.get("compute", {}).get("enable_tf32", True):
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True


# ---------------------------
# MNIST dataloader (resized, float32 in [0, 1])
# ---------------------------
def get_loader_mnist(bs, nw, img_size):
    """Return a shuffled MNIST train DataLoader yielding float32 images in [0, 1].

    Args:
        bs: batch size.
        nw: number of DataLoader worker processes.
        img_size: target size passed to ``transforms.Resize``.
    """
    tfm = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),                        # [0, 1], C x H x W
        transforms.ConvertImageDtype(torch.float32),  # force float32
    ])
    ds = datasets.MNIST(root="./data", train=True, download=True, transform=tfm)
    return torch.utils.data.DataLoader(
        ds,
        batch_size=bs,
        shuffle=True,
        num_workers=nw,
        # Pinned memory only helps (and only avoids a warning) with a GPU.
        pin_memory=torch.cuda.is_available(),
    )


# ---------------------------
# Sparse norm logging helpers
# ---------------------------
def log_global_grad_norm_sparsely(model, step, every=1000):
    """Log ``train/global_grad_norm`` every ``every`` steps.

    The value is the L2 norm of the per-parameter gradient norms, i.e. the norm
    of all gradients concatenated. Call this while gradients are populated
    (after ``backward()``, before ``zero_grad()``); parameters whose ``.grad``
    is None are ignored and nothing is logged if none have gradients.
    A non-positive ``every`` disables logging.
    """
    if every <= 0 or step % every != 0:
        return
    with torch.no_grad():
        grad_norms = [p.grad.detach().norm() for p in model.parameters() if p.grad is not None]
        if not grad_norms:
            return
        # Reduce on-device: one host sync instead of one .item() per parameter.
        global_norm = float(torch.stack(grad_norms).norm().item())
    wandb.log({"train/global_grad_norm": global_norm, "step": step}, step=step)


# ---------------------------
# Prepare a real-image reference folder for FID (folder-vs-folder)
# ---------------------------
_REF_COMPLETE_MARKER = ".complete"


def ensure_real_ref_folder(dl, out_dir, max_images=50000, img_size=32, force_rgb=False):
    """Export up to ``max_images`` real images from ``dl`` to ``out_dir`` as PNGs.

    The folder serves as the FID reference. After a full export a marker file
    is written, so later calls return immediately; an interrupted (partial)
    export is redone instead of being silently used as the reference.

    Args:
        dl: DataLoader yielding ``(images, labels)`` with images as [0, 1] tensors.
        out_dir: destination folder.
        max_images: maximum number of images to export.
        img_size: unused; kept for API compatibility (images are written at the
            size produced by ``dl``).
        force_rgb: replicate 1-channel images to 3 channels, since some FID/IS
            toolchains expect RGB input.
    """
    os.makedirs(out_dir, exist_ok=True)
    marker = os.path.join(out_dir, _REF_COMPLETE_MARKER)
    if os.path.exists(marker):
        return
    existing = [f for f in os.listdir(out_dir) if f.lower().endswith(".png")]
    if len(existing) >= max_images:
        return  # full export from an earlier version that wrote no marker

    idx = 0
    for x, _ in dl:  # x: [B, C, H, W] in [0, 1]
        if force_rgb and x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        for img in x:
            if idx >= max_images:
                break
            save_image(img, os.path.join(out_dir, f"{idx:06d}.png"))
            idx += 1
        if idx >= max_images:
            break

    # Reached max_images or exhausted the dataset: the reference is complete.
    with open(marker, "w", encoding="utf-8") as fh:
        fh.write(f"{idx}\n")


# ---------------------------
# Generate a set of images for metrics
# ---------------------------
@torch.inference_mode()
def generate_images_to_folder(model, n_images=5000, batch_size=64, out_dir="./gen_eval", force_rgb=True):
    """Sample ``n_images`` with ``model.sample`` and save them as PNGs in ``out_dir``.

    The model is switched to eval mode while sampling (so dropout is off) and
    its previous train/eval mode is restored afterwards. Grayscale images are
    optionally tiled to RGB to satisfy metric toolchains.

    Raises:
        ValueError: if ``batch_size`` is not positive (would otherwise loop forever).
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    os.makedirs(out_dir, exist_ok=True)
    was_training = model.training
    model.eval()
    try:
        idx = 0
        while idx < n_images:
            cur = min(batch_size, n_images - idx)
            imgs = model.sample(cur)  # expected in [0, 1], shape [B, C, H, W]
            if force_rgb and imgs.shape[1] == 1:
                imgs = imgs.repeat(1, 3, 1, 1)
            for img in imgs[:cur]:
                save_image(img, os.path.join(out_dir, f"{idx:06d}.png"))
                idx += 1
    finally:
        model.train(was_training)


# ---------------------------
# Compute FID (clean-fid) and IS (torch-fidelity)
# ---------------------------
def compute_fid_cleanfid(gen_dir, real_dir):
    """Return the clean-fid FID between two image folders, or None if unavailable/failed."""
    if not HAS_CLEANFID:
        print("[metrics] clean-fid not installed; skip FID.")
        return None
    try:
        return float(clean_fid.compute_fid(gen_dir, real_dir))
    except Exception as e:  # best-effort metric: report and continue training
        print("[metrics] clean-fid error:", e)
        return None


def compute_inception_score_torchfidelity(gen_dir, cuda=True):
    """Return ``(IS mean, IS std)`` for images in ``gen_dir``, or ``(None, None)`` on failure.

    A missing metric key in the torch-fidelity result is reported as NaN.
    """
    if not HAS_TORCH_FIDELITY:
        print("[metrics] torch-fidelity not installed; skip IS.")
        return None, None
    try:
        metrics = tf_calculate_metrics(
            input1=gen_dir,
            cuda=cuda and torch.cuda.is_available(),
            isc=True, fid=False, kid=False, prc=False,
        )
        return (
            float(metrics.get("inception_score_mean", float("nan"))),
            float(metrics.get("inception_score_std", float("nan"))),
        )
    except Exception as e:  # best-effort metric: report and continue training
        print("[metrics] torch-fidelity error:", e)
        return None, None


# ---------------------------
# Training helpers
# ---------------------------
def _build_diffusion(cfg, device):
    """Build the UNet + GaussianDiffusion pair described by ``cfg`` on ``device``."""
    mcfg = cfg["model"]
    dcfg = cfg["diffusion"]
    image_size = cfg["data"]["image_size"]
    unet = UNet(
        dim=mcfg["dim"],
        dim_mults=tuple(mcfg["dim_mults"]),
        channels=mcfg["channels"],
        attn_heads=mcfg["attn_heads"],
        attn_dim_head=mcfg["attn_dim_head"],
        dropout=mcfg["dropout"],
        self_condition=mcfg["self_condition"],
        learned_variance=mcfg["learned_variance"],
        outer_attn=mcfg["outer_attn"],
    ).to(device)
    return GaussianDiffusion(
        unet,
        image_size=(image_size, image_size),
        timesteps=dcfg["T"],
        beta_schedule=dcfg["beta_schedule"],
        objective=dcfg["objective"],
        sampling_steps=dcfg["sampling_steps"],
        eta=dcfg["eta"],
        self_condition=dcfg["self_condition"],
        auto_normalize=True,
        clamp_x0=dcfg["clamp_x0"],
    ).to(device)


def _periodic_sampling(cfg, diffusion, ema, x, step, run):
    """Save/log one timed sample grid and, if enabled, reverse/forward trajectory videos."""
    sampler = ema.ema_model if ema is not None else diffusion
    sample_n = int(cfg["diffusion"]["sample_n"])
    viz_cfg = cfg.get("viz", {})

    diffusion.eval()
    # The EMA copy is never trained; make sure its dropout is off for sampling.
    sampler.eval()
    try:
        with torch.inference_mode():
            # (a) sample grid + sampling speed
            t0 = time.perf_counter()
            samples = sampler.sample(sample_n)
            dt = max(time.perf_counter() - t0, 1e-6)
            path = f"./samples/mnist_step_{step}.png"
            save_image(samples, path, nrow=8)
            if run:
                wandb.log({
                    "samples_grid": wandb.Image(path),
                    "speed/sampling_imgs_per_sec": sample_n / dt,
                    "speed/sampling_sec": dt,
                    "step": step,
                }, step=step)

            # (b) reverse trajectory video (x_t over denoising); only useful with W&B.
            if (run and viz_cfg.get("enable_reverse_traj", False)
                    and step % int(viz_cfg["reverse_every_steps"]) == 0):
                n_rev = int(viz_cfg["reverse_batch_n"])
                size = cfg["data"]["image_size"]
                # Always the DDPM trajectory; a DDIM variant would need its own helper.
                _, frames_xt, _ = sampler.ddpm_sample_trajectory(
                    shape=(n_rev, diffusion.channels, size, size),
                    record_every=int(viz_cfg["reverse_record_every"]),
                    return_x0=False,
                )
                video = frames_to_wandb_video(
                    frames_xt, nrow=min(8, n_rev), fps=int(viz_cfg["video_fps"]))
                wandb.log({"viz/reverse_xt": video, "step": step}, step=step)

            # (c) forward noising trajectory q(x_t | x_0) on the current batch.
            if (run and viz_cfg.get("enable_forward_traj", False)
                    and step % int(viz_cfg["forward_every_steps"]) == 0):
                n_fwd = int(viz_cfg["forward_batch_n"])
                frames_fwd = diffusion.forward_noising_trajectory(
                    x0=x[:n_fwd].detach(),  # loader output, in [0, 1]
                    t_values=viz_cfg["forward_t_values"],  # list of ints
                )
                video_fwd = frames_to_wandb_video(
                    frames_fwd, nrow=min(8, n_fwd), fps=int(viz_cfg["video_fps"]))
                wandb.log({"viz/forward_xt": video_fwd, "step": step}, step=step)
    finally:
        diffusion.train()


def _evaluate_metrics(dl, sampler, cfg, step, run, *, fid_due, is_due, n_gen, batch_size):
    """Generate ``n_gen`` images and compute whichever of FID / IS is due (slow)."""
    image_size = cfg["data"]["image_size"]
    # Folder name encodes the resolution so a stale reference of another size is never reused.
    real_ref_dir = f"./metrics_ref/mnist_train_{image_size}_rgb"
    if fid_due:
        ensure_real_ref_folder(dl, real_ref_dir, max_images=50000,
                               img_size=image_size, force_rgb=True)

    gen_dir = f"./metrics_gen/step_{step}"
    t0 = time.perf_counter()
    generate_images_to_folder(sampler, n_images=n_gen, batch_size=batch_size,
                              out_dir=gen_dir, force_rgb=True)
    gen_fps = n_gen / max(time.perf_counter() - t0, 1e-6)

    log_payload = {"step": step, "metrics/gen_imgs_per_sec": gen_fps}
    if fid_due:
        fid_score = compute_fid_cleanfid(gen_dir, real_ref_dir)
        if fid_score is not None:
            log_payload["metrics/FID_clean"] = fid_score
    if is_due:
        is_mean, is_std = compute_inception_score_torchfidelity(gen_dir, cuda=True)
        if is_mean is not None:
            log_payload["metrics/IS_mean"] = is_mean
        if is_std is not None:
            log_payload["metrics/IS_std"] = is_std
    if run:
        wandb.log(log_payload, step=step)


# ---------------------------
# Main training
# ---------------------------
def main(cfg_path="config_mnist_small.yaml", seed=42):
    """Train the MNIST diffusion model described by the YAML config at ``cfg_path``.

    ``step`` counts micro-batches (forward/backward passes). The optimizer steps
    once every ``train.grad_accum`` micro-batches, so ``train.max_steps`` and all
    ``*_every`` intervals are expressed in micro-batches. The global gradient
    norm is only logged on optimizer-step boundaries.
    """
    torch.manual_seed(seed)

    # Load config and setup
    with open(cfg_path, encoding="utf-8") as fh:
        cfg = yaml.safe_load(fh)
    os.makedirs(cfg["train"]["ckpt_dir"], exist_ok=True)
    os.makedirs("./samples", exist_ok=True)
    maybe_enable_cuda_speedups(cfg)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # W&B init
    run = None
    if cfg["wandb"]["enabled"]:
        if cfg["wandb"].get("mode", "online") == "offline":
            os.environ["WANDB_MODE"] = "offline"
        wandb.login()
        run = wandb.init(
            project=cfg["project"],
            name=cfg["run_name"],
            config=cfg,
            tags=cfg["wandb"].get("tags", []),
        )
        # Log diffusion hyperparameters once for visibility (requires an active run).
        wandb.config.update({
            "hparams/T": cfg["diffusion"]["T"],
            "hparams/beta_schedule": cfg["diffusion"]["beta_schedule"],
            "hparams/sampling_steps": cfg["diffusion"]["sampling_steps"],
            "hparams/eta": cfg["diffusion"]["eta"],
        }, allow_val_change=True)

    # Data
    dl = get_loader_mnist(cfg["data"]["batch_size"], cfg["data"]["num_workers"],
                          cfg["data"]["image_size"])

    # Model + Diffusion (FP32 default)
    diffusion = _build_diffusion(cfg, device)

    # Optimizer (FP32)
    opt = Adam(diffusion.parameters(), lr=cfg["opt"]["lr"], betas=tuple(cfg["opt"]["betas"]))

    # EMA (recommended)
    ema = None
    if cfg.get("ema", {}).get("enabled", True):
        ema = EMA(diffusion, beta=cfg["ema"]["decay"], update_every=cfg["ema"]["update_every"])
        ema.to(device)

    # Train loop params
    max_steps = int(cfg["train"]["max_steps"])
    log_every = int(cfg["train"]["log_every"])
    grad_accum = int(cfg["train"].get("grad_accum", 1))
    sample_every = int(cfg["diffusion"]["sample_every"])
    ckpt_every = 5 * sample_every

    # Norm logging and metric (FID / IS) params, optional YAML section "metrics"
    metrics_cfg = cfg.get("metrics", {})
    global_norm_every = int(metrics_cfg.get("global_norm_every", 1000))
    enable_fid = bool(metrics_cfg.get("enable_fid", False))
    enable_is = bool(metrics_cfg.get("enable_is", False))
    fid_every = int(metrics_cfg.get("fid_every", 4000))
    is_every = int(metrics_cfg.get("is_every", 4000))
    metric_n_gen = int(metrics_cfg.get("metric_num_gen", 5000))
    metric_bs = int(metrics_cfg.get("metric_batch_size", 64))

    step = 0
    pbar = tqdm(total=max_steps, desc="training")
    opt.zero_grad(set_to_none=True)

    # Throughput is measured over the window between two loss logs.
    last_log_time = time.perf_counter()
    last_log_step = 0

    while step < max_steps:
        for x, _ in dl:
            # Move batch to device and force float32
            x = x.to(device, non_blocking=True).float()

            # Standard FP32 forward/backward (no AMP)
            loss = diffusion(x) / grad_accum
            loss.backward()

            if (step + 1) % grad_accum == 0:
                # Log the (pre-clipping) gradient norm here: after
                # zero_grad(set_to_none=True) every .grad is None.
                if run:
                    log_global_grad_norm_sparsely(diffusion, step + 1, every=global_norm_every)
                torch.nn.utils.clip_grad_norm_(diffusion.parameters(), cfg["opt"]["grad_clip"])
                opt.step()
                opt.zero_grad(set_to_none=True)
                if ema is not None:
                    ema.update()

            step += 1
            pbar.update(1)

            # -------- sparse scalar logging --------
            if run and step % log_every == 0:
                now = time.perf_counter()
                delta_t = max(now - last_log_time, 1e-6)
                ips = (step - last_log_step) / delta_t
                wandb.log({
                    "train/loss": float(loss.item() * grad_accum),
                    "speed/iter_per_sec": ips,
                    "speed/ms_per_iter": 1000.0 / max(ips, 1e-9),
                    "step": step,
                }, step=step)
                last_log_time = now
                last_log_step = step

            # -------- periodic sampling (with speed) --------
            if step % sample_every == 0:
                _periodic_sampling(cfg, diffusion, ema, x, step, run)

            # -------- sparse checkpointing --------
            if step % ckpt_every == 0:
                save_obj = {"step": step, "model": diffusion.state_dict(), "opt": opt.state_dict()}
                if ema is not None:
                    save_obj["ema"] = ema.state_dict()
                torch.save(save_obj, os.path.join(cfg["train"]["ckpt_dir"], f"mnist_step_{step}.pt"))

            # -------- optional FID & IS evaluation (sparse, time-consuming) --------
            # Each metric follows its own interval; images are generated only
            # if at least one installed metric is due at this step.
            fid_due = enable_fid and HAS_CLEANFID and step % fid_every == 0
            is_due = enable_is and HAS_TORCH_FIDELITY and step % is_every == 0
            if fid_due or is_due:
                sampler = ema.ema_model if ema is not None else diffusion
                _evaluate_metrics(dl, sampler, cfg, step, run,
                                  fid_due=fid_due, is_due=is_due,
                                  n_gen=metric_n_gen, batch_size=metric_bs)

            if step >= max_steps:
                break

    pbar.close()
    if run:
        run.finish()


if __name__ == "__main__":
    import argparse

    ap = argparse.ArgumentParser()
    ap.add_argument("--config", type=str, default="config_mnist_small.yaml",
                    help="Path to YAML config")
    ap.add_argument("--seed", type=int, default=42,
                    help="Random seed passed to torch.manual_seed")
    args = ap.parse_args()
    main(args.config, seed=args.seed)