Desmond-Dong committed on
Commit
e5e9146
·
1 Parent(s): 273c295

fix: 重写gesture_detector.py,修正模型输出格式

Browse files

- 检测器输出: boxes(N,4), labels(1,), scores(N,) 无batch维度
- 分类器输出: labels(batch,45)
- 预处理: (img - 127) / 128

reachy_mini_ha_voice/gesture_detector.py CHANGED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Gesture detection using HaGRID ONNX models."""
2
+
3
+ from __future__ import annotations
4
+ import logging
5
+ from enum import Enum
6
+ from pathlib import Path
7
+ from typing import Optional, Tuple
8
+
9
+ import cv2
10
+ import numpy as np
11
+ from numpy.typing import NDArray
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
class Gesture(Enum):
    """Gestures this module can report.

    Each member's value is the lowercase gesture name; ``NONE`` is the
    "nothing recognised" result.
    """

    NONE = "no_gesture"
    CALL = "call"
    DISLIKE = "dislike"
    FIST = "fist"
    FOUR = "four"
    LIKE = "like"
    MUTE = "mute"
    OK = "ok"
    ONE = "one"
    PALM = "palm"
    PEACE = "peace"
    PEACE_INVERTED = "peace_inverted"
    ROCK = "rock"
    STOP = "stop"
    STOP_INVERTED = "stop_inverted"
    THREE = "three"
    THREE2 = "three2"
    TWO_UP = "two_up"
    TWO_UP_INVERTED = "two_up_inverted"
36
+
37
+
38
# Class names indexed by the classifier's output position (45 entries).
# The order must match the model's output layout, so do not sort or dedupe.
_GESTURE_CLASSES = [
    'hand_down', 'hand_right', 'hand_left', 'thumb_index', 'thumb_left',
    'thumb_right', 'thumb_down', 'half_up', 'half_left', 'half_right',
    'half_down', 'part_hand_heart', 'part_hand_heart2', 'fist_inverted',
    'two_left', 'two_right', 'two_down', 'grabbing', 'grip', 'point',
    'call', 'three3', 'little_finger', 'middle_finger', 'dislike', 'fist',
    'four', 'like', 'mute', 'ok', 'one', 'palm', 'peace', 'peace_inverted',
    'rock', 'stop', 'stop_inverted', 'three', 'three2', 'two_up',
    'two_up_inverted', 'three_gun', 'one_left', 'one_right', 'one_down',
]

# Class name -> Gesture for the subset of classes this module reports.
# Every reportable member's value is exactly its class name, so the table is
# derived from the enum rather than repeated by hand; class names that are not
# enum values (e.g. 'grip', 'point') are absent and resolve to Gesture.NONE
# at lookup time.
_NAME_TO_GESTURE = {
    gesture.value: gesture
    for gesture in Gesture
    if gesture is not Gesture.NONE
}
59
+
60
+
61
class GestureDetector:
    """Two-stage hand-gesture recogniser backed by two ONNX models.

    A detector model proposes hand boxes for a frame; a classifier model then
    labels a roughly square crop around the best box. Both model files are
    looked up in the ``models`` directory next to this module. When
    ``onnxruntime`` cannot be imported, a model file is missing, or session
    creation fails, the instance stays unavailable and :meth:`detect` always
    returns ``(Gesture.NONE, 0.0)``.
    """

    def __init__(self, confidence_threshold: float = 0.5, detection_threshold: float = 0.5):
        """Configure thresholds and try to load the models.

        Args:
            confidence_threshold: Minimum classifier softmax probability for a
                gesture to be reported.
            detection_threshold: A detection is used only if its score is
                strictly greater than this value.
        """
        self._confidence_threshold = confidence_threshold
        self._detection_threshold = detection_threshold
        models_dir = Path(__file__).parent / "models"
        self._detector_path = models_dir / "hand_detector.onnx"
        self._classifier_path = models_dir / "crops_classifier.onnx"
        self._detector = None
        self._classifier = None
        # Input/output tensor names; filled in by _load_models on success.
        self._det_input = None
        self._det_outputs = []
        self._cls_input = None
        self._available = False
        # Normalisation applied to every input: (pixel - 127) / 128.
        self._mean = np.array([127, 127, 127], dtype=np.float32)
        self._std = np.array([128, 128, 128], dtype=np.float32)
        # Passed to cv2.resize as dsize, i.e. (width, height).
        self._detector_size = (320, 240)
        self._classifier_size = (128, 128)
        self._load_models()

    def _load_models(self) -> None:
        """Create both inference sessions; leave the instance unavailable on failure."""
        try:
            import onnxruntime as ort
        except ImportError:
            logger.warning("onnxruntime not installed")
            return
        if not (self._detector_path.exists() and self._classifier_path.exists()):
            logger.warning("Model files not found")
            return
        try:
            providers = ['CPUExecutionProvider']
            logger.info("Loading gesture models...")
            self._detector = ort.InferenceSession(str(self._detector_path), providers=providers)
            self._classifier = ort.InferenceSession(str(self._classifier_path), providers=providers)
            self._det_input = self._detector.get_inputs()[0].name
            self._det_outputs = [o.name for o in self._detector.get_outputs()]
            self._cls_input = self._classifier.get_inputs()[0].name
            self._available = True
            logger.info("Gesture detection ready")
        except Exception as e:
            logger.error("Failed to load models: %s", e)
            # Drop whichever session did load so the instance is never left
            # half-initialised.
            self.close()

    @property
    def is_available(self) -> bool:
        """Whether both models loaded and :meth:`detect` can run inference."""
        return self._available

    def _preprocess(self, frame: NDArray, size: Tuple[int, int]) -> NDArray:
        """Convert a BGR image into a normalised ``(1, 3, H, W)`` float32 tensor.

        Args:
            frame: Image in BGR channel order (it is converted to RGB here).
            size: Target size handed to ``cv2.resize``.
        """
        img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, size)
        img = (img.astype(np.float32) - self._mean) / self._std
        img = np.transpose(img, [2, 0, 1])  # HWC -> CHW
        return np.expand_dims(img, axis=0)

    def _select_box(
        self, boxes: NDArray, scores: NDArray, frame_w: int, frame_h: int
    ) -> Optional[Tuple[int, int, int, int, float]]:
        """Pick the highest-scoring box above the detection threshold.

        Box coordinates are multiplied by the frame size, so they are assumed
        to be normalised ``x1, y1, x2, y2`` values — TODO confirm against the
        exported model.

        Returns:
            ``(x1, y1, x2, y2, score)`` in pixels with ``x2``/``y2`` exclusive
            and clamped to the frame, or ``None`` if no box qualifies.
        """
        boxes = np.asarray(boxes)
        scores = np.asarray(scores).reshape(-1)
        if boxes.size == 0 or scores.size == 0:
            return None
        # Tolerate a leading batch axis on the boxes tensor.
        boxes = boxes.reshape(-1, boxes.shape[-1])
        count = min(len(boxes), len(scores))
        scores = scores[:count]
        eligible = scores > self._detection_threshold  # NaN compares False
        if not eligible.any():
            return None
        best_i = int(np.argmax(np.where(eligible, scores, -np.inf)))
        best_c = float(scores[best_i])
        b = boxes[best_i]
        x1, y1 = int(b[0] * frame_w), int(b[1] * frame_h)
        x2, y2 = int(b[2] * frame_w), int(b[3] * frame_h)
        x1, y1 = max(0, x1), max(0, y1)
        # Upper bounds are exclusive slice ends, so clamp to the full extent.
        x2, y2 = min(frame_w, x2), min(frame_h, y2)
        if x2 <= x1 or y2 <= y1:
            return None
        return (x1, y1, x2, y2, best_c)

    def _detect_hand(self, frame: NDArray) -> Optional[Tuple[int, int, int, int, float]]:
        """Run the detector and return the best hand box, or ``None``.

        Reads output 0 as boxes and output 2 as scores (output 1 is unused).
        """
        if self._detector is None:
            return None
        h, w = frame.shape[:2]
        inp = self._preprocess(frame, self._detector_size)
        outs = self._detector.run(self._det_outputs, {self._det_input: inp})
        return self._select_box(outs[0], outs[2], w, h)

    def _get_square_crop(self, frame: NDArray, box: Tuple[int, int, int, int]) -> NDArray:
        """Return a view of ``frame`` around ``box``, grown to a square.

        The shorter side of the box is extended symmetrically to match the
        longer one. The result is then clamped to the frame, so a box near an
        edge can still produce a non-square crop.
        """
        h, w = frame.shape[:2]
        x1, y1, x2, y2 = box
        bw, bh = x2 - x1, y2 - y1
        if bh < bw:
            y1 -= (bw - bh) // 2
            y2 = y1 + bw
        elif bh > bw:
            x1 -= (bh - bw) // 2
            x2 = x1 + bh
        x1, y1 = max(0, x1), max(0, y1)
        # Slice ends are exclusive: clamping to w/h keeps the last column/row.
        x2, y2 = min(w, x2), min(h, y2)
        return frame[y1:y2, x1:x2]

    def _classify(self, crop: NDArray) -> Tuple[Gesture, float]:
        """Classify a hand crop.

        Returns:
            ``(gesture, confidence)`` where confidence is the softmax
            probability of the top class. The gesture is ``Gesture.NONE`` when
            the confidence is below the threshold or the top class is not one
            this module reports.
        """
        if self._classifier is None or crop.size == 0:
            return Gesture.NONE, 0.0
        inp = self._preprocess(crop, self._classifier_size)
        logits = self._classifier.run(None, {self._cls_input: inp})[0][0]
        idx = int(np.argmax(logits))
        # Subtract the max before exponentiating for numerical stability.
        exp_l = np.exp(logits - np.max(logits))
        conf = float(exp_l[idx] / np.sum(exp_l))
        if idx >= len(_GESTURE_CLASSES) or conf < self._confidence_threshold:
            return Gesture.NONE, conf
        name = _GESTURE_CLASSES[idx]
        return _NAME_TO_GESTURE.get(name, Gesture.NONE), conf

    def detect(self, frame: NDArray) -> Tuple[Gesture, float]:
        """Detect and classify the most confident hand in ``frame``.

        Never raises: any inference error is logged and reported as no gesture.

        Returns:
            ``(gesture, score)`` where score is the detection score multiplied
            by the classifier confidence, or ``(Gesture.NONE, 0.0)`` when
            nothing could be detected.
        """
        if not self._available:
            return Gesture.NONE, 0.0
        # A missing or empty frame is not an error worth a warning per call.
        if frame is None or frame.size == 0:
            return Gesture.NONE, 0.0
        try:
            det = self._detect_hand(frame)
            if det is None:
                return Gesture.NONE, 0.0
            x1, y1, x2, y2, det_c = det
            crop = self._get_square_crop(frame, (x1, y1, x2, y2))
            if crop.size == 0:
                return Gesture.NONE, 0.0
            gest, cls_c = self._classify(crop)
            if gest != Gesture.NONE:
                logger.debug("Gesture: %s (det=%.2f cls=%.2f)", gest.value, det_c, cls_c)
            return gest, det_c * cls_c
        except Exception as e:
            logger.warning("Gesture error: %s", e)
            return Gesture.NONE, 0.0

    def close(self) -> None:
        """Release both sessions and mark the detector unavailable."""
        self._detector = self._classifier = None
        self._available = False