"""Gradio demo that serves text-to-image or text-to-video generation.

The mode is chosen once at import time from the ``SPACE_ID`` environment
variable: if it contains one of the video keywords the video UI is built,
otherwise the image UI is built. Model ids can be overridden through the
``IMAGE_MODEL_ID`` and ``VIDEO_MODEL_ID`` environment variables, and the image
size through ``IMAGE_WIDTH`` / ``IMAGE_HEIGHT``.
"""

import os
import tempfile
from typing import Optional

import gradio as gr
import torch
from PIL import Image

try:
    import spaces
except Exception:
    # Deliberate best-effort: any failure importing `spaces` means the
    # generation functions simply run without the `spaces.GPU` decorator.
    spaces = None

from diffusers import DiffusionPipeline
from diffusers.utils import export_to_video

SPACE_ID = os.getenv("SPACE_ID", "").lower()
IS_VIDEO_SPACE = any(k in SPACE_ID for k in ["hunyuanvideo", "wan-2-1"])

IMAGE_MODEL_ID = os.getenv("IMAGE_MODEL_ID", "runwayml/stable-diffusion-v1-5")
VIDEO_MODEL_ID = os.getenv("VIDEO_MODEL_ID", "damo-vilab/text-to-video-ms-1.7b")

# Known ungated defaults per space: avoids GatedRepoError on HF Spaces without
# manual model-license acceptance.
SPACE_IMAGE_DEFAULTS = {
    "fhdr-uncensored": "SG161222/Realistic_Vision_V6.0_B1_noVAE",
    "z-image-turbo": "stabilityai/sdxl-turbo",
}

# Tried in order; empty entries (no per-space default) are skipped at load time.
FALLBACK_IMAGE_MODELS = [
    IMAGE_MODEL_ID,
    SPACE_IMAGE_DEFAULTS.get(SPACE_ID.split("/")[-1], ""),
    "runwayml/stable-diffusion-v1-5",
]

# Lazily created, process-wide pipeline singletons.
_image_pipe: Optional[DiffusionPipeline] = None
_video_pipe: Optional[DiffusionPipeline] = None


def _device_dtype():
    """Return the ``(device, dtype)`` pair to load pipelines with.

    Returns:
        ``("cuda", torch.bfloat16)`` when CUDA is available and device 0 has
        compute capability major version >= 8, ``("cuda", torch.float16)`` for
        other CUDA devices, and ``("cpu", torch.float32)`` otherwise.
    """
    if not torch.cuda.is_available():
        return "cpu", torch.float32
    if torch.cuda.get_device_properties(0).major >= 8:
        return "cuda", torch.bfloat16
    return "cuda", torch.float16


def _place_pipe(pipe: DiffusionPipeline, device: str) -> None:
    """Enable model CPU offload on CUDA, otherwise move the pipeline to CPU."""
    if device == "cuda":
        pipe.enable_model_cpu_offload()
    else:
        pipe.to("cpu")


def _from_pretrained_image(model_id: str, dtype: torch.dtype) -> DiffusionPipeline:
    """Load one image pipeline, preferring the safety-checker-free signature.

    If the pipeline rejects the safety-checker keyword arguments with a
    ``TypeError`` the load is retried without them. Any exception from the
    retry propagates to the caller so that it can move on to the next model.
    """
    try:
        # Try to disable safety checker when supported.
        return DiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=dtype,
            safety_checker=None,
            requires_safety_checker=False,
        )
    except TypeError:
        return DiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype)


def _load_image_pipe() -> DiffusionPipeline:
    """Return the cached image pipeline, loading it on first use.

    Each distinct, non-empty entry of ``FALLBACK_IMAGE_MODELS`` is tried in
    order until one loads.

    Raises:
        RuntimeError: if none of the candidate models could be loaded; the
            last load error is attached as the cause.
    """
    global _image_pipe
    if _image_pipe is not None:
        return _image_pipe

    device, dtype = _device_dtype()
    # Drop empty entries and duplicates (the env default and the hard-coded
    # fallback are often the same id) while keeping the priority order.
    candidates = list(dict.fromkeys(m for m in FALLBACK_IMAGE_MODELS if m))

    pipe = None
    last_error = None
    for model_id in candidates:
        try:
            pipe = _from_pretrained_image(model_id, dtype)
        except Exception as e:
            # Best-effort fallback: remember the error and try the next model.
            last_error = e
            continue
        break

    if pipe is None:
        raise RuntimeError(
            f"Unable to load image model from {FALLBACK_IMAGE_MODELS}: {last_error}"
        ) from last_error

    # Explicit runtime bypass for diffusion pipelines exposing an NSFW safety checker.
    if hasattr(pipe, "safety_checker"):
        pipe.safety_checker = None
    if hasattr(pipe, "requires_safety_checker"):
        pipe.requires_safety_checker = False

    _place_pipe(pipe, device)

    # Publish only after setup succeeded, so a failure above is retried on the
    # next call instead of leaving a half-configured pipeline in the cache.
    _image_pipe = pipe
    return _image_pipe


def _load_video_pipe() -> DiffusionPipeline:
    """Return the cached video pipeline for ``VIDEO_MODEL_ID``, loading it on first use.

    The pipeline is loaded in float16 on CUDA and float32 on CPU.
    """
    global _video_pipe
    if _video_pipe is not None:
        return _video_pipe

    device, _ = _device_dtype()
    dtype = torch.float16 if device == "cuda" else torch.float32
    pipe = DiffusionPipeline.from_pretrained(VIDEO_MODEL_ID, torch_dtype=dtype)
    _place_pipe(pipe, device)

    # Publish only after setup succeeded (see _load_image_pipe).
    _video_pipe = pipe
    return _video_pipe


