NEXUS_Visual_Weaver / modal_nexus_refine_v2.py
specimba's picture
feat: real Modal refinement with multi-LoRA, A100 GPU, LoRA registry - wired not mocked
7a7354f verified
Raw
History Blame Contribute Delete
13.4 kB
"""
NEXUS Visual Weaver — Modal Refinement Pipeline v2
===================================================
Real FLUX.1-Kontext-dev img2img refinement with multi-LoRA on Modal.
GPU options: A100-80GB, A100-40GB, L40S, T4
LoRA adapters: NO8D/BodyControl, NO8D/ExpressionControl, fal/realism-detailer,
ilkerzgi/metallic, ilkerzgi/glittering-portrait, ilkerzgi/embroidery-patch
Usage:
modal run modal_nexus_refine_v2.py --image-path input.png
Or call remotely from HF Space:
fn = modal.Function.from_name("nexus-couture-refine-v2", "refine_couture")
result_bytes = fn.remote(image_bytes=..., lora_adapters=["garment", "hardware"])
"""
import modal
from io import BytesIO
from PIL import Image
from typing import List, Optional
app = modal.App("nexus-couture-refine-v2")
# ─── Image with all dependencies for FLUX Kontext + LoRA ───
image = (
    modal.Image.debian_slim(python_version="3.12")
    .apt_install("git", "libgl1-mesa-glx", "libglib2.0-0")
    .pip_install(
        "torch==2.5.1",
        "torchvision==0.20.1",
        # FluxKontextPipeline (imported by the refinement function) first
        # shipped in diffusers 0.34.0; an older floor would allow a version
        # where that import fails.
        "diffusers>=0.34.0",
        "transformers>=4.45.0",
        "accelerate>=1.1.0",
        "safetensors",
        "Pillow",
        "huggingface-hub",
        "peft>=0.13.0",  # LoRA loading backend used by diffusers' load_lora_weights
        "protobuf",
        "sentencepiece",  # tokenizer dependency for the T5 text encoder
    )
)
# Persistent volume for model caching (saves startup time & bandwidth).
# Functions below mount it at /cache and pass cache_dir="/cache" to from_pretrained.
volume = modal.Volume.from_name("nexus-model-cache", create_if_missing=True)
# ─── NEXUS Taste Profile — the "soul" of the generator ───
# Comma-separated style phrases that every prompt is built on; kept as a
# list so individual traits are easy to read, reorder, or tweak.
NEXUS_CORE_STYLE = ", ".join(
    [
        "Slavic woman",
        "rain-slick neon cyberpunk city at night",
        "long structured black patent leather coat",
        "faux fur collar",
        "Chantilly lace neckline",
        "glowing crimson hardware",
        "platform boots",
        "floating NEXUS sigils and code streams",
        "ultra detailed wet fabric texture",
        "cinematic lighting",
        "high fashion editorial",
        "photorealistic",
        "8k",
    ]
)
# ─── LoRA Adapter Registry ───
# Short UI name -> config dict with keys (in order) repo_id, adapter_name,
# weight, description. Written as one row per adapter and expanded below.
# NOTE(review): "hardware" points at NO8D/ExpressionControl — confirm this
# mapping is intended.
LORA_REGISTRY = {
    short_name: {
        "repo_id": repo_id,
        "adapter_name": adapter_name,
        "weight": weight,
        "description": description,
    }
    for short_name, repo_id, adapter_name, weight, description in (
        ("garment", "NO8D/BodyControl", "garment_control", 0.75,
         "Body/garment shape control for FLUX"),
        ("hardware", "NO8D/ExpressionControl", "expression_control", 0.70,
         "Expression/hardware detail control"),
        ("realism", "fal/realism-detailer", "realism_detail", 0.60,
         "Photorealistic detail enhancement"),
        ("metallic", "ilkerzgi/metallic", "metallic_finish", 0.55,
         "Metallic material finish (hardware, buckles)"),
        ("glittering", "ilkerzgi/glittering-portrait", "glittering_portrait", 0.55,
         "Glittering/sparkling portrait effects"),
        ("embroidery", "ilkerzgi/embroidery-patch", "embroidery_patch", 0.55,
         "Embroidery and patch textures on garments"),
    )
}
# GPU pricing for cost tracker (USD per hour), keyed by the same labels as GPU_MAP.
# NOTE(review): hard-coded figures — verify against Modal's current pricing.
GPU_PRICING = {
    "A100-80GB": 1.80,
    "A100-40GB": 1.10,
    "L40S": 1.05,
    "T4": 0.40,
}
# Map GPU labels to Modal `gpu=` identifiers.
# Modal accepts the A100 memory variants explicitly ("A100-40GB" / "A100-80GB").
# The previous mapping sent "A100-40GB" to "A10G", a different 24 GB card.
# It also sent "A100-80GB" to a bare "A100", which does not pin the 80 GB variant.
GPU_MAP = {
    "A100-80GB": "A100-80GB",
    "A100-40GB": "A100-40GB",
    "L40S": "L40S",
    "T4": "T4",
}
def _get_lora_adapters(adapter_keys: Optional[List[str]] = None) -> List[dict]:
"""Resolve LoRA adapter keys to full config dicts."""
if not adapter_keys:
return []
adapters = []
for key in adapter_keys:
key = key.strip().lower()
if key in LORA_REGISTRY:
adapters.append(LORA_REGISTRY[key])
else:
print(f"⚠️ Unknown LoRA adapter key: {key}, skipping")
return adapters
def _load_lora_adapters(pipe, adapter_cfgs: List[dict]) -> List[dict]:
    """Load each LoRA config into ``pipe`` and activate the ones that loaded.

    Failures are logged and skipped so one bad adapter never aborts a render.

    Args:
        pipe: A diffusers pipeline exposing ``load_lora_weights`` and ``set_adapters``.
        adapter_cfgs: Registry entries with ``repo_id``, ``adapter_name`` and ``weight``.

    Returns:
        The subset of ``adapter_cfgs`` whose weights loaded successfully.
    """
    loaded: List[dict] = []
    if not adapter_cfgs:
        return loaded
    print(f"🔌 Loading {len(adapter_cfgs)} LoRA adapter(s)...")
    for cfg in adapter_cfgs:
        try:
            print(f" Loading: {cfg['repo_id']} ({cfg['adapter_name']})")
            pipe.load_lora_weights(cfg["repo_id"], adapter_name=cfg["adapter_name"])
        except Exception as e:
            print(f" ❌ Failed to load {cfg['repo_id']}: {e}")
            print(" ⚠️ Continuing without this adapter")
            continue
        loaded.append(cfg)
        print(f" ✅ Loaded: {cfg['adapter_name']}")
    if not loaded:
        return loaded

    names = [c["adapter_name"] for c in loaded]
    weights = [c["weight"] for c in loaded]
    try:
        pipe.set_adapters(names, adapter_weights=weights)
        print(f" ✅ Activated {len(loaded)} adapter(s): {names}")
    except Exception as e:
        print(f" ⚠️ Could not set multi-adapter weights: {e}")
        # Fallback: activate only the first adapter that loaded.
        try:
            pipe.set_adapters([names[0]], adapter_weights=[weights[0]])
        except Exception:
            # NOTE(review): adapters registered by load_lora_weights may still be
            # active here; confirm "base model only" holds for the diffusers version.
            print(" ⚠️ Single adapter fallback also failed, using base model only")
    return loaded


