import gradio as gr
from ultralytics import YOLO
import cv2
import tempfile
import numpy as np
import os
import json
from datetime import datetime
from download_model import is_lfs_pointer, download_file
# JSON file in the working directory that stores download / code-view counters.
STATS_FILE = "download_stats.json"

# Initialize Model
# The weights are loaded once at import time. If the local file is absent or is
# only a Git LFS pointer, it is fetched first. Any failure leaves ``model`` as
# None so the inference handlers can detect it.
try:
    if not (os.path.exists("aerialEye.pt") and not is_lfs_pointer("aerialEye.pt")):
        print("aerialEye.pt is missing or is an LFS pointer. Downloading from Hugging Face...")
        download_file("aerialEye.pt")
    model = YOLO("aerialEye.pt")
except Exception as e:
    print(f"Error loading model: {e}")
    model = None
def predict_image(image, conf_threshold, iou_threshold):
    """Run detection on a single image and return the annotated frame in RGB.

    Args:
        image: Array supplied by the ``gr.Image(type="numpy")`` input, or None
            when the button is pressed without an upload.
        conf_threshold: Minimum confidence passed to the model as ``conf``.
        iou_threshold: IoU threshold passed to the model as ``iou``.

    Returns:
        The annotated image as an RGB array, or None when no image was given.
        Exactly one value is returned because the click handler is wired to a
        single output component.

    Raises:
        gr.Error: If the model failed to load at startup.
    """
    if model is None:
        # Surface the failure in the UI instead of returning a mismatched tuple.
        raise gr.Error("Error: Model not loaded.")
    if image is None:
        return None

    # Ultralytics interprets raw numpy input as BGR, while Gradio hands over
    # RGB. Convert three-channel input so inference and plotting see the
    # channel order they expect.
    source = image
    if getattr(image, "ndim", 0) == 3 and image.shape[2] == 3:
        source = np.ascontiguousarray(image[..., ::-1])

    results = model.predict(
        source=source,
        conf=conf_threshold,
        iou=iou_threshold,
        show_labels=True,
        show_conf=True,
    )

    # plot() returns the annotated frame in the same (BGR) order as its input.
    annotated = results[0].plot()
    if annotated.ndim == 3 and annotated.shape[2] == 3:
        return annotated[..., ::-1]  # Return RGB for Gradio
    return annotated
def predict_video(video_path, conf_threshold, iou_threshold, max_frames=300):
    """Annotate the leading frames of a video and return the output file path.

    Args:
        video_path: Path of the uploaded video, or None if nothing was uploaded.
        conf_threshold: Minimum confidence passed to the model as ``conf``.
        iou_threshold: IoU threshold passed to the model as ``iou``.
        max_frames: Upper bound on processed frames. The default of 300 is
            10 seconds at 30 fps and keeps a request from hanging the space.

    Returns:
        Path to a temporary ``.mp4`` with the annotated frames, or None when
        the model is unavailable, the input cannot be opened, the writer cannot
        be created, or no frame could be read.
    """
    if model is None or not video_path:
        return None

    # Open the video
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        return None

    writer = None
    out_path = None
    frame_count = 0
    try:
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Keep fractional rates such as 29.97 and fall back to 30 when the
        # container reports no usable rate (0 or NaN).
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = 30.0

        # Create temp file for output. Only the name is needed, so the handle
        # is closed immediately and the writer owns the file.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_file:
            out_path = temp_file.name

        # Define codec and create VideoWriter object
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
        if not writer.isOpened():
            frame_count = 0
        else:
            while cap.isOpened() and frame_count < max_frames:
                success, frame = cap.read()
                if not success:
                    break
                results = model.predict(
                    source=frame,
                    conf=conf_threshold,
                    iou=iou_threshold,
                    verbose=False,
                )
                writer.write(results[0].plot())
                frame_count += 1
    finally:
        # Release native handles even if inference raised part-way through.
        cap.release()
        if writer is not None:
            writer.release()

    if frame_count == 0:
        # Nothing was written; do not hand an empty file to the video player.
        if out_path is not None and os.path.exists(out_path):
            os.remove(out_path)
        return None
    return out_path
