"""RunPod serverless worker for a multi-output text classifier.

A transformer encoder is mean-pooled over non-padding tokens and fed to one
linear head per output category; each head predicts a severity level.
"""
import runpod
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
import os
import logging
from typing import Dict, Any, List

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Output categories, in the same order as MultiOutputClassifier.classifiers.
CLASS_NAMES = ['vyhruzky', 'vulgarity', 'rasismus']
# Maps a predicted level index (argmax over a head's logits) to its label.
LEVEL_LABELS = {0: 'none', 1: 'moderate', 2: 'severe'}


class MultiOutputClassifier(nn.Module):
    """Shared encoder with one independent linear head per output category.

    Args:
        encoder: Transformer model whose output exposes ``last_hidden_state``.
        hidden_size: Width of the encoder's hidden states.
        num_classes: Number of output categories (one linear head each).
        num_levels: Number of levels each head scores.
    """

    def __init__(self, encoder, hidden_size, num_classes=3, num_levels=3):
        super().__init__()
        self.encoder = encoder
        self.classifiers = nn.ModuleList([
            nn.Linear(hidden_size, num_levels) for _ in range(num_classes)
        ])

    def forward(self, input_ids, attention_mask):
        """Return stacked head logits, shape (batch, num_classes, num_levels).

        Token states are mean-pooled with ``attention_mask`` as weights so that
        padding positions do not contribute. Assumes ``last_hidden_state`` is
        (batch, seq, hidden) — the usual Hugging Face layout.
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state
        # The float32 mask also promotes bfloat16 hidden states to float32,
        # which matches the dtype of the (float32) linear heads.
        mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        sum_hidden = torch.sum(hidden_states * mask, dim=1)
        # Clamp avoids division by zero for a row that is entirely padding.
        sum_mask = torch.clamp(mask.sum(dim=1), min=1e-9)
        pooled_output = sum_hidden / sum_mask
        logits = [classifier(pooled_output) for classifier in self.classifiers]
        return torch.stack(logits, dim=1)


# Populated by load_model(); shared by every handler invocation.
model = None
tokenizer = None
device = None


def load_model():
    """Load tokenizer, encoder and classification heads into module globals.

    The model directory defaults to ``/app/model`` and can be overridden with
    the ``MODEL_PATH`` environment variable.

    Returns:
        The loaded ``MultiOutputClassifier``, in eval mode on ``device``.

    Raises:
        ValueError: if ``classifier_heads.pt`` holds a different number of
            heads than the model has.
    """
    global model, tokenizer, device
    model_path = os.environ.get("MODEL_PATH", "/app/model")
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    logger.info("Loading model from %s", model_path)
    logger.info("Using device: %s", device)

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Some tokenizers ship without a pad token; reuse EOS so batching works.
        tokenizer.pad_token = tokenizer.eos_token

    encoder = AutoModel.from_pretrained(
        model_path,
        use_safetensors=True,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16 if use_cuda else torch.float32,
    )

    model = MultiOutputClassifier(
        encoder=encoder,
        hidden_size=encoder.config.hidden_size,
        num_classes=len(CLASS_NAMES),
        num_levels=len(LEVEL_LABELS),
    )

    classifier_heads_path = os.path.join(model_path, "classifier_heads.pt")
    if os.path.exists(classifier_heads_path):
        logger.info("Loading classification heads")
        # NOTE(review): torch.load unpickles; acceptable for a file baked into
        # the image, but pass weights_only=True if the torch version allows.
        checkpoint = torch.load(classifier_heads_path, map_location=device)
        classifiers_list = checkpoint['classifiers']
        if len(classifiers_list) != len(model.classifiers):
            raise ValueError(
                f"classifier_heads.pt has {len(classifiers_list)} heads, "
                f"expected {len(model.classifiers)}"
            )
        for classifier, state_dict in zip(model.classifiers, classifiers_list):
            classifier.load_state_dict(state_dict)
    else:
        # Without trained heads every prediction is random; make that loud.
        logger.warning(
            "%s not found; classification heads are randomly initialised "
            "and predictions will be meaningless",
            classifier_heads_path,
        )

    model.to(device)
    model.eval()

    logger.info("✓ Model loaded and ready!")
    return model


def _parse_input(job_input):
    """Validate a job's ``input`` payload.

    Returns:
        ``(texts, max_length, return_scores)`` where ``texts`` is a list of str.

    Raises:
        ValueError: if the payload is not a dict, ``text``/``inputs`` is not a
            string or a list of strings, or ``max_length`` is not a positive int.
    """
    if not isinstance(job_input, dict):
        raise ValueError("'input' must be an object")
    text_input = job_input.get("text", job_input.get("inputs", ""))
    max_length = job_input.get("max_length", 128)
    return_scores = job_input.get("return_scores", True)

    if isinstance(text_input, str):
        texts = [text_input]
    elif isinstance(text_input, (list, tuple)) and all(
        isinstance(t, str) for t in text_input
    ):
        texts = list(text_input)
    else:
        raise ValueError("'text' must be a string or a list of strings")

    # bool is a subclass of int, so reject it explicitly.
    if isinstance(max_length, bool) or not isinstance(max_length, int) or max_length <= 0:
        raise ValueError("'max_length' must be a positive integer")

    return texts, max_length, return_scores


def _format_results(probs, preds, return_scores):
    """Turn (batch, num_classes, ...) prediction tensors into JSON-ready dicts.

    Args:
        probs: Softmax probabilities, (batch, num_classes, num_levels).
        preds: Argmax level indices, (batch, num_classes).
        return_scores: Whether to include the winning level's probability.
    """
    # Pick each head's probability for its predicted level, then move
    # everything to the host once instead of syncing per cell with .item().
    top_probs = probs.gather(2, preds.unsqueeze(-1)).squeeze(-1)
    pred_rows = preds.cpu().tolist()
    prob_rows = top_probs.cpu().tolist()

    results = []
    for levels, scores in zip(pred_rows, prob_rows):
        result = {}
        for class_name, level, score in zip(CLASS_NAMES, levels, scores):
            entry = {"label": LEVEL_LABELS[level], "level": level}
            if return_scores:
                entry["score"] = round(score, 4)
            result[class_name] = entry
        results.append(result)
    return results


def handler(job: Dict[str, Any]) -> Dict[str, Any]:
    """
    RunPod serverless handler.

    Input format:
    {
        "input": {
            "text": "Your text here" OR ["text1", "text2"],  # alias: "inputs"
            "max_length": 128,       # optional, positive int
            "return_scores": true    # optional
        }
    }

    Returns ``{"output": [per-text result, ...]}`` on success, where each
    result maps a category name to ``{"label", "level"[, "score"]}``, or
    ``{"error": message}`` on any failure.
    """
    try:
        if model is None:
            # Allows the handler to be used without the __main__ entry point.
            load_model()

        texts, max_length, return_scores = _parse_input(job["input"])
        if not texts:
            return {"output": []}

        # NOTE(review): fixed-length padding is kept as-is; dynamic padding
        # would be cheaper but may change outputs if the encoder derives
        # positions from padded length (e.g. left padding) — verify first.
        encoded = tokenizer(
            texts,
            truncation=True,
            padding='max_length',
            max_length=max_length,
            return_tensors='pt'
        )
        input_ids = encoded['input_ids'].to(device)
        attention_mask = encoded['attention_mask'].to(device)

        with torch.no_grad():
            logits = model(input_ids=input_ids, attention_mask=attention_mask)
            probs = torch.softmax(logits, dim=2)
            preds = torch.argmax(logits, dim=2)

        return {"output": _format_results(probs, preds, return_scores)}

    except Exception as e:
        # Top-level boundary: log with traceback, report to the caller.
        logger.exception("Error in handler: %s", e)
        return {"error": str(e)}


if __name__ == "__main__":
    logger.info("Starting RunPod serverless handler...")
    load_model()
    logger.info("Starting RunPod serverless worker...")
    runpod.serverless.start({"handler": handler})