def _downscale_to_max_pixels(img: "Image.Image", max_pixels: int = 2_000_000) -> "Image.Image":
    """Return ``img`` unchanged, or a LANCZOS-resized copy of at most ~``max_pixels`` pixels.

    The aspect ratio is preserved. Each side is clamped to >= 1 px so extreme
    aspect ratios cannot produce a zero-sized image.
    """
    width, height = img.size
    if width * height <= max_pixels:
        return img
    scale = (max_pixels / (width * height)) ** 0.5
    new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
    resized = img.resize(new_size, Image.LANCZOS)
    print(f" 📏 Resized from {width}x{height} to {new_size[0]}x{new_size[1]}")
    return resized


@app.function(
    image=image,
    # Pin the 80 GB variant explicitly. A bare "A100" does not guarantee it, and
    # FLUX Kontext in bf16 leaves little headroom on 40 GB.
    gpu="A100-80GB",
    volumes={"/cache": volume},
    timeout=600,  # 10 minutes max per run
    # allow_concurrent_inputs is intentionally not set. Each call loads its own
    # full FLUX pipeline onto the GPU, so several concurrent inputs in one
    # container would multiply VRAM use and risk OOM.
)
def refine_couture(
    image_bytes: bytes,
    user_addition: str = "",
    strength: float = 0.58,
    steps: int = 32,
    guidance_scale: float = 3.8,
    seed: int = -1,
    lora_adapters: Optional[List[str]] = None,
    negative_prompt: str = "blurry, low quality, deformed, extra limbs, bad anatomy, watermark, text",
    gpu_type: str = "A100-80GB",
) -> bytes:
    """
    Refines an input image using FLUX.1-Kontext-dev with optional multi-LoRA.
    Preserves the core NEXUS aesthetic while applying user modifications.

    Args:
        image_bytes: Input image as PNG/JPEG bytes. Images larger than ~2 MP
            are downscaled (aspect ratio kept) before inference.
        user_addition: Additional prompt text appended to NEXUS_CORE_STYLE.
        strength: img2img strength (0.0-1.0, higher = more change). Forwarded
            only if the loaded pipeline's ``__call__`` accepts ``strength``;
            otherwise it is logged and ignored.
        steps: Number of inference steps.
        guidance_scale: Guidance scale passed to the pipeline.
        seed: Random seed (-1 for random; the chosen seed is logged).
        lora_adapters: List of adapter keys: "garment", "hardware", "realism",
            "metallic", "glittering", "embroidery". Unknown keys and adapters
            that fail to load are skipped.
        negative_prompt: Negative prompt, forwarded when the pipeline accepts it.
            NOTE(review): Flux pipelines typically apply it only when
            true_cfg_scale > 1; confirm for the installed diffusers version.
        gpu_type: Label used only in the startup log. The GPU actually used is
            fixed by the decorator (override per call with Modal's
            ``with_options``; confirm for your Modal version).

    Returns:
        PNG image bytes of the refined result.
    """
    import inspect
    import random
    import time

    import torch
    from diffusers import FluxKontextPipeline

    started = time.time()
    print("🎨 NEXUS Kontext Refinement v2")
    print(f" GPU: {gpu_type} | Strength: {strength} | Steps: {steps} | Guidance: {guidance_scale}")
    print(f" LoRA adapters requested: {lora_adapters}")

    # ─── Load Pipeline ───
    print("⏳ Loading FLUX.1-Kontext-dev pipeline...")
    pipe = FluxKontextPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Kontext-dev",
        torch_dtype=torch.bfloat16,
        cache_dir="/cache",
    ).to("cuda")
    # Memory-efficient attention is optional; fall back silently to the default.
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception:
        print(" ℹ️ xformers not available, using default attention")

    # ─── Load LoRA Adapters ───
    _load_lora_adapters(pipe, _get_lora_adapters(lora_adapters))

    # ─── Process Input Image ───
    init_image = _downscale_to_max_pixels(Image.open(BytesIO(image_bytes)).convert("RGB"))

    # ─── Construct Final Prompt ───
    final_prompt = f"{NEXUS_CORE_STYLE}, {user_addition}" if user_addition else NEXUS_CORE_STYLE

    # ─── Seed Handling ───
    if seed == -1:
        seed = random.randint(0, 2**32 - 1)
    generator = torch.Generator(device="cuda").manual_seed(seed)
    print(f"🎯 Generating with seed {seed}")
    print(f" Prompt: {final_prompt[:120]}...")

    # ─── Run Inference ───
    call_kwargs = {
        "image": init_image,
        "prompt": final_prompt,
        "negative_prompt": negative_prompt,
        "guidance_scale": guidance_scale,
        "num_inference_steps": steps,
        "strength": strength,
        "generator": generator,
    }
    # FluxKontextPipeline.__call__ has no `strength` parameter (in diffusers,
    # img2img strength lives on other Flux pipeline variants). Passing it
    # unconditionally raises TypeError, so forward optional keywords only when
    # this pipeline version accepts them.
    params = inspect.signature(pipe.__call__).parameters
    takes_var_kw = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())
    if not takes_var_kw:
        for optional_kw in ("strength", "negative_prompt"):
            if optional_kw not in params:
                print(f" ℹ️ Pipeline does not accept '{optional_kw}'; ignoring it")
                del call_kwargs[optional_kw]
    result = pipe(**call_kwargs).images[0]

    # ─── Return as PNG bytes ───
    buf = BytesIO()
    result.save(buf, format="PNG")
    elapsed = time.time() - started
    print(f"✅ Refinement complete in {elapsed:.1f}s")
    return buf.getvalue()
