"""Compare standard YOLO inference with a SAHI-style sliced inference pass.

The script runs one full-image prediction, then runs the same model over
overlapping tiles of the image, merges the per-tile detections with NMS,
writes both annotated images to disk and prints a small comparison table.
"""

import argparse
import os
import time

import cv2
import numpy as np
import torch
from ultralytics import YOLO

from download_model import download_file, is_lfs_pointer

# File names that main() will try to fetch through download_file() when they
# are missing locally or are only Git LFS pointer files.
_DOWNLOADABLE_FILES = frozenset({
    "aerialEye.pt",
    "best.pt",
    "sample_aerial_street.jpg",
    "sample_drone_roundabout.jpg",
    "bus.jpg",
})


def slice_image(image, slice_size=640, overlap=0.2):
    """Cut ``image`` into overlapping square tiles.

    Tiles are laid out on a grid whose stride is ``slice_size * (1 - overlap)``.
    A tile that would cross the right or bottom edge is shifted back so that it
    ends exactly on that edge; tiles are therefore ``slice_size`` pixels on each
    side unless the image itself is smaller than ``slice_size``.

    Args:
        image: Array whose first two dimensions are (height, width). Both
            2-D (grayscale) and 3-D (with channels) arrays are accepted.
        slice_size: Side length of each tile in pixels; must be positive.
        overlap: Fraction of a tile shared with its neighbour, in ``[0, 1)``.

    Returns:
        A list of dicts with keys ``'image'`` (the cropped tile), ``'x_offset'``
        and ``'y_offset'`` (top-left corner of the tile in the full image).

    Raises:
        ValueError: If ``slice_size`` is not positive or ``overlap`` is outside
            ``[0, 1)``.
    """
    if slice_size <= 0:
        raise ValueError(f"slice_size must be positive, got {slice_size}")
    if not 0 <= overlap < 1:
        raise ValueError(f"overlap must be in [0, 1), got {overlap}")

    height, width = image.shape[:2]
    # Clamp to 1 so tiny slice sizes with high overlap cannot produce a zero
    # stride (which range() rejects).
    step = max(1, int(slice_size * (1 - overlap)))

    slices = []
    for y in range(0, height, step):
        y_end = min(y + slice_size, height)
        y_start = max(0, y_end - slice_size)
        for x in range(0, width, step):
            x_end = min(x + slice_size, width)
            x_start = max(0, x_end - slice_size)
            slices.append({
                'image': image[y_start:y_end, x_start:x_end],
                'x_offset': x_start,
                'y_offset': y_start,
            })
            # The tile already reaches the right edge; further columns would
            # only repeat it.
            if x_end == width:
                break
        # Same reasoning for rows once the bottom edge is covered.
        if y_end == height:
            break
    return slices


def run_sliced_inference(model, image, slice_size=640, overlap=0.2, conf=0.25,
                         iou_threshold=0.45, class_agnostic=False):
    """Run ``model`` on every tile of ``image`` and merge the detections.

    Each tile produced by :func:`slice_image` is passed to ``model.predict``;
    the resulting boxes are translated back into full-image coordinates and
    de-duplicated with non-maximum suppression.

    Args:
        model: Object exposing ``predict(image, conf=..., verbose=...)`` that
            returns results with a ``boxes`` attribute (``xyxy``, ``conf``,
            ``cls`` tensors), as the Ultralytics ``YOLO`` class does.
        image: Full image array, indexed as (row, column[, channel]).
        slice_size: Tile side length in pixels.
        overlap: Fractional overlap between neighbouring tiles.
        conf: Confidence threshold forwarded to ``model.predict``.
        iou_threshold: IoU above which overlapping boxes are suppressed.
        class_agnostic: If True, boxes suppress each other regardless of
            class. The default (False) only suppresses boxes of the same
            class, so overlapping objects of different classes both survive.

    Returns:
        ``(detections, seconds)`` where ``detections`` is a list of dicts with
        keys ``'box'`` (``[x1, y1, x2, y2]`` in full-image pixels), ``'conf'``
        and ``'class'``, ordered by decreasing confidence, and ``seconds`` is
        the wall-clock time spent in the per-tile prediction loop only
        (slicing and NMS are not included).
    """
    tiles = slice_image(image, slice_size, overlap)

    box_chunks = []
    conf_chunks = []
    cls_chunks = []

    t_start = time.perf_counter()
    for tile in tiles:
        # (x, y, x, y) shift that maps tile coordinates to image coordinates.
        offset = torch.tensor(
            [tile['x_offset'], tile['y_offset']] * 2, dtype=torch.float32
        )
        for result in model.predict(tile['image'], conf=conf, verbose=False):
            boxes = result.boxes
            if boxes is None or len(boxes) == 0:
                continue
            box_chunks.append(boxes.xyxy.detach().cpu().float() + offset)
            conf_chunks.append(boxes.conf.detach().cpu().float())
            cls_chunks.append(boxes.cls.detach().cpu().long())
    t_inference = time.perf_counter() - t_start

    if not box_chunks:
        return [], t_inference

    all_boxes = torch.cat(box_chunks)
    all_confs = torch.cat(conf_chunks)
    all_clss = torch.cat(cls_chunks)

    # Imported lazily so torchvision is only required when there is something
    # to merge.
    from torchvision.ops import batched_nms, nms

    if class_agnostic:
        keep_indices = nms(all_boxes, all_confs, iou_threshold)
    else:
        # Per-class NMS: a box can only be suppressed by a higher-scoring box
        # of the same class.
        keep_indices = batched_nms(all_boxes, all_confs, all_clss, iou_threshold)

    filtered_results = [
        {
            'box': all_boxes[idx].tolist(),
            'conf': float(all_confs[idx]),
            'class': int(all_clss[idx]),
        }
        for idx in keep_indices.tolist()
    ]
    return filtered_results, t_inference


