| import torch |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig |
| from torchvision import models, transforms |
| import torch.nn as nn |
| import os |
| import json |
| import cv2 |
| from PIL import Image |
| import gradio as gr |
|
|
class MultimodalRiskBehaviorModel(nn.Module):
    """Late-fusion binary classifier over a caption and a short stack of frames.

    The text branch is a Hugging Face sequence-classification model and the
    visual branch is a torchvision ResNet-50 whose final ``fc`` layer is
    replaced by ``nn.Identity``. Per-frame visual features are averaged over
    the frame axis, concatenated with the text features and passed through
    ``fc1 -> ReLU -> dropout -> fc2 -> sigmoid``.
    """

    def __init__(self, text_model_name="bert-base-uncased", hidden_dim=512, dropout=0.3):
        """Build both encoders and the fusion head.

        Args:
            text_model_name: Identifier or path handed to
                ``AutoModelForSequenceClassification.from_pretrained``.
            hidden_dim: Width of the hidden fusion layer (``fc1`` output).
            dropout: Dropout probability applied after ``fc1``.
        """
        super().__init__()

        self.text_model_name = text_model_name
        self.text_model = AutoModelForSequenceClassification.from_pretrained(
            text_model_name, num_labels=2
        )

        self.visual_model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        visual_feature_dim = self.visual_model.fc.in_features
        # Drop the classification layer so the network emits pooled features.
        self.visual_model.fc = nn.Identity()

        text_feature_dim = self.text_model.config.hidden_size
        self.fc1 = nn.Linear(text_feature_dim + visual_feature_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, encoding, frames):
        """Score a batch of (caption, frames) pairs.

        Args:
            encoding: Mapping with ``'input_ids'`` and ``'attention_mask'``
                tensors. A singleton dimension at index 1 is squeezed away.
            frames: 5-D tensor laid out as
                ``(batch, num_frames, channels, height, width)``.

        Returns:
            Tensor of shape ``(batch, 1)`` holding sigmoid outputs.
        """
        # Use the device the module itself lives on instead of a module-level
        # global, so the model works wherever it has been moved.
        device = next(self.parameters()).device

        input_ids = encoding['input_ids'].squeeze(1).to(device)
        attention_mask = encoding['attention_mask'].squeeze(1).to(device)

        # ``fc1`` is sized for ``config.hidden_size`` text features, so the
        # 2-dim classification logits cannot be used here (shape mismatch).
        # Take the final-layer hidden state of the first token instead.
        # NOTE(review): confirm this matches the features used at training time.
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        text_features = text_outputs.hidden_states[-1][:, 0, :]

        frames = frames.to(device)
        batch_size, num_frames, channels, height, width = frames.size()
        # Fold the frame axis into the batch axis so the 2-D CNN sees images.
        flat_frames = frames.view(batch_size * num_frames, channels, height, width)
        visual_features = self.visual_model(flat_frames)
        # Restore the frame axis and average over it: one vector per clip.
        visual_features = visual_features.view(batch_size, num_frames, -1).mean(dim=1)

        combined_features = torch.cat((text_features, visual_features), dim=1)
        x = self.dropout(torch.relu(self.fc1(combined_features)))
        return torch.sigmoid(self.fc2(x))

    def save_pretrained(self, save_directory):
        """Write ``pytorch_model.bin`` and ``config.json`` into *save_directory*.

        The directory is created if it does not exist.
        """
        os.makedirs(save_directory, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(save_directory, 'pytorch_model.bin'))
        config = {
            "text_model_name": self.text_model_name,
            "hidden_dim": self.fc1.out_features,
            # Persist the dropout rate so a reloaded model is configured
            # identically; older loaders simply ignore this key.
            "dropout": self.dropout.p,
        }
        with open(os.path.join(save_directory, 'config.json'), 'w', encoding='utf-8') as f:
            json.dump(config, f)

    @classmethod
    def from_pretrained(cls, load_directory, map_location=None):
        """Rebuild a model from a local directory or a Hugging Face identifier.

        Args:
            load_directory: Directory produced by ``save_pretrained``, or any
                identifier accepted by
                ``AutoModelForSequenceClassification.from_pretrained``.
            map_location: Forwarded to ``torch.load`` for the local-directory
                case; unused otherwise.

        Returns:
            A ``MultimodalRiskBehaviorModel`` instance.
        """
        if os.path.isdir(load_directory):
            config_path = os.path.join(load_directory, 'config.json')
            state_dict_path = os.path.join(load_directory, 'pytorch_model.bin')

            with open(config_path, 'r', encoding='utf-8') as f:
                config_dict = json.load(f)

            model = cls(
                text_model_name=config_dict["text_model_name"],
                hidden_dim=config_dict["hidden_dim"],
                # Configs written before the dropout key existed fall back to
                # the constructor default.
                dropout=config_dict.get("dropout", 0.3),
            )
            # SECURITY: torch.load unpickles the file; only load checkpoints
            # from a trusted source.
            state_dict = torch.load(state_dict_path, map_location=map_location)
            model.load_state_dict(state_dict)
        else:
            hf_model = AutoModelForSequenceClassification.from_pretrained(
                load_directory, num_labels=2
            )
            # NOTE(review): in this branch only the text encoder comes from
            # ``load_directory``; fc1/fc2 keep their random initialisation and
            # ``hidden_dim`` is set to the encoder's hidden size.
            model = cls(
                text_model_name=hf_model.config.name_or_path,
                hidden_dim=hf_model.config.hidden_size,
            )
            model.text_model = hf_model

        return model