@app.function(
    image=image,
    gpu="A100",
    volumes={"/cache": volume},
    timeout=600,
)
def check_modal_health() -> dict:
    """Quick health check — reports CUDA/GPU visibility and static config.

    Runs inside the GPU container but does NOT load the FLUX pipeline. It only
    confirms that torch imports and can query CUDA device 0.

    Returns:
        On success, a dict with:
          - ``status``: "healthy"
          - ``cuda``: bool
          - ``gpu``: device name, or "N/A"
          - ``gpu_memory_gb``: total device memory in GB (1e9 bytes), 1 decimal; 0 without CUDA
          - ``lora_registry``: list of registry keys
          - ``gpu_pricing``: the pricing table
        If any probe raises: ``{"status": "error", "message": str}``.
    """
    import torch

    try:
        cuda_available = torch.cuda.is_available()
        gpu_name = torch.cuda.get_device_name(0) if cuda_available else "N/A"
        # The device-properties attribute is ``total_memory``. The previous
        # ``total_mem`` raised AttributeError, so the check always reported "error".
        gpu_mem = torch.cuda.get_device_properties(0).total_memory if cuda_available else 0
        return {
            "status": "healthy",
            "cuda": cuda_available,
            "gpu": gpu_name,
            "gpu_memory_gb": round(gpu_mem / 1e9, 1),
            "lora_registry": list(LORA_REGISTRY.keys()),
            "gpu_pricing": GPU_PRICING,
        }
    except Exception as e:
        return {"status": "error", "message": str(e)}
