"""FastAPI service exposing a multi-head Czech text classifier.

The model is a transformer encoder (loaded with ``transformers.AutoModel``)
followed by masked mean pooling and one linear head per category in
``CLASS_NAMES``; each head scores the severity levels in ``LEVEL_LABELS``.
"""

import logging
import os
import threading
from typing import Any, Dict, List, Optional, Union

import torch
import torch.nn as nn
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from transformers import AutoModel, AutoTokenizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Czech Text Classification API")

# One classification head per entry, in this order; head j's output is reported under CLASS_NAMES[j].
CLASS_NAMES = ['vyhruzky', 'vulgarity', 'rasismus']
# Maps a head's argmax index to its severity label.
LEVEL_LABELS = {0: 'none', 1: 'moderate', 2: 'severe'}
# Token length used when the request does not specify parameters.max_length.
DEFAULT_MAX_LENGTH = 128


class InferenceRequest(BaseModel):
    """Request body in the HuggingFace Inference API shape.

    Attributes:
        inputs: a single text or a list of texts to classify.
        parameters: optional settings; the endpoint reads ``max_length``
            (positive int, default 128) and ``return_scores`` (truthy,
            default True). May be omitted or sent as ``null``.
    """

    inputs: Union[str, List[str]]
    parameters: Optional[Dict[str, Any]] = Field(default_factory=dict)


