enuma-elis's picture
Update handler.py
1fd845e verified
Raw
History Blame
7.99 kB
from typing import Dict, List, Any
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
import os
class MultiOutputClassifier(nn.Module):
    """Encoder followed by several independent linear classification heads.

    Each head maps the same mean-pooled sentence representation to
    ``num_levels`` logits; there are ``num_classes`` heads in total.
    """

    def __init__(self, encoder, hidden_size, num_classes=3, num_levels=3):
        """Store the encoder and build one linear head per output class.

        Args:
            encoder: Callable module returning an object that exposes
                ``last_hidden_state``.
            hidden_size: Width of the encoder's hidden states.
            num_classes: Number of heads to create.
            num_levels: Number of logits each head produces.
        """
        super().__init__()
        self.encoder = encoder

        # One head per output class, all reading the same pooled vector.
        heads = []
        for _ in range(num_classes):
            heads.append(nn.Linear(hidden_size, num_levels))
        self.classifiers = nn.ModuleList(heads)

    def forward(self, input_ids, attention_mask):
        """Return logits stacked as ``(batch, num_heads, num_levels)``."""
        token_states = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state

        # Masked mean pooling: padded positions get weight zero, and the
        # divisor is clamped so an all-zero mask cannot divide by zero.
        weights = attention_mask.unsqueeze(-1).expand(token_states.size()).float()
        masked_total = (token_states * weights).sum(dim=1)
        token_count = weights.sum(dim=1).clamp(min=1e-9)
        pooled = masked_total / token_count

        # Run every head on the pooled vector and stack along a new axis 1.
        return torch.stack([head(pooled) for head in self.classifiers], dim=1)