@app.function(
    image=image,
    gpu="A100",
    volumes={"/cache": volume},
    timeout=900,
)
def generate_from_text(
    prompt: str,
    user_addition: str = "",
    width: int = 1024,
    height: int = 1024,
    steps: int = 4,
    guidance_scale: float = 1.0,
    seed: int = -1,
    lora_adapters: Optional[List[str]] = None,
) -> bytes:
    """
    Generate a new image from text using FLUX.2-Klein-9B with optional LoRA.
    For the Space's primary generation (no input image needed).

    Args:
        prompt: Base text prompt. If empty or blank, NEXUS_CORE_STYLE is used.
        user_addition: Extra text appended to the base prompt after ", ".
        width: Output width in pixels.
        height: Output height in pixels.
        steps: Number of inference steps.
        guidance_scale: Guidance scale passed to the pipeline.
        seed: Random seed (-1 for random; the chosen seed is logged).
        lora_adapters: LORA_REGISTRY keys to load and activate. Unknown keys
            and adapters that fail to load are skipped with a warning.

    Returns:
        PNG image bytes of the generated image.
    """
    import random

    import torch
    from diffusers import Flux2KleinPipeline

    print("🎨 NEXUS Text-to-Image Generation (Modal)")
    # NOTE(review): confirm Flux2KleinPipeline exists in the installed diffusers
    # and that the repo id below is correct; the image's diffusers floor may
    # predate FLUX.2 support.
    pipe = Flux2KleinPipeline.from_pretrained(
        "black-forest-labs/FLUX.2-klein-9B",
        torch_dtype=torch.bfloat16,
        cache_dir="/cache",
    ).to("cuda")

    # Load LoRA adapters if specified (best effort).
    # NOTE(review): registry adapters are described as FLUX LoRAs; loading them
    # into a FLUX.2 pipeline may fail, in which case they are skipped below.
    adapters = _get_lora_adapters(lora_adapters)
    loaded = []
    for adapter_cfg in adapters:
        try:
            pipe.load_lora_weights(adapter_cfg["repo_id"], adapter_name=adapter_cfg["adapter_name"])
            loaded.append(adapter_cfg)
        except Exception as e:
            print(f"⚠️ Failed to load LoRA {adapter_cfg['repo_id']}: {e}")
    if loaded:
        names = [a["adapter_name"] for a in loaded]
        try:
            pipe.set_adapters(names, adapter_weights=[a["weight"] for a in loaded])
        except Exception as e:
            # Still best effort, but report the failure instead of swallowing it.
            print(f"⚠️ Could not activate LoRA adapters {names}: {e}")

    if seed == -1:
        seed = random.randint(0, 2**32 - 1)
    generator = torch.Generator(device="cuda").manual_seed(seed)
    print(f"🎯 Generating with seed {seed}")

    # The caller's prompt is always the base; user_addition only extends it.
    # Previously a non-empty user_addition silently replaced `prompt` with
    # NEXUS_CORE_STYLE.
    base_prompt = prompt if prompt and prompt.strip() else NEXUS_CORE_STYLE
    final_prompt = f"{base_prompt}, {user_addition}" if user_addition else base_prompt

    result = pipe(
        prompt=final_prompt,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        num_inference_steps=steps,
        generator=generator,
    ).images[0]
    buf = BytesIO()
    result.save(buf, format="PNG")
    return buf.getvalue()