# Download Tracking & Snippets Logic
# Counters that are always present in the stats file, even at zero.
_KNOWN_DOWNLOADS = (
    "aerialEye.pt",
    "aerialEye.onnx",
    "best.pt",
    "best.onnx",
    "best_full_integer_quant.tflite",
)
_KNOWN_CODE_VIEWS = (
    "Python (ultralytics)",
    "Python (onnxruntime)",
    "CLI (ultralytics)",
)
# Number of most recent history entries kept in the stats file.
_HISTORY_LIMIT = 50


def _record_event(section, known_keys, name, event_type):
    """Increment one counter in STATS_FILE and append a history entry.

    Args:
        section: Top-level stats key holding the counters ("downloads" or
            "code_views").
        known_keys: Counter names seeded with 0 when missing.
        name: Counter to increment; names outside the section are not counted.
        event_type: Label stored in the history entry's "type" field.

    Raises:
        Any I/O or JSON error; callers decide how to report it.
    """
    if os.path.exists(STATS_FILE):
        with open(STATS_FILE, "r", encoding="utf-8") as f:
            data = json.load(f)
    else:
        data = {}

    # A file holding valid JSON that is not an object cannot be extended;
    # start again from an empty structure instead of failing on every call.
    if not isinstance(data, dict):
        data = {}

    data.setdefault("downloads", {})
    data.setdefault("code_views", {})
    data.setdefault("history", [])

    counters = data[section]
    for key in known_keys:
        counters.setdefault(key, 0)
    if name in counters:
        counters[name] += 1

    data["history"].append({
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "type": event_type,
        "file": name,
    })
    # Keep only the newest entries so the file does not grow without bound.
    if len(data["history"]) > _HISTORY_LIMIT:
        data["history"] = data["history"][-_HISTORY_LIMIT:]

    with open(STATS_FILE, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)


def track_download(filename):
    """Record one download of ``filename`` in the stats file.

    Tracking is best-effort: any failure is printed and swallowed so that it
    never interrupts the download itself.
    """
    try:
        _record_event("downloads", _KNOWN_DOWNLOADS, filename, "Download")
    except Exception as e:
        print(f"Error tracking download: {e}")


def track_code_view(env_name):
    """Record one view of the ``env_name`` integration snippet.

    Tracking is best-effort: any failure is printed and swallowed so that it
    never interrupts rendering of the snippet.
    """
    try:
        _record_event("code_views", _KNOWN_CODE_VIEWS, env_name, "Code View")
    except Exception as e:
        print(f"Error tracking code view: {e}")
def get_stats():
    """Return the persisted usage statistics, or zeroed defaults.

    The content of STATS_FILE is returned as parsed when the file exists and
    loads. If the file is absent, or reading/parsing fails (the error is
    printed), a fresh structure with every known counter at zero and an empty
    history is returned instead.
    """
    empty_stats = {
        "downloads": {
            "aerialEye.pt": 0, "aerialEye.onnx": 0, "best.pt": 0, "best.onnx": 0, "best_full_integer_quant.tflite": 0
        },
        "code_views": {
            "Python (ultralytics)": 0, "Python (onnxruntime)": 0, "CLI (ultralytics)": 0
        },
        "history": []
    }

    if not os.path.exists(STATS_FILE):
        return empty_stats

    try:
        with open(STATS_FILE, "r") as stats_fh:
            return json.load(stats_fh)
    except Exception as e:
        print(f"Error loading stats: {e}")
    return empty_stats
