# (Page header captured with this file: author asamasach, commit 15c850d
#  "add examples", 5.21 kB. Kept as a comment so the module parses as Python.)
import gradio as gr
import onnxruntime as ort
import numpy as np
import cv2
from huggingface_hub import hf_hub_download
import os
# Private Hugging Face Hub model repositories, keyed by the label shown in the
# "Select Model" dropdown (the dropdown lists them in this insertion order).
# get_session() downloads "best.onnx" from the repository chosen here.
MODELS = {
    "Dental Implant": "smartfalcon-ai/Dental-Implant-Defect-Detection",
    "Data Matrix": "smartfalcon-ai/Data-Matrix-Defect-Detection",
    "Ball Pen": "smartfalcon-ai/Ball-Pen-Defect-Detection",
    "Knit Up": "smartfalcon-ai/Knit-Up-Defect-Detection",
    "Knit Back": "smartfalcon-ai/Knit-Back-Defect-Detection",
    "Jean Back": "smartfalcon-ai/Jean-Back-Defect-Detection",
    "Jean Up": "smartfalcon-ai/Jean-Up-Defect-Detection",
    "Tire Cord": "smartfalcon-ai/Tire-Cord-Defect-Detection",
}
# Example inputs for the Gradio demo: three images per model task.
# Each row is [image_path, model_label, confidence_threshold]; image files are
# named "examples/<label lowercased, spaces replaced by hyphens>-<n>.jpg".
_EXAMPLE_LABELS = (
    "Dental Implant",
    "Data Matrix",
    "Ball Pen",
    "Knit Up",
    "Knit Back",
    "Jean Back",
    "Jean Up",
    "Tire Cord",
)
EXAMPLES = [
    [f"examples/{label.lower().replace(' ', '-')}-{n}.jpg", label, 0.25]
    for label in _EXAMPLE_LABELS
    for n in (1, 2, 3)
]
# Cache of ONNX Runtime sessions keyed by model label, so each model is
# downloaded and loaded only on its first request. No lock is taken: if two
# first requests for the same label run concurrently, both may build a session
# and the later assignment wins.
sessions = {}


def get_session(model_name):
    """Return the ONNX Runtime session for ``model_name``, loading it on first use.

    On a cache miss, ``best.onnx`` is fetched from the Hub repository listed in
    ``MODELS[model_name]`` (passing the ``HUGGINGFACE_TOKEN`` environment
    variable as the token, or ``None`` when it is unset) and loaded with the
    CPU execution provider. The session is then stored in ``sessions``.

    Raises:
        KeyError: if ``model_name`` is neither cached nor a key of ``MODELS``.
    """
    try:
        return sessions[model_name]
    except KeyError:
        pass  # Cache miss: download and build the session below.

    token = os.environ.get("HUGGINGFACE_TOKEN", None)
    onnx_path = hf_hub_download(
        repo_id=MODELS[model_name],
        filename="best.onnx",
        token=token,
    )
    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    sessions[model_name] = session
    return session
# Side length (pixels) of the square model input; images are resized to
# IMG_SIZE x IMG_SIZE before inference.
IMG_SIZE = 640
# IoU threshold handed to non-maximum suppression.
IOU_THRESHOLD = 0.45


def preprocess(img):
    """Turn an H x W x C image into a model-ready NCHW float32 batch of one.

    The image is stretched (no letterboxing) to ``IMG_SIZE`` x ``IMG_SIZE``,
    divided by 255 (i.e. it assumes 8-bit pixel values), moved to
    channel-first layout and given a leading batch axis of size 1. Channel
    order is passed through unchanged.

    Returns:
        ``(blob, width, height)``: the ``(1, C, IMG_SIZE, IMG_SIZE)`` tensor plus
        the original image width and height, used to map boxes back.
    """
    height, width = img.shape[:2]
    scaled = cv2.resize(img, (IMG_SIZE, IMG_SIZE)).astype(np.float32) / 255.0
    blob = np.transpose(scaled, (2, 0, 1))[np.newaxis, ...]
    return blob, width, height
def xywh2xyxy(x):
    """Convert boxes from centre format to corner format.

    Args:
        x: 2-D array whose first four columns are ``(cx, cy, w, h)``; any
            further columns are copied through untouched.

    Returns:
        A new array with the same shape and dtype in which the first four
        columns are ``(x1, y1, x2, y2)``, i.e. centre minus/plus half the
        width and height. The input array is not modified.
    """
    cx, cy = x[:, 0], x[:, 1]
    half_w, half_h = x[:, 2] / 2, x[:, 3] / 2
    corners = np.copy(x)
    corners[:, 0] = cx - half_w
    corners[:, 1] = cy - half_h
    corners[:, 2] = cx + half_w
    corners[:, 3] = cy + half_h
    return corners
