"""
Gradio Demo for Local Translation Model.

Uses VAGOsolutions LFM2.5 model for inference with GPU support and Flash Attention 2.
"""

import os
import subprocess
import sys

import torch
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer


# --- Initialization & Flash Attention 2 ---
def install_fa2():
    """Install the ``flash-attn`` package if it cannot be imported.

    The installation is best-effort: a failed install is reported on stdout
    and otherwise ignored, because ``load_model_if_needed`` falls back to the
    default attention implementation when Flash Attention 2 is unusable.
    """
    try:
        import flash_attn  # noqa: F401
    except ImportError:
        print("Installing Flash Attention 2...")
        # Invoke pip through the running interpreter (and without a shell) so
        # the package is installed into the environment this process uses.
        result = subprocess.run(
            [
                sys.executable,
                "-m",
                "pip",
                "install",
                "flash-attn",
                "--no-build-isolation",
            ],
            check=False,
        )
        if result.returncode != 0:
            print(
                "Flash Attention 2 installation failed "
                f"(pip exit code {result.returncode})."
            )
    else:
        print("Flash Attention 2 already installed.")


install_fa2()

# Access token read from the HF_KEY environment variable; None when unset.
hf_token = os.getenv("HF_KEY")

# --- Configuration ---
MODEL_PATH = "VAGOsolutions/SauerkrautLM-Translator-LFM2.5-1.2B"

# Language code -> language name inserted into the translation prompt.
LANGUAGE_NAMES = {
    "de": "German",
    "it": "Italian",
    "es": "Spanish",
    "fr": "French",
    "en": "English",
}

# (label, value) pairs for the target-language dropdown.
LANG_CHOICES = [
    ("German", "de"),
    ("Italian", "it"),
    ("Spanish", "es"),
    ("French", "fr"),
    ("English", "en"),
]

# --- Global State ---
model = None
tokenizer = None
model_loaded = False


def create_translation_prompt(text: str, target_lang: str) -> str:
    """Build the user prompt asking for a translation of *text*.

    Args:
        text: Source text to translate.
        target_lang: Language code; looked up in ``LANGUAGE_NAMES``.
            Unknown codes fall back to "German".

    Returns:
        The prompt string sent to the model as the user message.
    """
    lang_name = LANGUAGE_NAMES.get(target_lang, "German")
    return f"Translate this text in {lang_name}:\n\n{text}"


def load_model_if_needed():
    """Loads the model once and keeps it in memory.

    Populates the module-level ``model`` and ``tokenizer`` and sets
    ``model_loaded``. Loading is first attempted with Flash Attention 2; if
    that raises, the model is loaded again with the default attention
    implementation.
    """
    global model, tokenizer, model_loaded

    if model_loaded and model is not None and tokenizer is not None:
        return

    print(f"Loading model: {MODEL_PATH}...")
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_PATH,
        trust_remote_code=True,
        token=hf_token,
    )

    # Arguments shared by the Flash Attention attempt and the fallback.
    common_kwargs = {
        "torch_dtype": torch.bfloat16,
        "trust_remote_code": True,
        "device_map": "auto",
        "token": hf_token,
    }

    try:
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_PATH,
            attn_implementation="flash_attention_2",
            **common_kwargs,
        ).eval()
        print("Using Flash Attention 2")
    except Exception as e:
        # Deliberately broad: any failure to initialise Flash Attention 2
        # triggers a retry with the default attention implementation.
        print(f"Flash Attention 2 initialization failed, falling back to default attention: {e}")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_PATH,
            **common_kwargs,
        ).eval()

    model_loaded = True
    print("Model loaded successfully!")


