import base64
import io
import logging
import math
import os
import secrets
import time
import uuid

import numpy as np
import torch
from fastapi import FastAPI, File, Form, HTTPException, Query, Security, UploadFile
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from PIL import Image

from gradcam import generate_heatmap_with_cam
from model_loader import get_model
|
|
# Shared secret for the bearer-token check. It is read at import time, so a
# missing API_TOKEN environment variable aborts startup with a KeyError.
BEARER_TOKEN = os.environ["API_TOKEN"]

app = FastAPI(title="Creative Intelligence Scorer")

# Bearer-scheme extractor consumed by ``verify_token``; auto_error=True is the
# HTTPBearer default (missing/malformed Authorization headers are rejected).
security = HTTPBearer(auto_error=True)
|
|
# User-facing ``detail`` strings for the HTTPExceptions raised by the endpoints.
ERROR_MESSAGES = dict(
    model_not_loaded="Model is warming up. Please retry in 30 seconds.",
    invalid_image="Could not process image. Please upload a JPG or PNG.",
    inference_failed="Scoring failed. Please try again.",
)
|
|
# Per-vertical reference statistics served verbatim by the /benchmark endpoint.
# NOTE(review): units/provenance (CTR as a fraction? half-life in days?) are not
# stated here — confirm against the data source.
BENCHMARKS = {
    name: {"median_ctr": ctr, "median_halflife": halflife, "sample_size": count}
    for name, ctr, halflife, count in (
        ("gaming", 0.119, 10.7, 7485),
        ("ecommerce", 0.125, 11.2, 7078),
        ("finance", 0.111, 10.0, 6520),
        ("other", 0.118, 10.6, 1000),
    )
}
|
|
# Row-major labels for the 3x3 attention grid (index = row * 3 + col) used by
# ``_cam_region_labels``; the middle cell is plain "center".
GRID_LABELS = [
    "center" if (row, col) == ("mid", "center") else f"{row}-{col}"
    for row in ("top", "mid", "bot")
    for col in ("left", "center", "right")
]
|
|
|
|
def verify_token(
    credentials: HTTPAuthorizationCredentials = Security(security),
) -> None:
    """Reject the request unless its bearer token equals ``BEARER_TOKEN``.

    Requests with a missing or malformed Authorization header are rejected by
    the ``security`` scheme (HTTPBearer) before this function runs.

    The comparison is constant-time, so response latency does not reveal how
    many leading characters of a guessed token were correct. Both sides are
    encoded to bytes first because ``secrets.compare_digest`` raises TypeError
    on non-ASCII ``str`` input.

    Raises:
        HTTPException: 401 when the supplied token does not match.
    """
    supplied = credentials.credentials.encode("utf-8")
    expected = BEARER_TOKEN.encode("utf-8")
    if not secrets.compare_digest(supplied, expected):
        raise HTTPException(
            status_code=401,
            detail="Invalid token",
            # RFC 6750: a 401 for bearer auth should advertise the scheme.
            headers={"WWW-Authenticate": "Bearer"},
        )
|
|
|
|
def _open_image(data: bytes) -> Image.Image:
    """Decode raw upload bytes into an RGB PIL image.

    Args:
        data: The uploaded file's raw bytes.

    Returns:
        A new, fully loaded ``Image`` in mode ``"RGB"``.

    Raises:
        HTTPException: 422 with the ``invalid_image`` message when Pillow cannot
            identify or decode the data. The original error is chained as the
            cause for server-side debugging.
    """
    try:
        with Image.open(io.BytesIO(data)) as source:
            # Image.open is lazy; convert() forces the full decode (so corrupt
            # data fails here, inside the try) and returns a new image, which
            # lets the source be released when the with-block exits.
            return source.convert("RGB")
    except Exception as exc:
        raise HTTPException(
            status_code=422, detail=ERROR_MESSAGES["invalid_image"]
        ) from exc
