mlx-nucleus-image / convert_and_upload.py
treadon's picture
Upload convert_and_upload.py with huggingface_hub
e60c3e2 verified
Raw
History Blame
3.95 kB
#!/usr/bin/env python3
"""Convert Nucleus-Image weights to MLX format and upload directly to HuggingFace.
Avoids saving to disk (no disk space needed). Converts in memory and streams to HF.
"""
import argparse
import json
import shutil
import tempfile
from pathlib import Path
import mlx.core as mx
from huggingface_hub import HfApi, snapshot_download
def convert_vae_weights(raw_vae: dict) -> dict:
"""Convert VAE weights: CausalConv3d->Conv2d (last temporal slice), transpose to NHWC."""
vae_w = {}
for k, v in raw_vae.items():
if k.startswith("encoder.") or k.startswith("quant_conv"):
continue
if k.startswith("latents_") or k in ("spatial_scale_factor", "temporal_scale_factor"):
continue
if k.startswith("bn."):
continue
if "weight" in k and v.ndim == 5:
D = v.shape[2]
v = v[:, :, -1, :, :] if D > 1 else v.squeeze(2)
v = v.transpose(0, 2, 3, 1)
elif "weight" in k and v.ndim == 4:
v = v.transpose(0, 2, 3, 1)
if "gamma" in k:
v = v.squeeze()
vae_w[k] = v
return vae_w
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--source", default="NucleusAI/Nucleus-Image")
parser.add_argument("--dest", default="treadon/mlx-nucleus-image")
args = parser.parse_args()
api = HfApi()
src = Path(snapshot_download(args.source))
# ── Upload configs ──
print("Uploading configs...")
with tempfile.NamedTemporaryFile(suffix=".json", mode="w", delete=False) as f:
json.dump(json.load(open(src / "transformer" / "config.json")), f)
f.flush()
api.upload_file(path_or_fileobj=f.name, path_in_repo="dit/config.json", repo_id=args.dest)
with tempfile.NamedTemporaryFile(suffix=".json", mode="w", delete=False) as f:
json.dump(json.load(open(src / "vae" / "config.json")), f)
f.flush()
api.upload_file(path_or_fileobj=f.name, path_in_repo="vae/config.json", repo_id=args.dest)
# ── Upload DiT weights shard by shard ──
# DiT weights are already in the right format for MLX (bfloat16 linear weights).
# Just upload the original shards with a consistent naming.
dit_shards = sorted((src / "transformer").glob("*.safetensors"))
print(f"Uploading {len(dit_shards)} DiT weight shards...")
if len(dit_shards) == 1:
api.upload_file(
path_or_fileobj=str(dit_shards[0]),
path_in_repo="dit/weights.safetensors",
repo_id=args.dest,
)
else:
for i, shard in enumerate(dit_shards):
print(f" Uploading shard {i+1}/{len(dit_shards)}: {shard.name}")
api.upload_file(
path_or_fileobj=str(shard),
path_in_repo=f"dit/{shard.name}",
repo_id=args.dest,
)
# Also upload the index file if it exists
idx = src / "transformer" / "diffusion_pytorch_model.safetensors.index.json"
if idx.exists():
api.upload_file(
path_or_fileobj=str(idx),
path_in_repo="dit/weights.index.json",
repo_id=args.dest,
)
# ── Convert and upload VAE ──
print("Converting VAE weights (Conv3d->Conv2d)...")
raw_vae = mx.load(str(src / "vae" / "diffusion_pytorch_model.safetensors"))
vae_w = convert_vae_weights(raw_vae)
print(f" {len(vae_w)} tensors converted")
with tempfile.NamedTemporaryFile(suffix=".safetensors", delete=False) as f:
mx.save_safetensors(f.name, vae_w)
print(f" VAE saved to temp ({Path(f.name).stat().st_size / 1e6:.0f} MB)")
api.upload_file(
path_or_fileobj=f.name,
path_in_repo="vae/weights.safetensors",
repo_id=args.dest,
)
print(f"\nDone! Weights uploaded to {args.dest}")
if __name__ == "__main__":
main()