"""Translation service using Gemma 3 for multilingual support.""" import logging import torch from transformers import AutoModelForCausalLM, AutoTokenizer from .constants import SUPPORTED_LANGUAGES, TRANSLATION_MODEL_ID from .prompts import TRANSLATION_PROMPT logger = logging.getLogger(__name__) class TranslationService: """Service for translating text to supported languages using Gemma 3.""" def __init__(self, model, tokenizer): """Initialize with pre-loaded model and tokenizer.""" self.model = model self.tokenizer = tokenizer def generate_text(self, prompt: str) -> str: """ Generate text using Gemma 3 (for explanations). Args: prompt: Text prompt for generation Returns: Generated text Raises: ValueError: If generation fails """ try: messages = [{"role": "user", "content": prompt}] inputs = self.tokenizer.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", ).to(self.model.device) input_len = inputs["input_ids"].shape[-1] with torch.inference_mode(): outputs = self.model.generate( **inputs, max_new_tokens=2048, do_sample=True, temperature=0.7, ) response = self.tokenizer.decode( outputs[0][input_len:], skip_special_tokens=True ) return response.strip() except Exception as e: logger.error(f"Text generation failed: {e}") raise ValueError("Could not generate text. Please try again.") def translate_text(self, text: str, target_language: str) -> str: """ Translate text to the target language. Args: text: Text to translate (in English) target_language: Target language display name (e.g., "Thai") Returns: Translated text Raises: ValueError: If translation fails or language not supported """ if target_language not in SUPPORTED_LANGUAGES: raise ValueError(f"Unsupported language: {target_language}") # If target is English, no translation needed if target_language == "English": return text try: prompt = TRANSLATION_PROMPT.format( target_language=target_language, text=text, ) return self.generate_text(prompt) except Exception as e: logger.error(f"Translation to {target_language} failed: {e}") raise ValueError( f"Could not translate to {target_language}. Please try again." ) def load_translation_model(): """ Load Gemma 3 model for translation. Returns: Tuple of (model, tokenizer) """ logger.info(f"Loading translation model: {TRANSLATION_MODEL_ID}") tokenizer = AutoTokenizer.from_pretrained(TRANSLATION_MODEL_ID) model = AutoModelForCausalLM.from_pretrained( TRANSLATION_MODEL_ID, torch_dtype=torch.bfloat16, device_map="auto", ) logger.info("Translation model loaded successfully") return model, tokenizer