SSL_demo / app.py
Andrei-Iulian SĂCELEANU
fix audio-mixmatch
5506e99
Raw
History Blame
6.56 kB
import re
import gradio as gr
import librosa
import numpy as np
from transformers import AutoTokenizer,ViTImageProcessor
from unidecode import unidecode
from models import *
# Pre-processing objects shared by the predict functions below; both are
# created once, when the module is first executed.
tok = AutoTokenizer.from_pretrained(
    "readerbench/RoBERT-base",
)
processor = ViTImageProcessor.from_pretrained(
    "google/vit-base-patch16-224",
)
def preprocess(x):
    """Normalize the raw input string ``x`` before tokenization.

    The text is passed through ``unidecode`` and lower-cased, then a fixed
    sequence of regex substitutions is applied in order:

    1. drop bracketed lowercase tags such as ``[tag]``;
    2. drop asterisks;
    3. turn every run of non-alphanumeric characters into one space;
    4. squeeze runs of spaces into one space;
    5. collapse any run of a repeated character into a single occurrence.

    Returns the normalized string (it is not stripped).
    """
    # (pattern, replacement) pairs; the order is significant.
    substitutions = (
        (r"\[[a-z]+\]", ""),
        (r"\*", ""),
        (r"[^a-zA-Z0-9]+", " "),
        (r" +", " "),
        (r"(.)\1+", r"\1"),
    )
    text = unidecode(x).lower()
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text
# Class names for the text models. They are zipped positionally with the
# model output in ssl_predict, so the order here must not change.
label_names = [
    "ABUSE",
    "INSULT",
    "OTHER",
    "PROFANITY",
]

# Class names for the audio models, zipped positionally with the model
# output in ssl_predict2.
audio_label_names = [
    "Laughter",
    "Sigh",
    "Cough",
    "Throat clearing",
    "Sneeze",
    "Sniff",
]
def ssl_predict(in_text, model_type):
    """Classify a text with the model trained by the selected SSL method.

    Args:
        in_text: Raw text from the UI textbox. ``None`` (e.g. a cleared
            textbox) is treated as an empty string.
        model_type: One of ``"fixmatch"``, ``"freematch"``, ``"mixmatch"``,
            ``"contrastive_reg"`` or ``"label_propagation"``.

    Returns:
        A dict mapping each entry of ``label_names`` to a float score taken
        from the first row of the model output.

    Raises:
        gr.Error: If ``model_type`` is missing or not one of the values
            above, so the UI shows a readable message instead of a
            ``TypeError`` raised from indexing ``None``.
    """
    preprocessed = preprocess(in_text or "")
    toks = tok(
        preprocessed,
        padding="max_length",
        max_length=96,
        truncation=True,
        return_tensors="tf",
    )
    model_inputs = [toks["input_ids"], toks["attention_mask"]]

    # NOTE: the model is rebuilt and its weights reloaded on every call.
    # Some variants load weights into the whole model, others only into the
    # classification head; some return a tuple whose first item is the
    # prediction, others return the prediction directly.
    if model_type == "fixmatch":
        model = FixMatchTune(encoder_name="readerbench/RoBERT-base")
        model.load_weights("./checkpoints/fixmatch_tune")
        preds, _ = model(model_inputs, training=False)
    elif model_type == "freematch":
        model = FixMatchTune(encoder_name="andrei-saceleanu/ro-offense-freematch")
        model.cls_head.load_weights("./checkpoints/freematch_tune")
        preds, _ = model(model_inputs, training=False)
    elif model_type == "mixmatch":
        model = MixMatch(bert_model="andrei-saceleanu/ro-offense-mixmatch")
        model.cls_head.load_weights("./checkpoints/mixmatch")
        preds = model(model_inputs, training=False)
    elif model_type == "contrastive_reg":
        model = FixMatchTune(encoder_name="readerbench/RoBERT-base")
        model.load_weights("./checkpoints/contrastive")
        preds, _ = model(model_inputs, training=False)
    elif model_type == "label_propagation":
        model = LPModel()
        model.load_weights("./checkpoints/label_prop")
        preds = model(model_inputs, training=False)
    else:
        # Reached when nothing is selected in the dropdown (model_type is
        # None) or an unexpected value is passed in.
        raise gr.Error(
            "Please select a training method before submitting "
            f"(got {model_type!r})."
        )

    # Only the first (and only) row of the batch is reported.
    return {
        name: float(score)
        for name, score in zip(label_names, preds[0].numpy())
    }
def _load_normalized_spectrogram(path, sample_rate=16000, clip_seconds=5, n_mels=128):
    """Load ``path`` and return a min-max scaled dB mel spectrogram (float32)."""
    signal, _ = librosa.load(path, sr=sample_rate)

    # Force a fixed clip length: zero-pad short clips, truncate long ones.
    length = clip_seconds * sample_rate
    if len(signal) < length:
        signal = np.pad(signal, (0, length - len(signal)), "constant")
    else:
        signal = signal[:length]

    spectrogram = librosa.feature.melspectrogram(y=signal, sr=sample_rate, n_mels=n_mels)
    spectrogram = librosa.power_to_db(S=spectrogram, ref=np.max)

    spectrogram_min, spectrogram_max = spectrogram.min(), spectrogram.max()
    value_range = spectrogram_max - spectrogram_min
    if value_range > 0:
        spectrogram = (spectrogram - spectrogram_min) / value_range
    else:
        # A constant spectrogram (e.g. an all-zero / silent clip) would give
        # 0/0 = NaN everywhere; use zeros so the model gets finite input.
        spectrogram = np.zeros_like(spectrogram)
    return spectrogram.astype("float32")


