# Source-page metadata captured with this file (not part of the program):
# author: pcr12 · commit ac8f59c "Slice 3: initial deploy" · 6.08 kB
import base64
import io
import logging
import math
import os
import secrets
import time
import uuid

import numpy as np
import torch
from fastapi import FastAPI, File, Form, HTTPException, Query, Security, UploadFile
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from PIL import Image

from gradcam import generate_heatmap_with_cam
from model_loader import get_model
# FastAPI application exposing /health, /benchmark, /score and /heatmap.
app = FastAPI(title="Creative Intelligence Scorer")
# FastAPI's HTTP Bearer security scheme; consumed by verify_token below.
security = HTTPBearer()
# Shared secret for the protected endpoints. Read at import time, so a missing
# API_TOKEN environment variable raises KeyError and the module fails to load.
BEARER_TOKEN = os.environ["API_TOKEN"]
# User-facing `detail` strings for the HTTPExceptions raised by the endpoints,
# keyed by failure kind.
ERROR_MESSAGES = {
    "model_not_loaded": "Model is warming up. Please retry in 30 seconds.",
    "invalid_image": "Could not process image. Please upload a JPG or PNG.",
    "inference_failed": "Scoring failed. Please try again.",
}
# Per-vertical reference statistics returned verbatim by /benchmark. Keys are
# the lower-case vertical names that endpoint accepts.
# NOTE(review): median_halflife is presumably in days (to match halflife_days
# from /score), and the provenance of these figures (including the round
# sample_size for "other") is not visible here. Confirm against the source data.
BENCHMARKS = {
    "gaming": {"median_ctr": 0.119, "median_halflife": 10.7, "sample_size": 7485},
    "ecommerce": {"median_ctr": 0.125, "median_halflife": 11.2, "sample_size": 7078},
    "finance": {"median_ctr": 0.111, "median_halflife": 10.0, "sample_size": 6520},
    "other": {"median_ctr": 0.118, "median_halflife": 10.6, "sample_size": 1000},
}
# Names of the 3×3 attention-grid cells in row-major order (top row first,
# left to right). This is the same order in which _cam_region_labels builds
# its cell means.
GRID_LABELS = [
    "top-left", "top-center", "top-right",
    "mid-left", "center", "mid-right",
    "bot-left", "bot-center", "bot-right",
]
def verify_token(
    credentials: HTTPAuthorizationCredentials = Security(security),
) -> None:
    """Reject requests whose bearer token does not match ``BEARER_TOKEN``.

    Used as a ``Security`` dependency on the protected endpoints.

    Raises:
        HTTPException: 401 if the presented token differs from ``BEARER_TOKEN``.
    """
    # Constant-time comparison: a plain `!=` can short-circuit on the first
    # differing character, letting response timing leak the token prefix.
    # Compare bytes because compare_digest raises TypeError for non-ASCII str.
    presented = credentials.credentials.encode("utf-8")
    expected = BEARER_TOKEN.encode("utf-8")
    if not secrets.compare_digest(presented, expected):
        raise HTTPException(status_code=401, detail="Invalid token")
def _open_image(data: bytes) -> Image.Image:
    """Decode raw upload bytes into an RGB PIL image.

    Any format Pillow can decode is accepted; the result is always mode "RGB".

    Raises:
        HTTPException: 422 with the ``invalid_image`` message if opening or
            converting the image fails for any reason.
    """
    try:
        # Image.open is lazy and keeps the source open; convert() loads it and
        # returns an independent RGB copy, so the source can be closed here.
        with Image.open(io.BytesIO(data)) as src:
            return src.convert("RGB")
    except Exception as exc:
        raise HTTPException(
            status_code=422, detail=ERROR_MESSAGES["invalid_image"]
        ) from exc