@spaces.GPU
def translate(text: str, target_lang: str) -> tuple[str, str, str]:
    """Translate text using the local model on GPU.

    Args:
        text: Source text entered by the user.
        target_lang: Target language code passed to
            ``create_translation_prompt``.

    Returns:
        A ``(prompt, raw_response, translation)`` tuple. When the input is
        empty, or when an exception occurs, the first two items are empty
        strings and the third carries the message shown to the user.
    """
    # Guard against None as well as blank input.
    if not text or not text.strip():
        return "", "", "Please enter some text to translate."

    try:
        load_model_if_needed()

        prompt = create_translation_prompt(text, target_lang)
        messages = [{"role": "user", "content": prompt}]
        chat_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # The model is placed by device_map="auto", so send the inputs to the
        # device the model reports rather than guessing from CUDA availability.
        inputs = tokenizer(
            chat_text,
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(model.device)

        # Fall back to EOS when the tokenizer defines no padding token.
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=2048,
                temperature=0.3,
                top_k=50,
                top_p=0.1,
                repetition_penalty=1.05,
                do_sample=True,
                pad_token_id=pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # Keep only the newly generated tokens (drop the echoed prompt).
        prompt_length = inputs["input_ids"].shape[1]
        generated_tokens = outputs[0][prompt_length:]
        full_response = tokenizer.decode(generated_tokens, skip_special_tokens=True)

        return prompt, full_response, full_response.strip()

    except Exception as e:
        import traceback

        # UI boundary: report the failure in the output box instead of raising.
        return "", "", f"Error: {str(e)}\n{traceback.format_exc()}"


# --- Custom theme ---
custom_theme = gr.themes.Base(
    primary_hue="slate",
    secondary_hue="slate",
    neutral_hue="slate",
)

# --- Build Interface ---
with gr.Blocks(title="LFM2 Translation") as demo:
    # Header
    gr.Markdown("# SauerkrautLM-Translator-LFM2.5-1.2b Demo")
    gr.Markdown("**Model:** `VAGOsolutions/SauerkrautLM-Translator-LFM2.5-1.2b`")

    # Main content
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            input_text = gr.Textbox(
                placeholder="Enter text to translate...",
                lines=12,
                label="Source Text",
            )
            with gr.Row():
                target_lang = gr.Dropdown(
                    choices=LANG_CHOICES,
                    value="en",
                    label="Target Language",
                    scale=2,
                )
                translate_btn = gr.Button(
                    "Translate",
                    variant="primary",
                    scale=1,
                )

        with gr.Column(scale=1):
            translation_output = gr.Textbox(
                lines=14,
                label="Translation",
                interactive=False,
            )

    # Technical details (collapsed by default)
    with gr.Accordion("Technical Details", open=False):
        with gr.Row():
            prompt_display = gr.Textbox(
                lines=6,
                label="Prompt",
                interactive=False,
            )
            full_raw_response = gr.Textbox(
                lines=6,
                label="Raw Response",
                interactive=False,
            )

    # Examples
    with gr.Accordion("Examples", open=False):
        gr.Examples(
            examples=[
                ["Just completed my first marathon in 3 hours and 45 minutes! 🏃‍♂️ Six months ago I could barely run 5km. This journey taught me that consistency beats intensity every single time. Grateful for everyone who supported me along the way. What's your next challenge?", "de"],
                ["BREAKING: The European Central Bank announced today a 0.25% interest rate cut, citing slowing inflation across the eurozone. Markets responded positively, with the DAX rising 1.2% in early trading. Analysts expect further cuts in the coming months.", "it"],
                ["To assemble the bookshelf: 1) Lay all pieces flat on the floor. 2) Attach the side panels to the base using the provided screws. 3) Insert the shelves at your desired height. 4) Secure the back panel with small nails. 5) Mount to wall for stability.", "fr"],
                ["La reunión de mañana ha sido cancelada debido a un conflicto de agenda. Por favor, confirmen su disponibilidad para el jueves a las 15:00. Necesitamos revisar el presupuesto del próximo trimestre antes del viernes.", "en"],
                ["Omas Apfelkuchen: 200g Mehl, 100g Butter, 80g Zucker, 1 Ei, 4 Äpfel. Teig kneten, in die Form drücken, Äpfel darauf verteilen, bei 180°C 45 Minuten backen. Mit Puderzucker bestäuben und warm servieren.", "es"],
            ],
            inputs=[input_text, target_lang],
            label="",
        )

    # Event Handlers
    # The button click and the textbox submit event run the same translation
    # with identical inputs and outputs.
    for register_event in (translate_btn.click, input_text.submit):
        register_event(
            fn=translate,
            inputs=[input_text, target_lang],
            outputs=[prompt_display, full_raw_response, translation_output],
        )

if __name__ == "__main__":
    demo.queue().launch(debug=True, theme=custom_theme)