| import spaces |
| import gradio as gr |
| import torch |
| import numpy as np |
| from PIL import Image |
| from transformers import Sam3Processor, Sam3Model |
| import requests |
| import warnings |
| warnings.filterwarnings("ignore") |
|
|
| |
| device = "cuda" if torch.cuda.is_available() else "cpu" |
| model = Sam3Model.from_pretrained("facebook/sam3", torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32).to(device) |
| processor = Sam3Processor.from_pretrained("facebook/sam3") |
|
|
| @spaces.GPU() |
| def segment(image: Image.Image, text: str, threshold: float, mask_threshold: float): |
| """ |
| Perform promptable concept segmentation using SAM3. |
| Returns format compatible with gr.AnnotatedImage: (image, [(mask, label), ...]) |
| """ |
| if image is None: |
| return None, "❌ Please upload an image." |
| |
| if not text.strip(): |
| return (image, []), "❌ Please enter a text prompt." |
| |
| try: |
| inputs = processor(images=image, text=text.strip(), return_tensors="pt").to(device) |
| |
| for key in inputs: |
| if inputs[key].dtype == torch.float32: |
| inputs[key] = inputs[key].to(model.dtype) |
| |
| with torch.no_grad(): |
| outputs = model(**inputs) |
| |
| results = processor.post_process_instance_segmentation( |
| outputs, |
| threshold=threshold, |
| mask_threshold=mask_threshold, |
| target_sizes=inputs.get("original_sizes").tolist() |
| )[0] |
| |
| n_masks = len(results['masks']) |
| if n_masks == 0: |
| return (image, []), f"❌ No objects found matching '{text}' (try adjusting thresholds)." |
| |
| |
| |
| annotations = [] |
| for i, (mask, score) in enumerate(zip(results['masks'], results['scores'])): |
| |
| mask_np = mask.cpu().numpy().astype(np.float32) |
| label = f"{text} #{i+1} ({score:.2f})" |
| annotations.append((mask_np, label)) |
| |
| scores_text = ", ".join([f"{s:.2f}" for s in results['scores'].cpu().numpy()[:5]]) |
| info = f"✅ Found **{n_masks}** objects matching **'{text}'**\nConfidence scores: {scores_text}{'...' if n_masks > 5 else ''}" |
| |
| |
| return (image, annotations), info |
| |
| except Exception as e: |
| return (image, []), f"❌ Error during segmentation: {str(e)}" |
|
|
| def clear_all(): |
| """Clear all inputs and outputs""" |
| return None, "", None, 0.5, 0.5, "📝 Enter a prompt and click **Segment** to start." |
|
|
| def segment_example(image_path: str, prompt: str): |
| """Handle example clicks""" |
| if image_path.startswith("http"): |
| image = Image.open(requests.get(image_path, stream=True).raw).convert("RGB") |
| else: |
| image = Image.open(image_path).convert("RGB") |
| return segment(image, prompt, 0.5, 0.5) |
|
|
| |
| with gr.Blocks( |
| theme=gr.themes.Soft(), |
| title="SAM3 - Promptable Concept Segmentation", |
| css=".gradio-container {max-width: 1400px !important;}" |
| ) as demo: |
| gr.Markdown( |
| """ |
| # SAM3 - Promptable Concept Segmentation (PCS) |
| |
| **SAM3** performs zero-shot instance segmentation using natural language prompts. |
| Upload an image, enter a text prompt (e.g., "person", "car", "dog"), and get segmentation masks. |
| |
| Built with [anycoder](https://huggingface.co/spaces/akhaliq/anycoder) |
| """ |
| ) |
| |
| gr.Markdown("### Inputs") |
| with gr.Row(variant="panel"): |
| image_input = gr.Image( |
| label="Input Image", |
| type="pil", |
| height=400, |
| ) |
| |
| image_output = gr.AnnotatedImage( |
| label="Output (Segmented Image)", |
| height=400, |
| show_legend=True, |
| ) |
| |
| with gr.Row(): |
| text_input = gr.Textbox( |
| label="Text Prompt", |
| placeholder="e.g., person, ear, cat, bicycle...", |
| scale=3 |
| ) |
| clear_btn = gr.Button("🔍 Clear", size="sm", variant="secondary") |
| |
| with gr.Row(): |
| thresh_slider = gr.Slider( |
| minimum=0.0, |
| maximum=1.0, |
| value=0.5, |
| step=0.01, |
| label="Detection Threshold", |
| info="Higher = fewer detections" |
| ) |
| mask_thresh_slider = gr.Slider( |
| minimum=0.0, |
| maximum=1.0, |
| value=0.5, |
| step=0.01, |
| label="Mask Threshold", |
| info="Higher = sharper masks" |
| ) |
| |
| info_output = gr.Markdown( |
| value="📝 Enter a prompt and click **Segment** to start.", |
| label="Info / Results" |
| ) |
| |
| segment_btn = gr.Button("🎯 Segment", variant="primary", size="lg") |
| |
| gr.Examples( |
| examples=[ |
| ["http://images.cocodataset.org/val2017/000000077595.jpg", "cat"], |
| ], |
| inputs=[image_input, text_input], |
| outputs=[image_output, info_output], |
| fn=segment_example, |
| cache_examples=False, |
| ) |
| |
| clear_btn.click( |
| fn=clear_all, |
| outputs=[image_input, text_input, image_output, thresh_slider, mask_thresh_slider, info_output] |
| ) |
| |
| segment_btn.click( |
| fn=segment, |
| inputs=[image_input, text_input, thresh_slider, mask_thresh_slider], |
| outputs=[image_output, info_output] |
| ) |
| |
| gr.Markdown( |
| """ |
| ### Notes |
| - **Model**: [facebook/sam3](https://huggingface.co/facebook/sam3) |
| - Click on segments in the output to see labels |
| - GPU recommended for faster inference |
| """ |
| ) |
|
|
| if __name__ == "__main__": |
| demo.launch(server_name="0.0.0.0", server_port=7860, share=False, debug=True) |