"""MedGemma service for prescription extraction and explanation.""" import logging from typing import Optional import torch from PIL import Image from transformers import AutoModelForImageTextToText, AutoProcessor from .constants import MEDGEMMA_MODEL_ID from .prompts import EXTRACTION_PROMPT, EXPLANATION_PROMPT logger = logging.getLogger(__name__) class MedGemmaService: """Service for extracting medications from prescription images using MedGemma.""" def __init__(self, model, processor): """Initialize with pre-loaded model and processor.""" self.model = model self.processor = processor def extract_medications(self, image: Image.Image) -> str: """ Extract medication information from a prescription image. Args: image: PIL Image of the prescription Returns: Extracted medication information as text Raises: ValueError: If extraction fails """ try: messages = [ { "role": "user", "content": [ {"type": "image", "image": image}, {"type": "text", "text": EXTRACTION_PROMPT}, ], } ] inputs = self.processor.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", ).to(self.model.device) input_len = inputs["input_ids"].shape[-1] with torch.inference_mode(): outputs = self.model.generate( **inputs, max_new_tokens=1024, do_sample=False, ) response = self.processor.decode( outputs[0][input_len:], skip_special_tokens=True ) return response.strip() except Exception as e: logger.error(f"Medication extraction failed: {e}") raise ValueError( "Could not read the prescription. Please try uploading a clearer image." ) def generate_explanation(self, medication_info: str) -> str: """ Generate a plain-language explanation of medications. Args: medication_info: Extracted medication information Returns: Patient-friendly explanation Raises: ValueError: If explanation generation fails """ try: prompt = EXPLANATION_PROMPT.format(medication_info=medication_info) messages = [{"role": "user", "content": prompt}] inputs = self.processor.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", ).to(self.model.device) input_len = inputs["input_ids"].shape[-1] with torch.inference_mode(): outputs = self.model.generate( **inputs, max_new_tokens=1024, do_sample=False, ) response = self.processor.decode( outputs[0][input_len:], skip_special_tokens=True ) return response.strip() except Exception as e: logger.error(f"Explanation generation failed: {e}", exc_info=True) raise ValueError( f"Could not generate explanation: {str(e)}" ) def load_medgemma_model(): """ Load MedGemma model and processor. Returns: Tuple of (model, processor) """ logger.info(f"Loading MedGemma model: {MEDGEMMA_MODEL_ID}") processor = AutoProcessor.from_pretrained(MEDGEMMA_MODEL_ID) model = AutoModelForImageTextToText.from_pretrained( MEDGEMMA_MODEL_ID, torch_dtype=torch.bfloat16, device_map="auto", ) logger.info("MedGemma model loaded successfully") return model, processor