File size: 5,214 Bytes
403d6e9
 
 
 
 
d4f172e
403d6e9
d4f172e
403d6e9
 
 
 
 
 
 
 
 
 
 
d2af2d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
403d6e9
 
 
 
d4f172e
 
 
 
 
 
403d6e9
 
 
 
d4f172e
403d6e9
 
 
 
 
d4f172e
403d6e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d4f172e
403d6e9
 
 
 
 
4a319e3
403d6e9
 
 
d2af2d4
 
403d6e9
 
 
15c850d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import gradio as gr
import onnxruntime as ort
import numpy as np
import cv2
from huggingface_hub import hf_hub_download
import os

# Display name -> private Hugging Face repository holding that task's
# ``best.onnx`` weights (fetched with the HUGGINGFACE_TOKEN env var in
# get_session).
MODELS = {
    "Dental Implant": "smartfalcon-ai/Dental-Implant-Defect-Detection",
    "Data Matrix": "smartfalcon-ai/Data-Matrix-Defect-Detection",
    "Ball Pen": "smartfalcon-ai/Ball-Pen-Defect-Detection",
    "Knit Up": "smartfalcon-ai/Knit-Up-Defect-Detection",
    "Knit Back": "smartfalcon-ai/Knit-Back-Defect-Detection",
    "Jean Back": "smartfalcon-ai/Jean-Back-Defect-Detection",
    "Jean Up": "smartfalcon-ai/Jean-Up-Defect-Detection",
    "Tire Cord": "smartfalcon-ai/Tire-Cord-Defect-Detection"
}

# Gradio example rows ``[image_path, model_name, conf_threshold]``: three
# images per model, named ``examples/<lower-case-dashed-model-name>-<n>.jpg``
# and listed in the same order as MODELS.
EXAMPLES = [
    [f"examples/{name.lower().replace(' ', '-')}-{n}.jpg", name, 0.25]
    for name in MODELS
    for n in (1, 2, 3)
]

# Loaded ONNX Runtime sessions keyed by model display name; filled lazily by
# get_session so a model is downloaded and loaded once and then reused.
sessions = {}

def get_session(model_name):
    """Return the ONNX Runtime session for ``model_name``, loading it on first use.

    The first call for a model downloads ``best.onnx`` from the model's
    repository in ``MODELS`` via ``hf_hub_download`` (passing the
    ``HUGGINGFACE_TOKEN`` environment variable as the token, or ``None`` when
    unset). It wraps the file in a CPU-only ``InferenceSession`` and stores it
    in ``sessions``. Later calls return the cached session.

    Raises:
        KeyError: if ``model_name`` is not a key of ``MODELS``.
    """
    if model_name in sessions:
        return sessions[model_name]

    token = os.environ.get("HUGGINGFACE_TOKEN", None)
    onnx_file = hf_hub_download(
        repo_id=MODELS[model_name],
        filename="best.onnx",
        token=token
    )
    session = ort.InferenceSession(onnx_file, providers=["CPUExecutionProvider"])
    sessions[model_name] = session
    return session

# Side length in pixels of the square network input; every image is resized
# to IMG_SIZE x IMG_SIZE before inference.
IMG_SIZE = 640
# IoU threshold that infer passes to non_max_suppression.
IOU_THRESHOLD = 0.45


def preprocess(img):
    """Turn an image array into the 4-D float input blob for the network.

    The image is stretched to IMG_SIZE x IMG_SIZE (aspect ratio is not
    preserved). It is then scaled by 1/255, reordered from HWC to CHW, and
    given a leading batch axis of size 1. Channel order is passed through
    unchanged.

    Args:
        img: H x W x C image array (presumably uint8 with 3 channels --
            confirm with the caller).

    Returns:
        Tuple ``(blob, w, h)``. ``blob`` is the float32 array of shape
        ``(1, C, IMG_SIZE, IMG_SIZE)``. ``w`` and ``h`` are the original
        width and height, used to scale detections back.
    """
    height, width = img.shape[:2]
    scaled = cv2.resize(img, (IMG_SIZE, IMG_SIZE)).astype(np.float32) / 255.0
    chw = np.transpose(scaled, (2, 0, 1))
    blob = chw[np.newaxis, ...]
    return blob, width, height

def xywh2xyxy(x):
    """Convert boxes from center form to corner form.

    Each row's first four columns ``(cx, cy, w, h)`` become
    ``(x1, y1, x2, y2)`` = ``(cx - w/2, cy - h/2, cx + w/2, cy + h/2)``. Any
    further columns are copied through unchanged. A new array is returned
    and ``x`` itself is not modified.
    """
    boxes = np.copy(x)
    cx, cy, w, h = x[:, 0], x[:, 1], x[:, 2], x[:, 3]
    half_w, half_h = w / 2, h / 2
    boxes[:, 0], boxes[:, 2] = cx - half_w, cx + half_w
    boxes[:, 1], boxes[:, 3] = cy - half_h, cy + half_h
    return boxes