|
|
# Choose the compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The same identifier is used for both the tokenizer and the model weights.
_CHECKPOINT = 'Souha-BH/BERT_Resnet50'
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = MultimodalRiskBehaviorModel.from_pretrained(_CHECKPOINT)
model.to(device)
|
|
|
|
|
|
| |
def load_frames_from_video(video_path, transform, num_frames=10):
    """Read the first *num_frames* frames of a video as one batched tensor.

    Each decoded frame is converted from BGR to RGB, wrapped in a PIL image
    and passed through *transform*.

    Args:
        video_path: Path or URL handed to ``cv2.VideoCapture``.
        transform: Callable applied to each PIL frame; must return a tensor.
        num_frames: Maximum number of leading frames to read.

    Returns:
        The stacked frames with a leading batch axis of size 1. Fewer than
        *num_frames* frames are returned when the video is shorter.

    Raises:
        ValueError: If the video cannot be opened or yields no frames.
    """
    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {video_path}")

        while len(frames) < num_frames:
            success, frame = cap.read()
            if not success:
                break
            # OpenCV decodes to BGR; convert to RGB before building the image.
            image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            frames.append(transform(image))
    finally:
        # Always free the capture handle, even if decoding or transform fails.
        cap.release()

    if not frames:
        # torch.stack([]) would otherwise fail with an unhelpful message.
        raise ValueError(f"No frames could be read from video: {video_path}")

    return torch.stack(frames).unsqueeze(0)
|
|
def predict_video(model, video_path, text_input, tokenizer, transform):
    """Classify one (video, caption) pair with *model*.

    Args:
        model: Module called as ``model(encoding, frames)``.
        video_path: Path or URL of the video to sample frames from.
        text_input: Caption text to tokenize.
        tokenizer: Callable tokenizer returning a mapping of tensors.
        transform: Per-frame preprocessing passed to
            ``load_frames_from_video``.

    Returns:
        ``1.0`` or ``0.0`` (model output thresholded at 0.5), or the string
        ``"Error during prediction"`` if any step raises.
    """
    try:
        model.eval()

        # Place inputs on the device of the model that was passed in rather
        # than relying on a module-level global.
        target_device = next(model.parameters()).device

        encoding = tokenizer(
            text_input,
            padding='max_length',
            truncation=True,
            max_length=128,
            return_tensors='pt',
        )
        encoding = {key: val.to(target_device) for key, val in encoding.items()}

        frames = load_frames_from_video(video_path, transform).to(target_device)

        print(f"Encoding device: {next(iter(encoding.values())).device}, Frames shape: {frames.shape}")

        with torch.no_grad():
            output = model(encoding, frames)

        prediction = (output.squeeze(-1) > 0.5).float()
        return prediction.item()

    except Exception as e:
        # Top-level boundary for the UI: report the failure and return a
        # sentinel message instead of propagating.
        print(f"Prediction error: {e}")
        return "Error during prediction"
