"""ShotCraft — JSON schemas & validation for Stage 1 output (FR-1.2, FR-1.4)."""
from __future__ import annotations

import json
from dataclasses import dataclass, field, asdict

CATEGORIES = ["Fashion", "Beauty", "Food & Beverage", "Electronics", "Home", "Jewelry", "Other"]
STYLE_PRESETS = ["Minimal", "Luxury", "Lifestyle", "Bold & Colorful", "Editorial"]

# Shared style suffixes injected into every FLUX prompt (FR-2.3)
STYLE_SUFFIXES = {
    "Minimal": "clean minimal composition, soft studio light, negative space, commercial product photography",
    "Luxury": "premium editorial look, dramatic rim light, rich shadows, high-end advertising photography",
    "Lifestyle": "natural candid setting, warm daylight, shallow depth of field, lifestyle product photography",
    "Bold & Colorful": "vivid saturated backdrop colors, hard light, playful geometric backdrop, pop-art product photography",
    "Editorial": "magazine editorial styling, cinematic lighting, artful composition, fashion campaign photography",
}


@dataclass
class Shot:
    """One shot concept from the Stage 1 concept package."""

    id: int
    concept_name: str
    scene: str
    camera_angle: str
    lighting: str
    color_palette: list[str]  # hex strings
    props: str
    marketing_angle: str
    image_prompt: str  # FLUX-optimized, English


@dataclass
class ProductAnalysis:
    """The model's analysis of the product shown in the input."""

    product_type: str
    materials: str
    colors: list[str]
    distinguishing_features: str
    # Locked one-sentence product identity, prefixed to every FLUX prompt so
    # the SAME product appears in all shots. Backfilled when the model omits it.
    canonical_description: str = ""


@dataclass
class ConceptPackage:
    """Validated Stage 1 output: the product analysis plus its shot concepts."""

    product_analysis: ProductAnalysis
    shots: list[Shot] = field(default_factory=list)

    def to_json(self) -> str:
        """Serialize the package, nested dataclasses included, as indented JSON."""
        return json.dumps(asdict(self), indent=2, ensure_ascii=False)


REQUIRED_SHOT_KEYS = {"id", "concept_name", "scene", "camera_angle", "lighting",
                      "color_palette", "props", "marketing_angle", "image_prompt"}


def validate_package(raw: str | dict, *, expected_shots: int = 5) -> ConceptPackage:
    """Parse + validate Stage 1 model output.

    Args:
        raw: The model's JSON text, or an already-decoded dict. A dict
            argument (and the shot dicts inside it) is never modified.
        expected_shots: Exact number of shots the package must contain.

    Returns:
        A ConceptPackage holding one Shot per input shot, in input order.
        ``product_analysis.colors`` is a fresh list (``[]`` when the model's
        ``colors`` is missing or not a list).

    Raises:
        ValueError: with a readable message when the payload is malformed.
            Invalid JSON text also surfaces as ValueError, because
            ``json.JSONDecodeError`` subclasses it.
    """
    data = json.loads(raw) if isinstance(raw, str) else raw
    # Guard before calling .get(): a JSON array/scalar would otherwise
    # escape as AttributeError instead of the documented ValueError.
    if not isinstance(data, dict):
        raise ValueError(f"expected a JSON object at top level, got {type(data).__name__}")

    pa = data.get("product_analysis")
    if not isinstance(pa, dict):
        raise ValueError("missing product_analysis object")

    shots = data.get("shots")
    if not isinstance(shots, list) or len(shots) != expected_shots:
        got = len(shots) if isinstance(shots, list) else "none"
        raise ValueError(f"expected exactly {expected_shots} shots, got {got}")

    detected_colors = pa.get("colors")
    parsed = []
    for i, s in enumerate(shots, start=1):
        if not isinstance(s, dict):
            raise ValueError(f"shot {i} is not an object")
        # Work on a copy so the backfill below never mutates the caller's data.
        s = dict(s)
        # Tolerate a missing (or null) color_palette (observed MiniCPM flake):
        # backfill from the detected product colors instead of failing the package.
        if s.get("color_palette") is None and isinstance(detected_colors, list):
            s["color_palette"] = list(detected_colors)
        missing = REQUIRED_SHOT_KEYS - set(s)
        if missing:
            raise ValueError(f"shot {i} missing keys: {sorted(missing)}")
        parsed.append(Shot(**{k: s[k] for k in REQUIRED_SHOT_KEYS}))

    return ConceptPackage(
        product_analysis=ProductAnalysis(
            product_type=pa.get("product_type", ""),
            materials=pa.get("materials", ""),
            # Copy so the package does not alias the caller's list; a non-list
            # value would violate the list[str] field type, so drop it.
            colors=list(detected_colors) if isinstance(detected_colors, list) else [],
            distinguishing_features=pa.get("distinguishing_features", ""),
            canonical_description=_canonical_description(pa),
        ),
        shots=parsed,
    )


