Spaces:
Running on Zero
Running on Zero
| import os | |
| import tempfile | |
| from typing import Optional | |
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| try: | |
| import spaces | |
| except Exception: | |
| spaces = None | |
| from diffusers import DiffusionPipeline | |
| from diffusers.utils import export_to_video | |
# Lower-cased Space identifier taken from the environment ("" when unset).
SPACE_ID = os.getenv("SPACE_ID", "").lower()
# Last path component of SPACE_ID; used as the key into SPACE_IMAGE_DEFAULTS.
_SPACE_NAME = SPACE_ID.split("/")[-1]
# The video UI is selected when the Space id contains one of these markers.
IS_VIDEO_SPACE = any(marker in SPACE_ID for marker in ("hunyuanvideo", "wan-2-1"))

# Generic model used when neither the environment nor the space table names one.
_GENERIC_IMAGE_MODEL = "runwayml/stable-diffusion-v1-5"
IMAGE_MODEL_ID = os.getenv("IMAGE_MODEL_ID", _GENERIC_IMAGE_MODEL)
VIDEO_MODEL_ID = os.getenv("VIDEO_MODEL_ID", "damo-vilab/text-to-video-ms-1.7b")

# Known ungated defaults per space: avoids GatedRepoError on HF Spaces without
# manual model-license acceptance.
SPACE_IMAGE_DEFAULTS = {
    "fhdr-uncensored": "SG161222/Realistic_Vision_V6.0_B1_noVAE",
    "z-image-turbo": "stabilityai/sdxl-turbo",
}

# Candidate image models in priority order:
#   1. an explicit IMAGE_MODEL_ID environment override,
#   2. the per-space default from SPACE_IMAGE_DEFAULTS,
#   3. the generic default.
# The environment is read directly here (instead of reusing IMAGE_MODEL_ID)
# so that the generic default does not shadow the per-space default when no
# override is set. dict.fromkeys() removes duplicates while keeping order, so
# a model that fails to load is never retried.
FALLBACK_IMAGE_MODELS = list(
    dict.fromkeys(
        model_id
        for model_id in (
            os.getenv("IMAGE_MODEL_ID", ""),
            SPACE_IMAGE_DEFAULTS.get(_SPACE_NAME, ""),
            _GENERIC_IMAGE_MODEL,
        )
        if model_id
    )
)

# Lazily created pipeline singletons; populated by the _load_*_pipe helpers.
# The annotations are quoted so they are not evaluated at import time.
_image_pipe: Optional["DiffusionPipeline"] = None
_video_pipe: Optional["DiffusionPipeline"] = None
| def _device_dtype(): | |
| if torch.cuda.is_available(): | |
| if torch.cuda.get_device_properties(0).major >= 8: | |
| return "cuda", torch.bfloat16 | |
| return "cuda", torch.float16 | |
| return "cpu", torch.float32 | |
def _load_image_pipe() -> DiffusionPipeline:
    """Return the cached image pipeline, loading it on first use.

    Each non-empty entry of ``FALLBACK_IMAGE_MODELS`` is tried in order until
    one loads. The loaded pipeline has its safety checker disabled (when it
    exposes one) and is placed with ``enable_model_cpu_offload()`` on CUDA or
    moved to the CPU otherwise.

    Returns:
        The module-level ``DiffusionPipeline`` singleton.

    Raises:
        RuntimeError: If none of the candidate models could be loaded; the
            last loading error is attached as the cause.
    """
    global _image_pipe
    if _image_pipe is not None:
        return _image_pipe

    device, dtype = _device_dtype()
    # Drop empty ids and duplicates so a failing model is not retried.
    candidates = list(dict.fromkeys(m for m in FALLBACK_IMAGE_MODELS if m))

    pipe = None
    last_error = None
    for model_id in candidates:
        try:
            try:
                # Try to disable safety checker when supported.
                pipe = DiffusionPipeline.from_pretrained(
                    model_id,
                    torch_dtype=dtype,
                    safety_checker=None,
                    requires_safety_checker=False,
                )
            except TypeError:
                # Pipeline class rejected the safety-checker kwargs; retry
                # without them. This retry lives inside the outer ``try`` so
                # that a failure here still falls through to the next model
                # instead of aborting the whole fallback loop.
                pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype)
        except Exception as exc:  # any load failure: move on to the next candidate
            last_error = exc
            pipe = None
            continue
        break

    if pipe is None:
        raise RuntimeError(
            f"Unable to load image model from {FALLBACK_IMAGE_MODELS}: {last_error}"
        ) from last_error

    # Explicit runtime bypass for diffusion pipelines exposing an NSFW safety checker.
    if hasattr(pipe, "safety_checker"):
        pipe.safety_checker = None
    if hasattr(pipe, "requires_safety_checker"):
        pipe.requires_safety_checker = False

    if device == "cuda":
        pipe.enable_model_cpu_offload()
    else:
        pipe.to("cpu")

    # Publish to the module-level cache only once the pipeline is fully
    # configured, so a failure above never leaves a half-initialised singleton.
    _image_pipe = pipe
    return _image_pipe
