Desmond-Dong committed on
Commit
273c295
·
1 Parent(s): 5ce6940

fix: 重写手势检测器,使用正确的模型输入输出格式

Browse files

参考 reference/dynamic_gestures/onnx_models.py:
- 手部检测器: 320x240 输入, 输出 boxes/labels/probs
- 分类器: 128x128 正方形裁剪输入
- 预处理: (image - 127) / 128

reachy_mini_ha_voice/gesture_detector.py CHANGED
@@ -1,390 +0,0 @@
1
- """Gesture detection using HaGRID ONNX models.
2
-
3
- Uses models from ai-forever/dynamic_gestures:
4
- - hand_detector.onnx (~1.2MB): Detects hand bounding boxes
5
- - crops_classifier.onnx (~0.4MB): Classifies hand gestures (18 HaGRID classes)
6
-
7
- Total size: ~1.6MB - optimized for Raspberry Pi CM4.
8
- """
9
-
10
- from __future__ import annotations
11
- import logging
12
- import time
13
- from enum import Enum
14
- from pathlib import Path
15
- from typing import Optional, Callable, Dict, Tuple, List
16
-
17
- import cv2
18
- import numpy as np
19
- from numpy.typing import NDArray
20
-
21
- logger = logging.getLogger(__name__)
22
-
23
-
24
class Gesture(Enum):
    """Gestures the detector can report.

    Each member's value is the lowercase class-name string used for that
    gesture. ``NONE`` is the sentinel for "no gesture recognised".
    """

    # Sentinel: nothing detected / below the confidence threshold.
    NONE = "no_gesture"

    # Recognisable gestures, in alphabetical order of their class names.
    CALL = "call"
    DISLIKE = "dislike"
    FIST = "fist"
    FOUR = "four"
    LIKE = "like"
    MUTE = "mute"
    OK = "ok"
    ONE = "one"
    PALM = "palm"
    PEACE = "peace"
    PEACE_INVERTED = "peace_inverted"
    ROCK = "rock"
    STOP = "stop"
    STOP_INVERTED = "stop_inverted"
    THREE = "three"
    THREE2 = "three2"
    TWO_UP = "two_up"
    TWO_UP_INVERTED = "two_up_inverted"
45
-
46
-
47
# Class names indexed by the classifier's output position.
# NOTE(review): the order is assumed to match crops_classifier.onnx - confirm
# against the model's label list.
_HAGRID_CLASSES = [
    "call",
    "dislike",
    "fist",
    "four",
    "like",
    "mute",
    "ok",
    "one",
    "palm",
    "peace",
    "peace_inverted",
    "rock",
    "stop",
    "stop_inverted",
    "three",
    "three2",
    "two_up",
    "two_up_inverted",
]

