| """ |
Unified service exposing FastAPI endpoints (for MonitaQC) and a Gradio UI (for testing/demo).
This lets multiple MonitaQC vision engines call the API while keeping the Gradio UI accessible.
| """ |
|
|
import logging
import os
from io import BytesIO
from typing import Optional

import cv2
import gradio as gr
import numpy as np
import onnxruntime as ort
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from huggingface_hub import hf_hub_download
|
|
| |
# Log at INFO so model downloads/loads and per-request summaries are visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The FastAPI application serving the REST endpoints; the Gradio UI is mounted
# onto it further down the module.
app = FastAPI(
    title="Industrial Defect Detection API",
    description="ONNX-based defect detection service for MonitaQC vision engines",
    version="1.0.0",
)
|
|
| |
# Model catalogue: API/URL key -> display name (shown in the Gradio dropdown)
# and the Hugging Face repo that get_session downloads ``best.onnx`` from.
MODELS = {
    "dental-implant": {
        "name": "Dental Implant",
        "repo": "smartfalcon-ai/Dental-Implant-Defect-Detection",
    },
    "data-matrix": {
        "name": "Data Matrix",
        "repo": "smartfalcon-ai/Data-Matrix-Defect-Detection",
    },
    "ball-pen": {
        "name": "Ball Pen",
        "repo": "smartfalcon-ai/Ball-Pen-Defect-Detection",
    },
    "knit-up": {
        "name": "Knit Up",
        "repo": "smartfalcon-ai/Knit-Up-Defect-Detection",
    },
    "knit-back": {
        "name": "Knit Back",
        "repo": "smartfalcon-ai/Knit-Back-Defect-Detection",
    },
    "jean-back": {
        "name": "Jean Back",
        "repo": "smartfalcon-ai/Jean-Back-Defect-Detection",
    },
    "jean-up": {
        "name": "Jean Up",
        "repo": "smartfalcon-ai/Jean-Up-Defect-Detection",
    },
    "tire-cord": {
        "name": "Tire Cord",
        "repo": "smartfalcon-ai/Tire-Cord-Defect-Detection",
    },
}
|
|
| |
# Gradio sample rows: [image path, model display name, confidence threshold].
# Three bundled images per model, in MODELS order; every model in MODELS is
# therefore expected to ship examples/<key>-1.jpg .. examples/<key>-3.jpg.
EXAMPLES = [
    [f"examples/{key}-{idx}.jpg", info["name"], 0.25]
    for key, info in MODELS.items()
    for idx in (1, 2, 3)
]
|
|
| |
# Lazily populated cache of ONNX Runtime sessions, keyed by model key (see get_session).
sessions: dict = {}

# Model used by /v1/object-detection/detect when the request omits ``model``;
# override it with the DEFAULT_MODEL environment variable.
DEFAULT_MODEL: str = os.environ.get("DEFAULT_MODEL", "data-matrix")

# Side length of the square model input, and the IoU threshold used for NMS.
IMG_SIZE: int = 640
IOU_THRESHOLD: float = 0.45
|
|
|
|
def get_session(model_key: str):
    """Return the cached ONNX Runtime session for ``model_key``, loading it on first use.

    On a cache miss, the model's ``best.onnx`` is fetched from its Hugging Face
    repo (``MODELS[model_key]["repo"]``). The ``HUGGINGFACE_TOKEN`` environment
    variable is forwarded as the auth token. The file is then opened on the CPU
    execution provider and stored in ``sessions``.

    Not guarded by a lock: concurrent first calls for the same key may each
    download/load the model; the last one stored wins.

    Raises:
        ValueError: if ``model_key`` is not a key of ``MODELS``.
        Exception: any download/load error is logged with its traceback and re-raised.
    """
    cached = sessions.get(model_key)
    if cached is not None:
        return cached

    if model_key not in MODELS:
        raise ValueError(f"Model '{model_key}' not found. Available: {list(MODELS.keys())}")

    repo_id = MODELS[model_key]["repo"]
    try:
        logger.info(f"Downloading model: {repo_id}")
        model_path = hf_hub_download(
            repo_id=repo_id,
            filename="best.onnx",
            token=os.environ.get("HUGGINGFACE_TOKEN"),
        )
        session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    except Exception:
        # logger.exception keeps the traceback that logger.error(str(e)) dropped.
        logger.exception(f"Failed to load model '{model_key}'")
        raise

    sessions[model_key] = session
    logger.info(f"Model '{model_key}' loaded successfully")
    return session