def _load_video_pipe() -> DiffusionPipeline:
    """Return the cached video pipeline, loading ``VIDEO_MODEL_ID`` on first use.

    Weights are loaded as ``torch.float16`` on CUDA and ``torch.float32`` on
    CPU. On CUDA the pipeline uses ``enable_model_cpu_offload()``; otherwise
    it is moved to the CPU.

    Returns:
        The module-level ``DiffusionPipeline`` singleton.
    """
    global _video_pipe
    if _video_pipe is not None:
        return _video_pipe

    device, _ = _device_dtype()
    dtype = torch.float16 if device == "cuda" else torch.float32

    pipe = DiffusionPipeline.from_pretrained(VIDEO_MODEL_ID, torch_dtype=dtype)
    if device == "cuda":
        pipe.enable_model_cpu_offload()
    else:
        pipe.to("cpu")

    # Cache only after device placement succeeded; otherwise a failure above
    # would leave an unconfigured pipeline to be returned by every later call.
    _video_pipe = pipe
    return _video_pipe
def _gpu_decorator(seconds: int):
    """Build a decorator that requests a GPU via ``spaces.GPU`` when available.

    Args:
        seconds: Value forwarded as ``duration`` to ``spaces.GPU``.

    Returns:
        ``spaces.GPU(duration=seconds)`` when the ``spaces`` module was
        imported, otherwise a decorator that returns the function unchanged.
    """
    if spaces is None:
        # No ``spaces`` package: decorating must be a no-op.
        def _identity(fn):
            return fn

        return _identity

    return spaces.GPU(duration=seconds)
# The GPU decorator was defined but never applied; without it the handler
# never requests a GPU through ``spaces.GPU``. 120 is the ``duration`` value
# passed to ``spaces.GPU`` (a no-op when ``spaces`` is not installed).
@_gpu_decorator(120)
def generate_image(prompt: str, steps: int, guidance_scale: float, seed: int):
    """Generate one image from a text prompt.

    Args:
        prompt: Text prompt; blank or ``None`` falls back to a built-in prompt.
        steps: Number of inference steps (coerced to ``int``).
        guidance_scale: Guidance scale (coerced to ``float``).
        seed: Seed for the CPU random generator. ``None`` (an emptied number
            field) is treated as 42, the default shown in the UI.

    Returns:
        The first image of the pipeline output.
    """
    prompt = (prompt or "").strip() or "A cinematic photo of a woman on a beach at sunset"
    seed_value = 42 if seed is None else int(seed)

    pipe = _load_image_pipe()
    gen = torch.Generator(device="cpu").manual_seed(seed_value)

    # Default edge length: 512 when the loaded model path mentions "v1-5",
    # otherwise 1024. IMAGE_WIDTH / IMAGE_HEIGHT override it.
    # NOTE(review): this name check may pick 1024 for other checkpoints whose
    # native resolution is 512 -- confirm against the configured models.
    model_path = str(getattr(pipe, "name_or_path", "")).lower()
    default_size = 512 if "v1-5" in model_path else 1024
    width = int(os.getenv("IMAGE_WIDTH", default_size))
    height = int(os.getenv("IMAGE_HEIGHT", default_size))

    image: Image.Image = pipe(
        prompt=prompt,
        num_inference_steps=int(steps),
        guidance_scale=float(guidance_scale),
        generator=gen,
        width=width,
        height=height,
    ).images[0]
    return image