def _gpu_decorator(seconds: int):
    """Return a decorator for GPU-bound entry points.

    When the ``spaces`` package was imported, this is
    ``spaces.GPU(duration=seconds)``; otherwise it is a no-op decorator that
    returns the function unchanged.
    """
    if spaces is not None:
        return spaces.GPU(duration=seconds)

    def _wrap(fn):
        return fn

    return _wrap


@_gpu_decorator(120)
def generate_image(prompt: str, steps: int, guidance_scale: float, seed: int):
    """Generate one image from a text prompt.

    Args:
        prompt: Text prompt; a blank or missing prompt falls back to a built-in default.
        steps: Number of inference steps (coerced to ``int``).
        guidance_scale: Guidance scale (coerced to ``float``).
        seed: Seed for the CPU ``torch.Generator`` (coerced to ``int``).

    Returns:
        The first image of the pipeline output.
    """
    prompt = (prompt or "").strip() or "A cinematic photo of a woman on a beach at sunset"
    pipe = _load_image_pipe()
    gen = torch.Generator(device="cpu").manual_seed(int(seed))

    # 512 px when the loaded pipeline's name mentions "v1-5", else 1024 px;
    # either side can be overridden through the environment.
    pipe_name = str(getattr(pipe, "name_or_path", "")).lower()
    default_size = 512 if "v1-5" in pipe_name else 1024
    width = int(os.getenv("IMAGE_WIDTH", str(default_size)))
    height = int(os.getenv("IMAGE_HEIGHT", str(default_size)))

    image: Image.Image = pipe(
        prompt=prompt,
        num_inference_steps=int(steps),
        guidance_scale=float(guidance_scale),
        generator=gen,
        width=width,
        height=height,
    ).images[0]
    return image


@_gpu_decorator(240)
def generate_video(prompt: str, steps: int, fps: int, num_frames: int, seed: int):
    """Generate a video from a text prompt and return the path of the MP4 file.

    Args:
        prompt: Text prompt; a blank or missing prompt falls back to a built-in default.
        steps: Number of inference steps (coerced to ``int``).
        fps: Frame rate used when exporting the video (coerced to ``int``).
        num_frames: Number of frames to generate (coerced to ``int``).
        seed: Seed for the CPU ``torch.Generator`` (coerced to ``int``).

    Returns:
        Path of the exported video inside the system temp directory.
    """
    prompt = (prompt or "").strip() or "A woman walking on a sunny beach, cinematic shot"
    pipe = _load_video_pipe()
    gen = torch.Generator(device="cpu").manual_seed(int(seed))

    result = pipe(
        prompt,
        num_inference_steps=int(steps),
        num_frames=int(num_frames),
        generator=gen,
    )
    frames = result.frames[0]

    # Fixed file name: each call overwrites the previous export.
    out_path = os.path.join(tempfile.gettempdir(), "generated_video.mp4")
    export_to_video(frames, out_path, fps=int(fps))
    return out_path


def build_ui():
    """Build and return the Gradio ``Blocks`` app for the current space.

    The page title comes from ``SPACE_TITLE`` or, failing that, from the last
    path segment of ``SPACE_ID`` (dashes turned into spaces, title-cased), with
    "AI Generator" as the final fallback. The video UI is built when
    ``IS_VIDEO_SPACE`` is true, the image UI otherwise.
    """
    title = os.getenv(
        "SPACE_TITLE",
        SPACE_ID.split("/")[-1].replace("-", " ").title() or "AI Generator",
    )

    # NOTE(review): the original indentation was lost, so which widgets sit
    # inside each gr.Row() is reconstructed (sliders + seed) — confirm layout.
    if IS_VIDEO_SPACE:
        with gr.Blocks(theme=gr.themes.Soft()) as demo:
            gr.Markdown(f"## {title} — AI Video Generation")
            prompt = gr.Textbox(
                label="Prompt",
                value="A woman walking on a sunny beach, cinematic shot",
            )
            with gr.Row():
                steps = gr.Slider(8, 40, value=20, step=1, label="Inference steps")
                num_frames = gr.Slider(8, 32, value=16, step=1, label="Frames")
                fps = gr.Slider(4, 16, value=8, step=1, label="FPS")
                seed = gr.Number(value=42, precision=0, label="Seed")
            out = gr.Video(label="Generated video")
            btn = gr.Button("Generate")
            # Input order must match generate_video's parameter order.
            btn.click(generate_video, [prompt, steps, fps, num_frames, seed], [out])
        return demo

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown(f"## {title} — AI Image Generation")
        prompt = gr.Textbox(
            label="Prompt",
            value="A cinematic photo of a woman on a beach at sunset",
        )
        with gr.Row():
            steps = gr.Slider(4, 40, value=20, step=1, label="Inference steps")
            guidance = gr.Slider(1.0, 10.0, value=3.5, step=0.1, label="Guidance scale")
            seed = gr.Number(value=42, precision=0, label="Seed")
        out = gr.Image(type="pil", label="Generated image")
        btn = gr.Button("Generate")
        # Input order must match generate_image's parameter order.
        btn.click(generate_image, [prompt, steps, guidance, seed], [out])
    return demo


# Built at import time so hosting environments can pick up the module-level `demo`.
demo = build_ui()

if __name__ == "__main__":
    demo.launch()