AIDetect / miragenews /test_single_pair.py
Lucii1's picture
refactor code
ecccf5c
raw
history blame contribute delete
5.35 kB
import torch
from models import get_model
from data import load_config
from utils import load_model_checkpoint
import os
# --- 1. LOAD CONFIGS AND INITIALIZE MODEL (Run once) ---
# Everything in this section runs at import time and defines the module-level
# names `device`, `config_img`, `config_multi`, `mirage_img` and
# `checkpoint_path_img` used by the rest of the script.
print("Loading configs and initializing models ONCE...")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
try:
    config_img = load_config("configs/image/mirage.yaml")
    config_multi = load_config("configs/multimodal/mirage.yaml")
    mirage_img = get_model(config_img).to(device)

    # --- 2. Load Model Checkpoint (Run once) ---
    print("Loading model checkpoint ONCE...")
    # The image checkpoint path is read from the *multimodal* config.
    checkpoint_path_img = config_multi['training']['image_model_path']
    if not os.path.exists(checkpoint_path_img):
        print(f"FATAL ERROR: Model checkpoint not found at {checkpoint_path_img}")
        # Exit with a non-zero status so callers (shell, CI) can detect the
        # failure. SystemExit is not an Exception subclass, so the handler
        # below does not intercept it.
        raise SystemExit(1)
    # load_model_checkpoint returns a pair; only the model is needed here.
    mirage_img, _ = load_model_checkpoint(mirage_img, checkpoint_path_img)
    print(f"Loaded image model from {checkpoint_path_img}")

    # --- 3. Set to Evaluation Mode (Run once) ---
    mirage_img.eval()
    print("Model set to evaluation mode.")
except Exception as e:
    print(f"FATAL ERROR during model initialization or checkpoint loading: {e}")
    # The bare exit() helper is provided by the `site` module for interactive
    # use and would report status 0; raise SystemExit explicitly instead.
    raise SystemExit(1)
# --- 4. PREDICTION FUNCTION FROM .PT FILE ---
def predict_authenticity_from_pt(pt_file_path, model, device):
    """
    Load encoding data from a .pt file and predict authenticity (real/fake).

    Progress and error messages are printed to stdout. Every failure (missing
    file, unreadable file, wrong tensor shape, model error, or a model output
    that is not a single value) is reported and signalled with (None, None)
    rather than raised.

    Args:
        pt_file_path (str): Path to the .pt file containing the [1, 301] tensor.
        model (torch.nn.Module): The image model with the checkpoint loaded and in eval mode.
        device (str): The device, 'cuda' or 'cpu'.

    Returns:
        tuple: (probability_fake, prediction_label)
            - probability_fake (float): The probability that the image is FAKE (0.0 to 1.0).
            - prediction_label (str): The predicted label ("real" or "fake").
            Returns (None, None) on error.
    """
    print(f"\n--- Processing: {pt_file_path} ---")

    # 4.1. Check if file exists
    if not os.path.exists(pt_file_path):
        print(f"ERROR: Input file not found at {pt_file_path}")
        return None, None

    # 4.2. Load encoding data
    try:
        # map_location remaps storages while deserializing, so an encoding that
        # was saved on a GPU can still be read on a CPU-only machine.
        # NOTE(security): torch.load unpickles the file; only feed it .pt files
        # from a trusted source.
        image_encodings = torch.load(pt_file_path, map_location=device).to(device)
        # Check shape (must be [1, 301])
        if image_encodings.shape != (1, 301):
            print(f"ERROR: Expected tensor shape [1, 301], but got {image_encodings.shape} from {pt_file_path}")
            return None, None
        print(f"Loaded image data (301-dim) with shape: {image_encodings.shape}")
    except Exception as e:
        print(f"ERROR loading or checking data from {pt_file_path}: {e}")
        return None, None

    # 4.3. Run Inference (Prediction)
    try:
        with torch.no_grad():
            output_logits_img = model(image_encodings)  # The image model takes the [1, 301] input
        # 4.4. Process results (Logic: 1 = fake)
        # Kept inside the try: .item() raises when the output holds more than
        # one element, and that must be reported as (None, None) as well.
        probs_img = torch.sigmoid(output_logits_img)
        raw_logit = output_logits_img.squeeze().item()
        probability_fake = probs_img.squeeze().item()  # Probability the image is FAKE
    except Exception as e:
        print(f"ERROR during model inference for {pt_file_path}: {e}")
        return None, None

    # Decide label based on a 0.5 threshold:
    #   "fake" probability >= 0.5 -> fake
    #   "fake" probability <  0.5 -> real
    prediction_label = "fake" if probability_fake >= 0.5 else "real"

    print(f"Raw output (logits interpreted as 'fake' logit): {raw_logit:.4f}")
    print(f"Probability (0=real, 1=fake): {probability_fake:.4f}")
    print(f"Predicted Label: {prediction_label}")

    # Return the fake probability and the label
    return probability_fake, prediction_label
# --- 5. HOW TO USE THE FUNCTION ---
if __name__ == "__main__":
    # --- PROCESS A SINGLE FILE ---
    # Point this at the .pt encoding you want to classify.
    single_pt_path = "encodings/predictions/image/merged/my_single_image_dir/real.pt"

    print("\n--- Processing a single file ---")
    single_prob, single_label = predict_authenticity_from_pt(single_pt_path, mirage_img, device)
    # A None in either slot means the helper reported an error.
    if single_prob is None or single_label is None:
        print(f"\nFailed to process {single_pt_path}.")
    else:
        print(f"\nFinal result for {single_pt_path}: Probability Fake={single_prob:.4f}, Label='{single_label}'")

    print("\n" + "=" * 50 + "\n")  # Divider between the two demos

    # --- EXAMPLE: PROCESS MULTIPLE FILES ---
    batch_pt_paths = [
        "encodings/predictions/image/merged/my_single_image_dir/real.pt",  # Replace with a real file path
        "path/to/nonexistent.pt",  # Deliberately missing, to show the failure path
    ]

    print("\n--- Processing multiple files ---")
    # Map each path to the (probability, label) pair returned for it.
    results = {
        pt_path: predict_authenticity_from_pt(pt_path, mirage_img, device)
        for pt_path in batch_pt_paths
    }

    print("\n--- Summary ---")
    for pt_path, (fake_prob, verdict) in results.items():
        if fake_prob is None:
            print(f"{pt_path}: Processing FAILED")
        else:
            print(f"{pt_path}: Prob Fake={fake_prob:.4f}, Label='{verdict}'")

    print("\nScript finished.")