import os
import sys

import torch

from models import get_model
from data import load_config
from utils import load_model_checkpoint

# --- 1. LOAD CONFIGS AND INITIALIZE MODEL (Run once) ---
print("Loading configs and initializing models ONCE...")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

try:
    config_img = load_config("configs/image/mirage.yaml")
    config_multi = load_config("configs/multimodal/mirage.yaml")
    mirage_img = get_model(config_img).to(device)

    # --- 2. Load Model Checkpoint (Run once) ---
    print("Loading model checkpoint ONCE...")
    # The image-model checkpoint path is read from the *multimodal* config.
    checkpoint_path_img = config_multi['training']['image_model_path']
    if not os.path.exists(checkpoint_path_img):
        print(f"FATAL ERROR: Model checkpoint not found at {checkpoint_path_img}")
        # Non-zero status so shells/CI see the failure; SystemExit is a
        # BaseException, so the `except Exception` below does not catch it.
        sys.exit(1)
    mirage_img, _ = load_model_checkpoint(mirage_img, checkpoint_path_img)
    print(f"Loaded image model from {checkpoint_path_img}")

    # --- 3. Set to Evaluation Mode (Run once) ---
    mirage_img.eval()
    print("Model set to evaluation mode.")
except Exception as e:
    print(f"FATAL ERROR during model initialization or checkpoint loading: {e}")
    sys.exit(1)


# --- 4. PREDICTION FUNCTION FROM .PT FILE ---
def predict_authenticity_from_pt(pt_file_path, model, device, *, threshold=0.5):
    """Load an encoding tensor from a .pt file and classify it as real or fake.

    Args:
        pt_file_path (str): Path to the .pt file containing a [1, 301] tensor.
        model (torch.nn.Module): Image model with its checkpoint loaded and
            already in eval mode.
        device (str): Device to run on, e.g. 'cuda' or 'cpu'.
        threshold (float): Probability at or above which the input is
            labelled "fake". Defaults to 0.5.

    Returns:
        tuple: (probability_fake, prediction_label)
            - probability_fake (float): sigmoid of the model's single output
              logit, i.e. the probability that the image is FAKE (0.0 to 1.0).
            - prediction_label (str): "fake" if probability_fake >= threshold,
              otherwise "real".
        Returns (None, None) on any error: missing file, load failure, wrong
        input shape, inference failure, or a model output that is not a
        single logit.
    """
    print(f"\n--- Processing: {pt_file_path} ---")

    # 4.1. Check if file exists
    if not os.path.exists(pt_file_path):
        print(f"ERROR: Input file not found at {pt_file_path}")
        return None, None

    # 4.2. Load encoding data
    try:
        # map_location lets a tensor that was saved on a GPU load on a CPU-only
        # machine; without it torch.load fails before .to(device) is reached.
        # NOTE(review): torch.load unpickles the file -- only load trusted .pt
        # files (weights_only=True is safer on torch versions that support it).
        image_encodings = torch.load(pt_file_path, map_location=device).to(device)
        # The model expects exactly one 301-dim encoding.
        if image_encodings.shape != (1, 301):
            print(f"ERROR: Expected tensor shape [1, 301], but got {image_encodings.shape} from {pt_file_path}")
            return None, None
        print(f"Loaded image data (301-dim) with shape: {image_encodings.shape}")
    except Exception as e:
        print(f"ERROR loading or checking data from {pt_file_path}: {e}")
        return None, None

    # 4.3. Run Inference (Prediction)
    try:
        with torch.no_grad():
            output_logits_img = model(image_encodings)  # Image model takes a [1, 301] input
    except Exception as e:
        print(f"ERROR during model inference for {pt_file_path}: {e}")
        return None, None

    # 4.4. Process results (label convention: 1 = fake)
    # .item() below requires a tensor with exactly one element. Guard here so an
    # unexpected output (e.g. two-class logits or a tuple) honours the
    # (None, None) error contract instead of raising out of this function.
    if not torch.is_tensor(output_logits_img) or output_logits_img.numel() != 1:
        got = (
            tuple(output_logits_img.shape)
            if torch.is_tensor(output_logits_img)
            else type(output_logits_img).__name__
        )
        print(f"ERROR: Expected a single output logit for {pt_file_path}, but got {got}")
        return None, None

    raw_logit = output_logits_img.squeeze().item()
    probability_fake = torch.sigmoid(output_logits_img).squeeze().item()  # Probability the image is FAKE

    # probability_fake >= threshold -> "fake"; otherwise -> "real".
    prediction_label = "fake" if probability_fake >= threshold else "real"

    print(f"Raw output (logits interpreted as 'fake' logit): {raw_logit:.4f}")
    print(f"Probability (0=real, 1=fake): {probability_fake:.4f}")
    print(f"Predicted Label: {prediction_label}")

    # Return the fake probability and the label
    return probability_fake, prediction_label
# --- 5. HOW TO USE THE FUNCTION ---
if __name__ == "__main__":
    # --- Single-file example ---
    # Point this at the .pt encoding you want to classify.
    single_path = "encodings/predictions/image/merged/my_single_image_dir/real.pt"

    print("\n--- Processing a single file ---")
    single_prob, single_label = predict_authenticity_from_pt(single_path, mirage_img, device)

    if single_prob is None or single_label is None:
        print(f"\nFailed to process {single_path}.")
    else:
        print(f"\nFinal result for {single_path}: Probability Fake={single_prob:.4f}, Label='{single_label}'")

    divider = "=" * 50
    print(f"\n{divider}\n")  # Visual separator between the two examples

    # --- Batch example ---
    batch_paths = [
        "encodings/predictions/image/merged/my_single_image_dir/real.pt",  # Replace with a real file path
        "path/to/nonexistent.pt",  # Deliberately missing, to show the failure path
    ]

    print("\n--- Processing multiple files ---")
    # path -> (probability_fake, label); both entries are None when processing failed.
    results = {
        path: predict_authenticity_from_pt(path, mirage_img, device)
        for path in batch_paths
    }

    print("\n--- Summary ---")
    for path, (prob, label) in results.items():
        if prob is None:
            print(f"{path}: Processing FAILED")
        else:
            print(f"{path}: Prob Fake={prob:.4f}, Label='{label}'")

    print("\nScript finished.")