"""Gradio front end for a Stable Diffusion XL pipeline with a silicon photonics LoRA adapter."""

import base64
import io

import gradio as gr
import torch
from diffusers import StableDiffusionXLPipeline
from peft import PeftModel
from PIL import Image

# Global variables
# Cached pipeline; set by load_model() only after a fully successful load.
pipe = None


def load_model():
    """Load the model (cached after first load).

    Builds the SDXL base pipeline, wraps its UNet with the
    "silicon-photonics-lora" adapter, and moves it to CUDA when available
    (CPU otherwise).

    Returns:
        The ready-to-use pipeline. Subsequent calls return the cached object.

    Raises:
        Exception: whatever the underlying loaders raise; the error is
            printed and re-raised with its original traceback.
    """
    global pipe
    if pipe is not None:
        return pipe

    try:
        use_cuda = torch.cuda.is_available()
        device = "cuda" if use_cuda else "cpu"
        # Half precision is only sensible on GPU; on CPU float16 is either
        # unsupported or extremely slow, so fall back to float32 there.
        dtype = torch.float16 if use_cuda else torch.float32

        # Load base model
        loaded = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=dtype,
            use_safetensors=True,
        )

        # Load LoRA adapter
        loaded.unet = PeftModel.from_pretrained(loaded.unet, "silicon-photonics-lora")

        # Move to device
        loaded = loaded.to(device)
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        raise

    # Publish to the cache only now, so a failure above can never leave a
    # half-initialised pipeline (e.g. base model without the adapter) cached.
    pipe = loaded
    return pipe


def generate_image(prompt, negative_prompt, num_inference_steps, guidance_scale, width, height, seed):
    """Generate image using the silicon photonics LoRA model.

    Args:
        prompt: Text prompt describing the desired image.
        negative_prompt: Text describing what to avoid.
        num_inference_steps: Number of denoising steps (coerced to int, min 1).
        guidance_scale: Classifier-free guidance scale.
        width: Output width in pixels (coerced to int).
        height: Output height in pixels (coerced to int).
        seed: Seed for reproducible output; 0 or None means a random seed.

    Returns:
        The first generated image from the pipeline output.

    Raises:
        gr.Error: if loading the model or running the pipeline fails, so the
            message is shown in the Gradio UI instead of a silent blank image.
    """
    try:
        # Load model if not already loaded
        pipeline = load_model()

        # Gradio numeric widgets may hand over floats (e.g. 42.0), but
        # torch.Generator.manual_seed and the pipeline's size/step arguments
        # require ints.
        steps = max(1, int(num_inference_steps))
        width = int(width)
        height = int(height)
        seed = int(seed) if seed else 0

        # A seed of 0 keeps the original "random" behaviour (no generator).
        generator = torch.Generator().manual_seed(seed) if seed else None

        # Generate image
        with torch.no_grad():
            result = pipeline(
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=steps,
                guidance_scale=float(guidance_scale),
                width=width,
                height=height,
                generator=generator,
            )
        return result.images[0]
    except Exception as e:
        print(f"Error generating image: {str(e)}")
        # Surface the failure to the user rather than returning None.
        raise gr.Error(f"Error generating image: {e}") from e


# Create Gradio interface
with gr.Blocks(title="Silicon Photonics Image Generator") as demo:
    gr.Markdown("# 🔬 Silicon Photonics Image Generator")
    gr.Markdown("Generate technical diagrams and visualizations for silicon photonics using our fine-tuned LoRA model.")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                value="A detailed technical diagram of a silicon waveguide showing light propagation",
                lines=3,
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                value="blurry, low quality, cartoon, artistic, hand-drawn, sketchy, unprofessional, low resolution",
                lines=2,
            )

            with gr.Row():
                # step=1 keeps the step count integral; without an explicit
                # step the slider may produce fractional values.
                num_inference_steps = gr.Slider(1, 50, value=20, step=1, label="Inference Steps")
                guidance_scale = gr.Slider(1.0, 20.0, value=7.0, label="Guidance Scale")

            with gr.Row():
                width = gr.Slider(512, 1024, value=1024, step=64, label="Width")
                height = gr.Slider(512, 1024, value=1024, step=64, label="Height")

            # precision=0 makes the component deliver an integer seed.
            seed = gr.Number(value=42, precision=0, label="Seed (0 for random)")

            generate_btn = gr.Button("🎨 Generate Image", variant="primary")

        with gr.Column():
            output_image = gr.Image(label="Generated Image", type="pil")

    # Example prompts
    gr.Markdown("### 💡 Example Prompts")
    examples = [
        "A detailed technical diagram of a silicon waveguide showing light propagation",
        "Cross-section view of a silicon photonic integrated circuit with waveguides and couplers",
        "3D visualization of a silicon ring resonator for optical filtering",
        "Schematic diagram of a silicon photonic modulator with electrodes",
        "Top view of a silicon photonic array with multiple waveguides and couplers",
    ]

    gr.Examples(
        examples=[[ex] for ex in examples],
        inputs=[prompt],
    )

    # Event handlers
    generate_btn.click(
        fn=generate_image,
        inputs=[prompt, negative_prompt, num_inference_steps, guidance_scale, width, height, seed],
        outputs=[output_image],
    )


if __name__ == "__main__":
    demo.launch()