enuma-elis commited on
Commit
d331aa7
·
verified ·
1 Parent(s): 1b67745

Upload runpod_handler.py

Browse files
Files changed (1) hide show
  1. runpod_handler.py +164 -0
runpod_handler.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import runpod
2
+ import torch
3
+ import torch.nn as nn
4
+ from transformers import AutoTokenizer, AutoModel
5
+ import os
6
+ import logging
7
+ from typing import Dict, Any, List
8
+
9
+ logging.basicConfig(level=logging.INFO)
10
+ logger = logging.getLogger(__name__)
11
+
12
+ class MultiOutputClassifier(nn.Module):
13
+ """Multi-output classifier"""
14
+ def __init__(self, encoder, hidden_size, num_classes=3, num_levels=3):
15
+ super().__init__()
16
+ self.encoder = encoder
17
+ self.classifiers = nn.ModuleList([
18
+ nn.Linear(hidden_size, num_levels) for _ in range(num_classes)
19
+ ])
20
+
21
+ def forward(self, input_ids, attention_mask):
22
+ outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
23
+ hidden_states = outputs.last_hidden_state
24
+
25
+ attention_mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
26
+ sum_hidden = torch.sum(hidden_states * attention_mask_expanded, dim=1)
27
+ sum_mask = torch.clamp(attention_mask_expanded.sum(dim=1), min=1e-9)
28
+ pooled_output = sum_hidden / sum_mask
29
+
30
+ logits = [classifier(pooled_output) for classifier in self.classifiers]
31
+ logits = torch.stack(logits, dim=1)
32
+
33
+ return logits
34
+
35
+ # Global model instance
36
+ model = None
37
+ tokenizer = None
38
+ device = None
39
+
40
+ def load_model():
41
+ """Load model once at startup"""
42
+ global model, tokenizer, device
43
+
44
+ model_path = "/app/model"
45
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
46
+
47
+ logger.info(f"Loading model from {model_path}")
48
+ logger.info(f"Using device: {device}")
49
+
50
+ # Load tokenizer
51
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
52
+ if tokenizer.pad_token is None:
53
+ tokenizer.pad_token = tokenizer.eos_token
54
+
55
+ # Load encoder
56
+ encoder = AutoModel.from_pretrained(
57
+ model_path,
58
+ use_safetensors=True,
59
+ trust_remote_code=True,
60
+ torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
61
+ )
62
+
63
+ # Initialize model
64
+ hidden_size = encoder.config.hidden_size
65
+ model = MultiOutputClassifier(
66
+ encoder=encoder,
67
+ hidden_size=hidden_size,
68
+ num_classes=3,
69
+ num_levels=3
70
+ )
71
+
72
+ # Load classification heads
73
+ classifier_heads_path = os.path.join(model_path, "classifier_heads.pt")
74
+ if os.path.exists(classifier_heads_path):
75
+ logger.info(f"Loading classification heads")
76
+ checkpoint = torch.load(classifier_heads_path, map_location=device)
77
+ classifiers_list = checkpoint['classifiers']
78
+
79
+ for i, classifier in enumerate(model.classifiers):
80
+ classifier.load_state_dict(classifiers_list[i])
81
+
82
+ model.to(device)
83
+ model.eval()
84
+ logger.info("✓ Model loaded and ready!")
85
+
86
+ return model
87
+
88
+ def handler(job: Dict[str, Any]) -> Dict[str, Any]:
89
+ """
90
+ RunPod serverless handler
91
+
92
+ Input format:
93
+ {
94
+ "input": {
95
+ "text": "Your text here" OR ["text1", "text2"],
96
+ "max_length": 128, # optional
97
+ "return_scores": true # optional
98
+ }
99
+ }
100
+ """
101
+ try:
102
+ job_input = job["input"]
103
+
104
+ # Extract inputs
105
+ text_input = job_input.get("text", job_input.get("inputs", ""))
106
+ max_length = job_input.get("max_length", 128)
107
+ return_scores = job_input.get("return_scores", True)
108
+
109
+ # Handle both single string and list
110
+ if isinstance(text_input, str):
111
+ texts = [text_input]
112
+ else:
113
+ texts = text_input
114
+
115
+ # Tokenize
116
+ encoded = tokenizer(
117
+ texts,
118
+ truncation=True,
119
+ padding='max_length',
120
+ max_length=max_length,
121
+ return_tensors='pt'
122
+ )
123
+
124
+ input_ids = encoded['input_ids'].to(device)
125
+ attention_mask = encoded['attention_mask'].to(device)
126
+
127
+ # Inference
128
+ with torch.no_grad():
129
+ logits = model(input_ids=input_ids, attention_mask=attention_mask)
130
+ probs = torch.softmax(logits, dim=2)
131
+ preds = torch.argmax(logits, dim=2)
132
+
133
+ # Format results
134
+ class_names = ['vyhruzky', 'vulgarity', 'rasismus']
135
+ level_labels = {0: 'none', 1: 'moderate', 2: 'severe'}
136
+
137
+ results = []
138
+ for i in range(len(texts)):
139
+ result = {}
140
+ for j, class_name in enumerate(class_names):
141
+ pred_class = preds[i, j].item()
142
+ pred_prob = probs[i, j, pred_class].item()
143
+
144
+ result[class_name] = {
145
+ "label": level_labels[pred_class],
146
+ "level": pred_class
147
+ }
148
+
149
+ if return_scores:
150
+ result[class_name]["score"] = round(pred_prob, 4)
151
+
152
+ results.append(result)
153
+
154
+ return {"output": results}
155
+
156
+ except Exception as e:
157
+ logger.error(f"Error in handler: {str(e)}")
158
+ return {"error": str(e)}
159
+
160
+ if __name__ == "__main__":
161
+ logger.info("Starting RunPod serverless handler...")
162
+ load_model()
163
+ logger.info("Starting RunPod serverless worker...")
164
+ runpod.serverless.start({"handler": handler})