| from typing import Dict, List, Any
|
| import torch
|
| import base64
|
| import io
|
| from PIL import Image
|
| from tryon_core import TryOnEngine
|
| from api_utils import prepare_image_for_processing, image_to_base64
|
|
|
class EndpointHandler:
    """Custom inference handler wrapping the IDM-VTON ``TryOnEngine``.

    The instance is built once (loading all models in ``__init__``) and is then
    called with each request payload via ``__call__``.
    """

    # Fallbacks used when a request omits the corresponding optional field.
    DEFAULT_DESCRIPTION = "a photo of a garment"
    DEFAULT_CATEGORY = "upper_body"
    DEFAULT_DENOISE_STEPS = 30
    DEFAULT_SEED = 42

    def __init__(self, path=""):
        """Create the try-on engine and load its models.

        Args:
            path: Accepted for compatibility with the handler calling
                convention; not used by this handler (``TryOnEngine`` is
                constructed without it).
        """
        print("Initializing IDM-VTON Handler...")
        # 4-bit loading with CPU offload disabled; the meaning of fixed_vae is
        # defined by tryon_core.TryOnEngine.
        self.engine = TryOnEngine(load_mode="4bit", enable_cpu_offload=False, fixed_vae=True)
        self.engine.load_models()
        self.engine.load_processing_models()
        print("Handler Initialized!")

    @staticmethod
    def _decode_base64_payload(encoded):
        """Decode base64 text or bytes, accepting an optional ``data:...;base64,`` prefix."""
        # Clients (especially browsers) often send data URIs. b64decode would
        # silently drop ':' ';' ',' and decode the header letters as payload,
        # producing corrupt bytes instead of an error.
        if isinstance(encoded, str) and encoded.startswith("data:"):
            _, _, encoded = encoded.partition(",")
        return base64.b64decode(encoded)

    @classmethod
    def _load_image(cls, inputs: Dict[str, Any], key: str):
        """Decode the base64 image under ``inputs[key]`` and run it through
        ``prepare_image_for_processing``.

        Raises:
            ValueError: if the field is missing or empty, or its base64 is
                malformed (``binascii.Error`` is a ``ValueError`` subclass).
        """
        encoded = inputs.get(key)
        if not encoded:
            raise ValueError(f"Missing required input field {key!r} (base64-encoded image).")
        raw = cls._decode_base64_payload(encoded)
        return prepare_image_for_processing(Image.open(io.BytesIO(raw)))

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run one virtual try-on request.

        Args:
            data: Request payload. Fields are read from ``data["inputs"]`` when
                that key exists, otherwise from ``data`` itself:

                - ``human_image`` (required): base64 image of the person; a
                  ``data:<mime>;base64,`` prefix is accepted.
                - ``garment_image`` (required): base64 image of the garment,
                  same encoding rules.
                - ``garment_description`` (optional, default
                  ``"a photo of a garment"``).
                - ``category`` (optional, default ``"upper_body"``).
                - ``denoise_steps`` (optional int, default 30).
                - ``seed`` (optional int, default 42).

                The payload is not modified.

        Returns:
            A one-element list containing a dict with ``"generated_image"`` and
            ``"masked_image"``, each produced by ``image_to_base64``.

        Raises:
            ValueError: if a required image field is missing, empty, or not
                valid base64, or if ``denoise_steps``/``seed`` is not an
                integer string/number.
        """
        # Read rather than pop so the caller's payload stays intact.
        inputs = data.get("inputs", data)

        description = inputs.get("garment_description", self.DEFAULT_DESCRIPTION)
        category = inputs.get("category", self.DEFAULT_CATEGORY)
        denoise_steps = int(inputs.get("denoise_steps", self.DEFAULT_DENOISE_STEPS))
        seed = int(inputs.get("seed", self.DEFAULT_SEED))

        human_img = self._load_image(inputs, "human_image")
        garment_img = self._load_image(inputs, "garment_image")

        generated_images, masked_image = self.engine.generate(
            human_img=human_img,
            garment_img=garment_img,
            garment_description=description,
            category=category,
            use_auto_mask=True,
            use_auto_crop=True,
            denoise_steps=denoise_steps,
            seed=seed,
            # Only the first generated image is returned below.
            num_images=1,
        )

        return [{
            "generated_image": image_to_base64(generated_images[0]),
            "masked_image": image_to_base64(masked_image),
        }]
|
|
|