@app.local_entrypoint()
def test_refine(
    image_path: str = "test_input.png",
    output_path: str = "test_output.png",
    user_prompt: str = "glowing crimson buckles, wet pavement reflection",
    loras: str = "garment,realism",
):
    """Local test entrypoint — runs the refinement on Modal.

    Reads ``image_path``, or synthesizes a 512x512 placeholder if it does not
    exist. Sends it to ``refine_couture`` with the comma-separated LoRA keys
    from ``loras`` (blank entries dropped; ``None`` if ``loras`` is empty) and
    writes the returned PNG bytes to ``output_path``.
    """
    from pathlib import Path

    source = Path(image_path)
    if source.exists():
        image_bytes = source.read_bytes()
    else:
        print(f"❌ Input image not found: {image_path}")
        print("Creating a dummy 512x512 test image...")
        placeholder = BytesIO()
        Image.new("RGB", (512, 512), color=(30, 10, 50)).save(placeholder, format="PNG")
        image_bytes = placeholder.getvalue()

    lora_list = [key.strip() for key in loras.split(",") if key.strip()] if loras else None

    print("🚀 Sending to Modal A100 for refinement...")
    result_bytes = refine_couture.remote(
        image_bytes=image_bytes,
        user_addition=user_prompt,
        lora_adapters=lora_list,
        strength=0.58,
        steps=32,
    )
    Path(output_path).write_bytes(result_bytes)
    print(f"✅ Success! Output saved to {output_path}")