| import gradio as gr |
| import jax |
| import jax.numpy as jnp |
| from jax.experimental import ode |
| import yaml |
| from flax import nnx |
| import pickle |
| import spaces |
|
|
|
|
def load_model(config_path, ckpt_path):
    """Build the DiT model from a YAML config and fill it with checkpoint leaves.

    Args:
        config_path: Path to a YAML file. Its ``"model"`` mapping is passed as
            keyword arguments to ``DiTConfig``.
        ckpt_path: Path to a pickle file holding a flat sequence of leaves. They
            are assumed to be in the order ``jax.tree_util.tree_flatten``
            yields for the model's NNX state (TODO confirm against the saver).

    Returns:
        ``(graphdef, state)`` as produced by ``nnx.split``, with ``state``
        rebuilt from the checkpoint leaves. Floating-point leaves are cast to
        bfloat16; all other leaves are passed through unchanged.

    Raises:
        ValueError: If the checkpoint's leaf count differs from the model's.
    """
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # SECURITY: pickle.load can execute arbitrary code. Only load checkpoints
    # that come from a trusted source.
    with open(ckpt_path, "rb") as f:
        leaves = pickle.load(f)

    def _to_bfloat16(x):
        # Only downcast floating-point leaves. Casting integer leaves to
        # bfloat16 would lose precision, and key-typed leaves cannot be cast.
        if jnp.issubdtype(x.dtype, jnp.floating):
            return x.astype(jnp.bfloat16)
        return x

    leaves = jax.tree_util.tree_map(_to_bfloat16, leaves)

    # Project-local model definition, imported lazily.
    from model import DiT, DiTConfig

    dit_config = DiTConfig(**config["model"])
    # eval_shape builds the model abstractly. Only its tree structure is used
    # here, so no real parameter initialisation is needed.
    model = nnx.eval_shape(lambda: DiT(dit_config, rngs=nnx.Rngs(0)))
    graphdef, state = nnx.split(model)
    _, treedef = jax.tree_util.tree_flatten(state)

    if len(leaves) != treedef.num_leaves:
        raise ValueError(
            f"checkpoint {ckpt_path!r} has {len(leaves)} leaves, but the model "
            f"state expects {treedef.num_leaves}"
        )
    state = jax.tree_util.tree_unflatten(treedef, leaves)
    return graphdef, state
|
|
|
|
@jax.jit
def sample_images(graphdef, state, x0, t):
    """Integrate the flow model from ``x0`` over ``t`` and return the clipped end state.

    Args:
        graphdef: Structure half of an ``nnx.split`` model.
        state: Parameter half of an ``nnx.split`` model. It is re-joined with
            ``graphdef`` via ``nnx.merge`` inside the jitted function.
        x0: Initial value handed to ``odeint``.
        t: Times at which ``odeint`` reports the solution. Only the value at
            the last time is returned.

    Returns:
        The solution at ``t[-1]`` (same shape as ``x0``), clipped to ``[0, 1]``.
    """
    network = nnx.merge(graphdef, state)

    def velocity(y, time):
        # The network is evaluated in bfloat16, and its output is handed back
        # to the solver as float32. The scalar time gets a leading length-1
        # axis (presumably broadcast over the batch by the model; confirm).
        prediction = network(y.astype(jnp.bfloat16), time.astype(jnp.bfloat16)[None])
        return prediction.astype(jnp.float32)

    trajectory = ode.odeint(velocity, x0, t, rtol=1e-4)
    end_state = trajectory[-1]
    return jnp.clip(end_state, 0, 1)
|
|
|
|
@spaces.GPU
def generate_grid(seed, noise_level):
    """Sample a square batch of images and tile them into one grid image.

    Args:
        seed: Seed for the initial-noise PRNG. It is coerced with ``int`` so
            float or NumPy-integer seeds also work.
        noise_level: Symmetric bound for the initial noise, which is drawn
            from a standard normal truncated to ``[-noise_level, noise_level]``.

    Returns:
        A host (NumPy) array holding a 4x4 grid of the sampled images, laid
        out row-major. Values are clipped to ``[0, 1]`` by ``sample_images``.
    """
    grid_side = 4  # images per row and per column

    # NOTE(review): the checkpoint is re-read from disk on every request.
    # Caching it would help, but first confirm how objects persist across
    # @spaces.GPU calls.
    graphdef, state = load_model("config.yaml", "ckpt_1000k.pkl")

    # odeint reports the solution at each requested time. Only t=0 and t=1
    # are requested, and sample_images keeps the last one.
    t = jnp.linspace(0, 1, 2)
    x0 = jax.random.truncated_normal(
        # nnx.Rngs only builds a key from an int seed.
        nnx.Rngs(int(seed))(),
        -noise_level,
        noise_level,
        shape=(grid_side * grid_side, 64, 64, 3),
        dtype=jnp.float32,
    )

    images = sample_images(graphdef, state, x0, t)

    # Tile (rows*cols, H, W, C) -> (rows, cols, H, W, C) -> (rows, H, cols, W, C)
    # -> (rows*H, cols*W, C). Image k lands at grid row k // grid_side and
    # grid column k % grid_side.
    height, width, channels = images.shape[1:]
    grid = (
        images.reshape(grid_side, grid_side, height, width, channels)
        .transpose(0, 2, 1, 3, 4)
        .reshape(grid_side * height, grid_side * width, channels)
    )

    return jax.device_get(grid)
|
|
|
|
| |
# Input widgets: an integer seed and the truncation bound for the initial noise.
_seed_input = gr.Number(label="Random Seed", value=0, precision=0)
_noise_input = gr.Slider(minimum=0, maximum=10, value=3.0, label="Noise Scale")

# Web UI: the two inputs feed generate_grid, and its tiled result is shown as one image.
demo = gr.Interface(
    fn=generate_grid,
    inputs=[_seed_input, _noise_input],
    outputs=gr.Image(label="Generated Images"),
    title="Anime Flow",
    description="Generate a 4x4 grid of anime faces using Anime Flow",
)


if __name__ == "__main__":
    # Start the server only when run as a script, not when imported.
    demo.launch()
|
|