# The GPU decorator was defined but never applied; without it the handler
# never requests a GPU through ``spaces.GPU``. 180 is the ``duration`` value
# passed to ``spaces.GPU`` (a no-op when ``spaces`` is not installed).
@_gpu_decorator(180)
def generate_video(prompt: str, steps: int, fps: int, num_frames: int, seed: int):
    """Generate a short video from a text prompt and write it to an MP4 file.

    Args:
        prompt: Text prompt; blank or ``None`` falls back to a built-in prompt.
        steps: Number of inference steps (coerced to ``int``).
        fps: Frames per second of the exported file (coerced to ``int``).
        num_frames: Number of frames to generate (coerced to ``int``).
        seed: Seed for the CPU random generator. ``None`` (an emptied number
            field) is treated as 42, the default shown in the UI.

    Returns:
        Path of the written ``.mp4`` file inside the system temp directory.
    """
    prompt = (prompt or "").strip() or "A woman walking on a sunny beach, cinematic shot"
    seed_value = 42 if seed is None else int(seed)

    pipe = _load_video_pipe()
    gen = torch.Generator(device="cpu").manual_seed(seed_value)
    result = pipe(
        prompt,
        num_inference_steps=int(steps),
        num_frames=int(num_frames),
        generator=gen,
    )
    frames = result.frames[0]

    # Reserve a unique file per call. A single fixed name would let one
    # request overwrite the video another request is still reading.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        out_path = tmp.name
    export_to_video(frames, out_path, fps=int(fps))
    return out_path
def build_ui():
    """Assemble the Gradio Blocks app for this Space.

    The page title comes from ``SPACE_TITLE`` when set; otherwise it is
    derived from the last component of ``SPACE_ID`` (dashes become spaces,
    title-cased), falling back to "AI Generator". ``IS_VIDEO_SPACE`` selects
    between the video layout and the image layout.

    Returns:
        The constructed ``gr.Blocks`` instance.
    """
    space_name = SPACE_ID.split("/")[-1]
    derived_title = space_name.replace("-", " ").title() or "AI Generator"
    title = os.getenv("SPACE_TITLE", derived_title)

    if not IS_VIDEO_SPACE:
        # Image layout: prompt, sampling controls, one image output.
        with gr.Blocks(theme=gr.themes.Soft()) as app:
            gr.Markdown(f"## {title} — AI Image Generation")
            prompt_box = gr.Textbox(
                label="Prompt",
                value="A cinematic photo of a woman on a beach at sunset",
            )
            with gr.Row():
                steps_slider = gr.Slider(4, 40, value=20, step=1, label="Inference steps")
                guidance_slider = gr.Slider(1.0, 10.0, value=3.5, step=0.1, label="Guidance scale")
                seed_box = gr.Number(value=42, precision=0, label="Seed")
            image_out = gr.Image(type="pil", label="Generated image")
            run_button = gr.Button("Generate")
            run_button.click(
                generate_image,
                [prompt_box, steps_slider, guidance_slider, seed_box],
                [image_out],
            )
        return app

    # Video layout: prompt, sampling/frame controls, one video output.
    with gr.Blocks(theme=gr.themes.Soft()) as app:
        gr.Markdown(f"## {title} — AI Video Generation")
        prompt_box = gr.Textbox(
            label="Prompt",
            value="A woman walking on a sunny beach, cinematic shot",
        )
        with gr.Row():
            steps_slider = gr.Slider(8, 40, value=20, step=1, label="Inference steps")
            frames_slider = gr.Slider(8, 32, value=16, step=1, label="Frames")
            fps_slider = gr.Slider(4, 16, value=8, step=1, label="FPS")
            seed_box = gr.Number(value=42, precision=0, label="Seed")
        video_out = gr.Video(label="Generated video")
        run_button = gr.Button("Generate")
        run_button.click(
            generate_video,
            [prompt_box, steps_slider, fps_slider, frames_slider, seed_box],
            [video_out],
        )
    return app
# Build the interface at import time and expose it as module-level ``demo``.
demo = build_ui()


if __name__ == "__main__":
    # Executed as a script: start the Gradio server with default settings.
    demo.launch()