def get_code_snippet(option):
    """Return the integration snippet for ``option`` and record the view.

    Every call is passed to ``track_code_view``, including calls with an
    unrecognised option. Unrecognised options receive the
    "Python (ultralytics)" snippet.
    """
    ultralytics_snippet = """from ultralytics import YOLO
# 1. Load the model weights
model = YOLO("aerialEye.pt")
# 2. Run inference on an image
results = model("sample_aerial_street.jpg")
# 3. Save or display results
results[0].show()
# results[0].save(filename="result.jpg")"""

    onnxruntime_snippet = """import cv2
import numpy as np
import onnxruntime as ort
# 1. Load ONNX model
session = ort.InferenceSession("aerialEye.onnx")
# 2. Preprocess input image (resize to 640x640, scale 0-1, CHW format)
img = cv2.imread("sample_aerial_street.jpg")
img_resized = cv2.resize(img, (640, 640))
input_data = np.transpose(img_resized, (2, 0, 1)).astype(np.float32) / 255.0
input_data = np.expand_dims(input_data, axis=0) # Add batch dimension
# 3. Run Inference
input_name = session.get_inputs()[0].name
outputs = session.run(None, {input_name: input_data})
print("ONNX Outputs:", outputs)"""

    cli_snippet = """# Run inference via ultralytics CLI
yolo task=detect mode=predict model=aerialEye.pt source=sample_aerial_street.jpg conf=0.25"""

    by_method = {
        "Python (ultralytics)": ultralytics_snippet,
        "Python (onnxruntime)": onnxruntime_snippet,
        "CLI (ultralytics)": cli_snippet,
    }

    # Count the view before answering; tracking swallows its own errors.
    track_code_view(option)

    if option in by_method:
        return by_method[option]
    return ultralytics_snippet