|
|
|
|
def preprocess(img):
    """Turn an HxWxC image into the NCHW float32 blob fed to the ONNX model.

    The image is stretched (no letterboxing) to IMG_SIZE x IMG_SIZE and divided
    by 255. For uint8 input this maps it to [0, 1]. It is then reordered to
    channels-first with a leading batch axis of 1. Channel order is passed
    through unchanged.

    Returns:
        (blob, width, height), where width/height are the ORIGINAL image size.
        Callers need them to map boxes back from the IMG_SIZE frame.
    """
    orig_h, orig_w = img.shape[:2]
    square = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
    normalised = square.astype(np.float32) / 255.0
    blob = np.transpose(normalised, (2, 0, 1))[np.newaxis, ...]
    # NOTE(review): callers in this file pass BGR images (cv2.imdecode, or an
    # explicit RGB2BGR conversion). Confirm the exported model was trained on
    # BGR rather than RGB input.
    return blob, orig_w, orig_h
|
|
|
|
def xywh2xyxy(x):
    """Convert boxes from centre/size ``(cx, cy, w, h)`` to corner ``(x1, y1, x2, y2)`` form.

    Works on the first four columns of a 2-D array and returns a new array.
    Any extra columns are copied through unchanged, and ``x`` itself is not modified.
    """
    centre = x[:, 0:2]
    half = x[:, 2:4] / 2
    corners = np.copy(x)
    corners[:, 0:2] = centre - half
    corners[:, 2:4] = centre + half
    return corners
|
|
|
|
def non_max_suppression(preds, conf_thres=0.25, iou_thres=0.45):
    """Filter raw detector output and apply class-agnostic non-maximum suppression.

    Args:
        preds: Model output whose first batch entry ``preds[0]`` is a 2-D array
            of rows ``[cx, cy, w, h, objectness, class_score_0, class_score_1, ...]``.
            This is the layout the indexing below assumes (YOLOv5-style exports).
        conf_thres: Rows whose objectness is not above this are dropped. NMS then
            applies the same threshold to ``objectness * best_class_score``.
        iou_thres: IoU above which the lower-scoring overlapping box is suppressed.

    Returns:
        A list of dicts, one per kept box. Each has:
            ``bbox``: ``[x1, y1, x2, y2]`` in the coordinate frame of ``preds``.
            ``x1``/``y1``/``x2``/``y2``: the same corners as separate keys.
            ``confidence``: objectness * class score.
            ``class_id``: index of the best class.
    """
    rows = preds[0]
    rows = rows[rows[:, 4] > conf_thres]
    if rows.shape[0] == 0:
        return []

    boxes = xywh2xyxy(rows[:, :4])
    class_scores = rows[:, 5:]
    cls_ids = np.argmax(class_scores, axis=1)
    final_scores = rows[:, 4] * class_scores.max(axis=1)

    # cv2.dnn.NMSBoxes takes rectangles as (x, y, width, height). Passing corner
    # boxes (x1, y1, x2, y2) makes it read x2/y2 as width/height and compute
    # wrong overlaps, so hand it the top-left corner plus the original w/h.
    nms_rects = np.column_stack((boxes[:, :2], rows[:, 2:4]))
    indices = cv2.dnn.NMSBoxes(
        bboxes=nms_rects.tolist(),
        scores=final_scores.tolist(),
        score_threshold=conf_thres,
        nms_threshold=iou_thres,
    )
    # Depending on the OpenCV version the result may be (), an (N, 1) array or
    # an (N,) array; normalise to a flat integer array (empty -> no boxes).
    indices = np.asarray(indices, dtype=int).reshape(-1)

    output = []
    for idx in indices:
        x1, y1, x2, y2 = (float(v) for v in boxes[idx])
        output.append({
            "bbox": [x1, y1, x2, y2],
            "confidence": float(final_scores[idx]),
            "class_id": int(cls_ids[idx]),
            "x1": x1,
            "y1": y1,
            "x2": x2,
            "y2": y2,
        })
    return output
