enuma-elis commited on
Commit
a3bfc6c
·
verified ·
1 Parent(s): 599e4f8

Upload 2 files

Browse files
Files changed (2) hide show
  1. Dockerfile +25 -0
  2. inference_server.py +164 -0
Dockerfile ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime
2
+
3
+ WORKDIR /app
4
+
5
+ # Install dependencies
6
+ RUN pip install --no-cache-dir \
7
+ transformers>=4.35.0 \
8
+ safetensors>=0.4.0 \
9
+ fastapi>=0.104.0 \
10
+ uvicorn[standard]>=0.24.0 \
11
+ pydantic>=2.0.0
12
+
13
+ # Copy model files
14
+ COPY . /app/model
15
+
16
+ # Copy inference server script
17
+ COPY inference_server.py /app/
18
+
19
+ EXPOSE 8080
20
+
21
+ # Health check
22
+ HEALTHCHECK --interval=30s --timeout=10s --start-period=120s --retries=3 \
23
+ CMD curl -f http://localhost:8080/health || exit 1
24
+
25
+ CMD ["uvicorn", "inference_server:app", "--host", "0.0.0.0", "--port", "8080", "--workers", "1"]
inference_server.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from fastapi.responses import JSONResponse
3
+ from pydantic import BaseModel
4
+ from typing import List, Dict, Any, Optional, Union
5
+ import torch
6
+ import torch.nn as nn
7
+ from transformers import AutoTokenizer, AutoModel
8
+ import os
9
+ import logging
10
+
11
+ logging.basicConfig(level=logging.INFO)
12
+ logger = logging.getLogger(__name__)
13
+
14
+ app = FastAPI(title="Czech Text Classification API")
15
+
16
+ class InferenceRequest(BaseModel):
17
+ inputs: Union[str, List[str]]
18
+ parameters: Optional[Dict[str, Any]] = {}
19
+
20
+ class MultiOutputClassifier(nn.Module):
21
+ def __init__(self, encoder, hidden_size, num_classes=3, num_levels=3):
22
+ super().__init__()
23
+ self.encoder = encoder
24
+ self.classifiers = nn.ModuleList([
25
+ nn.Linear(hidden_size, num_levels) for _ in range(num_classes)
26
+ ])
27
+
28
+ def forward(self, input_ids, attention_mask):
29
+ outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
30
+ hidden_states = outputs.last_hidden_state
31
+
32
+ attention_mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
33
+ sum_hidden = torch.sum(hidden_states * attention_mask_expanded, dim=1)
34
+ sum_mask = torch.clamp(attention_mask_expanded.sum(dim=1), min=1e-9)
35
+ pooled_output = sum_hidden / sum_mask
36
+
37
+ logits = [classifier(pooled_output) for classifier in self.classifiers]
38
+ logits = torch.stack(logits, dim=1)
39
+
40
+ return logits
41
+
42
+ # Global variables
43
+ model = None
44
+ tokenizer = None
45
+ device = None
46
+
47
+ @app.on_event("startup")
48
+ async def load_model():
49
+ global model, tokenizer, device
50
+
51
+ model_path = "/app/model"
52
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
53
+
54
+ logger.info(f"Loading model from {model_path}")
55
+ logger.info(f"Using device: {device}")
56
+
57
+ # Load tokenizer
58
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
59
+ if tokenizer.pad_token is None:
60
+ tokenizer.pad_token = tokenizer.eos_token
61
+
62
+ # Load encoder
63
+ encoder = AutoModel.from_pretrained(
64
+ model_path,
65
+ use_safetensors=True,
66
+ trust_remote_code=True,
67
+ torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
68
+ )
69
+
70
+ # Initialize model
71
+ hidden_size = encoder.config.hidden_size
72
+ model = MultiOutputClassifier(
73
+ encoder=encoder,
74
+ hidden_size=hidden_size,
75
+ num_classes=3,
76
+ num_levels=3
77
+ )
78
+
79
+ # Load classification heads
80
+ classifier_heads_path = os.path.join(model_path, "classifier_heads.pt")
81
+ if os.path.exists(classifier_heads_path):
82
+ logger.info(f"Loading classification heads from {classifier_heads_path}")
83
+ checkpoint = torch.load(classifier_heads_path, map_location=device)
84
+ classifiers_list = checkpoint['classifiers']
85
+
86
+ for i, classifier in enumerate(model.classifiers):
87
+ classifier.load_state_dict(classifiers_list[i])
88
+ logger.info(f"Loaded classifier {i}")
89
+ else:
90
+ logger.warning("classifier_heads.pt not found!")
91
+
92
+ model.to(device)
93
+ model.eval()
94
+ logger.info("Model loaded and ready!")
95
+
96
+ @app.get("/health")
97
+ async def health():
98
+ if model is None:
99
+ return JSONResponse(
100
+ status_code=503,
101
+ content={"status": "loading", "model_loaded": False}
102
+ )
103
+ return {"status": "healthy", "model_loaded": True}
104
+
105
+ @app.post("/")
106
+ async def predict(request: InferenceRequest):
107
+ """Main inference endpoint - HuggingFace compatible"""
108
+ if model is None:
109
+ raise HTTPException(status_code=503, detail="Model not loaded yet")
110
+
111
+ # Handle input
112
+ inputs = request.inputs if isinstance(request.inputs, list) else [request.inputs]
113
+ max_length = request.parameters.get("max_length", 128)
114
+ return_scores = request.parameters.get("return_scores", True)
115
+
116
+ # Tokenize
117
+ encoded = tokenizer(
118
+ inputs,
119
+ truncation=True,
120
+ padding='max_length',
121
+ max_length=max_length,
122
+ return_tensors='pt'
123
+ )
124
+
125
+ input_ids = encoded['input_ids'].to(device)
126
+ attention_mask = encoded['attention_mask'].to(device)
127
+
128
+ # Inference
129
+ with torch.no_grad():
130
+ logits = model(input_ids=input_ids, attention_mask=attention_mask)
131
+ probs = torch.softmax(logits, dim=2)
132
+ preds = torch.argmax(logits, dim=2)
133
+
134
+ # Format results
135
+ class_names = ['vyhruzky', 'vulgarity', 'rasismus']
136
+ level_labels = {0: 'none', 1: 'moderate', 2: 'severe'}
137
+
138
+ results = []
139
+ for i in range(len(inputs)):
140
+ result = {}
141
+ for j, class_name in enumerate(class_names):
142
+ pred_class = preds[i, j].item()
143
+ pred_prob = probs[i, j, pred_class].item()
144
+
145
+ result[class_name] = {
146
+ "label": level_labels[pred_class],
147
+ "level": pred_class
148
+ }
149
+
150
+ if return_scores:
151
+ result[class_name]["score"] = round(pred_prob, 4)
152
+
153
+ results.append(result)
154
+
155
+ return results
156
+
157
+ @app.post("/predict")
158
+ async def predict_alt(request: InferenceRequest):
159
+ """Alternative endpoint"""
160
+ return await predict(request)
161
+
162
+ if __name__ == "__main__":
163
+ import uvicorn
164
+ uvicorn.run(app, host="0.0.0.0", port=8080)