def non_max_suppression(preds, conf_thres=0.25, iou_thres=0.45):
    """Decode raw detector output and suppress overlapping boxes.

    Args:
        preds: Raw output of the ONNX model. Only ``preds[0]`` (the first
            batch item) is used. Each of its rows is read as
            ``(cx, cy, w, h, objectness, class_score_0, class_score_1, ...)``.
            This matches a YOLOv5-style export layout -- confirm against the
            model.
        conf_thres: Rows with objectness ``<= conf_thres`` are dropped up
            front. The same value is given to ``cv2.dnn.NMSBoxes`` as the
            score threshold for the final score (objectness * best class
            score).
        iou_thres: IoU threshold handed to ``cv2.dnn.NMSBoxes``.

    Returns:
        List of ``[x1, y1, x2, y2, score, cls_id]`` entries in the same
        coordinate units as ``preds``, or ``[]`` when nothing survives.
        ``cls_id`` is the argmax over the class scores.

    Note:
        Suppression is class-agnostic: a single NMS pass runs over all boxes,
        so overlapping boxes of different classes can suppress each other.
    """
    preds = preds[0]
    preds = preds[preds[:, 4] > conf_thres]
    if preds.shape[0] == 0:
        return []

    boxes = xywh2xyxy(preds[:, :4])
    class_scores = preds[:, 5:]
    cls_ids = np.argmax(class_scores, axis=1)
    final_scores = preds[:, 4] * class_scores.max(axis=1)

    # cv2.dnn.NMSBoxes takes rectangles as (x, y, width, height) with (x, y)
    # the top-left corner -- NOT (x1, y1, x2, y2). Passing corner boxes makes
    # OpenCV treat x2/y2 as width/height and compute the wrong IoU.
    nms_rects = np.column_stack((boxes[:, 0], boxes[:, 1], preds[:, 2], preds[:, 3]))
    indices = cv2.dnn.NMSBoxes(
        bboxes=nms_rects.tolist(),
        scores=final_scores.tolist(),
        score_threshold=conf_thres,
        nms_threshold=iou_thres,
    )
    # The return type varies across OpenCV versions: an (K,) or (K, 1) array,
    # or an empty tuple. Normalise it to a flat integer index array.
    indices = np.asarray(indices, dtype=np.int64).reshape(-1)

    output = []
    for idx in indices:
        x1, y1, x2, y2 = boxes[idx]
        output.append([x1, y1, x2, y2, final_scores[idx], cls_ids[idx]])
    return output

def infer(image, model_name, conf_threshold):
    """Detect defects in ``image`` with the selected model and draw the results.

    Args:
        image: Image array from the Gradio ``Image`` input (``type="numpy"``),
            in RGB channel order. It is ``None`` when the user submits without
            uploading anything. Grayscale and RGBA arrays are converted to RGB.
        model_name: Display name (a key of ``MODELS``) chosen in the dropdown.
        conf_threshold: Confidence threshold forwarded to
            ``non_max_suppression``.

    Returns:
        A new RGB array (the input is not modified). Each kept detection has
        a green rectangle and a ``"<class_id>:<score>"`` label drawn on it.

    Raises:
        gr.Error: If no image was uploaded or no valid model was selected.
            Gradio shows the message to the user instead of a stack trace.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if model_name not in MODELS:
        raise gr.Error("Please select a model.")

    # Normalise to 3-channel RGB so the network input always has 3 channels.
    if image.ndim == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

    session = get_session(model_name)
    # The output layout parsed in non_max_suppression (objectness + class
    # scores) is the Ultralytics YOLO export format, which expects RGB input.
    # Gradio already supplies RGB, so the image is fed without a BGR
    # conversion. NOTE(review): confirm the models were trained on RGB.
    blob, orig_w, orig_h = preprocess(image)
    # Look the input name up rather than assuming "images", so models exported
    # with a different input name also work.
    input_name = session.get_inputs()[0].name
    preds = session.run(None, {input_name: blob})[0]
    detections = non_max_suppression(preds, conf_threshold, IOU_THRESHOLD)

    # Draw on a copy. (0, 255, 0) is green in RGB as well as BGR, so no
    # colour-space round trip is needed for drawing.
    annotated = image.copy()
    for x1, y1, x2, y2, score, cls_id in detections:
        # Detections are in IMG_SIZE x IMG_SIZE network coordinates; map them
        # back to the original image size (preprocess stretched the image).
        x1 = int(x1 / IMG_SIZE * orig_w)
        y1 = int(y1 / IMG_SIZE * orig_h)
        x2 = int(x2 / IMG_SIZE * orig_w)
        y2 = int(y2 / IMG_SIZE * orig_h)

        label = f"{int(cls_id)}:{score:.2f}"
        cv2.rectangle(annotated, (x1, y1), (x2, y2), (0, 255, 0), 2)
        # Keep the label visible when the box touches the top edge.
        text_y = max(y1 - 5, 15)
        cv2.putText(annotated, label, (x1, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)

    return annotated

# Gradio interface: image + model choice + confidence threshold in, annotated
# image out. EXAMPLES rows map onto the three inputs in order.
demo = gr.Interface(
    fn=infer,
    inputs=[
        gr.Image(type="numpy"),
        # Preselect the first model so submitting without touching the
        # dropdown does not pass None as the model name.
        gr.Dropdown(list(MODELS.keys()), value=next(iter(MODELS)), label="Select Model"),
        # 0.25 lies on the slider's 0.01 step grid. It matches the threshold
        # used in EXAMPLES and the non_max_suppression default; a near-zero
        # default floods the output with low-confidence boxes.
        gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.01, label="Confidence Threshold")
    ],
    outputs=gr.Image(type="numpy"),
    title="Industrial Defect Detection",
    description="Upload an image, select the defect model, and adjust the confidence threshold.",
    examples=EXAMPLES
)

if __name__ == "__main__":
    # allowed_paths lets Gradio serve the example images from ./examples.
    demo.launch(allowed_paths=["examples"])