|
|
|
|
|
|
|
|
# Per-frame preprocessing: resize, convert to tensor, then normalise per channel.
# The mean/std values are the ImageNet statistics conventionally used with
# torchvision's pretrained models.
_FRAME_SIZE = (224, 224)
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize(_FRAME_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)
|
|
|
|
| |
# Demo videos offered in the UI. ``video_paths`` and ``video_captions`` are
# parallel lists: entry i of one belongs to entry i of the other.
video_paths = [
    "https://drive.google.com/uc?export=download&id=1iWq1q1LM-jmf4iZxOqZTw4FaIBekJowM",
    "https://drive.google.com/uc?export=download&id=1_egBaC1HD2kIZgRRKsnCtsWG94vg1c7n",
    "https://drive.google.com/uc?export=download&id=12cGxBEkfU5Q1Ezg2jRk6zGyn2hoR3JLj",
]

# Caption text that accompanies each video, kept verbatim.
video_captions = [
    "Everytime i start a diet كل مرة أحاول أبدأ ريجيم 😓 #dietmemes #funnyvideos #animetiktok",
    "New sandwich from burger king 🍔👑 #mukbang #asmr #asmrmukbang #asmrsounds #eat #food #Foodie moe eats #yummy #cheese #chicken #burger #fries #burgerking @Burger King",
    "all workout guides l!nked in bi0 // honestly huge moment 😂 I’ve been so focused on growing my upper body that this feels like it finally shows! shorts from @KEEPTHATPUMP #upperbody #upperbodyworkout #glutegains #glutegrowth #gluteexercise #workout #strengthtraining #gym #trending #fyp",
]
|
|
|
|
def predict_risk(video_index):
    """Return a human-readable verdict for the demo video at *video_index*.

    Args:
        video_index: Zero-based index into ``video_paths`` and
            ``video_captions``.

    Returns:
        ``"Risky Health Behavior"`` or ``"Not Risky Health Behavior"``, or the
        error message produced by ``predict_video`` when prediction fails.
    """
    video_path = video_paths[video_index]
    text_input = video_captions[video_index]

    prediction = predict_video(model, video_path, text_input, tokenizer, transform)

    # predict_video reports failure by returning a message string. Surface it
    # instead of letting it fall through as a "not risky" verdict.
    if isinstance(prediction, str):
        return prediction

    return "Risky Health Behavior" if prediction == 1 else "Not Risky Health Behavior"
|
|
| |
def _choice_to_index(choice):
    """Map a radio label such as "Video 2" to a zero-based index.

    Returns None when nothing is selected.
    """
    if not choice:
        return None
    return int(choice.split()[-1]) - 1


with gr.Blocks() as interface:
    gr.Markdown("# Risk Behavior Prediction")
    gr.Markdown("Select a video to classify its behavior as risky or not.")

    # One label per demo video, numbered from 1, so the choices always match
    # the entries in video_paths.
    video_selector = gr.Radio(
        [f"Video {number}" for number in range(1, len(video_paths) + 1)],
        label="Choose a Video",
    )

    def show_selected_video(choice):
        """Return the video source and caption markdown for the chosen label."""
        idx = _choice_to_index(choice)
        if idx is None:
            # Nothing selected: clear the player and the caption.
            return None, ""
        return video_paths[idx], f"**Caption:** {video_captions[idx]}"

    def _predict_selected(choice):
        """Run the risk prediction for the chosen label."""
        idx = _choice_to_index(choice)
        if idx is None:
            # The button can be clicked before any video has been chosen.
            return "Please select a video first."
        return predict_risk(idx)

    video_player = gr.Video(width=320, height=240)
    caption_box = gr.Markdown()

    video_selector.change(
        fn=show_selected_video,
        inputs=video_selector,
        outputs=[video_player, caption_box],
    )

    predict_button = gr.Button("Predict Risk")
    output_text = gr.Textbox(label="Prediction")

    predict_button.click(
        fn=_predict_selected,
        inputs=video_selector,
        outputs=output_text,
    )

interface.launch()