"""ShotCraft — JSON schemas & validation for Stage 1 output (FR-1.2, FR-1.4)."""
from __future__ import annotations
import json
from dataclasses import dataclass, field, asdict
# Product categories.
CATEGORIES = [
    "Fashion",
    "Beauty",
    "Food & Beverage",
    "Electronics",
    "Home",
    "Jewelry",
    "Other",
]

# Shared style suffixes injected into every FLUX prompt (FR-2.3).
# Keyed by style preset name; the insertion order here is the preset order.
STYLE_SUFFIXES = {
    "Minimal": (
        "clean minimal composition, soft studio light, negative space, "
        "commercial product photography"
    ),
    "Luxury": (
        "premium editorial look, dramatic rim light, rich shadows, "
        "high-end advertising photography"
    ),
    "Lifestyle": (
        "natural candid setting, warm daylight, shallow depth of field, "
        "lifestyle product photography"
    ),
    "Bold & Colorful": (
        "vivid saturated backdrop colors, hard light, playful geometric backdrop, "
        "pop-art product photography"
    ),
    "Editorial": (
        "magazine editorial styling, cinematic lighting, artful composition, "
        "fashion campaign photography"
    ),
}

# Selectable style presets: derived from the STYLE_SUFFIXES keys (same names,
# same order) so a preset can never exist without a matching prompt suffix.
STYLE_PRESETS = list(STYLE_SUFFIXES)
@dataclass
class Shot:
    """One shot concept from the Stage 1 concept package.

    The field order below is part of the contract: it fixes the positional
    constructor signature and the key order that ``asdict`` emits.
    """

    id: int
    concept_name: str
    scene: str
    camera_angle: str
    lighting: str
    color_palette: list[str]  # hex color strings
    props: str
    marketing_angle: str
    image_prompt: str  # prompt optimized for FLUX, written in English
@dataclass
class ProductAnalysis:
    """Product-level findings from the Stage 1 output."""

    product_type: str
    materials: str
    colors: list[str]  # detected product colors (typically hex strings)
    distinguishing_features: str
    # Locked one-sentence product identity, prefixed to every FLUX prompt so
    # the SAME product appears in all shots. Backfilled when the model omits
    # it; empty string by default.
    canonical_description: str = ""
@dataclass
class ConceptPackage:
    """Validated Stage 1 result: the product analysis plus its shots."""

    product_analysis: ProductAnalysis
    shots: list[Shot] = field(default_factory=list)  # fresh list per instance

    def to_json(self) -> str:
        """Serialize the package, nested dataclasses included, as JSON.

        Output is indented by two spaces, and non-ASCII characters are kept
        as-is (``ensure_ascii=False``) instead of being escaped.
        """
        payload = asdict(self)
        return json.dumps(payload, ensure_ascii=False, indent=2)