def _save_image(path, image, description):
    """Write ``image`` to ``path`` and report whether the write succeeded."""
    if cv2.imwrite(path, image):
        print(f"Saved {description} to '{path}'")
        return True
    print(f"Error: Could not write '{path}'.")
    return False


def main():
    """Parse CLI arguments, run both inference modes and print a comparison."""
    parser = argparse.ArgumentParser(description="Simulate Standard vs SAHI Sliced Inference")
    parser.add_argument("--image", type=str, default="sample_aerial_street.jpg", help="Path to input image")
    parser.add_argument("--model", type=str, default="aerialEye.pt", help="Path to YOLO model")
    parser.add_argument("--conf", type=float, default=0.25, help="Confidence threshold")
    parser.add_argument("--slice-size", type=int, default=640, help="SAHI slice dimension")
    parser.add_argument("--overlap", type=float, default=0.2, help="Overlap ratio between slices")
    args = parser.parse_args()

    # Reject values that would make the tile stride zero or negative before
    # any model loading happens.
    if args.slice_size <= 0:
        parser.error("--slice-size must be a positive integer")
    if not 0 <= args.overlap < 1:
        parser.error("--overlap must be in the range [0, 1)")

    # Check/download files if missing
    for file_path in (args.model, args.image):
        if file_path not in _DOWNLOADABLE_FILES:
            continue
        if not os.path.exists(file_path) or is_lfs_pointer(file_path):
            print(f"File '{file_path}' is missing or is an LFS pointer. Downloading...")
            download_file(file_path)

    if not os.path.exists(args.image):
        print(f"Error: Image '{args.image}' not found.")
        return

    print(f"Loading YOLO model '{args.model}'...")
    model = YOLO(args.model)

    img = cv2.imread(args.image)
    if img is None:
        print(f"Error: Could not read image '{args.image}'.")
        return

    h, w = img.shape[:2]
    print(f"Loaded image: {args.image} ({w}x{h})")

    # Untimed warm-up: the first predict call on a freshly loaded model
    # typically pays one-off setup costs, which would otherwise be charged
    # entirely to the standard pass and skew the latency comparison.
    model.predict(img, conf=args.conf, verbose=False)

    # 1. Run Standard Inference
    print("\n--- Running Standard Inference ---")
    t_start = time.perf_counter()
    std_results = model.predict(img, conf=args.conf, verbose=False)
    std_time = time.perf_counter() - t_start

    std_boxes = std_results[0].boxes
    print(f"Standard Detections: {len(std_boxes)}")
    print(f"Standard Latency: {std_time:.4f} seconds")

    # Save standard image
    std_img = std_results[0].plot()
    _save_image("result_standard.jpg", std_img, "standard inference output")

    # 2. Run Sliced Inference
    print("\n--- Running SAHI Sliced Inference Simulation ---")
    slices = slice_image(img, args.slice_size, args.overlap)
    print(
        f"Generated {len(slices)} slices of size {args.slice_size}x{args.slice_size} "
        f"with {args.overlap * 100:g}% overlap."
    )

    sahi_results, sahi_time = run_sliced_inference(
        model, img, args.slice_size, args.overlap, args.conf
    )
    print(f"SAHI Detections (after NMS): {len(sahi_results)}")
    print(f"SAHI Total Latency: {sahi_time:.4f} seconds")

    # Render SAHI image
    sahi_img = img.copy()
    for res in sahi_results:
        x1, y1, x2, y2 = (int(v) for v in res['box'])

        # Draw bounding box
        cv2.rectangle(sahi_img, (x1, y1), (x2, y2), (0, 255, 0), 2)

        # Put label
        label = f"{model.names[res['class']]} {res['conf']:.2f}"
        cv2.putText(sahi_img, label, (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

    _save_image("result_sahi.jpg", sahi_img, "SAHI inference output")

    # 3. Print Comparison
    print("\n--- Summary Comparison ---")
    print("| Metric | Standard | SAHI (Sliced) | Difference |")
    print("|---|---|---|---|")
    print(f"| Objects Detected | {len(std_boxes)} | {len(sahi_results)} | {len(sahi_results) - len(std_boxes):+d} |")
    print(f"| Inference Latency | {std_time:.3f}s | {sahi_time:.3f}s | {sahi_time - std_time:+.3f}s |")


if __name__ == "__main__":
    main()