|
|
|
|
| |
| |
| |
|
|
@app.get("/")
async def root():
    """Describe the service: name, version, entry points and the model catalogue."""
    endpoints = {"api": "/docs", "gradio": "/gradio"}
    model_keys = [key for key in MODELS]
    return {
        "service": "Industrial Defect Detection API",
        "version": "1.0.0",
        "endpoints": endpoints,
        "models": model_keys,
        "default_model": DEFAULT_MODEL,
    }
|
|
|
|
@app.get("/health")
async def health_check():
    """Report liveness and the number of ONNX sessions currently cached."""
    loaded_count = len(sessions)
    return dict(status="healthy", models_loaded=loaded_count)
|
|
|
|
@app.get("/models")
async def list_models():
    """Map every model key to its display name and list the keys already loaded."""
    display_names = {key: MODELS[key]["name"] for key in MODELS}
    return {"models": display_names, "loaded": [*sessions]}
|
|
|
|
def _run_detection(img_bgr, model_key, conf):
    """Run the blocking part of a detection request; return boxes in original-image pixels.

    Steps:
        1. Load (or reuse) the model session.
        2. Preprocess ``img_bgr`` and run the ONNX model.
        3. Apply NMS.
        4. Rescale every box from the IMG_SIZE x IMG_SIZE model frame back to
           the uploaded image's width/height.

    Meant to be called from a worker thread so CPU-bound inference does not
    block the event loop.
    """
    session = get_session(model_key)
    blob, orig_w, orig_h = preprocess(img_bgr)
    preds = session.run(None, {"images": blob})[0]
    detections = non_max_suppression(preds, conf, IOU_THRESHOLD)

    for det in detections:
        x1, y1, x2, y2 = det["bbox"]
        scaled = [
            x1 / IMG_SIZE * orig_w,
            y1 / IMG_SIZE * orig_h,
            x2 / IMG_SIZE * orig_w,
            y2 / IMG_SIZE * orig_h,
        ]
        # Keep the list form and the flat x1..y2 keys in sync for API consumers.
        det["bbox"] = scaled
        det["x1"], det["y1"], det["x2"], det["y2"] = scaled
    return detections