class EndpointHandler:
    """Inference handler wrapping a fine-tuned multi-output text classifier.

    ``__init__`` loads the tokenizer, the encoder and the classification
    heads from a model directory; ``__call__`` classifies one or more texts
    and reports, for every category in ``class_names``, the predicted
    severity level.
    """

    def __init__(self, path: str = ""):
        """Load tokenizer, encoder and classification heads from ``path``.

        Args:
            path: Path to the model directory. It must contain the tokenizer
                and encoder files and, optionally, ``classifier_heads.pt``.

        If ``classifier_heads.pt`` is missing or cannot be loaded, a warning
        is printed and the handler keeps the randomly initialised heads
        (best-effort behaviour).
        """
        print(f"[INIT] Initializing handler from path: {path}")
        print(f"[INIT] Files in directory: {os.listdir(path) if os.path.exists(path) else 'PATH NOT FOUND'}")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"[INIT] Using device: {self.device}")

        # Output categories and the human-readable name of each level.
        self.class_names = ['vyhruzky', 'vulgarity', 'rasismus']
        self.level_labels = {
            0: 'none',
            1: 'moderate',
            2: 'severe'
        }

        self.tokenizer = AutoTokenizer.from_pretrained(
            path,
            trust_remote_code=True
        )
        # Padding is required for batching; fall back to EOS when the
        # tokenizer defines no dedicated pad token.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        print("[INIT] Tokenizer loaded successfully")

        print("[INIT] Loading base encoder from sharded safetensors...")
        encoder = AutoModel.from_pretrained(
            path,
            use_safetensors=True,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
        )
        hidden_size = encoder.config.hidden_size
        print(f"[INIT] Encoder loaded, hidden_size: {hidden_size}")

        # Head count and logits per head follow the label tables above so
        # the model layout and the output formatting cannot drift apart.
        self.model = MultiOutputClassifier(
            encoder=encoder,
            hidden_size=hidden_size,
            num_classes=len(self.class_names),
            num_levels=len(self.level_labels)
        )

        self._load_classifier_heads(path)

        self.model.to(self.device)
        self.model.eval()
        print(f"[INIT] Model moved to {self.device} and set to eval mode")
        print("[INIT] ✓✓✓ Handler initialization complete!")

    def _load_classifier_heads(self, path: str) -> None:
        """Best-effort load of the head weights from ``classifier_heads.pt``."""
        classifier_heads_path = os.path.join(path, "classifier_heads.pt")
        if not os.path.exists(classifier_heads_path):
            print(f"[INIT] ⚠⚠⚠ WARNING: classifier_heads.pt not found at {classifier_heads_path}")
            print("[INIT] Model will use random classifier heads!")
            return

        print(f"[INIT] Loading classification heads from {classifier_heads_path}")
        try:
            # NOTE(review): torch.load unpickles arbitrary objects. Only load
            # checkpoints from a trusted model repository; consider
            # weights_only=True once the checkpoint contents are confirmed
            # to be plain tensors/dicts.
            checkpoint = torch.load(classifier_heads_path, map_location=self.device)
            # The checkpoint holds a list with one state_dict per head.
            classifiers_list = checkpoint['classifiers']
            config = checkpoint.get('config', {})
            print(f"[INIT] Checkpoint config: {config}")
            print(f"[INIT] Found {len(classifiers_list)} classifiers")
            for i, classifier in enumerate(self.model.classifiers):
                classifier.load_state_dict(classifiers_list[i])
                print(f"[INIT] ✓ Loaded classifier {i} (weight shape: {classifier.weight.shape})")
            print("[INIT] ✓✓✓ Successfully loaded all classification heads!")
        except Exception as e:
            # Deliberate best-effort: report the failure and keep serving
            # with whatever head weights are currently in place.
            print(f"[INIT] ⚠ Error loading classifier heads: {e}")
            import traceback
            traceback.print_exc()

    @staticmethod
    def _parse_request(data: Dict[str, Any]):
        """Extract ``(inputs, return_scores, max_length)`` without mutating ``data``.

        Options are read from the top level of the payload first and then
        from a nested ``"parameters"`` dict.

        Raises:
            ValueError: If ``inputs`` is not a string or a list/tuple, or if
                ``max_length`` is not a positive integer.
        """
        inputs = data.get("inputs", data)
        if isinstance(inputs, str):
            inputs = [inputs]
        elif isinstance(inputs, (list, tuple)):
            inputs = list(inputs)
        else:
            raise ValueError(
                "Request must contain an 'inputs' field holding a string "
                "or a list of strings"
            )

        params = data.get("parameters")
        if not isinstance(params, dict):
            params = {}

        return_scores = data.get("return_scores", params.get("return_scores", True))
        raw_max_length = data.get("max_length", params.get("max_length", 128))
        try:
            max_length = int(raw_max_length)
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"max_length must be an integer, got {raw_max_length!r}"
            ) from err
        if max_length < 1:
            raise ValueError(f"max_length must be >= 1, got {max_length}")

        return inputs, return_scores, max_length

    def _format_results(self, probs, preds, return_scores) -> List[Dict[str, Any]]:
        """Convert prediction tensors into one result dict per input text."""
        # Pull everything to Python lists once instead of calling .item()
        # per element (each call forces a separate device sync).
        levels = preds.tolist()
        scores = torch.gather(probs, 2, preds.unsqueeze(-1)).squeeze(-1).tolist()

        results = []
        for row_levels, row_scores in zip(levels, scores):
            entry = {}
            for name, level, score in zip(self.class_names, row_levels, row_scores):
                item = {
                    "label": self.level_labels[level],
                    "level": level
                }
                if return_scores:
                    item["score"] = round(score, 4)
                entry[name] = item
            results.append(entry)
        return results

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Classify the text(s) in an inference request.

        Args:
            data: Dictionary containing:
                - inputs: str or List[str] - text(s) to classify
                - return_scores: bool (optional) - whether to return
                  confidence scores, default True
                - max_length: int (optional) - max token length, default 128
                The two options may also be nested under a ``"parameters"``
                dict; top-level values take precedence.

        Returns:
            List of predictions, one per input, with format:
            [
                {
                    "vyhruzky": {"label": "none", "level": 0, "score": 0.95},
                    "vulgarity": {"label": "moderate", "level": 1, "score": 0.87},
                    "rasismus": {"label": "none", "level": 0, "score": 0.98}
                }
            ]
            An empty ``inputs`` list yields an empty list.

        Raises:
            ValueError: If ``inputs`` is missing or not a string/list, or if
                ``max_length`` is not a positive integer.
        """
        print(f"[CALL] Received data: {data}")
        inputs, return_scores, max_length = self._parse_request(data)
        print(f"[CALL] Processing {len(inputs)} input(s)")

        # Nothing to classify: skip the tokenizer and model entirely, both
        # of which would fail on an empty batch.
        if not inputs:
            results: List[Dict[str, Any]] = []
            print(f"[CALL] Returning results: {results}")
            return results

        encoded = self.tokenizer(
            inputs,
            truncation=True,
            padding='max_length',
            max_length=max_length,
            return_tensors='pt'
        )
        input_ids = encoded['input_ids'].to(self.device)
        attention_mask = encoded['attention_mask'].to(self.device)
        print(f"[CALL] Input shape: {input_ids.shape}")

        with torch.no_grad():
            logits = self.model(input_ids=input_ids, attention_mask=attention_mask)
        print(f"[CALL] Logits shape: {logits.shape}")
        print(f"[CALL] Logits sample: {logits[0]}")

        # Softmax/argmax over the last axis: one distribution per head.
        probs = torch.softmax(logits, dim=2)
        preds = torch.argmax(logits, dim=2)
        print(f"[CALL] Predictions: {preds}")
        print(f"[CALL] Probabilities: {probs[0]}")

        results = self._format_results(probs, preds, return_scores)
        print(f"[CALL] Returning results: {results}")
        return results