# Keys every shot object in the model output must carry. They match the Shot
# dataclass fields one-to-one, because validate_package builds each Shot from
# exactly these keys.
REQUIRED_SHOT_KEYS = {
    "camera_angle",
    "color_palette",
    "concept_name",
    "id",
    "image_prompt",
    "lighting",
    "marketing_angle",
    "props",
    "scene",
}
def validate_package(raw: str | dict) -> ConceptPackage:
    """Parse and validate Stage 1 model output into a ConceptPackage.

    Args:
        raw: The model's JSON text, or an already-decoded object.

    Returns:
        A ConceptPackage holding the product analysis and exactly five shots.
        The caller's data is never modified: backfilled values go into copies.

    Raises:
        ValueError: with a readable message when the text is not valid JSON
            (``json.JSONDecodeError`` subclasses ValueError), the top level is
            not an object, ``product_analysis`` is missing or not an object,
            there are not exactly 5 shots, or a shot is not an object or lacks
            a required key.
    """
    data = json.loads(raw) if isinstance(raw, str) else raw
    # Guard before calling .get(): a JSON array/scalar would otherwise raise
    # AttributeError instead of the documented ValueError.
    if not isinstance(data, dict):
        raise ValueError(f"expected a JSON object, got {type(data).__name__}")
    pa = data.get("product_analysis")
    if not isinstance(pa, dict):
        raise ValueError("missing product_analysis object")
    shots = data.get("shots")
    if not isinstance(shots, list) or len(shots) != 5:
        got = len(shots) if isinstance(shots, list) else "none"
        raise ValueError(f"expected exactly 5 shots, got {got}")

    detected_colors = pa.get("colors")  # None when the key is absent
    parsed = []
    for i, s in enumerate(shots, start=1):
        if not isinstance(s, dict):
            raise ValueError(f"shot {i} is not an object")
        # Tolerate a missing color_palette (observed MiniCPM flake): backfill
        # from the detected product colors instead of failing the package.
        # The backfill goes into a shallow copy so the caller's dict is untouched.
        if "color_palette" not in s and isinstance(detected_colors, list):
            s = {**s, "color_palette": list(detected_colors)}
        missing = REQUIRED_SHOT_KEYS - set(s)
        if missing:
            raise ValueError(f"shot {i} missing keys: {sorted(missing)}")
        parsed.append(Shot(**{k: s[k] for k in REQUIRED_SHOT_KEYS}))

    # Copy a list of colors so the package does not alias the caller's list;
    # any other value (or the [] default) is passed through unchanged.
    colors = (list(detected_colors) if isinstance(detected_colors, list)
              else pa.get("colors", []))
    return ConceptPackage(
        product_analysis=ProductAnalysis(
            product_type=pa.get("product_type", ""),
            materials=pa.get("materials", ""),
            colors=colors,
            distinguishing_features=pa.get("distinguishing_features", ""),
            canonical_description=_canonical_description(pa),
        ),
        shots=parsed,
    )
def _canonical_description(pa: dict) -> str:
"""Locked product identity. Backfill from the analysis fields when the
model omits it so Stage 2 can always pin the product look."""
desc = str(pa.get("canonical_description") or "").strip()
if desc:
return desc
bits = [str(pa.get("product_type") or "").strip()]
colors = pa.get("colors")
if isinstance(colors, list) and colors:
names = [_color_name(str(c)) for c in colors[:5]]
bits.append("in " + ", ".join(dict.fromkeys(names))) # dedupe, keep order
if pa.get("materials"):
bits.append(f"made of {pa['materials']}")
if pa.get("distinguishing_features"):
bits.append(str(pa["distinguishing_features"]).strip())
return ", ".join(b for b in bits if b)
# FLUX barely understands hex codes - name them for the backfilled description.
_NAMED_COLORS = [
("white", (255, 255, 255)), ("off-white", (240, 238, 230)),
("light grey", (200, 200, 200)), ("grey", (128, 128, 128)),
("charcoal", (60, 60, 60)), ("black", (10, 10, 10)),
("red", (220, 40, 40)), ("orange", (240, 140, 30)),
("yellow", (245, 200, 40)), ("gum brown", (170, 120, 70)),
("brown", (110, 70, 40)), ("green", (60, 160, 70)),
("teal", (20, 160, 170)), ("blue", (50, 90, 200)),
("navy blue", (25, 35, 80)), ("purple", (130, 70, 190)),
("pink", (235, 120, 170)), ("beige", (220, 200, 170)),
]
def _color_name(hex_str: str) -> str:
"""Nearest plain-English name for a hex color; passes through non-hex."""
h = hex_str.strip().lstrip("#")
if len(h) == 3:
h = "".join(ch * 2 for ch in h)
if len(h) < 6 or any(ch not in "0123456789aAbBcCdDeEfF" for ch in h[:6]):
return hex_str # already a name or unparseable - keep as-is
r, g, b = int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16)
return min(_NAMED_COLORS,
key=lambda nc: (nc[1][0]-r)**2 + (nc[1][1]-g)**2 + (nc[1][2]-b)**2)[0]
|