@app.post("/v1/object-detection/detect")
async def detect_defects(
    image: UploadFile = File(...),
    model: Optional[str] = Form(DEFAULT_MODEL),
    confidence: Optional[float] = Form(0.25),
):
    """
    Detect defects in an uploaded image.

    Compatible with MonitaQC's YOLO inference API format.

    Args:
        image: Image file to analyze (anything cv2.imdecode can read).
        model: Model key from MODELS (default: DEFAULT_MODEL).
        confidence: Confidence threshold (default: 0.25; None is treated as 0.25).

    Returns:
        JSON array of detections. Each detection has:
            ``bbox``: [x1, y1, x2, y2] in original-image pixels.
            ``x1``/``y1``/``x2``/``y2``: the same values as separate keys.
            ``confidence`` and ``class_id``.

    Raises:
        HTTPException: 400 if the upload is empty or cannot be decoded,
            404 if ``model`` is not a known model key, 500 for any other failure.
    """
    try:
        contents = await image.read()
        # cv2.imdecode rejects an empty buffer with an error instead of returning None.
        img_bgr = (
            cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
            if contents
            else None
        )
        if img_bgr is None:
            raise HTTPException(status_code=400, detail="Invalid image file")

        if model not in MODELS:
            raise HTTPException(
                status_code=404,
                detail=f"Model '{model}' not found. Available: {list(MODELS.keys())}",
            )

        conf = 0.25 if confidence is None else confidence
        # Model download and inference are blocking; run them off the event loop.
        detections = await run_in_threadpool(_run_detection, img_bgr, model, conf)

        logger.info(f"Processed image with model '{model}': {len(detections)} detections")
        return detections

    except HTTPException:
        # Client errors raised above keep their status instead of becoming 500s.
        raise
    except Exception as e:
        logger.exception(f"Detection error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
@app.post("/v1/object-detection/{model_name}/detect")
async def detect_defects_with_model(
    model_name: str,
    image: UploadFile = File(...),
    confidence: Optional[float] = Form(0.25),
):
    """
    Detect defects with the model named in the URL path.

    This endpoint is compatible with MonitaQC's current YOLO API format. It
    delegates to ``detect_defects`` with ``model_name`` as the model key.

    Args:
        model_name: Model key to use (e.g. 'data-matrix', 'dental-implant').
        image: Image file to analyze.
        confidence: Confidence threshold (default: 0.25).

    Returns:
        JSON array of detections, exactly as returned by ``detect_defects``.
    """
    result = await detect_defects(image=image, model=model_name, confidence=confidence)
    return result
|
|
|
|
| |
| |
| |
|
|
def gradio_inference(image, model_display_name, conf_threshold):
    """Run a model on an RGB numpy image and return a copy with detections drawn.

    Args:
        image: RGB image array from ``gr.Image(type="numpy")``, or None when the
            user submits without an image.
        model_display_name: A ``MODELS[...]["name"]`` value from the dropdown.
        conf_threshold: Confidence threshold passed to NMS.

    Returns:
        The annotated RGB image. Returns None if no image was given, and the
        input unchanged if the display name matches no model. Each box is drawn
        in green with a ``"<class_id>:<score>"`` label.
    """
    if image is None:
        # Without this guard cv2.cvtColor fails with an opaque error in the UI.
        return None

    model_key = next(
        (key for key, info in MODELS.items() if info["name"] == model_display_name),
        None,
    )
    if model_key is None:
        return image

    session = get_session(model_key)
    img_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    blob, orig_w, orig_h = preprocess(img_bgr)
    preds = session.run(None, {"images": blob})[0]
    detections = non_max_suppression(preds, conf_threshold, IOU_THRESHOLD)

    for det in detections:
        # Map corners from the IMG_SIZE model frame back to the input image.
        x1 = int(det["x1"] / IMG_SIZE * orig_w)
        y1 = int(det["y1"] / IMG_SIZE * orig_h)
        x2 = int(det["x2"] / IMG_SIZE * orig_w)
        y2 = int(det["y2"] / IMG_SIZE * orig_h)

        label = f"{det['class_id']}:{det['confidence']:.2f}"
        cv2.rectangle(img_bgr, (x1, y1), (x2, y2), (0, 255, 0), 2)
        # Keep the label baseline on-canvas when the box touches the top edge.
        text_y = max(y1 - 5, 15)
        cv2.putText(img_bgr, label, (x1, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)

    return cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
|
|
|
|
| |
# Gradio testing UI. The model dropdown is pre-selected with DEFAULT_MODEL's
# display name when that key exists in MODELS, so "Submit" works without first
# picking a model. Otherwise gradio_inference returns the input unchanged.
gradio_app = gr.Interface(
    fn=gradio_inference,
    inputs=[
        gr.Image(type="numpy"),
        gr.Dropdown(
            [v["name"] for v in MODELS.values()],
            value=MODELS.get(DEFAULT_MODEL, {}).get("name"),
            label="Select Model",
        ),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.01, label="Confidence Threshold"),
    ],
    outputs=gr.Image(type="numpy"),
    title="Industrial Defect Detection - Testing Interface",
    description="""
**Testing Interface** for Industrial Defect Detection models.

- **For Production Use:** Use the FastAPI endpoints at `/v1/object-detection/detect`
- **For Testing:** Use this Gradio interface to visually inspect results

Upload an image, select a defect model, and adjust the confidence threshold.
You can also choose from the samples at the bottom of the page.
""",
    examples=EXAMPLES,
    # Show every bundled sample on one page; stays correct if EXAMPLES changes size.
    examples_per_page=len(EXAMPLES),
)

# Serve the Gradio UI under /gradio on the same FastAPI app as the REST API.
app = gr.mount_gradio_app(app, gradio_app, path="/gradio")
|
|
|
|
if __name__ == "__main__":
    # Standalone launcher: HOST/PORT come from the environment (default 0.0.0.0:8000).
    import uvicorn

    host = os.environ.get("HOST", "0.0.0.0")
    port = int(os.environ.get("PORT", 8000))

    startup_messages = (
        f"Starting Industrial Defect Detection Service on {host}:{port}",
        f" - FastAPI docs: http://{host}:{port}/docs",
        f" - Gradio UI: http://{host}:{port}/gradio",
        f"Available models: {list(MODELS.keys())}",
        f"Default model: {DEFAULT_MODEL}",
    )
    for message in startup_messages:
        logger.info(message)

    uvicorn.run(app, host=host, port=port)
|
|