|
|
|
|
def _cam_region_labels(cam_16: np.ndarray) -> tuple[list[str], list[str]]:
    """Rank a 3x3 grid of regions of a 2-D CAM by mean activation.

    Each axis is cut into integer bands ``[k*size//3, (k+1)*size//3)`` for
    k = 0, 1, 2. Each (row band, column band) pair gives one cell mean, in the
    same row-major order as ``GRID_LABELS``. The parameter name suggests a 16x16
    map. An axis shorter than 3 would produce empty bands, whose mean is NaN.

    Returns:
        ``(high, low)``:
        - ``high``: labels of the three highest-mean cells, highest first.
        - ``low``: labels of the three lowest-mean cells, ordered from the
          7th-highest down to the lowest, so the lowest cell comes last.
        Equal means keep row-major order because the sort is stable.
    """
    n_rows, n_cols = cam_16.shape
    row_bands = [(k * n_rows // 3, (k + 1) * n_rows // 3) for k in range(3)]
    col_bands = [(k * n_cols // 3, (k + 1) * n_cols // 3) for k in range(3)]
    cell_means = [
        float(cam_16[top:bottom, left:right].mean())
        for top, bottom in row_bands
        for left, right in col_bands
    ]

    order = sorted(range(len(cell_means)), key=cell_means.__getitem__, reverse=True)
    high = [GRID_LABELS[idx] for idx in order[:3]]
    low = [GRID_LABELS[idx] for idx in order[-3:]]
    return high, low
|
|
|
|
@app.on_event("startup")
async def startup() -> None:
    """Try to load the model once when the app starts.

    Failure is deliberately non-fatal, so the app still comes up. /health then
    reports ``model_loaded: false`` whenever ``get_model`` raises, and /score
    and /heatmap answer 503. Previously the error was discarded silently; it is
    now logged with its traceback so a broken deployment is diagnosable.
    NOTE(review): this presumes ``get_model`` caches the loaded model so later
    calls are cheap — confirm in model_loader.
    """
    try:
        get_model()
    except Exception:
        logging.getLogger(__name__).exception("Model warm-up failed at startup")
|
|
|
|
@app.get("/health")
async def health() -> dict:
    """Liveness probe that also reports model availability.

    ``status`` is always ``"ok"``. ``model_loaded`` is False when ``get_model``
    raises (including a result that does not unpack into two values) or
    returns a ``None`` model.
    """
    try:
        model, _ = get_model()
    except Exception:
        model_ready = False
    else:
        model_ready = model is not None
    return {"status": "ok", "model_loaded": model_ready}
|
|
|
|
@app.get("/benchmark")
async def benchmark(vertical: str = Query(...)) -> dict:
    """Return the reference statistics for one vertical.

    The lookup lowercases ``vertical``, so it is case-insensitive. Unlike
    /score and /heatmap, this endpoint does not depend on ``verify_token``.

    Returns:
        ``{"vertical": <lowercased name>, **BENCHMARKS[name]}``.

    Raises:
        HTTPException: 422 naming the valid verticals when the name is unknown.
    """
    normalized = vertical.lower()
    stats = BENCHMARKS.get(normalized)
    if stats is None:
        raise HTTPException(
            status_code=422,
            detail=f"Unknown vertical '{vertical}'. Valid: {list(BENCHMARKS)}",
        )
    return {"vertical": normalized, **stats}
|
|
|
|
@app.post("/score")
async def score(
    image: UploadFile = File(...),
    vertical: str = Form(...),
    _: None = Security(verify_token),
) -> dict:
    """Score one ad creative image.

    The upload goes through the processor, the CLIP image encoder
    (``model.clip``) and the scoring head. The handler returns the predicted
    CTR, a half-life derived from the head's Weibull parameters, and a
    confidence value.

    Args:
        image: Uploaded creative, decoded to RGB by ``_open_image``.
        vertical: Required form field. NOTE(review): it is accepted but not
            used or validated anywhere in this handler — confirm intent.
        _: Placeholder that makes FastAPI run ``verify_token`` first.

    Returns:
        A dict with:
        - ``ad_id``: a fresh random UUID4 string.
        - ``ctr_score``: rounded to 4 dp.
        - ``halflife_days``: rounded to 2 dp.
        - ``confidence``: rounded to 4 dp.
        - ``inference_ms``: wall-clock time in ms for preprocessing,
          inference and post-processing.

    Raises:
        HTTPException:
        - 503 if ``get_model`` raises or yields a ``None`` model/processor.
        - 422 if the upload cannot be read or decoded.
        - 500 if inference fails or yields a non-finite value.
    """
    try:
        model, processor = get_model()
    except Exception as exc:
        raise HTTPException(
            status_code=503, detail=ERROR_MESSAGES["model_not_loaded"]
        ) from exc
    # /health treats a None model as "not loaded". Report that state as 503
    # here too, instead of letting inference crash into a misleading 500.
    if model is None or processor is None:
        raise HTTPException(status_code=503, detail=ERROR_MESSAGES["model_not_loaded"])

    try:
        pil_image = _open_image(await image.read())
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(
            status_code=422, detail=ERROR_MESSAGES["invalid_image"]
        ) from exc

    # NOTE(review): the work below is synchronous CPU work inside an async
    # handler, so it blocks the event loop while it runs. Offloading it to a
    # thread needs care if gradcam mutates the shared model (e.g. registers
    # hooks) — TODO confirm before changing.
    try:
        t0 = time.monotonic()

        inputs = processor(images=pil_image, return_tensors="pt")
        pixel_values = inputs["pixel_values"]

        with torch.no_grad():
            clip_out = model.clip(pixel_values=pixel_values)
            embedding = clip_out.pooler_output
            outputs = model(embedding=embedding)

        ctr_score = float(outputs["ctr_score"].squeeze())

        # The head's weibull_params hold (log scale, log shape); row 0 is read.
        # Clamping the logs to [-10, 10] keeps math.exp() far from overflow.
        weibull_params = outputs["weibull_params"]
        scale = math.exp(float(weibull_params[0, 0].clamp(-10, 10)))
        shape = math.exp(float(weibull_params[0, 1].clamp(-10, 10)))
        # Median of a Weibull(scale, shape) distribution: scale * ln(2)^(1/shape).
        halflife_days = scale * (math.log(2) ** (1.0 / shape))

        # Equals 1.0 when ctr_score == 0.5 and falls linearly to 0.0 at 0 and 1.
        # NOTE(review): it goes negative if ctr_score leaves [0, 1] — confirm
        # the head's output range.
        confidence = float(1.0 - 2.0 * abs(ctr_score - 0.5))

        inference_ms = int((time.monotonic() - t0) * 1000)

        # The default JSON response rejects NaN/inf. Fail here with the
        # standard message instead of crashing during serialization.
        if not all(math.isfinite(v) for v in (ctr_score, halflife_days, confidence)):
            raise ValueError(
                f"non-finite model output: ctr_score={ctr_score}, "
                f"halflife_days={halflife_days}, confidence={confidence}"
            )
    except Exception as exc:
        logging.getLogger(__name__).exception("Scoring inference failed")
        raise HTTPException(
            status_code=500, detail=ERROR_MESSAGES["inference_failed"]
        ) from exc

    return {
        "ad_id": str(uuid.uuid4()),
        "ctr_score": round(ctr_score, 4),
        "halflife_days": round(halflife_days, 2),
        "confidence": round(confidence, 4),
        "inference_ms": inference_ms,
    }
|
|
|
|
@app.post("/heatmap")
async def heatmap(
    image: UploadFile = File(...),
    _: None = Security(verify_token),
) -> dict:
    """Return a Grad-CAM overlay for an ad creative plus coarse attention regions.

    Args:
        image: Uploaded creative, decoded to RGB by ``_open_image``.
        _: Placeholder that makes FastAPI run ``verify_token`` first.

    Returns:
        A dict with:
        - ``heatmap_b64``: base64 text of a PNG encoding of the overlay from
          ``generate_heatmap_with_cam``.
        - ``high_attention`` / ``low_attention``: grid labels from
          ``_cam_region_labels``.

    Raises:
        HTTPException:
        - 503 if ``get_model`` raises or yields a ``None`` model.
        - 422 if the upload cannot be read or decoded.
        - 500 if heatmap generation or encoding fails.
    """
    try:
        # The processor is not needed here. Use a named throwaway so the
        # dependency placeholder parameter ``_`` is not rebound.
        model, _processor = get_model()
    except Exception as exc:
        raise HTTPException(
            status_code=503, detail=ERROR_MESSAGES["model_not_loaded"]
        ) from exc
    # /health treats a None model as "not loaded"; answer 503 consistently.
    if model is None:
        raise HTTPException(status_code=503, detail=ERROR_MESSAGES["model_not_loaded"])

    try:
        pil_image = _open_image(await image.read())
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(
            status_code=422, detail=ERROR_MESSAGES["invalid_image"]
        ) from exc

    try:
        # NOTE(review): device is hard-coded to "cpu"; confirm it matches where
        # get_model() places the model.
        overlay, cam_16 = generate_heatmap_with_cam(model, pil_image, device="cpu")

        high_attention, low_attention = _cam_region_labels(cam_16)

        # Presumably overlay is an HxWx3 uint8 array accepted by fromarray — confirm.
        buf = io.BytesIO()
        Image.fromarray(overlay).save(buf, format="PNG")
        heatmap_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    except Exception as exc:
        logging.getLogger(__name__).exception("Heatmap generation failed")
        raise HTTPException(
            status_code=500, detail=ERROR_MESSAGES["inference_failed"]
        ) from exc

    return {
        "heatmap_b64": heatmap_b64,
        "high_attention": high_attention,
        "low_attention": low_attention,
    }
|
|