def generate_analytics_html(stats):
    """Render the analytics dashboard text from a stats dictionary.

    Args:
        stats: Mapping that may contain "downloads" and "code_views" (name to
            count) and "history" (list of event dicts with "timestamp",
            "type" and "file"). Missing keys are treated as empty.

    Returns:
        One string with three sections: per-file download shares, per-method
        code-view shares, and the ten most recent history entries with the
        newest first.
    """
    download_counts = stats.get("downloads", {})
    events = stats.get("history", [])
    view_counts = stats.get("code_views", {})

    def render_shares(counts):
        # One row per counter: label, raw count and percentage of the section
        # total (0 when the section has no hits yet).
        section_total = sum(counts.values())
        rows = []
        for label, hits in counts.items():
            share = (hits / section_total * 100) if section_total > 0 else 0
            rows.append(f"""
{label}
{hits} ({share:.1f}%)
""")
        return "".join(rows)

    progress_html = render_shares(download_counts)
    code_views_html = render_shares(view_counts)

    # Show last 10 entries, newest first
    row_chunks = []
    for event in events[-10:][::-1]:
        # NOTE(review): the badge colours are selected per event type but are
        # not interpolated into the row text below - confirm intended use.
        if event.get('type') == 'Download':
            badge_style = "background-color: #1e1b4b; color: #c084fc;"
        else:
            badge_style = "background-color: #1e3a8a; color: #93c5fd;"
        row_chunks.append(f"""
{event.get('timestamp')}
{event.get('type', 'Download')}
{event.get('file')}
""")

    history_rows = "".join(row_chunks) or """
No activity logged yet.
"""

    return f"""
Download Statistics
{progress_html}
Developer Integration Views
{code_views_html}
Recent Activity Log
Timestamp
Action
Resource
{history_rows}
"""
# CSS Design
# Stylesheet passed to gr.Blocks(css=...). It sets a dark page background and
# defines the ``hf-*`` classes (container, header, badges, cards, download
# metric). Of these, the layout code in this file attaches ``hf-container``
# and ``hf-card`` through ``elem_classes``.
css_styles = """
body {
background-color: #0b0f19 !important;
}
.gradio-container {
background-color: #0b0f19 !important;
border: none !important;
}
.hf-container {
padding: 1rem 0;
max-width: 1200px;
margin: 0 auto;
}
.hf-header {
border-bottom: 1px solid #1f2937;
padding-bottom: 1.5rem;
margin-bottom: 1.5rem;
}
.hf-repo-title {
font-size: 1.75rem;
font-weight: 700;
display: flex;
align-items: center;
gap: 0.5rem;
color: #ffffff;
}
.hf-repo-user {
color: #9ca3af;
font-weight: 400;
}
.hf-repo-sep {
color: #4b5563;
font-weight: 300;
}
.hf-badge-container {
display: flex;
flex-wrap: wrap;
gap: 0.5rem;
margin-top: 0.75rem;
}
.hf-badge {
background-color: #1f2937;
color: #e5e7eb;
padding: 0.2rem 0.5rem;
border-radius: 6px;
font-size: 0.75rem;
font-weight: 500;
display: inline-flex;
align-items: center;
gap: 0.25rem;
border: 1px solid #374151;
}
.hf-badge-green {
background-color: rgba(22, 101, 52, 0.2) !important;
color: #4ade80 !important;
border: 1px solid rgba(34, 197, 94, 0.3) !important;
}
.hf-badge-blue {
background-color: rgba(30, 58, 138, 0.2) !important;
color: #60a5fa !important;
border: 1px solid rgba(59, 130, 246, 0.3) !important;
}
.hf-badge-purple {
background-color: rgba(88, 28, 135, 0.2) !important;
color: #c084fc !important;
border: 1px solid rgba(147, 51, 234, 0.3) !important;
}
.hf-card {
background-color: #111827 !important;
border: 1px solid #1f2937 !important;
border-radius: 12px !important;
padding: 1.5rem !important;
margin-bottom: 1.5rem !important;
}
.hf-card-title {
font-size: 1.1rem !important;
font-weight: 600 !important;
margin-bottom: 1rem !important;
color: #ffffff !important;
display: flex;
align-items: center;
gap: 0.5rem;
}
.hf-downloads-metric {
font-size: 2.5rem !important;
font-weight: 800 !important;
color: #c084fc !important;
line-height: 1.2 !important;
margin-bottom: 0.25rem !important;
}
.hf-downloads-label {
color: #9ca3af !important;
font-size: 0.875rem !important;
}
"""
def generate_downloads_html():
    """Render the sidebar download metric from the current stats.

    Sums every counter under "downloads" in the stats returned by
    ``get_stats`` and formats the total with thousands separators.
    """
    per_file_counts = get_stats().get("downloads", {})
    grand_total = sum(per_file_counts.values())
    return f"""
{grand_total:,}
Downloads last month
"""
# Gradio Interface
# Three tabs (model card, playground, analytics) inside one centred column.
# Handlers and event wiring sit inside the Blocks context so the listeners are
# registered on ``demo``.
with gr.Blocks(title="AerialEye Patrol Detection", css=css_styles) as demo:
    with gr.Column(elem_classes=["hf-container"]):
        # Top Header
        gr.HTML("""
""")
        with gr.Tabs():
            # 1. Model Card
            with gr.TabItem("📄 Model Card"):
                with gr.Row():
                    # Left Column (70%)
                    with gr.Column(scale=7):
                        gr.HTML("""
Nemotron AerialEye 3.5 ASR / Patrol
Model Arch: YOLOv11-Nano
Params: 2.6M
Language: Multilingual
Domain: Aerial & Drone Imagery
AerialEye is a high-performance fine-tuned object detection model based on the YOLOv11-Nano architecture.
It is trained specifically for aerial patrolling, disaster response, and search-and-rescue (SAR) operations,
capable of detecting humans, vehicles, SOS markers, flooding, road damages, and structural cracks.
""")
                        gr.HTML("""
🔍 Slicing Aided Hyper Inference (SAHI) Integration
High-altitude drone imagery contains extremely small target objects. Standard models downscale high-resolution inputs,
which destroys pixel density and causes high-altitude targets to vanish.
AerialEye utilizes SAHI slicing, partition-based sliding inference, and Non-Maximum Suppression (NMS) merging
to preserve small object details.
""")
                        # Show comparison map
                        if os.path.exists("sutra_standard_vs_sahi.jpg"):
                            gr.Image("sutra_standard_vs_sahi.jpg", label="SAHI Slicing vs Standard Downscaled Inference Comparison", interactive=False)
                        else:
                            gr.Markdown("*(Comparison image `sutra_standard_vs_sahi.jpg` is not downloaded yet. Run the Python Downloader to retrieve it.)*")
                        gr.HTML("""
📊 Side-by-Side Evaluation Metrics
Metric
Standard (Downscaled)
SAHI (Sliced Window)
Delta / Improvement
Objects Detected
65
50
-15 (-23.1% False Positives)
Inference Latency
733.0 ms
293.6 ms
-439.4 ms (-59.9% Latency)
Resolution Processing
640x640 (Downscaled)
Multi-Tile Slicing (Full Scale)
Preserves pixel density
""")
                    # Right Column (30%)
                    with gr.Column(scale=3):
                        # Downloads Sidebar card
                        with gr.Group(elem_classes=["hf-card"]):
                            downloads_sidebar_metric = gr.HTML(value=generate_downloads_html())
                            download_format = gr.Dropdown(
                                choices=[
                                    "aerialEye.pt (PyTorch - 6.0 MB)",
                                    "aerialEye.onnx (ONNX - 11.7 MB)",
                                    "best.pt (PyTorch - 5.5 MB)",
                                    "best.onnx (ONNX - 10.6 MB)",
                                    "best_full_integer_quant.tflite (TFLite - 2.9 MB)"
                                ],
                                value="aerialEye.pt (PyTorch - 6.0 MB)",
                                label="Select Weights Format"
                            )
                            download_btn = gr.Button("⬇️ Download Weights", variant="primary")
                            download_file_output = gr.File(label="Downloaded Model File", visible=False)
                        # Use this model Sidebar card
                        with gr.Group(elem_classes=["hf-card"]):
                            gr.HTML("""
🚀 Use this model
""")
                            use_dropdown = gr.Dropdown(
                                choices=["Python (ultralytics)", "Python (onnxruntime)", "CLI (ultralytics)"],
                                value="Python (ultralytics)",
                                label="Integration Method"
                            )
                            # NOTE(review): get_code_snippet also records a code
                            # view, so building the UI counts one view per start.
                            code_output = gr.Code(
                                value=get_code_snippet("Python (ultralytics)"),
                                language="python",
                                interactive=False
                            )
                        # Model details metadata card
                        with gr.Group(elem_classes=["hf-card"]):
                            gr.HTML("""
📋 Model Attributes
Model Class Support: 6 Classes
human
sos
vehicle
flood
road_damage
crack
Dataset Size: 6,327 images
License: Apache 2.0
""")
            # 2. Interactive Testing Playground
            with gr.TabItem("🎮 Model Playground"):
                gr.HTML("""
🎮 Interactive Test Playground
Upload and test images or videos in real-time to verify AerialEye detections.
""")
                with gr.Tabs():
                    with gr.TabItem("Images"):
                        with gr.Row():
                            with gr.Column():
                                input_image = gr.Image(type="numpy", label="Input Image")
                                conf_slider_img = gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.05, label="Confidence Threshold")
                                iou_slider_img = gr.Slider(minimum=0.0, maximum=1.0, value=0.45, step=0.05, label="IoU Threshold")
                                submit_btn_img = gr.Button("Detect Image", variant="primary")
                            with gr.Column():
                                output_image = gr.Image(type="numpy", label="Detections")
                        with gr.Row():
                            gr.Examples(
                                examples=[["sample_train_1.jpg", 0.25, 0.45], ["sample_train_2.jpg", 0.25, 0.45], ["sample_train_3.jpg", 0.25, 0.45]],
                                inputs=[input_image, conf_slider_img, iou_slider_img],
                                label="Original Training Distribution Samples"
                            )
                        submit_btn_img.click(
                            fn=predict_image,
                            inputs=[input_image, conf_slider_img, iou_slider_img],
                            outputs=[output_image]
                        )
                    with gr.TabItem("Videos"):
                        with gr.Row():
                            with gr.Column():
                                input_video = gr.Video(label="Input Video (First 10s will be processed)")
                                conf_slider_vid = gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.05, label="Confidence Threshold")
                                iou_slider_vid = gr.Slider(minimum=0.0, maximum=1.0, value=0.45, step=0.05, label="IoU Threshold")
                                submit_btn_vid = gr.Button("Process Video", variant="primary")
                            with gr.Column():
                                output_video = gr.Video(label="Processed Detections")
                        submit_btn_vid.click(
                            fn=predict_video,
                            inputs=[input_video, conf_slider_vid, iou_slider_vid],
                            outputs=[output_video]
                        )
            # 3. Analytics Dashboard
            with gr.TabItem("📊 Analytics Dashboard"):
                gr.HTML("""
📊 Model Metrics & Usage Analytics
Real-time tracking of weights downloads and developer code integrations from the model card.
""")
                stats = get_stats()
                analytics_board = gr.HTML(value=generate_analytics_html(stats))
        gr.HTML("""
🚁 Running on CPU/GPU depending on environment available | AerialEye Patrol System
""")

    # Event wiring
    def handle_download(selected_option):
        """Serve the selected weights file and refresh both stats widgets.

        Args:
            selected_option: Label chosen in the weights dropdown; unknown
                labels fall back to "aerialEye.pt".

        Returns:
            A tuple of (file component update, sidebar metric text, analytics
            dashboard text), matching the three wired outputs.

        Raises:
            gr.Error: If the file is still missing after the download attempt.
        """
        mapping = {
            "aerialEye.pt (PyTorch - 6.0 MB)": "aerialEye.pt",
            "aerialEye.onnx (ONNX - 11.7 MB)": "aerialEye.onnx",
            "best.pt (PyTorch - 5.5 MB)": "best.pt",
            "best.onnx (ONNX - 10.6 MB)": "best.onnx",
            "best_full_integer_quant.tflite (TFLite - 2.9 MB)": "best_full_integer_quant.tflite"
        }
        filename = mapping.get(selected_option, "aerialEye.pt")

        # Download check
        if not os.path.exists(filename) or is_lfs_pointer(filename):
            download_file(filename)
        # Do not count or serve a file that never arrived.
        if not os.path.exists(filename):
            raise gr.Error(f"Could not retrieve {filename}.")

        # Track the download in stats
        track_download(filename)

        # Update sidebar metric and analytics dashboard from the fresh stats
        new_metric_html = generate_downloads_html()
        new_analytics_html = generate_analytics_html(get_stats())
        return gr.update(value=filename, visible=True), new_metric_html, new_analytics_html

    def handle_use_change(option):
        """Swap the displayed snippet and refresh the analytics dashboard.

        Args:
            option: Integration method chosen in the dropdown.

        Returns:
            A tuple of (code component update, analytics dashboard text).
        """
        code_val = get_code_snippet(option)
        # gr.Code names its shell highlighter "shell"; "bash" is not one of
        # its supported language values.
        lang = "shell" if "CLI" in option else "python"
        new_analytics_html = generate_analytics_html(get_stats())
        return gr.update(value=code_val, language=lang), new_analytics_html

    download_btn.click(
        fn=handle_download,
        inputs=[download_format],
        outputs=[download_file_output, downloads_sidebar_metric, analytics_board]
    )
    use_dropdown.change(
        fn=handle_use_change,
        inputs=[use_dropdown],
        outputs=[code_output, analytics_board]
    )
# Launch the app only when this file is run as a script. The server listens on
# 0.0.0.0 and no public share link is requested (share=False).
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", share=False)