Spaces:
Build error
Build error
| import torch | |
| from transformers import BertTokenizer | |
| from regression_models import BERTRegression | |
| import gradio as gr | |
| max_len = 80 | |
| # Load tokenizer | |
| tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") | |
| # Load model architecture | |
| bertregressor = BERTRegression() | |
| bertregressor.load_state_dict(torch.load('bert_regression_model.pth', map_location=torch.device('cpu'))) | |
| bertregressor.eval() | |
| def predict_price(name, item_condition, category, brand_name, shipping_included, item_description): | |
| print((name, item_condition, category, brand_name, shipping_included, item_description)) | |
| # Preprocess Input | |
| if shipping_included: | |
| shipping_str = "Includes Shipping" | |
| else: | |
| shipping_str = "No Shipping" | |
| combined = "Item Name: " + name + \ | |
| " Description: " + item_description + \ | |
| " Condition: " + item_condition + \ | |
| " Category: " + category + \ | |
| " Brand " + brand_name + \ | |
| " Shipping: " + shipping_str | |
| inputs = tokenizer.encode_plus( | |
| combined, | |
| None, | |
| add_special_tokens=True, | |
| max_length=max_len, | |
| padding="max_length", | |
| truncation=True, | |
| return_tensors="pt" | |
| ) | |
| input_ids = inputs["input_ids"] | |
| attention_mask = inputs["attention_mask"] | |
| with torch.no_grad(): | |
| output = bertregressor(input_ids, attention_mask) | |
| return output.item() | |
| demo = gr.Interface( | |
| fn = predict_price, | |
| inputs = [gr.Textbox(label="Item Name"), | |
| gr.Dropdown(['Poor', 'Okay', 'Good', 'Excellent', 'Like New'], label="Item Condition", info="What condition is the item in?"), | |
| gr.Textbox(label="Category on Mercari"), | |
| gr.Textbox(label="Brand"), | |
| gr.Checkbox(label="Shipping Included"), | |
| gr.Textbox(label="Description") | |
| ], | |
| #outputs = gr.Textbox() | |
| outputs= gr.Number() | |
| ) | |
| demo.launch() |