File size: 3,665 Bytes
ac8f59c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from functools import lru_cache
from pathlib import Path
from typing import List

import cv2
import numpy as np
import torch
from PIL import Image
from transformers import CLIPProcessor

from clip_head import CreativeScorer


_CLIP_PROCESSOR_NAME = "openai/clip-vit-base-patch32"
# 16 * 16 == 256, the width of shared_repr; the attribution vector is laid out
# on this grid purely for visualisation.
_CAM_GRID = 16
# Side length (pixels) of the overlay image returned to callers.
_OVERLAY_SIZE = 224


@lru_cache(maxsize=None)
def _get_processor(name: str = _CLIP_PROCESSOR_NAME) -> CLIPProcessor:
    """Load the CLIP preprocessor once per name and reuse it on later calls."""
    return CLIPProcessor.from_pretrained(name)


def _compute_cam(
    model: CreativeScorer,
    image: Image.Image,
    device: str,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Gradient-weighted feature attribution on the projection layer output.

    True GradCAM requires spatial feature maps. CLIP's pooler_output collapses
    the 197 patch tokens into a single 512-dim vector, discarding spatial structure,
    so genuine pixel-level heatmaps are not possible here. Instead we:
      - compute gradients of ctr_score w.r.t. the 256-dim shared_repr
      - weight each channel by its gradient magnitude (|grad| * activation)
      - reshape the 256-dim weighted vector into a 16x16 grid (256 = 16*16)
      - upsample to 224x224 and overlay on the original image

    This approximates which projection-layer feature dimensions drive the CTR
    prediction but does not map faithfully to image regions. Treat output as an
    attribution proxy, not a spatial saliency map.

    Side effects: moves ``model`` to ``device`` and switches it to eval mode.
    Parameter gradients produced by the backward pass are cleared before
    returning.

    Args:
        model: scorer whose forward returns a dict with "shared_repr" and
            "ctr_score" entries.
        image: input PIL image in any mode; it is converted to RGB internally.
        device: torch device string, e.g. "cpu" or "cuda".

    Returns (overlay uint8 224x224x3, cam_16x16 float32 normalized 0-1).
    """
    # The overlay blend below needs exactly 3 channels; grayscale ("L") or
    # RGBA inputs would otherwise fail to broadcast against the RGB heatmap.
    rgb_image = image.convert("RGB")

    # Cached: from_pretrained re-reads the processor config on every call.
    inputs = _get_processor()(images=rgb_image, return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(device)

    model = model.to(device)
    model.eval()
    model.zero_grad()

    # Forward outside torch.no_grad() so autograd builds the computation graph.
    # The clip block inside forward() uses its own no_grad, so the backbone stays
    # frozen while the projection layer's params remain in the graph.
    outputs = model(pixel_values)
    shared_repr = outputs["shared_repr"]  # (1, 256)
    # shared_repr is a non-leaf tensor; retain_grad keeps its .grad after backward.
    shared_repr.retain_grad()

    outputs["ctr_score"].squeeze().backward()

    grad = shared_repr.grad                                                # (1, 256)
    weights = grad.abs().squeeze(0)                                        # (256,)
    weighted = (weights * shared_repr.detach().squeeze(0)).cpu().numpy()  # (256,)

    # Release the parameter gradients the backward pass accumulated so the
    # caller's model is not left holding stale .grad tensors.
    model.zero_grad(set_to_none=True)

    # Min-max normalise to [0, 1]; a constant vector stays all-zero.
    cam = weighted.reshape(_CAM_GRID, _CAM_GRID)
    cam = cam - cam.min()
    if cam.max() > 0:
        cam = cam / cam.max()

    cam_upsampled = cv2.resize(
        cam.astype(np.float32),
        (_OVERLAY_SIZE, _OVERLAY_SIZE),
        interpolation=cv2.INTER_LINEAR,
    )
    cam_uint8 = (cam_upsampled * 255).astype(np.uint8)

    # applyColorMap emits BGR; convert so it blends correctly with PIL's RGB.
    heatmap_bgr = cv2.applyColorMap(cam_uint8, cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)

    orig = np.array(rgb_image.resize((_OVERLAY_SIZE, _OVERLAY_SIZE))).astype(np.float32)
    overlay = np.clip(0.6 * orig + 0.4 * heatmap, 0, 255).astype(np.uint8)

    return overlay, cam.astype(np.float32)


def generate_heatmap(model: CreativeScorer, image: Image.Image, device: str = "cpu") -> np.ndarray:
    """Build the attribution overlay for ``image``, discarding the raw CAM grid.

    Returns:
        uint8 ndarray (224x224x3): the resized image blended with a JET heatmap.
    """
    result = _compute_cam(model, image, device)
    return result[0]


def generate_heatmap_with_cam(
    model: CreativeScorer,
    image: Image.Image,
    device: str = "cpu",
) -> tuple[np.ndarray, np.ndarray]:
    """Build the attribution overlay and also expose the underlying CAM grid.

    Returns:
        A pair ``(overlay, cam)``: ``overlay`` is a uint8 224x224x3 blend and
        ``cam`` is the float32 16x16 attribution grid, min-max scaled to 0-1.
    """
    overlay, cam = _compute_cam(model, image, device)
    return overlay, cam


def save_heatmaps(
    model: CreativeScorer,
    image_paths: List[str],
    output_dir: str,
    device: str = "cpu",
) -> None:
    """Render a heatmap overlay for each image and write it to disk as PNG.

    Each result is saved as ``<output_dir>/<stem>_heatmap.png``, where
    ``stem`` is the input file name without its extension. ``output_dir`` is
    created if it does not exist. Inputs that share a stem (e.g. ``a/x.jpg``
    and ``b/x.png``) map to the same output file, and the later one overwrites
    the earlier.

    Args:
        model: scorer passed through to ``generate_heatmap``.
        image_paths: paths of the images to process, in order.
        output_dir: destination directory for the PNG overlays.
        device: torch device string used for the forward/backward pass.
    """
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)
    for path in image_paths:
        # Use a context manager so the underlying file handle is closed
        # promptly. convert() fully loads the pixels and returns an
        # independent image that stays valid after the file is closed.
        with Image.open(path) as src:
            image = src.convert("RGB")
        overlay = generate_heatmap(model, image, device)
        stem = Path(path).stem
        Image.fromarray(overlay).save(out / f"{stem}_heatmap.png")
        print(f"Saved: {stem}_heatmap.png")