class MultiOutputClassifier(nn.Module):
    """Encoder + masked mean pooling + ``num_classes`` independent linear heads.

    Args:
        encoder: a transformers model whose output exposes ``last_hidden_state``.
        hidden_size: width of ``last_hidden_state`` (input size of each head).
        num_classes: number of heads (categories).
        num_levels: number of logits produced by each head.
    """

    def __init__(self, encoder, hidden_size, num_classes=3, num_levels=3):
        super().__init__()
        self.encoder = encoder
        self.classifiers = nn.ModuleList([
            nn.Linear(hidden_size, num_levels) for _ in range(num_classes)
        ])

    def forward(self, input_ids, attention_mask):
        """Return logits stacked as (batch, num_classes, num_levels).

        ``attention_mask`` is expected to be (batch, seq) and
        ``last_hidden_state`` (batch, seq, hidden) — the mask is broadcast
        onto the hidden states below.
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state

        # Mean-pool over real tokens only: zero out padding positions, then
        # divide by the count of unmasked tokens (clamped to avoid 0/0).
        attention_mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        sum_hidden = torch.sum(hidden_states * attention_mask_expanded, dim=1)
        sum_mask = torch.clamp(attention_mask_expanded.sum(dim=1), min=1e-9)
        pooled_output = sum_hidden / sum_mask

        logits = [classifier(pooled_output) for classifier in self.classifiers]
        return torch.stack(logits, dim=1)


# Global state populated by load_model(); `model` stays None until fully loaded.
model = None
tokenizer = None
device = None
# Serializes inference: fast tokenizers mutate their padding/truncation state
# on every call and are not safe to use from several threads at once.
_inference_lock = threading.Lock()


@app.on_event("startup")
async def load_model():
    """Load tokenizer, encoder and classification heads into module globals.

    The global ``model`` is assigned only after the heads are loaded and the
    module is on ``device`` in eval mode, so readers never see a half-built model.

    Raises:
        ValueError: if classifier_heads.pt holds a different number of heads
            than the model defines.
    """
    global model, tokenizer, device

    model_path = "/app/model"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info("Loading model from %s", model_path)
    logger.info("Using device: %s", device)

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Batched padding requires a pad token; fall back to EOS.
        tokenizer.pad_token = tokenizer.eos_token

    encoder = AutoModel.from_pretrained(
        model_path,
        use_safetensors=True,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16 if device.type == "cuda" else torch.float32,
    )

    classifier = MultiOutputClassifier(
        encoder=encoder,
        hidden_size=encoder.config.hidden_size,
        num_classes=len(CLASS_NAMES),
        num_levels=len(LEVEL_LABELS),
    )

    heads_path = os.path.join(model_path, "classifier_heads.pt")
    if os.path.exists(heads_path):
        logger.info("Loading classification heads from %s", heads_path)
        # NOTE(security): torch.load unpickles the file — only load trusted checkpoints.
        checkpoint = torch.load(heads_path, map_location=device)
        head_states = checkpoint['classifiers']
        if len(head_states) != len(classifier.classifiers):
            raise ValueError(
                f"{heads_path} contains {len(head_states)} heads, "
                f"expected {len(classifier.classifiers)}"
            )
        for i, (head, state) in enumerate(zip(classifier.classifiers, head_states)):
            head.load_state_dict(state)
            logger.info("Loaded classifier %d", i)
    else:
        # NOTE(review): the service still starts, but heads keep their random init.
        logger.warning("classifier_heads.pt not found!")

    classifier.to(device)
    classifier.eval()
    model = classifier
    logger.info("Model loaded and ready!")


@app.get("/health")
async def health():
    """Report readiness: 503 until the model is loaded, 200 afterwards."""
    if model is None:
        return JSONResponse(
            status_code=503,
            content={"status": "loading", "model_loaded": False}
        )
    return {"status": "healthy", "model_loaded": True}


def _run_inference(texts: List[str], max_length: int):
    """Tokenize and classify ``texts``; return (probs, preds) as nested Python lists.

    ``probs`` is indexed [text][head][level]; ``preds`` is [text][head] holding
    the argmax level. Blocking — call it from a worker thread, not the event loop.
    """
    # no_grad is thread-local, so it must be entered here in the worker thread.
    with _inference_lock, torch.no_grad():
        encoded = tokenizer(
            texts,
            truncation=True,
            padding='max_length',
            max_length=max_length,
            return_tensors='pt'
        )
        input_ids = encoded['input_ids'].to(device)
        attention_mask = encoded['attention_mask'].to(device)

        logits = model(input_ids=input_ids, attention_mask=attention_mask)
        probs = torch.softmax(logits, dim=2)
        preds = torch.argmax(logits, dim=2)

    # One device->host transfer each instead of a sync per .item() call.
    return probs.cpu().tolist(), preds.cpu().tolist()


def _format_results(probs, preds, return_scores) -> List[Dict[str, Dict[str, Any]]]:
    """Build one {class_name: {"label", "level"[, "score"]}} dict per input text."""
    results = []
    for text_probs, text_preds in zip(probs, preds):
        result = {}
        for class_name, head_probs, level in zip(CLASS_NAMES, text_probs, text_preds):
            result[class_name] = {
                "label": LEVEL_LABELS[level],
                "level": level
            }
            if return_scores:
                result[class_name]["score"] = round(head_probs[level], 4)
        results.append(result)
    return results


@app.post("/")
async def predict(request: InferenceRequest):
    """Main inference endpoint - HuggingFace compatible.

    Returns a list with one result dict per input text (see ``_format_results``).

    Raises:
        HTTPException(503): the model has not finished loading.
        HTTPException(422): ``parameters.max_length`` is not a positive integer.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    inputs = request.inputs if isinstance(request.inputs, list) else [request.inputs]
    if not inputs:
        return []

    # "parameters": null in the JSON body arrives as None.
    parameters = request.parameters or {}
    max_length = parameters.get("max_length", DEFAULT_MAX_LENGTH)
    return_scores = parameters.get("return_scores", True)
    # bool is a subclass of int, so reject it explicitly.
    if isinstance(max_length, bool) or not isinstance(max_length, int) or max_length <= 0:
        raise HTTPException(
            status_code=422,
            detail="parameters.max_length must be a positive integer"
        )

    # Tokenization and the forward pass block; run them in the threadpool so
    # the event loop keeps serving other requests (e.g. /health) meanwhile.
    probs, preds = await run_in_threadpool(_run_inference, inputs, max_length)
    return _format_results(probs, preds, return_scores)


@app.post("/predict")
async def predict_alt(request: InferenceRequest):
    """Alternative endpoint"""
    return await predict(request)


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8080)