Spaces:
Paused
Paused
| import re | |
| import gradio as gr | |
| from transformers import AutoTokenizer | |
| from unidecode import unidecode | |
| from models import * | |
# Hugging Face checkpoint whose tokenizer turns preprocessed text into model
# inputs. The tokenizer is loaded once at import time and reused by every
# call to ssl_predict.
_TOKENIZER_NAME = "readerbench/RoBERT-base"
tok = AutoTokenizer.from_pretrained(_TOKENIZER_NAME)
def preprocess(x):
    """Normalize raw text the same way before it is tokenized.

    The input is transliterated to ASCII with ``unidecode`` and lowercased.
    The regex substitutions below are then applied in order.

    Args:
        x: raw input string.

    Returns:
        The normalized string (it may keep a leading/trailing space).
    """
    text = unidecode(x).lower()
    # Order matters: each rule sees the output of the previous one.
    rules = (
        (r"\[[a-z]+\]", ""),      # drop bracketed lowercase tags, e.g. "[user]"
        (r"\*", ""),              # drop asterisks
        (r"[^a-zA-Z0-9]+", " "),  # every run of non-alphanumerics -> one space
        (r" +", " "),             # squeeze spaces (safeguard after the rule above)
        # Collapse any repeated character to a single one. Note this also hits
        # legitimate doubles ("ll" -> "l") and digits ("100" -> "10").
        (r"(.)\1+", r"\1"),
    )
    for pattern, replacement in rules:
        text = re.sub(pattern, replacement, text)
    return text
# Class names reported to the UI. ssl_predict pairs them positionally with
# the model's output scores, so this order must match the model's outputs.
label_names = [
    "ABUSE",
    "INSULT",
    "OTHER",
    "PROFANITY",
]
# Training methods that have a deployed checkpoint, mapped to
# (encoder name passed to FixMatchTune, path of the classification-head weights).
_MODEL_SPECS = {
    "freematch": (
        "andrei-saceleanu/ro-offense-freematch",
        "./checkpoints/freematch_tune",
    ),
}
# Models already built by _load_model, keyed by training method. Building the
# model and loading its weights is expensive, so do it once per process
# instead of on every request.
_MODEL_CACHE = {}


def _load_model(model_type):
    """Return the FixMatchTune model for ``model_type``, building it on first use."""
    model = _MODEL_CACHE.get(model_type)
    if model is None:
        encoder_name, head_weights = _MODEL_SPECS[model_type]
        model = FixMatchTune(encoder_name=encoder_name)
        model.cls_head.load_weights(head_weights)
        _MODEL_CACHE[model_type] = model
    return model


def ssl_predict(in_text, model_type):
    """Classify ``in_text`` with the model trained by the ``model_type`` method.

    Args:
        in_text: raw user text; ``None`` is treated as an empty string.
        model_type: training-method key selected in the UI dropdown.

    Returns:
        dict mapping each name in ``label_names`` to the model's score for it.

    Raises:
        gr.Error: if no trained model is available for ``model_type``
            (including ``None`` when nothing was selected). Previously this
            case crashed with UnboundLocalError.
    """
    # Validate before doing any tokenization or model work.
    if model_type not in _MODEL_SPECS:
        available = ", ".join(sorted(_MODEL_SPECS))
        raise gr.Error(
            f"No trained model available for {model_type!r}; "
            f"choose one of: {available}"
        )

    toks = tok(
        preprocess(in_text or ""),
        padding="max_length",
        max_length=96,
        truncation=True,
        return_tensors="tf",
    )
    model = _load_model(model_type)
    preds, _ = model([toks["input_ids"], toks["attention_mask"]], training=False)
    # A single string was tokenized, so preds[0] holds its scores. Cast numpy
    # scalars to plain floats so the dict serializes cleanly for gr.Label.
    scores = preds[0].numpy()
    return {name: float(score) for name, score in zip(label_names, scores)}
# Gradio UI. The left column holds the input text, the training-method selector
# and the Clear/Submit buttons; the right column shows the predicted scores.
with gr.Blocks() as ssl_interface:
    with gr.Row():
        with gr.Column():
            in_text = gr.Textbox(label="Input text")
            model_list = gr.Dropdown(
                choices=["fixmatch", "freematch", "mixmatch"],
                # Pre-select a method so Submit never sends None as the model
                # type. "freematch" is the method ssl_predict has a checkpoint
                # for.
                value="freematch",
                label="Training method",
                allow_custom_value=False,
                info="Select trained model according to different SSL techniques from paper",
            )
            with gr.Row():
                clear_btn = gr.Button(value="Clear")
                submit_btn = gr.Button(value="Submit")
        with gr.Column():
            # Show every class rather than a hard-coded count.
            out_field = gr.Label(num_top_classes=len(label_names), label="Prediction")

    submit_btn.click(
        fn=ssl_predict,
        inputs=[in_text, model_list],
        outputs=[out_field],
    )
    # Clear both the input text and the prediction output.
    clear_btn.click(
        fn=lambda: (None, None),
        inputs=None,
        outputs=[in_text, out_field],
    )

ssl_interface.launch(server_name="0.0.0.0", server_port=7860)