"""Custom inference handler for a multi-output text classifier.

Defines the model wrapper (`MultiOutputClassifier`) and the request handler
(`EndpointHandler`) that loads the encoder, tokenizer and classification heads
from a model directory and serves predictions.
"""

import os
import traceback
from typing import Any, Dict, List

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer


class MultiOutputClassifier(nn.Module):
    """Encoder followed by one independent linear head per output category.

    Each head maps the mean-pooled encoder representation to ``num_levels``
    logits, so the model emits ``num_classes`` separate classifications for
    every input.
    """

    def __init__(self, encoder, hidden_size, num_classes=3, num_levels=3):
        """Build the classification heads on top of ``encoder``.

        Args:
            encoder: Module called as ``encoder(input_ids=..., attention_mask=...)``
                whose output exposes ``last_hidden_state``.
            hidden_size: Width of the encoder's hidden states.
            num_classes: Number of independent classification heads.
            num_levels: Number of logits produced by each head.
        """
        super().__init__()
        self.encoder = encoder
        # One head per category; the attribute name and ordering must stay
        # stable because saved head weights are loaded by index.
        self.classifiers = nn.ModuleList(
            [nn.Linear(hidden_size, num_levels) for _ in range(num_classes)]
        )

    def forward(self, input_ids, attention_mask):
        """Return logits stacked as ``(batch, num_classes, num_levels)``.

        Args:
            input_ids: Token ids passed straight to the encoder.
            attention_mask: Mask passed to the encoder and reused as the
                weighting for mean pooling (non-zero marks real tokens).
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state

        # Masked mean pooling: padding positions contribute nothing to the sum
        # and are excluded from the divisor. The mask is cast to float so the
        # arithmetic is not done in an integer dtype.
        mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        sum_hidden = torch.sum(hidden_states * mask, dim=1)
        # Clamp guards against division by zero for an all-padding row.
        sum_mask = torch.clamp(mask.sum(dim=1), min=1e-9)
        pooled_output = sum_hidden / sum_mask

        logits = [classifier(pooled_output) for classifier in self.classifiers]
        return torch.stack(logits, dim=1)


class EndpointHandler:
    """Request handler: loads the model once, then classifies incoming text.

    NOTE(review): the class name appears to be dictated by the hosting
    runtime's custom-handler convention — confirm before renaming.
    """

    def __init__(self, path: str = ""):
        """Load tokenizer, encoder and classification heads from ``path``.

        Args:
            path: Path to the model directory. It is expected to contain the
                tokenizer files, the encoder weights (safetensors) and,
                optionally, ``classifier_heads.pt``.

        If ``classifier_heads.pt`` is missing or fails to load, a warning is
        printed and the handler keeps running with randomly initialised heads.
        """
        print(f"[INIT] Initializing handler from path: {path}")
        print(f"[INIT] Files in directory: {os.listdir(path) if os.path.exists(path) else 'PATH NOT FOUND'}")

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"[INIT] Using device: {self.device}")

        self.tokenizer = AutoTokenizer.from_pretrained(
            path,
            trust_remote_code=True
        )
        # Padding is required for batching; fall back to EOS when the
        # tokenizer ships without a dedicated pad token.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        print("[INIT] Tokenizer loaded successfully")

        print("[INIT] Loading base encoder from sharded safetensors...")
        encoder = AutoModel.from_pretrained(
            path,
            use_safetensors=True,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
        )
        hidden_size = encoder.config.hidden_size
        print(f"[INIT] Encoder loaded, hidden_size: {hidden_size}")

        self.model = MultiOutputClassifier(
            encoder=encoder,
            hidden_size=hidden_size,
            num_classes=3,
            num_levels=3
        )

        classifier_heads_path = os.path.join(path, "classifier_heads.pt")
        if os.path.exists(classifier_heads_path):
            print(f"[INIT] Loading classification heads from {classifier_heads_path}")
            try:
                # SECURITY NOTE: torch.load unpickles the file; only load
                # checkpoints from a trusted model directory.
                checkpoint = torch.load(classifier_heads_path, map_location=self.device)

                # The checkpoint holds a list of per-head state_dicts under
                # 'classifiers' and optional metadata under 'config'.
                classifiers_list = checkpoint['classifiers']
                config = checkpoint.get('config', {})
                print(f"[INIT] Checkpoint config: {config}")
                print(f"[INIT] Found {len(classifiers_list)} classifiers")

                for i, classifier in enumerate(self.model.classifiers):
                    classifier.load_state_dict(classifiers_list[i])
                    print(f"[INIT] ✓ Loaded classifier {i} (weight shape: {classifier.weight.shape})")

                print("[INIT] ✓✓✓ Successfully loaded all classification heads!")
            except Exception as e:
                # Deliberately best-effort: the handler still starts, but any
                # head that was not loaded keeps its random initialisation.
                print(f"[INIT] ⚠ Error loading classifier heads: {e}")
                traceback.print_exc()
        else:
            print(f"[INIT] ⚠⚠⚠ WARNING: classifier_heads.pt not found at {classifier_heads_path}")
            print("[INIT] Model will use random classifier heads!")

        self.model.to(self.device)
        self.model.eval()
        print(f"[INIT] Model moved to {self.device} and set to eval mode")

        # Output category names, in the same order as the model's heads.
        self.class_names = ['vyhruzky', 'vulgarity', 'rasismus']
        # Maps a predicted level index to its human-readable label.
        self.level_labels = {
            0: 'none',
            1: 'moderate',
            2: 'severe'
        }

        print("[INIT] ✓✓✓ Handler initialization complete!")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Process an inference request.

        Args:
            data: Dictionary containing:
                - inputs: str or List[str] - text(s) to classify
                - return_scores: bool (optional) - whether to return
                  confidence scores, default True
                - max_length: int (optional) - max token length, default 128

        Returns:
            One dictionary per input text, keyed by category name:
            [
                {
                    "vyhruzky": {"label": "none", "level": 0, "score": 0.95},
                    "vulgarity": {"label": "moderate", "level": 1, "score": 0.87},
                    "rasismus": {"label": "none", "level": 0, "score": 0.98}
                }
            ]
            An empty ``inputs`` list yields an empty result list.

        Raises:
            ValueError: If ``inputs`` is neither a string nor a list/tuple.
        """
        print(f"[CALL] Received data: {data}")

        # Read the request without mutating the caller's dictionary.
        inputs = data.get("inputs", data)
        return_scores = data.get("return_scores", True)
        max_length = data.get("max_length", 128)

        if isinstance(inputs, str):
            inputs = [inputs]
        elif not isinstance(inputs, (list, tuple)):
            # Typically hit when the "inputs" key is missing and the whole
            # request dict would otherwise be handed to the tokenizer.
            raise ValueError(
                "'inputs' must be a string or a list of strings, "
                f"got {type(inputs).__name__}"
            )

        print(f"[CALL] Processing {len(inputs)} input(s)")

        if not inputs:
            # Nothing to tokenize; skip the model call entirely.
            results: List[Dict[str, Any]] = []
            print(f"[CALL] Returning results: {results}")
            return results

        encoded = self.tokenizer(
            list(inputs),
            truncation=True,
            padding='max_length',
            max_length=max_length,
            return_tensors='pt'
        )
        input_ids = encoded['input_ids'].to(self.device)
        attention_mask = encoded['attention_mask'].to(self.device)

        print(f"[CALL] Input shape: {input_ids.shape}")

        with torch.no_grad():
            logits = self.model(input_ids=input_ids, attention_mask=attention_mask)

        print(f"[CALL] Logits shape: {logits.shape}")
        print(f"[CALL] Logits sample: {logits[0]}")

        # Softmax/argmax over the last axis (levels within each category).
        probs = torch.softmax(logits, dim=2)
        preds = torch.argmax(logits, dim=2)

        print(f"[CALL] Predictions: {preds}")
        print(f"[CALL] Probabilities: {probs[0]}")

        # Pick the probability of each predicted level in one tensor op and
        # convert to Python lists once, instead of calling .item() per element
        # (each .item() forces a separate device-to-host transfer).
        chosen = probs.gather(2, preds.unsqueeze(-1)).squeeze(-1)
        level_rows = preds.tolist()
        score_rows = chosen.tolist()

        results = []
        for levels, scores in zip(level_rows, score_rows):
            result = {}
            for class_name, level, score in zip(self.class_names, levels, scores):
                entry = {
                    "label": self.level_labels[level],
                    "level": level
                }
                if return_scores:
                    entry["score"] = round(score, 4)
                result[class_name] = entry
            results.append(result)

        print(f"[CALL] Returning results: {results}")
        return results