def _cam_region_labels(cam_16: np.ndarray) -> tuple[list[str], list[str]]:
    """Rank the nine cells of a 3×3 grid laid over a 2-D CAM by mean value.

    Band edges use floor division, so when a side is not a multiple of 3 the
    last band absorbs the remainder (e.g. 16 -> bands of 5, 5 and 6).

    Returns:
        ``(high, low)``. ``high`` holds the labels of the three cells with
        the largest means, in descending order. ``low`` holds the three cells
        with the smallest means, also in descending order, so the
        lowest-scoring cell comes last.
    """
    height, width = cam_16.shape
    row_edges = [k * height // 3 for k in range(4)]
    col_edges = [k * width // 3 for k in range(4)]

    # Row-major cell order, matching GRID_LABELS.
    cell_means = [
        float(
            cam_16[
                row_edges[r]:row_edges[r + 1],
                col_edges[c]:col_edges[c + 1],
            ].mean()
        )
        for r in range(3)
        for c in range(3)
    ]

    # Stable descending sort: ties keep row-major order.
    order = sorted(range(len(cell_means)), key=cell_means.__getitem__, reverse=True)
    high = [GRID_LABELS[idx] for idx in order[:3]]
    low = [GRID_LABELS[idx] for idx in order[-3:]]
    return high, low
@app.on_event("startup")
async def startup() -> None:
    """Warm the model at boot so the first request does not pay the load cost.

    This is best-effort. A failure is logged with its traceback but does not
    stop the server; /health reports ``model_loaded=False`` until a later
    ``get_model()`` call succeeds.
    """
    try:
        get_model()
    except Exception:
        # Keep the server up, but record why so a broken model deploy is
        # diagnosable instead of silently surfacing as 503s.
        logging.getLogger(__name__).exception("Model warm-up failed at startup")
@app.get("/health")
async def health() -> dict:
    """Liveness probe that also reports whether the model is available.

    ``status`` is always "ok". ``model_loaded`` is False when ``get_model()``
    raises or returns ``None`` as the model.
    NOTE(review): get_model() is called on every probe; if it loads lazily, a
    probe during cold start may block. Confirm in model_loader.
    """
    try:
        loaded_model, _processor = get_model()
    except Exception:
        model_ready = False
    else:
        model_ready = loaded_model is not None
    return {"status": "ok", "model_loaded": model_ready}
@app.get("/benchmark")
async def benchmark(vertical: str = Query(...)) -> dict:
    """Return the reference statistics for one ad vertical.

    Matching is case-insensitive. The response echoes the lower-cased key
    merged with that vertical's BENCHMARKS entry.

    Raises:
        HTTPException: 422 if the vertical is not a key of ``BENCHMARKS``.
    """
    normalized = vertical.lower()
    stats = BENCHMARKS.get(normalized)
    if stats is None:
        raise HTTPException(
            status_code=422,
            detail=f"Unknown vertical '{vertical}'. Valid: {list(BENCHMARKS)}",
        )
    return {"vertical": normalized, **stats}
@app.post("/score")
async def score(
    image: UploadFile = File(...),
    vertical: str = Form(...),
    _: None = Security(verify_token),
) -> dict:
    """Predict CTR and creative half-life for an uploaded ad image.

    Args:
        image: Uploaded image; anything ``_open_image`` can decode.
        vertical: Required form field.
            NOTE(review): it is not used by the scoring below. Confirm whether
            it should be validated against BENCHMARKS or echoed back.
        _: Result of the bearer-token dependency (unused).

    Returns:
        A dict with:
        ``ad_id``: a fresh UUID4 string.
        ``ctr_score``: rounded to 4 dp.
        ``halflife_days``: the Weibull median, rounded to 2 dp.
        ``confidence``: in [0, 1], rounded to 4 dp.
        ``inference_ms``: wall-clock time for preprocessing plus the forward
        pass.

    Raises:
        HTTPException: 503 if the model cannot be obtained. 422 if the upload
            cannot be read or decoded. 500 if inference fails or yields a
            non-finite value.
    """
    try:
        model, processor = get_model()
    except Exception as exc:
        raise HTTPException(
            status_code=503, detail=ERROR_MESSAGES["model_not_loaded"]
        ) from exc

    try:
        pil_image = _open_image(await image.read())
    except HTTPException:
        raise
    except Exception as exc:
        # e.g. the upload stream itself failing to read
        raise HTTPException(
            status_code=422, detail=ERROR_MESSAGES["invalid_image"]
        ) from exc

    try:
        t0 = time.monotonic()
        inputs = processor(images=pil_image, return_tensors="pt")
        pixel_values = inputs["pixel_values"]  # (1, 3, 224, 224), CPU
        with torch.no_grad():
            clip_out = model.clip(pixel_values=pixel_values)
            embedding = clip_out.pooler_output  # (1, 768)
            outputs = model(embedding=embedding)

        ctr_score = float(outputs["ctr_score"].squeeze())
        # Clamp the log-parameters so exp() stays in a finite range.
        log_scale = float(outputs["weibull_params"][0, 0].clamp(-10, 10))
        log_shape = float(outputs["weibull_params"][0, 1].clamp(-10, 10))
        scale = math.exp(log_scale)
        shape = math.exp(log_shape)
        # Weibull median = scale * ln(2)^(1/shape) (survival-theoretic halflife)
        halflife_days = scale * (math.log(2) ** (1.0 / shape))
        # Confidence proxy: certainty peaks when the prediction is near 0 or 1
        # and is lowest at 0.5. The previous `1 - 2|p - 0.5|` had this inverted.
        # Clamped in case ctr_score falls outside [0, 1].
        confidence = min(1.0, 2.0 * abs(ctr_score - 0.5))

        # NaN/inf would otherwise fail later in JSON serialisation with an
        # opaque 500. Route them through the handled failure path instead.
        if not all(math.isfinite(v) for v in (ctr_score, halflife_days, confidence)):
            raise ValueError(
                f"non-finite model output: ctr={ctr_score!r}, "
                f"halflife={halflife_days!r}"
            )
        inference_ms = int((time.monotonic() - t0) * 1000)
    except Exception as exc:
        logging.getLogger(__name__).exception("Inference failed in /score")
        raise HTTPException(
            status_code=500, detail=ERROR_MESSAGES["inference_failed"]
        ) from exc

    return {
        "ad_id": str(uuid.uuid4()),
        "ctr_score": round(ctr_score, 4),
        "halflife_days": round(halflife_days, 2),
        "confidence": round(confidence, 4),
        "inference_ms": inference_ms,
    }
@app.post("/heatmap")
async def heatmap(
    image: UploadFile = File(...),
    _: None = Security(verify_token),
) -> dict:
    """Render a Grad-CAM overlay for an uploaded image and summarise its regions.

    Returns:
        A dict with:
        ``heatmap_b64``: the overlay encoded as a base64 PNG string.
        ``high_attention`` and ``low_attention``: three 3×3-grid labels each,
        as produced by ``_cam_region_labels``.

    Raises:
        HTTPException: 503 if the model cannot be obtained. 422 if the upload
            cannot be read or decoded. 500 if heatmap generation or encoding
            fails.
    """
    try:
        # Bind the processor to its own name. Reusing `_` would silently
        # rebind the auth-dependency parameter above.
        model, _processor = get_model()
    except Exception as exc:
        raise HTTPException(
            status_code=503, detail=ERROR_MESSAGES["model_not_loaded"]
        ) from exc

    try:
        pil_image = _open_image(await image.read())
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(
            status_code=422, detail=ERROR_MESSAGES["invalid_image"]
        ) from exc

    try:
        # Single pass: returns the overlay (for base64) and the raw cam_16
        # (for region labels).
        # NOTE(review): Image.fromarray assumes overlay is a uint8 HxW(x3)
        # array. Confirm against gradcam.generate_heatmap_with_cam.
        overlay, cam_16 = generate_heatmap_with_cam(model, pil_image, device="cpu")
        high_attention, low_attention = _cam_region_labels(cam_16)
        buf = io.BytesIO()
        Image.fromarray(overlay).save(buf, format="PNG")
        heatmap_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    except Exception as exc:
        logging.getLogger(__name__).exception("Heatmap generation failed")
        raise HTTPException(
            status_code=500, detail=ERROR_MESSAGES["inference_failed"]
        ) from exc

    return {
        "heatmap_b64": heatmap_b64,
        "high_attention": high_attention,
        "low_attention": low_attention,
    }