def ssl_predict2(audio_file, model_type):
    """Classify an audio clip with the model trained by the selected SSL method.

    Args:
        audio_file: The uploaded file from the UI. Either an object exposing
            the file path as ``.name`` or a plain path string is accepted.
        model_type: One of ``"fixmatch"``, ``"freematch"`` or ``"mixmatch"``.

    Returns:
        A dict mapping each entry of ``audio_label_names`` to a float score
        taken from the first row of the model output.

    Raises:
        gr.Error: If no file was uploaded, or if ``model_type`` is missing
            or not one of the values above.
    """
    if audio_file is None:
        raise gr.Error("Please upload an audio file before submitting.")
    audio_path = getattr(audio_file, "name", audio_file)

    spectrogram = _load_normalized_spectrogram(audio_path)

    # Add a leading batch axis and a trailing channel axis, then repeat the
    # single channel three times so the 2-D spectrogram is fed as a
    # 3-channel image.
    # NOTE(review): image_mean / image_std are presumably the statistics
    # used at training time -- confirm before changing them.
    inputs = processor.preprocess(
        np.repeat(spectrogram[np.newaxis, :, :, np.newaxis], 3, -1),
        image_mean=(-3.05, -3.05, -3.05),
        image_std=(2.33, 2.33, 2.33),
        return_tensors="tf",
    )
    pixel_values = inputs["pixel_values"]

    # NOTE: the model is rebuilt and its head weights reloaded on every call.
    if model_type == "fixmatch":
        model = AudioFixMatch(encoder_name="andrei-saceleanu/vit-base-fixmatch")
        model.cls_head.load_weights("./checkpoints/audio_fixmatch")
        preds, _ = model(pixel_values, training=False)
    elif model_type == "freematch":
        model = AudioFixMatch(encoder_name="andrei-saceleanu/vit-base-freematch")
        model.cls_head.load_weights("./checkpoints/audio_freematch")
        preds, _ = model(pixel_values, training=False)
    elif model_type == "mixmatch":
        model = AudioMixMatch(encoder_name="andrei-saceleanu/vit-base-mixmatch")
        model.cls_head.load_weights("./checkpoints/audio_mixmatch")
        preds = model(pixel_values, training=False)
    else:
        # Reached when nothing is selected in the dropdown (model_type is
        # None) or an unexpected value is passed in.
        raise gr.Error(
            "Please select a training method before submitting "
            f"(got {model_type!r})."
        )

    # Only the first (and only) row of the batch is reported.
    return {
        name: float(score)
        for name, score in zip(audio_label_names, preds[0].numpy())
    }
def _blank_pair():
    """Return one ``None`` per component that a Clear button resets."""
    return [None] * 2


# NOTE(review): the nesting below was reconstructed from a copy of this file
# whose indentation had been stripped -- confirm the layout (which column and
# row each component sits in) against the running app.
with gr.Blocks() as ssl_interface:

    # ---- Text tab ---------------------------------------------------------
    with gr.Tab("Text (RO-Offense)"):
        with gr.Row():
            with gr.Column():
                in_text = gr.Textbox(label="Input text")
                model_list = gr.Dropdown(
                    label="Training method",
                    info="Select trained model according to different SSL techniques from paper",
                    choices=[
                        "fixmatch",
                        "freematch",
                        "mixmatch",
                        "contrastive_reg",
                        "label_propagation",
                    ],
                    max_choices=1,
                    allow_custom_value=False,
                )
                with gr.Row():
                    clear_btn = gr.Button(value="Clear")
                    submit_btn = gr.Button(value="Submit")
            with gr.Column():
                out_field = gr.Label(label="Prediction", num_top_classes=4)

        # Submit runs the text classifier; Clear empties the textbox and
        # the prediction label.
        submit_btn.click(
            fn=ssl_predict,
            inputs=[in_text, model_list],
            outputs=[out_field],
        )
        clear_btn.click(
            fn=_blank_pair,
            inputs=None,
            outputs=[in_text, out_field],
            queue=False,
        )

    # ---- Audio tab --------------------------------------------------------
    with gr.Tab("Audio (VocalSound)"):
        with gr.Row():
            with gr.Column():
                audio_file = gr.File(
                    label="Input audio",
                    file_types=["audio"],
                    file_count="single",
                )
                model_list2 = gr.Dropdown(
                    label="Training method",
                    info="Select trained model according to different SSL techniques from paper",
                    choices=["fixmatch", "freematch", "mixmatch"],
                    max_choices=1,
                    allow_custom_value=False,
                )
                with gr.Row():
                    clear_btn2 = gr.Button(value="Clear")
                    submit_btn2 = gr.Button(value="Submit")
            with gr.Column():
                out_field2 = gr.Label(label="Prediction", num_top_classes=6)

        # Submit runs the audio classifier; Clear empties the file input and
        # the prediction label.
        submit_btn2.click(
            fn=ssl_predict2,
            inputs=[audio_file, model_list2],
            outputs=[out_field2],
        )
        clear_btn2.click(
            fn=_blank_pair,
            inputs=None,
            outputs=[audio_file, out_field2],
            queue=False,
        )

# Start the web server on all interfaces, port 7860.
ssl_interface.launch(server_name="0.0.0.0", server_port=7860)