# Spaces: Paused
import os
import sys

import torch

from models import get_model
from data import load_config
from utils import load_model_checkpoint
# --- 1. LOAD CONFIGS AND INITIALIZE MODEL (Run once) ---
# This runs at import time. It builds the image model, loads its checkpoint and
# switches it to eval mode. It leaves `device` and `mirage_img` as module-level
# globals, which the usage example at the bottom of the file relies on.
print("Loading configs and initializing models ONCE...")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
try:
    config_img = load_config("configs/image/mirage.yaml")
    # The image model's checkpoint path is read from the *multimodal* config
    # (training.image_model_path), not from the image config.
    config_multi = load_config("configs/multimodal/mirage.yaml")
    mirage_img = get_model(config_img).to(device)

    # --- 2. Load Model Checkpoint (Run once) ---
    print("Loading model checkpoint ONCE...")
    checkpoint_path_img = config_multi['training']['image_model_path']
    if not os.path.exists(checkpoint_path_img):
        print(f"FATAL ERROR: Model checkpoint not found at {checkpoint_path_img}", file=sys.stderr)
        # Exit with a non-zero status. The interactive `exit()` builtin would
        # raise SystemExit(None) and report success (status 0).
        # SystemExit derives from BaseException, so the handler below does not
        # catch it.
        sys.exit(1)
    # load_model_checkpoint returns a pair; only the model is used here.
    mirage_img, _ = load_model_checkpoint(mirage_img, checkpoint_path_img)
    print(f"Loaded image model from {checkpoint_path_img}")

    # --- 3. Set to Evaluation Mode (Run once) ---
    mirage_img.eval()
    print("Model set to evaluation mode.")
except Exception as e:
    print(f"FATAL ERROR during model initialization or checkpoint loading: {e}", file=sys.stderr)
    sys.exit(1)
# --- 4. PREDICTION FUNCTION FROM .PT FILE ---
def predict_authenticity_from_pt(pt_file_path, model, device, *, threshold=0.5):
    """
    Load encoding data from a .pt file and predict authenticity (real/fake).

    Args:
        pt_file_path (str): Path to the .pt file containing a [1, 301] tensor.
        model (torch.nn.Module): The image model, with its checkpoint loaded
            and in eval mode. It must return exactly one logit for the input.
            The sigmoid of that logit is read as the probability of FAKE.
        device (str): The device, 'cuda' or 'cpu'. The tensor is loaded
            directly onto this device.
        threshold (float, keyword-only): Decision threshold on the fake
            probability. Probabilities >= threshold are labelled "fake".
            Defaults to 0.5.

    Returns:
        tuple: (probability_fake, prediction_label)
            - probability_fake (float): Probability that the image is FAKE
              (0.0 to 1.0).
            - prediction_label (str): The predicted label, "real" or "fake".
        Returns (None, None) on any error: missing or unreadable file,
        non-tensor payload, wrong shape, inference failure, or a model
        output that is not a single logit.
    """
    print(f"\n--- Processing: {pt_file_path} ---")

    # 4.1. Check that the file exists.
    if not os.path.exists(pt_file_path):
        print(f"ERROR: Input file not found at {pt_file_path}")
        return None, None

    # 4.2. Load the encoding data.
    try:
        # map_location places tensors on `device` while loading. Without it, a
        # tensor saved on a GPU cannot be loaded on a CPU-only host.
        # NOTE(review): torch.load unpickles the file. Only load trusted files,
        # or pass weights_only=True on torch versions that support it.
        image_encodings = torch.load(pt_file_path, map_location=device)
        if not isinstance(image_encodings, torch.Tensor):
            print(f"ERROR: Expected a tensor in {pt_file_path}, but got {type(image_encodings).__name__}")
            return None, None
        # Check the shape (it must be [1, 301]).
        if image_encodings.shape != (1, 301):
            print(f"ERROR: Expected tensor shape [1, 301], but got {image_encodings.shape} from {pt_file_path}")
            return None, None
        print(f"Loaded image data (301-dim) with shape: {image_encodings.shape}")
    except Exception as e:
        print(f"ERROR loading or checking data from {pt_file_path}: {e}")
        return None, None

    # 4.3. Run inference.
    try:
        with torch.no_grad():
            output_logits_img = model(image_encodings)  # the image model takes a [1, 301] input
    except Exception as e:
        print(f"ERROR during model inference for {pt_file_path}: {e}")
        return None, None

    # The model must produce exactly one logit. Otherwise `.item()` below would
    # raise, instead of returning (None, None) as the docstring promises.
    if not isinstance(output_logits_img, torch.Tensor):
        print(f"ERROR: Model returned {type(output_logits_img).__name__}, not a tensor, for {pt_file_path}")
        return None, None
    if output_logits_img.numel() != 1:
        print(f"ERROR: Expected a single 'fake' logit from the model, but got output shape {output_logits_img.shape} for {pt_file_path}")
        return None, None

    # 4.4. Process the result (1 = fake). The sigmoid maps the logit to P(fake).
    probs_img = torch.sigmoid(output_logits_img)
    probability_fake = probs_img.squeeze().item()
    prediction_label = "fake" if probability_fake >= threshold else "real"

    print(f"Raw output (logits interpreted as 'fake' logit): {output_logits_img.squeeze().item():.4f}")
    print(f"Probability (0=real, 1=fake): {probability_fake:.4f}")
    print(f"Predicted Label: {prediction_label}")
    return probability_fake, prediction_label
# --- 5. HOW TO USE THE FUNCTION ---
if __name__ == "__main__":
    # Example 1: score a single encoding file.
    # Point this at the .pt file you want to check.
    input_pt_path_single = "encodings/predictions/image/merged/my_single_image_dir/real.pt"
    print("\n--- Processing a single file ---")
    prob_fake_single, label_single = predict_authenticity_from_pt(
        input_pt_path_single, mirage_img, device
    )
    if prob_fake_single is None or label_single is None:
        print(f"\nFailed to process {input_pt_path_single}.")
    else:
        print(
            f"\nFinal result for {input_pt_path_single}: "
            f"Probability Fake={prob_fake_single:.4f}, Label='{label_single}'"
        )

    # Divider between the two examples.
    print("\n" + "=" * 50 + "\n")

    # Example 2: score several files, then print a summary.
    # Failures (e.g. missing files) come back as (None, None) and are reported
    # in the summary rather than raised.
    pt_files_to_check = [
        "encodings/predictions/image/merged/my_single_image_dir/real.pt",  # replace with a real file path
        "path/to/nonexistent.pt",  # deliberately missing, to show the failure path
    ]
    print("\n--- Processing multiple files ---")
    # Files are scored in list order; a repeated path keeps its last result.
    results = {
        path: predict_authenticity_from_pt(path, mirage_img, device)
        for path in pt_files_to_check
    }

    print("\n--- Summary ---")
    for path, (prob_fake, label) in results.items():
        if prob_fake is None:
            print(f"{path}: Processing FAILED")
            continue
        print(f"{path}: Prob Fake={prob_fake:.4f}, Label='{label}'")

    print("\nScript finished.")