# Class-name string -> Gesture member, for translating classifier output.
_NAME_TO_GESTURE = dict(zip(_HAGRID_CLASSES, map(Gesture, _HAGRID_CLASSES)))
55
-
56
-
57
class GestureDetector:
    """Detect HaGRID hand gestures with a two-stage ONNX pipeline.

    Stage 1 (``hand_detector.onnx``) proposes hand bounding boxes; stage 2
    (``crops_classifier.onnx``) classifies a square crop around the most
    confident hand. ``process_frame`` adds hold/cooldown debouncing on top
    and fires the callback registered for the recognised gesture.

    Both models are loaded from the ``models`` directory next to this file.
    When onnxruntime or a model file is missing the detector stays disabled
    (``is_available`` is False) and every detection yields ``Gesture.NONE``.
    """

    # Fallback input sizes as (width, height), used when the ONNX graph does
    # not declare static spatial dimensions. Values follow the reference
    # dynamic_gestures ONNX wrappers: 320x240 detector, 128x128 classifier.
    _detector_size: Tuple[int, int] = (320, 240)
    _classifier_size: Tuple[int, int] = (128, 128)

    # Pixel normalisation shared by both models: (pixel - 127) / 128.
    _PIXEL_MEAN = 127.0
    _PIXEL_SCALE = 128.0

    # Heuristic: if no box coordinate exceeds this value the boxes are treated
    # as fractions of the frame, otherwise as pixels in model-input space.
    # NOTE(review): confirm the coordinate convention of hand_detector.onnx.
    _NORMALIZED_BOX_LIMIT = 1.5

    # One-shot logging flags (class-level defaults, flipped per instance).
    _logged_detector_shape = False
    _logged_detect_error = False

    def __init__(
        self,
        confidence_threshold: float = 0.6,
        detection_threshold: float = 0.5,
    ) -> None:
        """Initialize the detector and try to load both models.

        Args:
            confidence_threshold: Minimum classifier probability for a gesture
                to be reported.
            detection_threshold: Minimum detector score for a hand box to be
                considered.
        """
        self._confidence_threshold = confidence_threshold
        self._detection_threshold = detection_threshold

        # Model paths
        models_dir = Path(__file__).parent / "models"
        self._detector_path = models_dir / "hand_detector.onnx"
        self._classifier_path = models_dir / "crops_classifier.onnx"

        self._detector = None
        self._classifier = None
        self._available = False
        self._model_load_error: Optional[str] = None

        # One optional zero-argument callback per real gesture.
        self._callbacks: Dict[Gesture, Optional[Callable[[], None]]] = {
            g: None for g in Gesture if g != Gesture.NONE
        }

        # Debounce state. Timestamps come from time.monotonic(); -inf means
        # "never happened", so the first trigger is not blocked by cooldown.
        self._last_gesture = Gesture.NONE
        self._current_gesture = Gesture.NONE
        self._gesture_start_time: Optional[float] = None
        self._gesture_hold_threshold = 0.5  # seconds a gesture must be held
        self._gesture_cooldown = 1.5  # minimum seconds between triggers
        self._last_trigger_time: float = float("-inf")
        self._gesture_clear_delay = 2.0  # seconds before current_gesture resets
        self._last_gesture_time: float = float("-inf")

        self._load_models()

    @staticmethod
    def _static_input_size(inputs, default: Tuple[int, int]) -> Tuple[int, int]:
        """Return (width, height) declared by the first model input.

        Assumes an NCHW layout, matching the CHW tensors this class feeds.
        Falls back to ``default`` when the shape is missing, not rank 4, or has
        dynamic (non-integer) spatial dimensions.
        """
        if not inputs:
            return default
        shape = inputs[0].shape
        if shape is None or len(shape) != 4:
            return default
        height, width = shape[2], shape[3]
        if (
            isinstance(height, int)
            and isinstance(width, int)
            and height > 0
            and width > 0
        ):
            return width, height
        return default

    def _load_models(self) -> None:
        """Load both ONNX sessions; on any failure leave the detector disabled.

        The reason for a failure is stored in ``_model_load_error``.
        """
        try:
            import onnxruntime as ort
        except ImportError:
            self._model_load_error = "onnxruntime not installed"
            logger.warning("Gesture detection disabled - pip install onnxruntime")
            return

        for path in (self._detector_path, self._classifier_path):
            if not path.exists():
                self._model_load_error = f"Model not found: {path}"
                logger.warning(
                    "Gesture detection disabled - %s", self._model_load_error
                )
                return

        try:
            providers = ['CPUExecutionProvider']
            logger.info("Loading gesture models...")
            self._detector = ort.InferenceSession(
                str(self._detector_path), providers=providers
            )
            self._classifier = ort.InferenceSession(
                str(self._classifier_path), providers=providers
            )

            # Log model input/output info
            det_inputs = self._detector.get_inputs()
            det_outputs = self._detector.get_outputs()
            logger.info("Hand detector - inputs: %s, outputs: %s",
                        [(i.name, i.shape) for i in det_inputs],
                        [(o.name, o.shape) for o in det_outputs])

            cls_inputs = self._classifier.get_inputs()
            cls_outputs = self._classifier.get_outputs()
            logger.info("Classifier - inputs: %s, outputs: %s",
                        [(i.name, i.shape) for i in cls_inputs],
                        [(o.name, o.shape) for o in cls_outputs])

            # Prefer the sizes the graphs declare over the built-in defaults,
            # so a size mismatch cannot silently break every inference.
            self._detector_size = self._static_input_size(
                det_inputs, self._detector_size
            )
            self._classifier_size = self._static_input_size(
                cls_inputs, self._classifier_size
            )

            self._available = True
            logger.info("Gesture detection ready (18 HaGRID classes)")
        except Exception as e:
            # Do not keep a half-initialised pipeline around.
            self._detector = None
            self._classifier = None
            self._available = False
            self._model_load_error = str(e)
            logger.error("Failed to load gesture models: %s", e)

    @property
    def is_available(self) -> bool:
        """Whether both models loaded and detection can run."""
        return self._available

    @property
    def current_gesture(self) -> Gesture:
        """Most recently seen gesture (kept briefly after it disappears)."""
        return self._current_gesture

    def set_callback(
        self, gesture: Gesture, callback: Optional[Callable[[], None]]
    ) -> None:
        """Register (or clear, with None) the callback for one gesture.

        ``Gesture.NONE`` is ignored because it never triggers.
        """
        if gesture != Gesture.NONE:
            self._callbacks[gesture] = callback

    def set_callbacks(
        self,
        on_like: Optional[Callable[[], None]] = None,
        on_dislike: Optional[Callable[[], None]] = None,
        on_stop: Optional[Callable[[], None]] = None,
        on_peace: Optional[Callable[[], None]] = None,
        on_ok: Optional[Callable[[], None]] = None,
        on_call: Optional[Callable[[], None]] = None,
        on_fist: Optional[Callable[[], None]] = None,
        on_rock: Optional[Callable[[], None]] = None,
        on_one: Optional[Callable[[], None]] = None,
        on_palm: Optional[Callable[[], None]] = None,
        on_mute: Optional[Callable[[], None]] = None,
    ) -> None:
        """Set callbacks for the common gestures in one call.

        Every listed gesture is overwritten: arguments left at None clear any
        callback previously registered for that gesture.
        """
        self._callbacks.update({
            Gesture.LIKE: on_like,
            Gesture.DISLIKE: on_dislike,
            Gesture.STOP: on_stop,
            Gesture.PEACE: on_peace,
            Gesture.OK: on_ok,
            Gesture.CALL: on_call,
            Gesture.FIST: on_fist,
            Gesture.ROCK: on_rock,
            Gesture.ONE: on_one,
            Gesture.PALM: on_palm,
            Gesture.MUTE: on_mute,
        })

    def _to_tensor(
        self, image: NDArray[np.uint8], size: Tuple[int, int]
    ) -> NDArray[np.float32]:
        """Convert a BGR image to a normalised 1xCxHxW float32 tensor.

        Args:
            image: BGR image as produced by OpenCV.
            size: Target (width, height), the order cv2.resize expects.
        """
        resized = cv2.resize(image, size)
        rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
        scaled = (rgb.astype(np.float32) - self._PIXEL_MEAN) / self._PIXEL_SCALE
        chw = np.transpose(scaled, (2, 0, 1))  # HWC -> CHW
        return np.ascontiguousarray(chw[np.newaxis, ...], dtype=np.float32)

    def _preprocess_detector(self, frame: NDArray[np.uint8]) -> NDArray[np.float32]:
        """Preprocess a full frame for the hand detector."""
        return self._to_tensor(frame, self._detector_size)

    def _preprocess_classifier(self, crop: NDArray[np.uint8]) -> NDArray[np.float32]:
        """Preprocess a cropped hand image for the gesture classifier."""
        return self._to_tensor(crop, self._classifier_size)

    @staticmethod
    def _parse_detections(outputs):
        """Split raw detector outputs into (boxes[N, 4], scores[N]).

        Two layouts are understood:
          * three or more tensors, read as (boxes, labels, probs);
          * a single tensor of rows ``[x1, y1, x2, y2, conf, ...]``, optionally
            with a leading batch dimension.

        Returns None when the outputs match neither layout.
        """
        if not outputs:
            return None

        if len(outputs) >= 3:
            boxes = np.asarray(outputs[0], dtype=np.float32)
            scores = np.asarray(outputs[2], dtype=np.float32).reshape(-1)
            if boxes.size % 4 != 0:
                return None
            boxes = boxes.reshape(-1, 4)
        else:
            rows = np.asarray(outputs[0], dtype=np.float32)
            if rows.ndim == 3:
                rows = rows[0]  # Remove batch dim
            if rows.ndim != 2 or rows.shape[1] < 5:
                return None
            boxes, scores = rows[:, :4], rows[:, 4]

        if len(boxes) != len(scores):
            return None
        return boxes, scores

    def _detect_hand(
        self, frame: NDArray[np.uint8]
    ) -> Optional[Tuple[int, int, int, int]]:
        """Find the most confident hand box in the frame.

        Returns:
            (x1, y1, x2, y2) in frame pixel coordinates, clamped to the frame,
            or None if no valid box scores above the detection threshold.
        """
        if self._detector is None:
            return None

        frame_h, frame_w = frame.shape[:2]
        input_name = self._detector.get_inputs()[0].name
        outputs = self._detector.run(
            None, {input_name: self._preprocess_detector(frame)}
        )

        # Log the raw output layout once to help diagnose model mismatches.
        if not self._logged_detector_shape:
            logger.info("Hand detector output: %d tensors, shapes=%s",
                        len(outputs), [o.shape for o in outputs])
            self._logged_detector_shape = True

        parsed = self._parse_detections(outputs)
        if parsed is None:
            return None
        boxes, scores = parsed

        # Drop rows with NaN/inf so int() conversion below cannot raise.
        finite = np.isfinite(boxes).all(axis=1) & np.isfinite(scores)
        boxes, scores = boxes[finite], scores[finite]
        if boxes.size == 0:
            return None

        if float(np.max(boxes)) <= self._NORMALIZED_BOX_LIMIT:
            # Coordinates look like fractions of the image.
            scale_x, scale_y = float(frame_w), float(frame_h)
        else:
            # Coordinates look like pixels in model-input space.
            in_w, in_h = self._detector_size
            scale_x, scale_y = frame_w / in_w, frame_h / in_h

        best_box = None
        best_conf = self._detection_threshold
        for (bx1, by1, bx2, by2), conf in zip(boxes, scores):
            if not conf > best_conf:
                continue
            x1 = max(0, min(frame_w, int(bx1 * scale_x)))
            y1 = max(0, min(frame_h, int(by1 * scale_y)))
            x2 = max(0, min(frame_w, int(bx2 * scale_x)))
            y2 = max(0, min(frame_h, int(by2 * scale_y)))
            # Only a usable (non-empty) box may raise the bar; otherwise a
            # degenerate high-score box would hide valid lower-score ones.
            if x2 > x1 and y2 > y1:
                best_conf = float(conf)
                best_box = (x1, y1, x2, y2)

        return best_box

    @staticmethod
    def _square_crop(
        frame: NDArray[np.uint8], box: Tuple[int, int, int, int]
    ) -> NDArray[np.uint8]:
        """Crop a square region around ``box``, clamped to the frame.

        The shorter side of the box is grown around its centre to match the
        longer side. Near the frame border the result may be non-square.
        """
        frame_h, frame_w = frame.shape[:2]
        x1, y1, x2, y2 = box
        box_w, box_h = x2 - x1, y2 - y1
        side = max(box_w, box_h)

        x1 -= (side - box_w) // 2
        y1 -= (side - box_h) // 2
        x2, y2 = x1 + side, y1 + side

        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(frame_w, x2), min(frame_h, y2)
        return frame[y1:y2, x1:x2]

    def _classify_gesture(self, crop: NDArray[np.uint8]) -> Tuple[Gesture, float]:
        """Classify the gesture shown in a cropped hand image.

        Returns:
            (gesture, confidence). The gesture is ``Gesture.NONE`` when the top
            class is unknown or its probability is below the threshold.
        """
        if self._classifier is None:
            return Gesture.NONE, 0.0

        input_name = self._classifier.get_inputs()[0].name
        outputs = self._classifier.run(
            None, {input_name: self._preprocess_classifier(crop)}
        )

        logits = np.asarray(outputs[0][0], dtype=np.float64).reshape(-1)
        if logits.size == 0:
            return Gesture.NONE, 0.0

        # Softmax with the max subtracted first, so large logits cannot
        # overflow exp() and turn the probabilities into NaN.
        exps = np.exp(logits - np.max(logits))
        probs = exps / np.sum(exps)

        idx = int(np.argmax(probs))
        conf = float(probs[idx])

        if idx < len(_HAGRID_CLASSES) and conf >= self._confidence_threshold:
            gesture_name = _HAGRID_CLASSES[idx]
            return _NAME_TO_GESTURE.get(gesture_name, Gesture.NONE), conf

        return Gesture.NONE, conf

    def detect(self, frame: NDArray[np.uint8]) -> Tuple[Gesture, float]:
        """Detect the gesture shown in a single frame.

        Args:
            frame: Input image (BGR format from OpenCV).

        Returns:
            Tuple of (gesture, confidence). Returns ``(Gesture.NONE, 0.0)``
            when the detector is unavailable, the frame is empty, no hand is
            found, or inference fails.
        """
        if not self.is_available:
            return Gesture.NONE, 0.0
        if frame is None or getattr(frame, "size", 0) == 0:
            return Gesture.NONE, 0.0

        try:
            box = self._detect_hand(frame)
            if box is None:
                return Gesture.NONE, 0.0

            crop = self._square_crop(frame, box)
            if crop.size == 0:
                return Gesture.NONE, 0.0

            return self._classify_gesture(crop)

        except Exception as e:
            # Detection is best-effort per frame, but a persistent failure
            # (e.g. a model/input mismatch) must be visible at least once.
            if not self._logged_detect_error:
                logger.warning("Gesture detection error: %s", e, exc_info=True)
                self._logged_detect_error = True
            else:
                logger.debug("Gesture detection error: %s", e)
            return Gesture.NONE, 0.0

    def process_frame(self, frame: NDArray[np.uint8]) -> Optional[Gesture]:
        """Process a frame and fire the gesture callback once a gesture is held.

        A gesture triggers when it has been seen continuously for the hold
        threshold and the cooldown since the previous trigger has elapsed.
        After triggering, the same gesture must change before it can trigger
        again.

        Args:
            frame: Input image (BGR format).

        Returns:
            The gesture that triggered on this frame, or None.
        """
        gesture, confidence = self.detect(frame)
        # Monotonic clock: hold/cooldown must not be affected by wall-clock
        # adjustments such as an NTP sync.
        now = time.monotonic()

        # Keep the last gesture visible for display until the clear delay.
        if gesture != Gesture.NONE:
            self._current_gesture = gesture
            self._last_gesture_time = now
        elif now - self._last_gesture_time > self._gesture_clear_delay:
            self._current_gesture = Gesture.NONE

        # Track changes even during cooldown; otherwise a gesture released and
        # shown again inside the cooldown window would never restart its hold
        # timer.
        if gesture != self._last_gesture:
            self._last_gesture = gesture
            self._gesture_start_time = now if gesture != Gesture.NONE else None
            return None

        if now - self._last_trigger_time < self._gesture_cooldown:
            return None

        if gesture == Gesture.NONE or self._gesture_start_time is None:
            return None
        if now - self._gesture_start_time < self._gesture_hold_threshold:
            return None

        self._last_trigger_time = now
        self._gesture_start_time = None  # same gesture must change to re-arm

        callback = self._callbacks.get(gesture)
        if callback is not None:
            logger.info("Gesture triggered: %s (%.1f%%)",
                        gesture.value, confidence * 100)
            try:
                callback()
            except Exception as e:
                # A failing user callback must not break frame processing.
                logger.exception("Gesture callback error: %s", e)
        return gesture

    def close(self) -> None:
        """Release the ONNX sessions and mark the detector unavailable."""
        self._detector = None
        self._classifier = None
        self._available = False