TCube_Merging / utils /tools.py
razaimam45's picture
Upload 108 files
a96891a verified
Raw
History Blame
37.9 kB
import os
import time
import random
import numpy as np
import shutil
from enum import Enum
import torch
import torchvision.transforms as transforms
# from t_cube import get_logits
def set_random_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and CUDA) with ``seed``."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    # Same order as a hand-written sequence of four calls.
    for seed_fn in seeders:
        seed_fn(seed)
class Summary(Enum):
    """Which statistic a meter reports in its end-of-run summary line."""

    NONE = 0     # meter contributes an empty string to the summary
    AVERAGE = 1  # report the running mean
    SUM = 2      # report the running sum
    COUNT = 3    # report the number of samples seen
class AverageMeter(object):
    """Track the latest value, weighted sum, sample count and running mean of a metric.

    Args:
        name: Label used when the meter is printed.
        fmt: Format spec (including the leading ``:``) applied to ``val`` and
            ``avg`` in ``__str__``.
        summary_type: ``Summary`` member selecting what ``summary()`` reports.
    """

    def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
        self.name = name
        self.fmt = fmt
        self.summary_type = summary_type
        self.reset()

    def reset(self):
        """Zero all running statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed over ``n`` samples and refresh the mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        # An update with n == 0 on a fresh meter (e.g. an empty batch) used to
        # raise ZeroDivisionError; keep the average at 0 until data arrives.
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

    def summary(self):
        """Return the summary string selected by ``summary_type``.

        Raises:
            ValueError: if ``summary_type`` is not a known ``Summary`` member.
        """
        if self.summary_type is Summary.NONE:
            fmtstr = ''
        elif self.summary_type is Summary.AVERAGE:
            fmtstr = '{name} {avg:.3f}'
        elif self.summary_type is Summary.SUM:
            fmtstr = '{name} {sum:.3f}'
        elif self.summary_type is Summary.COUNT:
            fmtstr = '{name} {count:.3f}'
        else:
            raise ValueError('invalid summary type %r' % self.summary_type)
        return fmtstr.format(**self.__dict__)
class ProgressMeter(object):
    """Print a group of meters, either per batch or as a final summary.

    Args:
        num_batches: Total number of batches; fixes the width of the counter.
        meters: Objects providing ``__str__`` and ``summary()``.
        prefix: Text placed in front of the batch counter.
    """

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the batch counter and every meter's current state, tab-separated."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([header, *(str(meter) for meter in self.meters)]))

    def display_summary(self):
        """Print every meter's summary on a single line prefixed with `` *``."""
        parts = [" *"]
        parts.extend(meter.summary() for meter in self.meters)
        print(' '.join(parts))

    def _get_batch_fmtstr(self, num_batches):
        """Build a ``[{:Nd}/total]`` template wide enough for ``num_batches``."""
        width = len(str(num_batches // 1))
        slot = '{:' + str(width) + 'd}'
        total = slot.format(num_batches)
        return '[{}/{}]'.format(slot, total)
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each ``k`` in ``topk``.

    Args:
        output: Score tensor whose last dimension indexes classes.
        target: Tensor of ground-truth class indices, one per sample.
        topk: Iterable of ``k`` values to evaluate.

    Returns:
        A list with one single-element tensor per requested ``k``.
    """
    with torch.no_grad():
        # Previously only the single best prediction was kept, so every
        # requested k silently reported top-1 accuracy. Take the real top-k,
        # clamped so label sets smaller than max(topk) do not make topk() fail.
        maxk = min(max(topk), output.size(-1))
        batch_size = target.size(0)
        _, pred = output.topk(maxk, -1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            # Slicing past the available rows is harmless: k > maxk uses all rows.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
from sklearn.metrics import precision_score, recall_score, f1_score
def macro_prf(output, target):
    """Return ``[precision, recall, F1]``, macro-averaged and scaled to percent.

    Predictions are the argmax of ``output`` along dim 1; undefined scores are
    counted as 0 (``zero_division=0``).
    """
    y_pred = output.argmax(dim=1).cpu().numpy()
    y_true = target.cpu().numpy()
    scorers = (precision_score, recall_score, f1_score)
    return [
        scorer(y_true, y_pred, average='macro', zero_division=0) * 100
        for scorer in scorers
    ]
def load_model_weight(load_path, model, device, args):
    """Load a checkpoint from ``load_path`` into ``model`` and set ``args.start_epoch``.

    The checkpoint must provide ``'state_dict'`` and ``'epoch'``. The
    ``token_prefix`` / ``token_suffix`` entries are dropped before loading.
    If the state dict does not fit ``model`` it is loaded non-strictly into
    ``model.prompt_generator`` instead. When ``load_path`` is not a file a
    message is printed and nothing is loaded.

    Args:
        load_path: Path of the checkpoint file.
        model: Module that receives the weights.
        device: ``map_location`` passed to ``torch.load``.
        args: Namespace-like object; ``start_epoch`` is written to it.
    """
    if not os.path.isfile(load_path):
        print("=> no checkpoint found at '{}'".format(load_path))
        return

    print("=> loading checkpoint '{}'".format(load_path))
    checkpoint = torch.load(load_path, map_location=device)
    state_dict = checkpoint['state_dict']
    # Ignore fixed token vectors
    for fixed_key in ("token_prefix", "token_suffix"):
        if fixed_key in state_dict:
            del state_dict[fixed_key]
    args.start_epoch = checkpoint['epoch']

    try:
        best_acc1 = checkpoint['best_acc1']
    except KeyError:
        best_acc1 = torch.tensor(0)
    # The old `device is not 'cpu'` identity test was true for any non-literal
    # device (e.g. torch.device('cpu')) and then called .to() on plain floats.
    # Moving a tensor to the device it already lives on is a no-op, so only the
    # tensor check is needed. NOTE(review): best_acc1 is not used or returned.
    if torch.is_tensor(best_acc1):
        best_acc1 = best_acc1.to(device)

    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # Strict loading reports key/shape mismatches as RuntimeError.
        # TODO: implement this method for the generator class
        model.prompt_generator.load_state_dict(state_dict, strict=False)
    print("=> loaded checkpoint '{}' (epoch {})"
          .format(load_path, checkpoint['epoch']))
    del checkpoint
    torch.cuda.empty_cache()
def validate(val_loader, model, criterion, args, output_mask=None):
    """Evaluate ``model`` on ``val_loader`` and return the running top-1 accuracy.

    Args:
        val_loader: Iterable of ``(images, target)`` batches.
        model: Module to evaluate; it is switched to eval mode.
        criterion: Loss called as ``criterion(output, target)``.
        args: Object providing ``gpu`` and ``print_freq``.
        output_mask: Optional index (list, array or tensor) selecting a subset
            of output columns before the loss and accuracy are computed.

    Returns:
        ``top1.avg`` — the sample-weighted mean of the per-batch top-1 accuracy.
    """
    batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
    losses = AverageMeter('Loss', ':.4e', Summary.NONE)
    top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
    top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5],
        prefix='Test: ')

    # `if output_mask:` raises for multi-element tensors/arrays (ambiguous
    # truth value); test for presence and non-emptiness explicitly instead.
    use_mask = output_mask is not None and len(output_mask) > 0

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            if torch.cuda.is_available():
                target = target.cuda(args.gpu, non_blocking=True)

            # compute output
            with torch.cuda.amp.autocast():
                output = model(images)
                if use_mask:
                    output = output[:, output_mask]
                loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

    progress.display_summary()
    return top1.avg
import matplotlib.pyplot as plt
def plot_img(image, save_path='saved_plot.png', target=None, predicted=None):
    """Min-max normalise an image and save it as a 3x3 inch axis-free figure.

    Args:
        image: A tensor (squeezed, then permuted from channel-first to
            channel-last) or an array that ``imshow`` can display directly.
        save_path: Destination file for the figure.
        target: Unused; kept for call compatibility.
        predicted: Unused; kept for call compatibility.
    """
    # isinstance also accepts Tensor subclasses, unlike an exact type comparison.
    if isinstance(image, torch.Tensor):
        image_array = image.to('cpu').squeeze().permute(1, 2, 0).detach().numpy()
    else:
        image_array = image

    lowest = image_array.min()
    value_range = image_array.max() - lowest
    if value_range > 0:
        image_array = (image_array - lowest) / value_range
    else:
        # A constant image would otherwise yield 0/0 = NaN everywhere.
        image_array = np.zeros_like(image_array, dtype=float)

    plt.figure(figsize=(3, 3), tight_layout=True)
    plt.imshow(image_array)
    plt.axis('off')
    plt.savefig(save_path)
    plt.close()
from torchvision.transforms import ToPILImage
from PIL import Image
# Shared converter to PIL images, created once at import time and used by plot_pil_img.
to_pil = ToPILImage()
def plot_pil_img(image, save_path='saved_plot.png'):
    """Save ``image`` to ``save_path``, converting it to a PIL image first if needed."""
    pil_image = image if isinstance(image, Image.Image) else to_pil(image)
    pil_image.save(save_path)
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import pearsonr
def plot_entropy_vs_mi(
    entropies: np.ndarray,
    mi_values: np.ndarray,
    agreement_diff: np.ndarray = None,
    entropy_thresh: float = None,
    mi_thresh: float = None,
    figsize: tuple = (4.5, 4.5),
    save_path: str = 'mi_vs_entropy.png',
):
    """
    Plot MI vs. predictive entropy as a seaborn JointGrid and save it.

    Inputs may be NumPy arrays or torch tensors; tensors are detached and
    moved to the CPU first.

    Args:
        entropies: Per-sample entropy values (x axis).
        mi_values: Per-sample mutual information values (y axis).
        agreement_diff: Optional per-sample values used to colour the points.
        entropy_thresh: Optional x position of a vertical dashed line.
        mi_thresh: Optional y position of a horizontal dashed line.
        figsize: Only ``figsize[0]`` is used, as the JointGrid height.
        save_path: Output file; its parent directory is created if needed.
    """
    def _to_np(x):
        # The annotations promise ndarrays, but the old code called .cpu()
        # unconditionally and therefore only worked for tensors.
        return x.detach().cpu().numpy() if hasattr(x, 'cpu') else np.asarray(x)

    entropies = _to_np(entropies)
    mi_values = _to_np(mi_values)
    if agreement_diff is not None:
        agreement_diff = _to_np(agreement_diff)
    corr, _ = pearsonr(entropies, mi_values)

    # Create joint plot
    g = sns.JointGrid(
        x=entropies,
        y=mi_values,
        height=figsize[0],
        ratio=4,
        space=0.15
    )

    # Scatter with hue if available
    if agreement_diff is not None:
        cmap = sns.color_palette("coolwarm", as_cmap=True)
        g.plot_joint(
            sns.scatterplot,
            hue=agreement_diff,
            palette=cmap,
            s=18,
            linewidth=0.3,
            edgecolor="black",
            alpha=0.8
        )
        # The legend is absent when seaborn did not create one; guard the removal.
        if g.ax_joint.legend_ is not None:
            g.ax_joint.legend_.remove()  # cleaner
    else:
        g.plot_joint(sns.scatterplot, s=20, color='tab:blue', alpha=0.7)

    # Marginals
    g.plot_marginals(sns.histplot, kde=True, color='grey', alpha=0.5)

    # Regression
    sns.regplot(
        x=entropies,
        y=mi_values,
        scatter=False,
        ax=g.ax_joint,
        color='black',
        line_kws={"linestyle": "--", "linewidth": 1}
    )

    # Thresholds
    if entropy_thresh is not None:
        g.ax_joint.axvline(entropy_thresh, ls='--', color='grey', lw=1)
    if mi_thresh is not None:
        g.ax_joint.axhline(mi_thresh, ls='--', color='grey', lw=1)

    # Annotation in top-left, the important/key quadrant
    x_text = np.percentile(entropies, 5)
    y_text = np.percentile(mi_values, 95)
    g.ax_joint.text(x_text, y_text, 'High MI\nLow Entropy',
                    fontsize=10, fontweight='bold', color='black')

    # Labels and title
    g.set_axis_labels('Self-Entropy', 'Mutual Information', fontsize=11)
    g.ax_joint.set_title(f'Pearson ρ = {corr:.2f}', fontsize=12)
    g.ax_joint.tick_params(labelsize=9)

    plt.tight_layout()
    if os.path.dirname(save_path):
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
    plt.savefig(save_path, dpi=300)
    plt.close()
    return
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
# Display names of the merging methods, keyed by method identifier.
# plot_delta_performance treats every key here other than its `dyn_key` as a
# baseline and looks each one up in the results dict it is given.
method_names = {
    'model_ensemble': 'Model Ensemble',
    'wise_ft': 'Model Souping',
    'tcube': 'Entropy-based',
    'tcube_MI_bmm': 'Mutual Information',
}
def plot_delta_performance(
    dyn_v_stat_plot: dict,
    dyn_key: str = 'tcube_MI_bmm',
    figsize: tuple = (3, 3),
    save_path: str = 'delta_performance.png'
):
    """
    Bar-plot the gap between one method and the best of the others, per condition.

    Args:
        dyn_v_stat_plot: Dict with a 'conditions' entry plus one score sequence
            per key of ``method_names``.
        dyn_key: Key of the method whose margin is plotted.
        figsize: Figure size in inches.
        save_path: Output file; its parent directory is created if needed.

    Returns:
        ``(fig, ax)`` — the figure has already been closed when returned.
    """
    sns.set_style('white')
    conditions = np.array(dyn_v_stat_plot['conditions'])
    fig, ax = plt.subplots(1, 1, figsize=figsize, constrained_layout=True)

    # Margin of the chosen method over the element-wise best remaining method.
    chosen_scores = np.array(dyn_v_stat_plot[dyn_key])
    baseline_rows = [
        dyn_v_stat_plot[key] for key in method_names if key != dyn_key
    ]
    best_baseline = np.vstack(baseline_rows).max(axis=0)
    delta = chosen_scores - best_baseline

    bar_colors = sns.color_palette("rocket", n_colors=len(delta))
    ax.bar(
        x=np.arange(len(conditions)),
        height=delta,
        width=1.0,
        color=bar_colors,
        linewidth=0,
        edgecolor=None,
        alpha=0.85,
    )
    ax.axhline(0, color='grey', linewidth=1)

    ax.set_ylabel(r'$\Delta$ (%)', fontsize=10)
    ax.set_xlabel('Distribution Shifts', fontsize=10)
    # One unlabeled tick per condition.
    ax.set_xticks(np.arange(len(conditions)))
    ax.set_xticklabels([''] * len(conditions))
    ax.tick_params(axis='x', length=3, width=1)
    ax.tick_params(axis='y', labelsize=9)

    for side, shown in (('top', False), ('right', False),
                        ('left', True), ('bottom', True)):
        ax.spines[side].set_visible(shown)
    ax.grid(False)

    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    return fig, ax
import matplotlib.pyplot as plt
import seaborn as sns
import torch
def plot_lambda_histogram(
    lambda_dict: dict,
    bins: int = 50,
    figsize: tuple = (3, 3),
    save_path: str = None
):
    """
    Plot a histogram (with KDE) of sample-wise interpolation coefficients.

    Bars are coloured along a 'Blues' gradient, the KDE line is a thin black
    dashed line, and all four spines are shown.

    Args:
        lambda_dict (dict): one-entry dict, e.g. {'clean': tensor([...])}
        bins (int): number of histogram bins
        figsize (tuple): figure size in inches (w, h)
        save_path (str): optional path to save the figure; nothing is written
            when it is None or empty

    Returns:
        fig, ax

    Raises:
        ValueError: if ``lambda_dict`` does not hold exactly one tensor entry.
    """
    # Validate single key
    if len(lambda_dict) != 1:
        raise ValueError("lambda_dict must contain exactly one key.")
    condition, data = next(iter(lambda_dict.items()))
    if not isinstance(data, torch.Tensor):
        raise ValueError(f"lambda_dict['{condition}'] must be a torch.Tensor")

    # Prepare data
    values = data.detach().cpu().numpy().ravel()

    # Aesthetics setup
    sns.set_style("white")
    fig, ax = plt.subplots(figsize=figsize)

    # One colour per bin, light to dark
    cm = sns.color_palette("Blues", n_colors=(bins))

    # Plot histogram
    plot = sns.histplot(
        values,
        bins=bins,
        ax=ax,
        edgecolor=None,
        alpha=0.85,
        kde=True,
        linewidth=0  # no bar outlines
    )
    if plot.lines:
        kde_line = plot.lines[0]
        kde_line.set_color('black')
        kde_line.set_linestyle('--')
        kde_line.set_linewidth(0.5)
    for bin_, colour in zip(plot.patches, cm):
        bin_.set_facecolor(colour)

    # Labels
    ax.set_xlabel("Coefficient", fontsize=9)
    ax.set_ylabel("Frequency", fontsize=9)

    # Six evenly spaced x ticks across the data range, outward tick marks
    ax.set_xticks(np.round(np.linspace(values.min(), values.max(), num=6), 2))
    ax.tick_params(axis='x', labelsize=8)
    ax.tick_params(
        axis='x', which='both',
        bottom=True, top=False,
        length=4, direction='out'
    )
    ax.tick_params(
        axis='y', which='both',
        left=True, right=False,
        length=4, direction='out',
        labelsize=8
    )

    # Make all borders visible
    for spine in ['top', 'right', 'bottom', 'left']:
        ax.spines[spine].set_visible(True)

    plt.tight_layout()
    # save_path defaults to None; os.path.dirname(None) used to raise TypeError,
    # so the documented "optional" saving never worked without a path.
    if save_path:
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.show()
    return fig, ax
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import pearsonr
def _plot_ratio_vs_js_by_correctness(
    x_values, mi_values, correct_pt, correct_ft, x_label, figsize, save_path
):
    """Render the five-panel joint plot shared by the two public wrappers below.

    Panels: the entire set, then the four combinations of the two correctness
    masks. Each panel has a top histogram, a scatter with regression line and
    a right histogram, with axes limited to the 1st-99th percentile.
    """
    def to_np(x):
        return x.cpu().numpy() if hasattr(x, 'cpu') else x

    e = to_np(x_values)
    m = to_np(mi_values)
    # NOTE(review): the plotted y value is a random convex blend of x and MI
    # (weight drawn from U(0.05, 0.1) on x). This couples the two axes and makes
    # the figure non-reproducible; kept as-is, but confirm it is intended.
    alpha = np.random.uniform(0.05, 0.1)
    m = alpha * e + (1 - alpha) * m

    # `~` on a non-boolean array is a bitwise NOT (~1 == -2, still truthy), which
    # silently produced wrong subsets for 0/1 integer masks. Force booleans.
    cpt = np.asarray(to_np(correct_pt)).astype(bool)
    cft = np.asarray(to_np(correct_ft)).astype(bool)
    masks = {
        'Entire Set': np.ones_like(e, dtype=bool),
        'TrueTrue': np.logical_and(cpt, cft),
        'TrueFalse': np.logical_and(cpt, ~cft),
        'FalseTrue': np.logical_and(~cpt, cft),
        'FalseFalse': np.logical_and(~cpt, ~cft),
    }

    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(
        2, 10,
        width_ratios=[4, 1] * 5,
        height_ratios=[0.2, 1],
        wspace=0.075,
        hspace=0.2
    )

    for i, (label, mask) in enumerate(masks.items()):
        xe, ym = e[mask], m[mask]
        valid = np.isfinite(xe) & np.isfinite(ym)
        xe, ym = xe[valid], ym[valid]
        enough_points = len(xe) > 1

        # clamp to remove outliers; tiny subsets fall back to the full range
        if enough_points:
            xlow, xhigh = np.percentile(xe, [1, 99])
            ylow, yhigh = np.percentile(ym, [1, 99])
        else:
            xlow, xhigh = np.min(e), np.max(e)
            ylow, yhigh = np.min(m), np.max(m)

        # Top histogram (over the scatter's x-range)
        ax_marg_x = fig.add_subplot(gs[0, 2 * i])
        sns.histplot(
            xe, bins=25, kde=True,
            ax=ax_marg_x, color='grey', alpha=0.4
        )
        ax_marg_x.set_xlim(xlow, xhigh)
        ax_marg_x.axis('off')  # remove all spines & ticks

        # Joint scatter
        ax_joint = fig.add_subplot(gs[1, 2 * i])
        sns.scatterplot(
            x=xe, y=ym,
            s=25, color='violet',
            edgecolor='k', linewidth=0.2, alpha=0.7,
            ax=ax_joint
        )
        # A regression line needs at least two points; an empty subset used
        # to abort the whole figure here.
        if enough_points:
            sns.regplot(
                x=xe, y=ym, scatter=False, ax=ax_joint,
                line_kws={'linestyle': '--', 'color': 'black', 'linewidth': 1.25}
            )
        ax_joint.set_xlim(xlow, xhigh)
        ax_joint.set_ylim(ylow, yhigh)
        ax_joint.set_xticklabels([])
        ax_joint.set_yticklabels([])

        # Right histogram (over the scatter's y-range)
        ax_marg_y = fig.add_subplot(gs[1, 2 * i + 1])
        sns.histplot(
            y=ym, bins=25, kde=True,
            ax=ax_marg_y, color='grey', alpha=0.4,
            orientation='horizontal'
        )
        ax_marg_y.set_ylim(ylow, yhigh)
        ax_marg_y.axis('off')

        # annotate Pearson rho
        if enough_points:
            rho, _ = pearsonr(xe, ym)
            ax_joint.text(
                0.05, 0.90, f"$\\rho$={rho:.2f}",
                transform=ax_joint.transAxes,
                fontsize=12,
                bbox=dict(boxstyle="round,pad=0.2", fc="white", ec="none", alpha=0.6)
            )

        # x label on every panel, y label only on the first
        ax_joint.set_xlabel(x_label, fontsize=14)
        if i == 0:
            ax_joint.set_ylabel(
                r"$\mathbf{\sigma\left(\mathrm{JS}(P_{pt},P_{ft})\right)}$",
                fontsize=11,
            )
        ax_joint.set_title(label, fontsize=14)

    plt.tight_layout()
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)


def plot_entropy_vs_mi_by_correctness(
    entropies: np.ndarray,
    mi_values: np.ndarray,
    correct_pt: np.ndarray,
    correct_ft: np.ndarray,
    figsize: tuple = (20, 4),
    save_path: str = 'mi_vs_entropy_by_correctness_all.png',
):
    """
    Plot sigmoid(JS) vs. H-ratio in 5 joint panels: overall and TT/TF/FT/FF splits.

    Each panel is limited to the 1-99 percentile range, shows Pearson rho inside
    the joint axes, and has no tick labels. Inputs may be arrays or tensors.

    Args:
        entropies: Per-sample x values (labelled as the entropy ratio).
        mi_values: Per-sample y values.
        correct_pt: Per-sample correctness of the pretrained model.
        correct_ft: Per-sample correctness of the fine-tuned model.
        figsize: Figure size in inches.
        save_path: Output file; its parent directory is created if needed.
    """
    _plot_ratio_vs_js_by_correctness(
        entropies, mi_values, correct_pt, correct_ft,
        r"$\mathbf{\frac{H(P_{ft})}{H(P_{ft})+H(P_{pt})}}$",
        figsize, save_path,
    )


def plot_Xentropy_vs_mi_by_correctness(
    x_entropies: np.ndarray,
    mi_values: np.ndarray,
    correct_pt: np.ndarray,
    correct_ft: np.ndarray,
    figsize: tuple = (20, 4),
    save_path: str = 'mi_vs_entropy_by_correctness_all.png',
):
    """
    Plot sigmoid(JS) vs. CE-ratio in 5 joint panels: overall and TT/TF/FT/FF splits.

    Identical to ``plot_entropy_vs_mi_by_correctness`` except for the x-axis
    label, which names the cross-entropy ratio.

    Args:
        x_entropies: Per-sample x values (labelled as the cross-entropy ratio).
        mi_values: Per-sample y values.
        correct_pt: Per-sample correctness of the pretrained model.
        correct_ft: Per-sample correctness of the fine-tuned model.
        figsize: Figure size in inches.
        save_path: Output file; its parent directory is created if needed.
    """
    _plot_ratio_vs_js_by_correctness(
        x_entropies, mi_values, correct_pt, correct_ft,
        r"$\mathbf{\frac{CE(P_{ft},Y)}{CE(P_{ft},Y)+CE(P_{pt},Y)}}$",
        figsize, save_path,
    )
def plot_xentropy_vs_mi_entire(
    x_entropies: np.ndarray,
    mi_values: np.ndarray,
    figsize: tuple = (5, 5),
    save_path: str = 'xent_vs_mi_entire.png',
):
    """
    Draw one joint panel of sigmoid(JS) against the CE-ratio for the whole set.

    Layout: histogram on top, scatter plus regression line in the centre,
    histogram on the right. Axes are limited to the 1-99 percentile range,
    histograms are grey, points are violet, Pearson rho is printed inside the
    joint axes, and tick labels are hidden.
    """
    def to_np(x):
        return x.cpu().numpy() if hasattr(x, 'cpu') else x

    x_arr = to_np(x_entropies)
    y_arr = to_np(mi_values)
    # The y values are blended with x using a randomly drawn weight.
    blend = np.random.uniform(0.05, 0.1)
    y_arr = blend * x_arr + (1 - blend) * y_arr

    # Keep only pairs where both coordinates are finite.
    finite = np.isfinite(x_arr) & np.isfinite(y_arr)
    x_arr = x_arr[finite]
    y_arr = y_arr[finite]

    many = len(x_arr) > 1
    if many:
        x_lo, x_hi = np.percentile(x_arr, [1, 99])
        y_lo, y_hi = np.percentile(y_arr, [1, 99])
    else:
        x_lo, x_hi = np.min(x_arr), np.max(x_arr)
        y_lo, y_hi = np.min(y_arr), np.max(y_arr)

    # 2x2 grid: wide/tall joint cell, thin marginal strips.
    fig = plt.figure(figsize=figsize)
    grid = fig.add_gridspec(
        2, 2,
        width_ratios=[4, 1],
        height_ratios=[0.2, 1],
        wspace=0.05,
        hspace=0.05,
    )

    # Top marginal
    top_ax = fig.add_subplot(grid[0, 0])
    sns.histplot(x_arr, bins=25, kde=True, ax=top_ax, color='grey', alpha=0.4)
    top_ax.set_xlim(x_lo, x_hi)
    top_ax.axis('off')

    # Central scatter and regression line
    main_ax = fig.add_subplot(grid[1, 0])
    sns.scatterplot(
        x=x_arr, y=y_arr,
        s=25, color='violet',
        edgecolor='k', linewidth=0.2, alpha=0.7,
        ax=main_ax,
    )
    sns.regplot(
        x=x_arr, y=y_arr, scatter=False, ax=main_ax,
        line_kws={'linestyle': '--', 'color': 'black', 'linewidth': 1.25},
    )
    main_ax.set_xlim(x_lo, x_hi)
    main_ax.set_ylim(y_lo, y_hi)
    main_ax.set_xticklabels([])
    main_ax.set_yticklabels([])

    # Right marginal
    side_ax = fig.add_subplot(grid[1, 1])
    sns.histplot(
        y=y_arr, bins=25, kde=True,
        ax=side_ax, color='grey', alpha=0.4,
        orientation='horizontal',
    )
    side_ax.set_ylim(y_lo, y_hi)
    side_ax.axis('off')

    # Correlation annotation
    if many:
        rho, _ = pearsonr(x_arr, y_arr)
        main_ax.text(
            0.05, 0.90, f"$\\rho$ = {rho:.2f}",
            transform=main_ax.transAxes,
            fontsize=10,
            bbox=dict(boxstyle="round,pad=0.2", fc="white", ec="none", alpha=0.6),
        )

    main_ax.set_xlabel(r"$\mathbf{\frac{CE(P_{ft},Y)}{CE(P_{ft},Y)+CE(P_{pt},Y)}}$", fontsize=14)
    main_ax.set_ylabel(r"$\mathbf{\sigma\left(\mathrm{JS}(P_{pt},P_{ft})\right)}$", fontsize=11)

    plt.tight_layout()
    target_dir = os.path.dirname(save_path) or '.'
    os.makedirs(target_dir, exist_ok=True)
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
def plot_stacked_ce_vs_mi_bins(
    mi_values,
    ce_values_pt,
    ce_values_ft,
    bins: int = 12,
    figsize: tuple = (10, 5),
    save_path: str = 'ce_vs_mi_stacked_bins.png',
):
    """
    Plot stacked mean cross-entropy of two models per mutual-information bin.

    MI is min-max normalised to [0, 1] and split into ``bins`` equal-width
    bins. For each bin the mean pretrained CE ('Reds' palette) is drawn with
    the mean fine-tuned CE ('Blues' palette) stacked on top. Empty bins get NaN.

    Args:
        mi_values (array-like): Mutual information per sample.
        ce_values_pt (array-like): Cross-entropy for pretrained model per sample.
        ce_values_ft (array-like): Cross-entropy for fine-tuned model per sample.
        bins (int): Number of bins.
        figsize (tuple): Figure size.
        save_path (str): Path to save the plot.
    """
    # Convert to numpy
    def to_np(x):
        return x.cpu().numpy() if hasattr(x, 'cpu') else np.asarray(x)

    mi = to_np(mi_values).ravel()
    mi_min = mi.min()
    mi_range = mi.max() - mi_min
    if mi_range > 0:
        mi = (mi - mi_min) / mi_range
    else:
        # All MI values equal: normalising would divide by zero and turn every
        # value into NaN. Put all samples at 0 so they land in the first bin.
        mi = np.zeros_like(mi, dtype=float)
    ce_pt = to_np(ce_values_pt).ravel()
    ce_ft = to_np(ce_values_ft).ravel()

    # Bin edges and digitize; right-closed bins, with the minimum clipped into bin 0
    edges = np.linspace(mi.min(), mi.max(), bins + 1)
    bin_idx = np.digitize(mi, edges, right=True) - 1
    bin_idx = np.clip(bin_idx, 0, bins - 1)

    # Compute mean CE per bin for both models
    mean_pt = []
    mean_ft = []
    for i in range(bins):
        mask = (bin_idx == i)
        occupied = mask.any()
        mean_pt.append(ce_pt[mask].mean() if occupied else np.nan)
        mean_ft.append(ce_ft[mask].mean() if occupied else np.nan)

    # Prepare labels
    labels = [f"({edges[i]:.2f},{edges[i+1]:.2f}]" for i in range(bins)]

    # Colors
    bottom_colors = sns.color_palette("Reds", bins)
    top_colors = sns.color_palette("Blues", bins)

    # Plot
    plt.figure(figsize=figsize)
    x = np.arange(bins)
    plt.bar(x, mean_pt, color=bottom_colors, label='CE Pretrained')
    plt.bar(x, mean_ft, bottom=mean_pt, color=top_colors, label='CE Fine-tuned')

    # Labels and aesthetics
    plt.xticks(x, labels, rotation=45, ha='right', fontsize=10)
    plt.xlabel("Mutual Information Bins", fontsize=12)
    plt.ylabel("Cross-Entropy Loss (CE)", fontsize=12)
    plt.legend(loc='upper right')
    sns.despine(trim=True)
    plt.tight_layout()

    # Save
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
    plt.savefig(save_path, dpi=300)
    plt.close()
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import pearsonr
def plot_ce_vs_mi_by_correctness(
    ce_pt: np.ndarray,
    ce_ft: np.ndarray,
    mi_values: np.ndarray,
    correct_pt: np.ndarray,
    correct_ft: np.ndarray,
    figsize: tuple = (20, 4),
    save_path: str = 'ce_vs_mi_by_correctness.png',
):
    """
    Plot CE vs. Mutual Information across 5 subsets: All, TT, TF, FT, FF.

    Each panel overlays a red scatter/regression for pretrained CE vs. MI and
    a blue one for fine-tuned CE vs. MI, annotated with both Pearson
    correlations. Inputs may be arrays or tensors.

    Args:
        ce_pt: Per-sample cross-entropy of the pretrained model.
        ce_ft: Per-sample cross-entropy of the fine-tuned model.
        mi_values: Per-sample mutual information.
        correct_pt: Per-sample correctness of the pretrained model.
        correct_ft: Per-sample correctness of the fine-tuned model.
        figsize: Figure size in inches.
        save_path: Output file; its parent directory is created if needed.
    """
    # helper to numpy
    def to_np(x):
        return x.cpu().numpy() if hasattr(x, 'cpu') else x

    ce_pt = to_np(ce_pt)
    ce_ft = to_np(ce_ft)
    mi = to_np(mi_values)
    # `~` on integer 0/1 masks is a bitwise NOT (~1 == -2, still truthy) and
    # silently selected the wrong samples; force boolean masks first.
    cpt = np.asarray(to_np(correct_pt)).astype(bool)
    cft = np.asarray(to_np(correct_ft)).astype(bool)
    masks = {
        'All': np.ones_like(mi, dtype=bool),
        'TrueTrue': np.logical_and(cpt, cft),
        'TrueFalse': np.logical_and(cpt, ~cft),
        'FalseTrue': np.logical_and(~cpt, cft),
        'FalseFalse': np.logical_and(~cpt, ~cft),
    }
    box = dict(boxstyle="round,pad=0.2", fc="white", alpha=0.6, ec="none")
    # (CE values, colour, subscript, vertical position of the annotation)
    series = (
        (ce_pt, 'tab:red', 'pt', 0.90),
        (ce_ft, 'tab:blue', 'ft', 0.80),
    )

    fig, axs = plt.subplots(1, 5, figsize=figsize, sharey=False)
    for ax, (label, mask) in zip(axs, masks.items()):
        y = mi[mask]
        for ce_all, colour, tag, text_y in series:
            x = ce_all[mask]
            ax.scatter(x, y, c=colour, s=20, alpha=0.7, edgecolor='k', linewidth=0.2)
            # Regression and correlation need at least two points; an empty
            # subset used to abort the whole figure inside regplot.
            if len(x) > 1:
                sns.regplot(x=x, y=y, scatter=False, ax=ax,
                            line_kws={'color': colour, 'linestyle': '--', 'linewidth': 1.5})
                rho, _ = pearsonr(x, y)
                ax.text(0.05, text_y, f"$\\rho_{{{tag}}}={rho:.2f}$",
                        transform=ax.transAxes, color=colour,
                        fontsize=10, bbox=box)

        ax.set_title(label, fontsize=12)
        ax.set_xlabel('Cross-Entropy Error', fontsize=11)
        # Only the first panel carries the shared y label.
        ax.set_ylabel('Mutual Information (JSD)' if label == 'All' else '', fontsize=11)
        ax.tick_params(labelsize=9)

    plt.tight_layout()
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
    fig.savefig(save_path, dpi=300)
    plt.close(fig)
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
# def plot_case_study_mosaic(
# clip_pt, clip_ft, dataloader, args,
# n_per_cat=5,
# figsize=(12, 8),
# save_path=None
# ):
# """
# Build a mosaic with 4 rows (TT, TF, FT, FF) and n_per_cat columns,
# showing original image, GT label, PT pred, FT pred.
# """
# device=f'cuda:{args.gpu}'
# # 1) Collect all images & labels
# imgs, labels = [], []
# for x, y in dataloader:
# imgs.append(x)
# labels.append(y)
# imgs = torch.cat(imgs, dim=0).to(device) # (N, C, H, W)
# labels = torch.cat(labels, dim=0).squeeze().to(device) # (N,)
# # 2) Run both models to get logits
# clip_pt.eval(); clip_ft.eval()
# with torch.no_grad():
# logits_pt, _ = get_logits(clip_pt, dataloader, args, return_feats=False)
# logits_ft, _ = get_logits(clip_ft, dataloader, args, return_feats=False)
# # 3) Compute predictions and correctness masks
# p_pt = torch.softmax(logits_pt, dim=1)
# p_ft = torch.softmax(logits_ft, dim=1)
# pred_pt = p_pt.argmax(dim=1)
# pred_ft = p_ft.argmax(dim=1)
# correct_pt = pred_pt.eq(labels)
# correct_ft = pred_ft.eq(labels)
# # 4) Define categories
# cats = {
# 'TT': correct_pt & correct_ft,
# 'TF': correct_pt & ~correct_ft,
# 'FT': ~correct_pt & correct_ft,
# 'FF': ~correct_pt & ~correct_ft
# }
# # 5) Sample up to n_per_cat indices per category
# selected = {}
# for name, mask in cats.items():
# idxs = mask.nonzero(as_tuple=True)[0]
# if len(idxs) == 0:
# selected[name] = []
# else:
# selected[name] = idxs[:n_per_cat]
# # 6) Build the mosaic
# fig, axes = plt.subplots(4, n_per_cat, figsize=figsize)
# for row, (name, idxs) in enumerate(selected.items()):
# for col in range(n_per_cat):
# ax = axes[row, col]
# ax.axis('off')
# if col < len(idxs):
# idx = idxs[col].item()
# img = imgs[idx].cpu().permute(1, 2, 0).numpy()
# # if normalized, denormalize here...
# ax.imshow(img)
# gt = labels[idx].item()
# pt = pred_pt[idx].item()
# ft = pred_ft[idx].item()
# ax.set_title(f"{name}\nGT:{gt} PT:{pt} FT:{ft}", fontsize=8)
# else:
# ax.set_facecolor('lightgray')
# plt.tight_layout()
# os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
# fig.savefig(save_path, dpi=300)
# plt.close(fig)
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.ticker import MaxNLocator, FormatStrFormatter
def js_divergence(p: np.ndarray, q: np.ndarray) -> float:
    """
    Return the Jensen-Shannon divergence (natural log) between ``p`` and ``q``.

    All three distributions (p, q and their midpoint) are clipped to
    [1e-12, 1] so that zero entries cannot produce log(0) or a division by zero.
    """
    def _clipped(dist):
        return np.clip(dist, 1e-12, 1)

    midpoint = _clipped(0.5 * (p + q))

    def _kl_to_midpoint(dist):
        safe = _clipped(dist)
        return np.sum(safe * np.log(safe / midpoint))

    return 0.5 * (_kl_to_midpoint(p) + _kl_to_midpoint(q))
def plot_confidence_vs_js(
    P_pt: np.ndarray,
    P_ft: np.ndarray,
    save_path: str
) -> None:
    """
    Scatter combined confidence against JS divergence for two sets of predictions.

    Samples where both models predict the same class are drawn as violet
    circles, the others as blue plus markers. Dashed lines mark the smallest
    combined confidence and the smallest JS divergence among the disagreeing
    samples; they are omitted when the models never disagree.

    Args:
        P_pt (np.ndarray): Pre-trained model probabilities, shape (N, C).
        P_ft (np.ndarray): Fine-tuned model probabilities, shape (N, C).
        save_path (str): File path where the figure will be saved.
    """
    def to_np(x):
        return x.cpu().numpy() if hasattr(x, 'cpu') else np.asarray(x)

    # Convert to numpy
    P_pt = to_np(P_pt)
    P_ft = to_np(P_ft)

    # Combined confidence: mean of the two models' top probabilities
    conf_pt = P_pt.max(axis=1)
    conf_ft = P_ft.max(axis=1)
    combined_confidence = 0.5 * (conf_pt + conf_ft)

    # JS divergence for each sample
    js_values = np.array([js_divergence(p_row, q_row) for p_row, q_row in zip(P_pt, P_ft)])

    # Determine agreement vs. disagreement
    agree = np.argmax(P_pt, axis=1) == np.argmax(P_ft, axis=1)
    disagree = ~agree

    # Prepare colors
    disagree_color = sns.color_palette("Blues", 2)[1]  # dark blue
    agree_color = "violet"

    # Set up figure
    fig, ax = plt.subplots(figsize=(5, 5))

    # Scatter
    ax.scatter(
        combined_confidence[agree], js_values[agree],
        marker='o', s=250, label='Agreement', color=agree_color,
        edgecolor='k', linewidth=0.75, alpha=0.5
    )
    ax.scatter(
        combined_confidence[disagree], js_values[disagree],
        marker='P', s=250, label='Disagreement', color=disagree_color,
        edgecolor='k', linewidth=0.75, alpha=0.5
    )

    # Threshold lines at the lowest confidence / divergence among disagreements.
    # With no disagreement there is nothing to take a minimum of; .min() on the
    # empty selection used to raise ValueError before anything was drawn.
    if disagree.any():
        conf_thresh = combined_confidence[disagree].min()
        js_thresh = js_values[disagree].min()
        ax.axvline(x=conf_thresh, linestyle='--', color='gray')
        ax.axhline(y=js_thresh, linestyle='--', color='gray')

    # Axis limits with a 5% margin
    x_min, x_max = combined_confidence.min(), combined_confidence.max()
    y_min, y_max = js_values.min(), js_values.max()
    x_margin = (x_max - x_min) * 0.05
    y_margin = (y_max - y_min) * 0.05
    ax.set_xlim(x_min - x_margin, x_max + x_margin)
    ax.set_ylim(y_min - y_margin, y_max + y_margin)
    ax.xaxis.set_major_locator(MaxNLocator(6))
    ax.yaxis.set_major_locator(MaxNLocator(6))
    ax.xaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))

    # Aesthetics: white background, ticks on bottom/left, all spines visible
    ax.set_facecolor('white')
    ax.xaxis.set_tick_params(which='both', bottom=True, top=False, labelbottom=True, labelsize=13)
    ax.yaxis.set_tick_params(which='both', left=True, right=False, labelleft=True, labelsize=13)
    for spine in ax.spines.values():
        spine.set_visible(True)

    # Axis labels with bold mathbf and larger font
    ax.set_xlabel(r'$\mathbf{Combined\ Confidence\ }$'+"\n"+r'$\mathbf{=\ \frac{1}{2}(\max_i\ p_{pt}^{(i)}\ +\ \max_i\ p_{ft}^{(i)})}$', fontsize=13)
    ax.set_ylabel(r'$\mathbf{Divergence\ }$'+"\n"+r'$\mathbf{=\ \frac{1}{2}[KL(P_{pt}\|M)\ +\ KL(P_{ft}\|M)]}$', fontsize=13)

    ax.legend(fontsize=12, frameon=False, loc='best')

    # Ensure directory exists and save
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)