"""FastAPI service that turns a clinical dialogue transcript into a SOAP note.

At import time the base model ``google/txgemma-2b-predict`` is loaded with a
4-bit NF4 quantization config and the ``shalindasilva1/SOAPGemma`` PEFT adapter
is applied on top of it. Endpoints:

* ``GET /health``        -- liveness probe.
* ``POST /generate-soap`` -- body ``{"transcript": str}``, returns
  ``{"soap_note": str}``.
"""

import os

import torch
from fastapi import FastAPI
from peft import PeftModel
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Upper bound (in tokens) for the prompt fed to the model.
MAX_INPUT_TOKENS = 1536
# Upper bound on the number of tokens generated for one SOAP note.
MAX_NEW_TOKENS = 512

app = FastAPI()

print("Loading model...")

base_model_id = "google/txgemma-2b-predict"
adapter_model_id = "shalindasilva1/SOAPGemma"
# Read once; passed to every Hugging Face Hub download below (None if unset).
hf_token = os.getenv("HF_TOKEN")

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    quantization_config=bnb_config,
    device_map="auto",
    # NOTE(review): trust_remote_code=True allows the Hub repo to execute
    # arbitrary Python at load time. Gemma-family checkpoints are usually
    # supported natively by transformers, so this is likely unnecessary --
    # confirm against the installed transformers version and drop it if so.
    trust_remote_code=True,
    token=hf_token,
)

tokenizer = AutoTokenizer.from_pretrained(base_model_id, token=hf_token)
if tokenizer.pad_token is None:
    # generate() needs a pad id; fall back to EOS when the tokenizer has none.
    tokenizer.pad_token = tokenizer.eos_token

model = PeftModel.from_pretrained(base_model, adapter_model_id, token=hf_token)
model.eval()

print("Model loaded.")


def _build_prompt(transcript: str) -> str:
    """Wrap *transcript* in the ``dialogue: ... soap_note:`` prompt format."""
    return f"dialogue: {transcript} soap_note:"


def _fit_transcript(transcript: str, max_input_tokens: int = MAX_INPUT_TOKENS) -> str:
    """Trim *transcript* so the assembled prompt stays within ``max_input_tokens``.

    Truncating the assembled prompt (``truncation=True``) removes tokens from the
    right with the tokenizer's default ``truncation_side``. That would cut the
    trailing ``soap_note:`` cue the model is asked to continue from. Trimming
    the transcript instead keeps the prompt frame intact. The budget is
    approximate: re-tokenizing the decoded prefix inside the full prompt can
    differ by a token or so at the boundaries.

    Returns:
        *transcript* unchanged if it fits, otherwise the decoded text of its
        first ``budget`` tokens.
    """
    overhead = len(tokenizer(_build_prompt(""))["input_ids"])
    budget = max(max_input_tokens - overhead, 0)
    transcript_ids = tokenizer(transcript, add_special_tokens=False)["input_ids"]
    if len(transcript_ids) <= budget:
        return transcript
    return tokenizer.decode(transcript_ids[:budget], skip_special_tokens=True)


class SOAPRequest(BaseModel):
    """Request body for ``POST /generate-soap``."""

    # Free-text dialogue between clinician and patient.
    transcript: str


@app.get("/health")
def health():
    """Liveness probe; returns a constant payload once the app is serving."""
    return {"status": "ok"}


@app.post("/generate-soap")
def generate_soap(req: SOAPRequest):
    """Generate a SOAP note for ``req.transcript``.

    Returns:
        ``{"soap_note": str}``, containing only the text generated after the
        prompt, with surrounding whitespace stripped.
    """
    prompt = _build_prompt(_fit_transcript(req.transcript))
    # A single unpadded sequence: padding/truncation flags are unnecessary here.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            num_beams=1,
            # Greedy decoding. The previous `temperature=0.7` had no effect
            # without do_sample=True (it only produced a warning), so
            # deterministic greedy output is pinned explicitly.
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # generate() on a causal LM returns prompt + continuation. Decode only the
    # new tokens instead of string-splitting the echoed prompt.
    completion = tokenizer.decode(outputs[0, prompt_length:], skip_special_tokens=True)
    soap_note = completion.strip()
    return {"soap_note": soap_note}