File size: 4,058 Bytes
2a08a94
aa3a1ee
 
6ef2343
 
 
 
aa3a1ee
 
 
 
2a08a94
6ef2343
 
2a08a94
6ef2343
2a08a94
6ef2343
 
 
2a08a94
6ef2343
 
 
 
 
 
 
 
 
 
 
2a08a94
6ef2343
2a08a94
6ef2343
 
 
 
 
 
 
 
 
2a08a94
 
 
 
 
 
 
6ef2343
 
2a08a94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ef2343
2a08a94
 
 
6ef2343
 
2a08a94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ef2343
2a08a94
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import gradio as gr
import os
from huggingface_hub import login
import requests
import base64
from PIL import Image
import io

# Log in to the Hugging Face Hub only when the HF_TOKEN environment variable
# is set to a non-empty value; otherwise skip login entirely.
if _hf_token := os.getenv("HF_TOKEN"):
    login(token=_hf_token)

def generate_image_mlx(prompt, negative_prompt, num_inference_steps, guidance_scale, width, height, seed):
    """Generate an image from a text prompt via the Hugging Face Inference API.

    Despite the function name, no MLX code runs here. The request goes to the
    hosted ``stabilityai/stable-diffusion-xl-base-1.0`` model on the Hugging
    Face Inference API. It is authenticated with the ``HF_TOKEN`` environment
    variable when that variable is set.

    Args:
        prompt: Text description of the desired image. Blank prompts are skipped.
        negative_prompt: Concepts the model should steer away from.
        num_inference_steps: Number of denoising steps (coerced to ``int``).
        guidance_scale: Guidance strength (coerced to ``float``).
        width: Output width in pixels (coerced to ``int``).
        height: Output height in pixels (coerced to ``int``).
        seed: Random seed. A falsy value (``0`` or ``None``) omits the seed so
            the server chooses one.

    Returns:
        The decoded ``PIL.Image.Image`` on success. Returns ``None`` if the
        prompt is blank, the request fails, or the response body is not a
        decodable image. Failures are reported with ``print``, not raised.
    """
    # Seconds to wait for the API. Without a timeout a stalled connection
    # would block the Gradio worker indefinitely.
    request_timeout = 120

    if not prompt or not prompt.strip():
        print("Prompt is empty; skipping image generation.")
        return None

    try:
        print(f"Generating image with prompt: {prompt}")

        api_url = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-base-1.0"
        token = os.getenv("HF_TOKEN")
        # Only send an Authorization header when a token exists. Otherwise the
        # f-string would send the literal text "Bearer None".
        headers = {"Authorization": f"Bearer {token}"} if token else {}

        parameters = {
            "negative_prompt": negative_prompt,
            # Gradio sliders and numbers may deliver floats (e.g. 20.0).
            # Step counts, dimensions and seeds are integral, so coerce them.
            "num_inference_steps": int(num_inference_steps),
            "guidance_scale": float(guidance_scale),
            "width": int(width),
            "height": int(height),
        }
        # A falsy seed means "random": omit the key instead of sending null.
        if seed:
            parameters["seed"] = int(seed)

        payload = {"inputs": prompt, "parameters": parameters}

        response = requests.post(api_url, headers=headers, json=payload, timeout=request_timeout)

        if response.status_code != 200:
            print(f"API request failed: {response.status_code} - {response.text}")
            return None

        image = Image.open(io.BytesIO(response.content))
        # Image.open is lazy. Force the decode here so a corrupt payload is
        # caught below instead of failing later inside Gradio.
        image.load()
        print("Image generated successfully!")
        return image

    except Exception as e:
        # Top-level UI handler: report the error and return None so the
        # output component is simply left empty.
        print(f"Error generating image: {e}")
        return None

# Example prompts shown beneath the form. The first one doubles as the default
# prompt so the two cannot drift apart.
examples = [
    "A detailed technical diagram of a silicon waveguide showing light propagation",
    "Cross-section view of a silicon photonic integrated circuit with waveguides and couplers",
    "3D visualization of a silicon ring resonator for optical filtering",
    "Schematic diagram of a silicon photonic modulator with electrodes",
    "Top view of a silicon photonic array with multiple waveguides and couplers"
]

# Create Gradio interface. Components render in the order they are created
# inside each `with` context, so statement order here defines the layout.
with gr.Blocks(title="Silicon Photonics Image Generator") as demo:
    gr.Markdown("# 🔬 Silicon Photonics Image Generator")
    # Describe the backend accurately: images come from the hosted
    # Hugging Face Inference API (SDXL), not from local MLX inference.
    gr.Markdown("Generate technical diagrams and visualizations for silicon photonics using Stable Diffusion XL.")
    gr.Markdown("**🚀 Powered by the Hugging Face Inference API (stabilityai/stable-diffusion-xl-base-1.0)**")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                value=examples[0],
                lines=3
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                value="blurry, low quality, cartoon, artistic, hand-drawn, sketchy, unprofessional, low resolution",
                lines=2
            )

            with gr.Row():
                # Explicit integer step: inference steps are a count.
                num_inference_steps = gr.Slider(1, 50, value=20, step=1, label="Inference Steps")
                guidance_scale = gr.Slider(1.0, 20.0, value=7.0, label="Guidance Scale")

            with gr.Row():
                width = gr.Slider(512, 1024, value=1024, step=64, label="Width")
                height = gr.Slider(512, 1024, value=1024, step=64, label="Height")
                # precision=0 makes the component return an integer seed instead of a float.
                seed = gr.Number(value=42, label="Seed (0 for random)", precision=0)

            generate_btn = gr.Button("🎨 Generate Image", variant="primary")

        with gr.Column():
            output_image = gr.Image(label="Generated Image", type="pil")

    # Example prompts
    gr.Markdown("### 💡 Example Prompts")
    gr.Examples(
        examples=[[ex] for ex in examples],
        inputs=[prompt]
    )

    # Event handlers: the input order must match generate_image_mlx's parameters.
    generate_btn.click(
        fn=generate_image_mlx,
        inputs=[prompt, negative_prompt, num_inference_steps, guidance_scale, width, height, seed],
        outputs=[output_image]
    )

def _main():
    """Start the Gradio server for the module-level ``demo`` interface."""
    demo.launch()


if __name__ == "__main__":
    _main()