def non_max_suppression(preds, conf_thres=0.25, iou_thres=0.45):
    """Decode raw YOLO-style predictions and apply non-maximum suppression.

    Args:
        preds: raw model output; ``preds[0]`` is read as an ``(N, 5 + num_classes)``
            array of rows ``[cx, cy, w, h, objectness, class_0, class_1, ...]``
            (layout inferred from the indexing below -- confirm against the
            model export).
        conf_thres: candidates need objectness above this, and
            ``objectness * best_class_score`` above this to survive NMS.
        iou_thres: IoU above which the lower-scoring of two overlapping boxes
            is suppressed. NMS is class-agnostic: boxes of different classes
            also suppress each other.

    Returns:
        A list of ``[x1, y1, x2, y2, score, class_id]`` in model-input pixel
        coordinates; empty when nothing passes the thresholds.
    """
    candidates = preds[0]
    candidates = candidates[candidates[:, 4] > conf_thres]
    if candidates.shape[0] == 0:
        return []

    boxes = xywh2xyxy(candidates[:, :4])
    objectness = candidates[:, 4]
    class_scores = candidates[:, 5:]
    cls_ids = np.argmax(class_scores, axis=1)
    final_scores = objectness * class_scores.max(axis=1)

    # cv2.dnn.NMSBoxes expects (x, y, width, height) rectangles with (x, y) the
    # top-left corner -- not (x1, y1, x2, y2). Passing corners would make it
    # read x2/y2 as width/height and compute wrong overlaps.
    nms_rects = np.column_stack(
        (boxes[:, 0], boxes[:, 1], candidates[:, 2], candidates[:, 3])
    )
    keep = cv2.dnn.NMSBoxes(
        bboxes=nms_rects.tolist(),
        scores=final_scores.tolist(),
        score_threshold=conf_thres,
        nms_threshold=iou_thres,
    )
    # Depending on the OpenCV version, "nothing kept" may come back as an
    # empty tuple and kept indices as shape (K,) or (K, 1); normalise both.
    if len(keep) == 0:
        return []

    return [
        [*boxes[i], final_scores[i], cls_ids[i]]
        for i in np.asarray(keep).flatten()
    ]
def infer(image, model_name, conf_threshold):
    """Run the selected defect model on ``image`` and draw its detections.

    Args:
        image: RGB image as a NumPy array (as produced by
            ``gr.Image(type="numpy")``). Grayscale ``(H, W)`` / ``(H, W, 1)`` and
            RGBA ``(H, W, 4)`` inputs are converted to RGB first.
        model_name: a key of ``MODELS``.
        conf_threshold: confidence threshold forwarded to NMS.

    Returns:
        An RGB copy of the image with a green box and a
        ``"<class_id>:<score>"`` label for every detection.

    Raises:
        gr.Error: if no image or no known model is supplied; Gradio shows the
            message in the UI.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if model_name not in MODELS:
        raise gr.Error("Please select a model.")

    # Work on a private 3-channel RGB copy: the model is fed RGB (the channel
    # order Gradio supplies), and boxes are drawn on this copy so the caller's
    # array is never mutated.
    channels = 1 if image.ndim == 2 else image.shape[2]
    if channels == 1:
        img_rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif channels == 4:
        img_rgb = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    else:
        img_rgb = image.copy()

    session = get_session(model_name)
    blob, orig_w, orig_h = preprocess(img_rgb)
    # Use the model's declared input name rather than assuming "images".
    input_name = session.get_inputs()[0].name
    preds = session.run(None, {input_name: blob})[0]
    detections = non_max_suppression(preds, conf_threshold, IOU_THRESHOLD)

    for x1, y1, x2, y2, score, cls_id in detections:
        # Boxes are in IMG_SIZE x IMG_SIZE model space; preprocess() stretches
        # without letterboxing, so each axis is scaled back independently.
        x1 = int(x1 / IMG_SIZE * orig_w)
        y1 = int(y1 / IMG_SIZE * orig_h)
        x2 = int(x2 / IMG_SIZE * orig_w)
        y2 = int(y2 / IMG_SIZE * orig_h)
        label = f"{int(cls_id)}:{score:.2f}"
        cv2.rectangle(img_rgb, (x1, y1), (x2, y2), (0, 255, 0), 2)
        # Put the label above the box, or just inside it when the box is near
        # the top edge and the text would otherwise be drawn off-image.
        text_y = y1 - 5 if y1 - 5 >= 15 else y1 + 20
        cv2.putText(img_rgb, label, (x1, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
    return img_rgb
# Gradio interface: image + model selector + confidence slider -> annotated image.
demo = gr.Interface(
    fn=infer,
    inputs=[
        gr.Image(type="numpy"),
        # Pre-select the first model so infer() is not called with no model.
        gr.Dropdown(list(MODELS), value=next(iter(MODELS)), label="Select Model"),
        # 0.25 sits on the slider's 0.01 step grid and matches both the
        # threshold used by EXAMPLES and non_max_suppression's default; very
        # low thresholds keep many low-confidence boxes.
        gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.01, label="Confidence Threshold"),
    ],
    outputs=gr.Image(type="numpy"),
    title="Industrial Defect Detection",
    description="Upload an image, select the defect model, and adjust the confidence threshold.",
    examples=EXAMPLES,
)

if __name__ == "__main__":
    # EXAMPLES points into the relative "examples" directory, so Gradio is
    # allowed to serve files from it.
    demo.launch(allowed_paths=["examples"])