def _canonical_description(pa: dict) -> str:
    """Return the locked product identity Stage 2 uses to pin the product look.

    Uses the model's ``canonical_description`` when it is non-blank. Otherwise
    it backfills one from the other analysis fields, joined with ", ", e.g.
    "sneaker, in white, gum brown, made of canvas, red heel tab".
    Colors (at most the first five) are converted to names via _color_name.
    """
    desc = str(pa.get("canonical_description") or "").strip()
    if desc:
        return desc
    bits = [str(pa.get("product_type") or "").strip()]
    colors = pa.get("colors")
    if isinstance(colors, list) and colors:
        names = [_color_name(str(c)) for c in colors[:5]]
        bits.append("in " + ", ".join(dict.fromkeys(names)))  # dedupe, keep order
    materials = str(pa.get("materials") or "").strip()
    if materials:
        bits.append(f"made of {materials}")
    features = str(pa.get("distinguishing_features") or "").strip()
    if features:
        bits.append(features)
    # Skip empty pieces (e.g. a missing product_type) so no ", ," appears.
    return ", ".join(b for b in bits if b)


# FLUX barely understands hex codes - name them for the backfilled description.
# Reference palette for _color_name as (plain-English name, (r, g, b)) pairs.
# Order matters: on an exact distance tie the earlier entry wins.
_NAMED_COLORS = [
    ("white", (255, 255, 255)),
    ("off-white", (240, 238, 230)),
    ("light grey", (200, 200, 200)),
    ("grey", (128, 128, 128)),
    ("charcoal", (60, 60, 60)),
    ("black", (10, 10, 10)),
    ("red", (220, 40, 40)),
    ("orange", (240, 140, 30)),
    ("yellow", (245, 200, 40)),
    ("gum brown", (170, 120, 70)),
    ("brown", (110, 70, 40)),
    ("green", (60, 160, 70)),
    ("teal", (20, 160, 170)),
    ("blue", (50, 90, 200)),
    ("navy blue", (25, 35, 80)),
    ("purple", (130, 70, 190)),
    ("pink", (235, 120, 170)),
    ("beige", (220, 200, 170)),
]


def _color_name(hex_str: str) -> str:
    """Translate a hex color code into the nearest _NAMED_COLORS name.

    Surrounding whitespace and leading '#' characters are ignored. A
    three-digit shorthand is expanded ("abc" -> "aabbcc"). Only the first six
    hex digits are read, so longer codes such as "#RRGGBBAA" ignore the tail.
    Input that does not begin with six hex digits (typically a color that is
    already a name) is returned exactly as given, unstripped. Closeness is the
    squared Euclidean distance in RGB space.
    """
    digits = hex_str.strip().lstrip("#")
    if len(digits) == 3:
        digits = digits[0] * 2 + digits[1] * 2 + digits[2] * 2
    head = digits[:6]
    looks_hex = len(head) == 6 and all(c in "0123456789abcdefABCDEF" for c in head)
    if not looks_hex:
        return hex_str  # already a name or unparseable - keep as-is
    rgb = tuple(int(head[pos:pos + 2], 16) for pos in (0, 2, 4))
    best_name, _ = min(
        _NAMED_COLORS,
        key=lambda entry: sum((ref - got) ** 2 for ref, got in zip(entry[1], rgb)),
    )
    return best_name