gokaygokay's picture
Update app.py
6b9f789 verified
Raw
History Blame Contribute Delete
2.53 kB
import spaces
import gradio as gr
import gradio_client.utils as _gc_utils
# Guard gradio_client's JSON-schema -> type-hint helpers against non-dict
# schemas. JSON Schema permits boolean sub-schemas (e.g. an
# ``"additionalProperties": true``); the wrappers below map any non-dict input
# to "Any" and defer to the original helpers for ordinary dict schemas.
_original_get_type = _gc_utils.get_type
_original_json_schema = _gc_utils._json_schema_to_python_type


def _safe_get_type(schema):
    """Return ``get_type(schema)``, or ``"Any"`` when *schema* is not a dict."""
    if not isinstance(schema, dict):
        return "Any"
    return _original_get_type(schema)


def _safe_json_schema_to_python_type(schema, defs=None):
    """Return the python type hint string for *schema*.

    Any non-dict schema (a boolean ``true``/``false`` schema, or e.g. a
    list-form ``items``) is mapped to ``"Any"`` before the original converter
    sees it, so this wrapper and ``_safe_get_type`` agree on which inputs they
    tolerate. Dict schemas are passed through unchanged together with *defs*.
    """
    if not isinstance(schema, dict):
        return "Any"
    return _original_json_schema(schema, defs)


# Install the wrappers as module attributes: code in gradio_client.utils that
# looks these names up on the module at call time will now go through them.
_gc_utils.get_type = _safe_get_type
_gc_utils._json_schema_to_python_type = _safe_json_schema_to_python_type
from PIL import Image
from transparent_background import Remover
import numpy as np
# Background-removal model. Left as None at import time and created on the
# first call to process_image (the @spaces.GPU entry point), i.e. lazily on the
# GPU worker rather than at startup; later calls in the same process reuse it.
remover: "Remover | None" = None
def _get_mask(img: Image.Image) -> Image.Image:
    """Run the background-removal model on *img* and return an 8-bit ``'L'`` mask.

    ``remover.process(img, type='map')`` may return a PIL image or a numpy
    array (uint8, or another dtype assumed to hold values in [0, 1]); every
    form is normalised to a single-channel uint8 PIL image.

    Uses the module-level ``remover``, which must already be initialised
    (``process_image`` does this before calling here).

    Raises:
        ValueError: if the model output array is neither 2-D nor 3-D.
    """
    out = remover.process(img, type='map')
    if isinstance(out, Image.Image):
        return out.convert('L')

    arr = np.asarray(out)
    if arr.ndim == 3:
        # Treat the last axis as channels and keep only the first one.
        arr = arr[..., 0]
    if arr.ndim != 2:
        raise ValueError(f"unexpected mask shape from model: {arr.shape}")
    if arr.dtype != np.uint8:
        # Scale a [0, 1] map to [0, 255]. NaNs become 0 and out-of-range values
        # are clamped first; rounding (not truncation) keeps e.g. 0.999 at 255.
        unit = np.clip(np.nan_to_num(arr.astype(np.float64), nan=0.0), 0.0, 1.0)
        arr = np.rint(unit * 255.0).astype(np.uint8)
    # A 2-D uint8 array is inferred as mode 'L'; passing ``mode=`` explicitly is
    # deprecated in recent Pillow releases.
    return Image.fromarray(arr)
@spaces.GPU
def process_image(input_image, output_type):
    """Remove the background from *input_image*.

    Args:
        input_image: PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.
        output_type: ``"Mask only"`` returns the grayscale mask; any other
            value (the UI offers ``"Default"``) returns the RGBA cut-out.

    Returns:
        The ``'L'`` mask, or an RGBA copy of the input whose alpha channel is
        that mask.

    Raises:
        gr.Error: if no image was provided.
    """
    global remover
    if input_image is None:
        # Show a readable message in the UI instead of failing with
        # AttributeError on None (and before paying for a model load).
        raise gr.Error("Please upload an image first.")
    if remover is None:
        # Created on first call and kept in the module global for later calls.
        remover = Remover(jit=False)

    rgb = input_image.convert('RGB')
    mask = _get_mask(rgb)
    if mask.size != rgb.size:
        # putalpha() requires the mask to match the image size exactly.
        mask = mask.resize(rgb.size, Image.BILINEAR)

    if output_type == "Mask only":
        return mask

    # Compose RGBA ourselves: original pixels + our mask as alpha.
    rgba = rgb.convert('RGBA')
    rgba.putalpha(mask)
    return rgba
# HTML header passed to gr.Interface(description=...). A block-level <center>
# inside <p> is invalid nesting (parsers close the <p> early and leave empty
# paragraphs), so the paragraph itself is centred instead.
description = """<h1 align="center">InSPyReNet Background Remover</h1>
<p align="center">
<a href="https://github.com/plemeri/InSPyReNet" target="_blank" rel="noopener noreferrer">[Github]</a>
<a href="https://dualview.ai" target="_blank" rel="noopener noreferrer">[Compare Results]</a>
</p>
"""
# Gradio UI: an input image plus an output-type selector, producing one PNG
# image. Components are created in the same order as the argument list below.
# Example outputs are cached (cache_examples=True).
iface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Input Image", height=512),
        gr.Radio(["Default", "Mask only"], label="Output Type", value="Default"),
    ],
    outputs=gr.Image(type="pil", label="Output Image", height=512, image_mode="RGBA", format="png"),
    description=description,
    theme='bethecloud/storj_theme',
    # Every bundled example image runs with the "Default" (RGBA) output.
    examples=[[example, "Default"] for example in ("1.png", "2.png", "3.jfif", "4.webp")],
    cache_examples=True,
)

if __name__ == "__main__":
    iface.launch()