# heyIamUmair's picture
# multilingual
# d821926 verified
# Raw
# History Blame Contribute Delete
# 4.55 kB
# import gradio as gr
# import torch, numpy as np, soundfile as sf
# from transformers import WhisperProcessor, WhisperForConditionalGeneration
# import os
# auth_token = os.environ.get("HF_TOKEN") # gets secret token
# # ---- Settings ----
# DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# MODEL_DIR = "heyIamUmair/whisper-base-sindhi1" # HF repo name
# # ---- Load model & processor (with token if private) ----
# processor = WhisperProcessor.from_pretrained(MODEL_DIR, use_auth_token=auth_token)
# model = WhisperForConditionalGeneration.from_pretrained(MODEL_DIR, use_auth_token=auth_token).to(DEVICE).eval()
# model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="Sindhi", task="transcribe")
# # ---- Resample function ----
# def resample_to_16k(audio, sr):
# if sr == 16000:
# return audio, sr
# import torchaudio
# wav = torch.tensor(audio, dtype=torch.float32).unsqueeze(0) # 1 x T
# wav = torchaudio.functional.resample(wav, sr, 16000)
# return wav.squeeze(0).cpu().numpy(), 16000
# # ---- Transcription function ----
# def transcribe(path):
# audio, sr = sf.read(path, always_2d=False)
# audio = audio.astype(np.float32) if audio.dtype != np.float32 else audio
# if audio.ndim > 1:
# audio = np.mean(audio, axis=1) # mono
# if audio.dtype != np.float32:
# maxv = np.iinfo(audio.dtype).max
# audio = (audio / maxv).astype(np.float32)
# audio, sr = resample_to_16k(audio, sr)
# inputs = processor(audio, sampling_rate=sr, return_tensors="pt")
# input_features = inputs.input_features.to(DEVICE)
# with torch.no_grad():
# pred_ids = model.generate(input_features, max_length=225)
# text = processor.batch_decode(pred_ids, skip_special_tokens=True)[0]
# return text.strip()
# # ---- Gradio Interface ----
# demo = gr.Interface(
# fn=transcribe,
# inputs=gr.Audio(sources=["upload", "microphone"], type="filepath"),
# outputs=gr.Textbox(lines=5, max_lines=20, interactive=True, label="Transcription"),
# title="Sindhi Speech-to-Text",
# description="Upload or record speech and get Sindhi transcription using Whisper fine-tuned model."
# )
# demo.launch()
import gradio as gr
import torch, numpy as np, soundfile as sf
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import os
# Hugging Face access token read from the environment (None when unset).
auth_token = os.getenv("HF_TOKEN")

# ---- Settings ----
MODEL_DIR = "heyIamUmair/whisper-base-sindhi1"  # HF repo name
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# ---- Load model & processor (with token if private) ----
# NOTE(review): newer transformers releases reportedly deprecate
# `use_auth_token` in favour of `token` -- confirm against the installed
# version before switching.
processor = WhisperProcessor.from_pretrained(MODEL_DIR, use_auth_token=auth_token)
model = WhisperForConditionalGeneration.from_pretrained(
    MODEL_DIR,
    use_auth_token=auth_token,
)
model.to(DEVICE)  # move the weights to the selected device
model.eval()  # switch to evaluation mode for inference
# ---- Resample function ----
def resample_to_16k(audio, sr):
    """Return ``audio`` at a 16 kHz sampling rate together with that rate.

    Args:
        audio: Audio samples; anything ``torch.tensor`` accepts when a
            resample is actually needed.
        sr: Sampling rate of ``audio``.

    Returns:
        ``(audio, sr)`` untouched when ``sr`` is already 16000; otherwise a
        ``(numpy_array, 16000)`` tuple holding the resampled signal.
    """
    if sr != 16000:
        # torchaudio is imported lazily so it is only required when a
        # resample is really performed.
        import torchaudio

        waveform = torch.tensor(audio, dtype=torch.float32)
        waveform = waveform.unsqueeze(0)  # 1 x T
        resampled = torchaudio.functional.resample(waveform, sr, 16000)
        resampled = resampled.squeeze(0)
        return resampled.cpu().numpy(), 16000

    # Already at the target rate: hand the input back as-is.
    return audio, sr
# ---- Transcription function ----
def transcribe(path, language):
    """Transcribe an audio file with the Whisper model in a forced language.

    Args:
        path: Path of the audio file to read. May be ``None`` or empty when
            no audio was supplied, in which case nothing is transcribed.
        language: Language name handed to
            ``processor.get_decoder_prompt_ids`` (e.g. ``"Sindhi"``).

    Returns:
        The decoded text with leading/trailing whitespace stripped, or an
        empty string when ``path`` is ``None``/empty.
    """
    # No audio supplied: return early instead of failing inside sf.read.
    if not path:
        return ""

    audio, sr = sf.read(path, always_2d=False)
    # Work in float32 from here on. Because this cast is unconditional, no
    # integer-dtype rescaling step is needed afterwards.
    audio = audio.astype(np.float32, copy=False)
    if audio.ndim > 1:
        audio = np.mean(audio, axis=1)  # mono

    audio, sr = resample_to_16k(audio, sr)

    inputs = processor(audio, sampling_rate=sr, return_tensors="pt")
    input_features = inputs.input_features.to(DEVICE)

    # Force selected language
    forced_ids = processor.get_decoder_prompt_ids(language=language, task="transcribe")
    with torch.no_grad():  # inference only, no gradients required
        pred_ids = model.generate(
            input_features,
            max_length=225,
            forced_decoder_ids=forced_ids,
        )

    # A single clip was submitted, so take the first decoded string.
    text = processor.batch_decode(pred_ids, skip_special_tokens=True)[0]
    return text.strip()
# ---- Gradio Interface ----
# Input widgets: the audio clip (delivered to `transcribe` as a file path)
# and the language selector.
audio_input = gr.Audio(sources=["upload", "microphone"], type="filepath")
language_input = gr.Dropdown(
    ["Sindhi", "Urdu", "English"],
    value="Sindhi",
    label="Select Language",
)

# Output widget: an editable text area that shows the transcription.
transcription_output = gr.Textbox(
    lines=8,
    max_lines=40,
    interactive=True,
    label="Transcription",
)

demo = gr.Interface(
    fn=transcribe,
    inputs=[audio_input, language_input],
    outputs=transcription_output,
    title="Multilingual Speech-to-Text",
    description="Upload or record speech and transcribe into Sindhi, Urdu, or English using a fine-tuned Whisper model.",
)

# Start the web app.
demo.launch()