# Whisper-Farsi / app.py
# AmirMohseni's picture
# Update app.py
# 659131a verified
import gradio as gr
from transformers import pipeline
import torch
import spaces
# --- 1. Setup and Global Definitions ---
# Use the first CUDA device when torch reports one, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
print(f"Using device: {device}")

# UI display name -> model id handed to the transformers pipeline.
# Insertion order matters: it is the order the choices are listed in the UI.
MODELS = {
    "Whisper Small": "AmirMohseni/whisper-small-persian-bf16",
    "Whisper Large v3": "AmirMohseni/whisper-large-v3-persian-bf16",
}

# Cache of already-built pipelines, keyed by model id (not by display name).
model_pipelines = dict()
# --- 2. Model Loading Function ---
def load_model(model_name):
    """Return the speech-recognition pipeline for ``model_name``.

    Pipelines are built once and stored in ``model_pipelines`` under their
    model id, so repeated calls for the same model reuse the same object.

    Args:
        model_name: Display name of the model; must be a key of ``MODELS``.

    Returns:
        The cached (or newly created) transformers pipeline.

    Raises:
        KeyError: If ``model_name`` is not a key of ``MODELS``.
    """
    checkpoint = MODELS[model_name]

    # Fast path: this checkpoint has already been loaded.
    if checkpoint in model_pipelines:
        return model_pipelines[checkpoint]

    print(f"Loading model: {model_name}...")
    model_pipelines[checkpoint] = pipeline(
        "automatic-speech-recognition",
        model=checkpoint,
        torch_dtype="auto",
        device=device,
    )
    print(f"Model '{model_name}' loaded successfully.")
    return model_pipelines[checkpoint]
# --- 3. Main Transcription Function ---
@spaces.GPU(duration=90)
def transcribe(audio, model_name):
    """Transcribe a recorded audio file to Persian text.

    Args:
        audio: Path to the recorded audio file, or ``None`` when nothing
            was recorded.
        model_name: Display name of the model to use (a key of ``MODELS``).

    Returns:
        The transcribed text, or an empty string when ``audio`` is ``None``.
    """
    if audio is None:
        # Shown to the user in the UI; an empty result keeps the textbox blank.
        gr.Warning("No audio recorded. Please record your voice first.")
        return ""

    selected_pipe = load_model(model_name)
    print(f"Transcribing with '{model_name}'...")

    # Whisper processes audio in 30-second windows. Without chunking, a longer
    # recording is either rejected or cut off, depending on the installed
    # transformers version, so let the pipeline split long inputs itself.
    chunk_seconds = 30
    result = selected_pipe(
        audio,
        chunk_length_s=chunk_seconds,
        generate_kwargs={"language": "persian", "task": "transcribe"},
    )

    transcription = result["text"]
    print(f"Transcription result: {transcription}")
    return transcription
# --- 4. Pre-load the Default Model ---
# Populate the pipeline cache for the default model before the UI is built.
print("Pre-loading the default model ('Whisper Large v3')...")
load_model("Whisper Large v3")
print("Default model pre-loaded. The interface is ready.")

# --- 5. Gradio Interface Definition ---
# Components are created in the same order they appear in the interface:
# microphone recorder, model picker, then the output textbox.
audio_input = gr.Audio(
    sources=["microphone"],
    type="filepath",  # transcribe() receives the recording as a file path
    label="Record Audio 🎤",
)
model_selector = gr.Radio(
    choices=list(MODELS),
    value="Whisper Large v3",
    label="Choose Model",
    info="The 'Large' model is more accurate but slower. The 'Small' model is faster.",
)
transcription_box = gr.Textbox(label="Transcription", lines=5)

iface = gr.Interface(
    fn=transcribe,
    inputs=[audio_input, model_selector],
    outputs=transcription_box,
    title="Whisper Farsi 🎙️",
    description="Real-time Persian speech recognition. Choose a model, press 'Record Audio', and start speaking.",
    allow_flagging="never",
)

# --- 6. Launch the Application ---
iface.launch()