| """ | |
| Generate images from trained model | |
| """ | |
| import argparse | |
| import pickle | |
| import jax | |
| import jax.numpy as jnp | |
| import matplotlib.pyplot as plt | |
| import yaml | |
| from flax import nnx | |
| from jax.experimental import ode | |
| from model import DiT, DiTConfig | |
def parse_args(argv=None):
    """Parse command-line arguments for the sampling script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            ``argparse`` reads ``sys.argv[1:]``, exactly as before.

    Returns:
        ``argparse.Namespace`` with ``config``, ``ckpt`` and ``seed``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", type=str, default="config.yaml", help="Path to config file"
    )
    # The checkpoint is mandatory: main() opens it unconditionally, so a
    # missing value used to surface as an opaque TypeError from open(None).
    parser.add_argument(
        "--ckpt", type=str, required=True, help="Path to checkpoint file"
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    return parser.parse_args(argv)
def load_config(config_path):
    """Load a YAML configuration file.

    Args:
        config_path: Path to the YAML file.

    Returns:
        The parsed top-level mapping (a ``dict``).

    Raises:
        ValueError: If the document's top level is not a mapping, e.g. an
            empty file, which ``yaml.safe_load`` returns as ``None``.
    """
    # YAML is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)
    # Callers index into the result (config["model"][...]); fail here with a
    # clear message instead of a TypeError further down.
    if not isinstance(config, dict):
        raise ValueError(
            f"Expected a mapping at the top level of {config_path!r}, "
            f"got {type(config).__name__}"
        )
    return config
def sample_images(graphdef, state, rng, num_samples=16, image_shape=(64, 64, 3)):
    """Draw samples by integrating the learned flow from Gaussian noise.

    The model is rebuilt from ``graphdef``/``state`` and used as the vector
    field of an ODE integrated from t=0 (noise) to t=1.

    Args:
        graphdef: Graph definition obtained from ``nnx.split``.
        state: Model state matching ``graphdef``.
        rng: JAX PRNG key used to draw the initial noise.
        num_samples: Number of images to generate (default 16).
        image_shape: Per-sample ``(H, W, C)`` shape of the noise/images
            (default ``(64, 64, 3)``).

    Returns:
        Array of shape ``(num_samples, *image_shape)`` clipped to [0, 1].
    """
    flow = nnx.merge(graphdef, state)

    def flow_fn(y, t):
        # odeint passes a scalar time; the model receives a 1-element time
        # array. NOTE(review): presumably broadcast against the batch inside
        # the model — confirm against DiT.__call__.
        return flow(y, t[None])

    x0 = jax.random.normal(
        rng, shape=(num_samples, *image_shape), dtype=jnp.float32
    )
    # Only the endpoint is needed. odeint's adaptive step sequence does not
    # depend on the requested output times, so asking for just [0, 1] gives
    # the same final state as a dense grid. It also avoids storing the whole
    # trajectory (previously 1000 copies of the image batch).
    ts = jnp.array([0.0, 1.0], dtype=jnp.float32)
    trajectory = ode.odeint(flow_fn, x0, ts)
    return jnp.clip(trajectory[-1], 0, 1)
def plot_new_images(graphdef, state, seed, output_path="samples.png"):
    """Sample images from the model and save them as a near-square grid.

    Args:
        graphdef: Graph definition obtained from ``nnx.split``.
        state: Model state matching ``graphdef``.
        seed: Integer seed for the sampling noise.
        output_path: File to write the figure to (default ``"samples.png"``).
    """
    import math

    images = sample_images(graphdef, state, nnx.Rngs(seed)())
    # Transfer the whole batch to host once instead of once per imshow call.
    images = jax.device_get(images)

    # Derive the grid from the actual sample count instead of hard-coding
    # 16 / 4x4; 16 samples still yield a 4x4 grid on a 2x2-inch figure.
    num_images = len(images)
    cols = max(1, math.ceil(math.sqrt(num_images)))
    rows = max(1, math.ceil(num_images / cols))

    # NOTE(review): at matplotlib's default 100 dpi each tile is ~50 px,
    # i.e. smaller than a 64 px sample — consider passing dpi to savefig.
    fig = plt.figure(figsize=(cols / 2, rows / 2))
    try:
        for i, image in enumerate(images):
            ax = fig.add_subplot(rows, cols, i + 1)
            ax.imshow(image)
            ax.axis("off")
        fig.subplots_adjust(left=0, bottom=0, top=1, right=1, wspace=0, hspace=0)
        fig.savefig(output_path)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
def main():
    """Build the model, load a checkpoint, and save a grid of samples.

    Reads the YAML config given by ``--config`` and builds an abstract
    ``DiT`` from its ``model`` section. Loads the model state from the
    pickle at ``--ckpt``, then writes samples via ``plot_new_images``.
    """
    args = parse_args()
    if args.ckpt is None:
        # Without a checkpoint there is nothing to sample from; fail clearly
        # instead of with a TypeError from open(None).
        raise SystemExit("error: --ckpt is required")
    config = load_config(args.config)

    model_cfg = config["model"]
    # Same keys, same order as the DiTConfig fields passed previously.
    dit_keys = (
        "input_dim",
        "hidden_dim",
        "num_blocks",
        "num_heads",
        "patch_size",
        "patch_stride",
        "time_freq_dim",
        "time_max_period",
        "mlp_ratio",
        "use_bias",
        "padding",
        "pos_embed_cls_token",
        "pos_embed_extra_tokens",
    )
    dit_config = DiTConfig(**{key: model_cfg[key] for key in dit_keys})

    # Only the graph structure is needed; eval_shape builds the module
    # abstractly, and the real parameter values come from the checkpoint.
    abstract_flow = nnx.eval_shape(lambda: DiT(dit_config, rngs=nnx.Rngs(0)))
    graphdef, _ = nnx.split(abstract_flow)

    # SECURITY: pickle.load can execute arbitrary code — only load
    # checkpoints from trusted sources.
    with open(args.ckpt, "rb") as f:
        state = pickle.load(f, fix_imports=True)
    # A checkpoint is either the bare model state (detected by a top-level
    # "time_embedding" entry) or a container whose first element is the
    # model state. NOTE(review): presumably (state, optimizer_state) —
    # confirm against the training script.
    if "time_embedding" not in state:
        state = state[0]

    plot_new_images(graphdef, state, args.seed)
# Run the sampling CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()