Spaces:
Running
Running
| import numpy as np | |
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from model import SRCNNModel, pred_SRCNN | |
| from PIL import Image | |
# User-facing text shown by the Gradio interface: the page title, the blurb
# displayed under it, and an HTML footer with references.
title = "Super Resolution with CNN"
description = """
Your low resolution image will be reconstructed to high resolution with a scale of 2 with a convolutional neural network!<br>
Detailed training and dataset can be found on my [github repo](https://github.com/susuhu/super-resolution).<br>
"""
# Footer HTML with the paper and dataset links. The emoji are written as
# named escapes so the source stays pure ASCII and cannot be mis-decoded.
article = """
<div style='margin:20px auto;'>
<p>Sources:</p>
<p>\N{PAGE FACING UP} <a href="https://arxiv.org/abs/1501.00092">Image Super-Resolution Using Deep Convolutional Networks</a></p>
<p>\N{PACKAGE} Dataset <a href="https://github.com/eugenesiow/super-image-data">this GitHub repo</a></p>
</div>
"""
# Load the trained SRCNN weights once at import time so every request reuses
# the same model instance. Use the GPU when one is available.
print("Loading SRCNN model...")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SRCNNModel().to(device)
# map_location remaps the saved tensors onto `device`, so a checkpoint saved
# on a GPU machine also loads on a CPU-only host. `device` is already a
# torch.device, so it is passed directly.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (newer PyTorch versions accept weights_only=True for plain state dicts).
model.load_state_dict(torch.load('SRCNNmodel_trained.pt', map_location=device))
# Switch to inference mode. This affects layers such as dropout or batch-norm
# if SRCNNModel contains any.
model.eval()
print("SRCNN model loaded!")
| # def image_grid(imgs, rows, cols): | |
| # ''' | |
| # imgs:list of PILImage | |
| # ''' | |
| # assert len(imgs) == rows*cols | |
| # w, h = imgs[0].size | |
| # grid = Image.new('RGB', size=(cols*w, rows*h)) | |
| # grid_w, grid_h = grid.size | |
| # for i, img in enumerate(imgs): | |
| # grid.paste(img, box=(i%cols*w, i//cols*h)) | |
| # return grid | |
def sepia(image):
    """Super-resolve an uploaded image and return it with a bicubic baseline.

    The function name does not describe its behaviour; it is kept unchanged
    because the Gradio Interface refers to it.

    Args:
        image: The uploaded image as a NumPy array (Gradio's numpy image
            type), or ``None`` when the user submits without uploading.

    Returns:
        A ``(cnn_output, bicubic_output)`` pair as produced by ``pred_SRCNN``.
        Returns ``(None, None)`` when no image was given, so the outputs are
        left empty instead of the callback raising.
    """
    if image is None:
        return None, None
    # Normalise to 3-channel RGB. Grayscale (H, W) or RGBA (H, W, 4) arrays
    # would be misread by fromarray(..., mode='RGB'), and fromarray's `mode`
    # argument is deprecated in recent Pillow releases.
    pil_image = Image.fromarray(image).convert("RGB")
    # Inference only: disabling autograd avoids building a graph. This is
    # harmless if pred_SRCNN already does so internally.
    with torch.no_grad():
        # The third value returned by pred_SRCNN is not needed here.
        out_final, image_bicubic, _ = pred_SRCNN(
            model=model, image=pil_image, device=device
        )
    return out_final, image_bicubic
# Build the Gradio UI: one uploaded image in, and the CNN reconstruction plus
# a bicubic baseline out, so the two can be compared side by side.
# gr.Image replaces the deprecated gr.inputs / gr.outputs namespaces, which
# were removed in Gradio 4.
demo = gr.Interface(
    fn=sepia,
    # type="numpy" matches what sepia() expects (it calls Image.fromarray).
    inputs=gr.Image(type="numpy", label="Upload image"),
    outputs=[
        gr.Image(label="Conv net"),
        gr.Image(label="Bicubic interpolation"),
    ],
    title=title,
    description=description,
    article=article,
    examples=[['LR_image.png'], ['barbara.png']],
)
demo.launch()