| """ShotCraft — JSON schemas & validation for Stage 1 output (FR-1.2, FR-1.4).""" |
| from __future__ import annotations |
| import json |
| from dataclasses import dataclass, field, asdict |
|
|
# Known product category labels.
CATEGORIES = [
    "Fashion",
    "Beauty",
    "Food & Beverage",
    "Electronics",
    "Home",
    "Jewelry",
    "Other",
]

# Descriptive text for each style preset, keyed by preset name.
# NOTE(review): presumably appended to image prompts — confirm against the caller.
STYLE_SUFFIXES = {
    "Minimal": "clean minimal composition, soft studio light, negative space, commercial product photography",
    "Luxury": "premium editorial look, dramatic rim light, rich shadows, high-end advertising photography",
    "Lifestyle": "natural candid setting, warm daylight, shallow depth of field, lifestyle product photography",
    "Bold & Colorful": "vivid saturated backdrop colors, hard light, playful geometric backdrop, pop-art product photography",
    "Editorial": "magazine editorial styling, cinematic lighting, artful composition, fashion campaign photography",
}

# Preset names in the order they are declared in STYLE_SUFFIXES.
STYLE_PRESETS = list(STYLE_SUFFIXES)
|
|
@dataclass
class Shot:
    """One shot concept from the Stage 1 output.

    The field names are the same keys that ``validate_package`` requires on
    every shot object (see ``REQUIRED_SHOT_KEYS``).
    """

    id: int
    concept_name: str
    scene: str
    camera_angle: str
    lighting: str
    # validate_package copies the product's colors here when a shot omits this key.
    color_palette: list[str]
    props: str
    marketing_angle: str
    image_prompt: str
|
|
@dataclass
class ProductAnalysis:
    """Product fields read from the ``product_analysis`` object of the output."""

    product_type: str
    materials: str
    # Entries may be hex codes: _canonical_description maps them to
    # plain-English names when it builds a fallback description.
    colors: list[str]
    distinguishing_features: str
    # Locked product identity. validate_package fills this from the other
    # analysis fields when the model output leaves it empty.
    canonical_description: str = ""
|
|
@dataclass
class ConceptPackage:
    """Result returned by ``validate_package``: a product analysis and its shots."""

    product_analysis: ProductAnalysis
    shots: list[Shot] = field(default_factory=list)

    def to_json(self) -> str:
        """Serialize the package to 2-space-indented JSON, keeping non-ASCII text unescaped."""
        payload = asdict(self)
        return json.dumps(payload, ensure_ascii=False, indent=2)
|
|
# Keys every shot object must provide; validate_package passes exactly these
# to the Shot constructor.
REQUIRED_SHOT_KEYS = {
    "id",
    "concept_name",
    "scene",
    "camera_angle",
    "lighting",
    "color_palette",
    "props",
    "marketing_angle",
    "image_prompt",
}
|
|
def validate_package(raw: str | dict) -> ConceptPackage:
    """Parse and validate Stage 1 model output.

    Args:
        raw: The JSON text produced by the model, or an already-parsed
            mapping. A mapping passed in is left unmodified.

    Returns:
        A ConceptPackage with the product analysis and exactly five shots.

    Raises:
        ValueError: With a readable message when the JSON is malformed
            (``json.JSONDecodeError`` is a ``ValueError``), the top level or
            a shot entry is not an object, ``product_analysis`` is missing,
            the shot count is not five, or a shot lacks a required key.
    """
    data = json.loads(raw) if isinstance(raw, str) else raw
    # Valid JSON can still be a list, string, number or null; reject those
    # here instead of failing later with an AttributeError on .get().
    if not isinstance(data, dict):
        raise ValueError(f"expected a JSON object, got {type(data).__name__}")
    pa = data.get("product_analysis")
    if not isinstance(pa, dict):
        raise ValueError("missing product_analysis object")
    shots = data.get("shots")
    if not isinstance(shots, list) or len(shots) != 5:
        raise ValueError(f"expected exactly 5 shots, got {len(shots) if isinstance(shots, list) else 'none'}")
    product_colors = pa.get("colors")
    parsed = []
    for i, s in enumerate(shots):
        if not isinstance(s, dict):
            raise ValueError(f"shot {i+1} is not an object")
        # Work on a shallow copy so the backfill below never mutates the
        # caller's data when an already-parsed dict was passed in.
        shot = dict(s)
        # Backfill a missing palette from the product's own colors.
        if "color_palette" not in shot and isinstance(product_colors, list):
            shot["color_palette"] = list(product_colors)
        missing = REQUIRED_SHOT_KEYS - set(shot)
        if missing:
            raise ValueError(f"shot {i+1} missing keys: {sorted(missing)}")
        # Extra keys on the shot are ignored; only the required ones are kept.
        parsed.append(Shot(**{k: shot[k] for k in REQUIRED_SHOT_KEYS}))
    return ConceptPackage(
        product_analysis=ProductAnalysis(
            product_type=pa.get("product_type", ""),
            materials=pa.get("materials", ""),
            colors=pa.get("colors", []),
            distinguishing_features=pa.get("distinguishing_features", ""),
            canonical_description=_canonical_description(pa),
        ),
        shots=parsed,
    )
|
|
| def _canonical_description(pa: dict) -> str: |
| """Locked product identity. Backfill from the analysis fields when the |
| model omits it so Stage 2 can always pin the product look.""" |
| desc = str(pa.get("canonical_description") or "").strip() |
| if desc: |
| return desc |
| bits = [str(pa.get("product_type") or "").strip()] |
| colors = pa.get("colors") |
| if isinstance(colors, list) and colors: |
| names = [_color_name(str(c)) for c in colors[:5]] |
| bits.append("in " + ", ".join(dict.fromkeys(names))) |
| if pa.get("materials"): |
| bits.append(f"made of {pa['materials']}") |
| if pa.get("distinguishing_features"): |
| bits.append(str(pa["distinguishing_features"]).strip()) |
| return ", ".join(b for b in bits if b) |
|
|
| |
| _NAMED_COLORS = [ |
| ("white", (255, 255, 255)), ("off-white", (240, 238, 230)), |
| ("light grey", (200, 200, 200)), ("grey", (128, 128, 128)), |
| ("charcoal", (60, 60, 60)), ("black", (10, 10, 10)), |
| ("red", (220, 40, 40)), ("orange", (240, 140, 30)), |
| ("yellow", (245, 200, 40)), ("gum brown", (170, 120, 70)), |
| ("brown", (110, 70, 40)), ("green", (60, 160, 70)), |
| ("teal", (20, 160, 170)), ("blue", (50, 90, 200)), |
| ("navy blue", (25, 35, 80)), ("purple", (130, 70, 190)), |
| ("pink", (235, 120, 170)), ("beige", (220, 200, 170)), |
| ] |
|
|
| def _color_name(hex_str: str) -> str: |
| """Nearest plain-English name for a hex color; passes through non-hex.""" |
| h = hex_str.strip().lstrip("#") |
| if len(h) == 3: |
| h = "".join(ch * 2 for ch in h) |
| if len(h) < 6 or any(ch not in "0123456789aAbBcCdDeEfF" for ch in h[:6]): |
| return hex_str |
| r, g, b = int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16) |
| return min(_NAMED_COLORS, |
| key=lambda nc: (nc[1][0]-r)**2 + (nc[1][1]-g)**2 + (nc[1][2]-b)**2)[0] |
|
|