| """Voice satellite protocol for Reachy Mini."""
|
|
|
| import hashlib
|
| import logging
|
| import posixpath
|
| import shutil
|
| import time
|
| from collections.abc import Iterable
|
| from typing import Dict, Optional, Set, Union, TYPE_CHECKING
|
| from urllib.parse import urlparse, urlunparse
|
| from urllib.request import urlopen
|
|
|
| if TYPE_CHECKING:
|
| from .camera_server import MJPEGCameraServer
|
|
|
|
|
| from aioesphomeapi.api_pb2 import (
|
| ButtonCommandRequest,
|
| CameraImageRequest,
|
| DeviceInfoRequest,
|
| DeviceInfoResponse,
|
| ListEntitiesDoneResponse,
|
| ListEntitiesRequest,
|
| MediaPlayerCommandRequest,
|
| NumberCommandRequest,
|
| SelectCommandRequest,
|
| SubscribeHomeAssistantStatesRequest,
|
| SubscribeStatesRequest,
|
| SwitchCommandRequest,
|
| VoiceAssistantAnnounceFinished,
|
| VoiceAssistantAnnounceRequest,
|
| VoiceAssistantAudio,
|
| VoiceAssistantConfigurationRequest,
|
| VoiceAssistantConfigurationResponse,
|
| VoiceAssistantEventResponse,
|
| VoiceAssistantExternalWakeWord,
|
| VoiceAssistantRequest,
|
| VoiceAssistantSetConfiguration,
|
| VoiceAssistantTimerEventResponse,
|
| VoiceAssistantWakeWord,
|
| )
|
| from aioesphomeapi.model import (
|
| VoiceAssistantEventType,
|
| VoiceAssistantFeature,
|
| VoiceAssistantTimerEventType,
|
| )
|
| from google.protobuf import message
|
| from pymicro_wakeword import MicroWakeWord
|
| from pyopen_wakeword import OpenWakeWord
|
|
|
| from .api_server import APIServer
|
| from .entity import MediaPlayerEntity
|
| from .entity_registry import EntityRegistry, get_entity_key
|
| from .models import AvailableWakeWord, ServerState, WakeWordType
|
| from .util import call_all
|
| from .reachy_controller import ReachyController
|
|
|
| _LOGGER = logging.getLogger(__name__)
|
|
|
|
|
| class VoiceSatelliteProtocol(APIServer):
|
| """Voice satellite protocol handler for ESPHome."""
|
|
|
| def __init__(self, state: ServerState, camera_server: Optional["MJPEGCameraServer"] = None) -> None:
|
| super().__init__(state.name)
|
| self.state = state
|
| self.state.satellite = self
|
| self.camera_server = camera_server
|
|
|
|
|
| self.reachy_controller = ReachyController(state.reachy_mini)
|
|
|
|
|
| self._entity_registry = EntityRegistry(
|
| server=self,
|
| reachy_controller=self.reachy_controller,
|
| camera_server=camera_server,
|
| play_emotion_callback=self._play_emotion,
|
| )
|
|
|
|
|
|
|
| if not getattr(self.state, '_entities_initialized', False):
|
| if self.state.media_player_entity is None:
|
| self.state.media_player_entity = MediaPlayerEntity(
|
| server=self,
|
| key=get_entity_key("reachy_mini_media_player"),
|
| name="Media Player",
|
| object_id="reachy_mini_media_player",
|
| music_player=state.music_player,
|
| announce_player=state.tts_player,
|
| )
|
| self.state.entities.append(self.state.media_player_entity)
|
|
|
|
|
| self._entity_registry.setup_all_entities(self.state.entities)
|
|
|
|
|
| self.state._entities_initialized = True
|
| _LOGGER.info("Entities initialized: %d total", len(self.state.entities))
|
| else:
|
| _LOGGER.debug("Entities already initialized, skipping setup")
|
|
|
| for entity in self.state.entities:
|
| entity.server = self
|
|
|
| self._entity_registry.find_entity_references(self.state.entities)
|
|
|
| self._is_streaming_audio = False
|
| self._tts_url: Optional[str] = None
|
| self._tts_played = False
|
| self._continue_conversation = False
|
| self._timer_finished = False
|
| self._external_wake_words: Dict[str, VoiceAssistantExternalWakeWord] = {}
|
|
|
| def handle_voice_event(
|
| self, event_type: VoiceAssistantEventType, data: Dict[str, str]
|
| ) -> None:
|
| _LOGGER.debug("Voice event: type=%s, data=%s", event_type.name, data)
|
|
|
| if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_RUN_START:
|
| self._tts_url = data.get("url")
|
| self._tts_played = False
|
| self._continue_conversation = False
|
|
|
| self._reachy_on_listening()
|
|
|
| elif event_type in (
|
| VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_END,
|
| VoiceAssistantEventType.VOICE_ASSISTANT_STT_END,
|
| ):
|
| self._is_streaming_audio = False
|
|
|
| self._reachy_on_thinking()
|
|
|
| elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_PROGRESS:
|
| if data.get("tts_start_streaming") == "1":
|
|
|
| self.play_tts()
|
|
|
| elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END:
|
| if data.get("continue_conversation") == "1":
|
| self._continue_conversation = True
|
|
|
| elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START:
|
|
|
| self._reachy_on_speaking()
|
|
|
| elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END:
|
| self._tts_url = data.get("url")
|
| self.play_tts()
|
|
|
| elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_RUN_END:
|
| self._is_streaming_audio = False
|
| if not self._tts_played:
|
| self._tts_finished()
|
| self._tts_played = False
|
|
|
| self._reachy_on_idle()
|
|
|
| def handle_timer_event(
|
| self,
|
| event_type: VoiceAssistantTimerEventType,
|
| msg: VoiceAssistantTimerEventResponse,
|
| ) -> None:
|
| _LOGGER.debug("Timer event: type=%s", event_type.name)
|
|
|
| if event_type == VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_FINISHED:
|
| if not self._timer_finished:
|
| self.state.active_wake_words.add(self.state.stop_word.id)
|
| self._timer_finished = True
|
| self.duck()
|
| self._play_timer_finished()
|
|
|
| self._reachy_on_timer_finished()
|
|
|
| def handle_message(self, msg: message.Message) -> Iterable[message.Message]:
|
| if isinstance(msg, VoiceAssistantEventResponse):
|
|
|
| data: Dict[str, str] = {}
|
| for arg in msg.data:
|
| data[arg.name] = arg.value
|
| self.handle_voice_event(VoiceAssistantEventType(msg.event_type), data)
|
|
|
| elif isinstance(msg, VoiceAssistantAnnounceRequest):
|
| _LOGGER.debug("Announcing: %s", msg.text)
|
| assert self.state.media_player_entity is not None
|
|
|
| urls = []
|
| if msg.preannounce_media_id:
|
| urls.append(msg.preannounce_media_id)
|
| urls.append(msg.media_id)
|
|
|
| self.state.active_wake_words.add(self.state.stop_word.id)
|
| self._continue_conversation = msg.start_conversation
|
| self.duck()
|
|
|
| yield from self.state.media_player_entity.play(
|
| urls, announcement=True, done_callback=self._tts_finished
|
| )
|
|
|
| elif isinstance(msg, VoiceAssistantTimerEventResponse):
|
| self.handle_timer_event(VoiceAssistantTimerEventType(msg.event_type), msg)
|
|
|
| elif isinstance(msg, DeviceInfoRequest):
|
| yield DeviceInfoResponse(
|
| uses_password=False,
|
| name=self.state.name,
|
| mac_address=self.state.mac_address,
|
| voice_assistant_feature_flags=(
|
| VoiceAssistantFeature.VOICE_ASSISTANT
|
| | VoiceAssistantFeature.API_AUDIO
|
| | VoiceAssistantFeature.ANNOUNCE
|
| | VoiceAssistantFeature.START_CONVERSATION
|
| | VoiceAssistantFeature.TIMERS
|
| ),
|
| )
|
|
|
| elif isinstance(
|
| msg,
|
| (
|
| ListEntitiesRequest,
|
| SubscribeHomeAssistantStatesRequest,
|
| SubscribeStatesRequest,
|
| MediaPlayerCommandRequest,
|
| NumberCommandRequest,
|
| SwitchCommandRequest,
|
| SelectCommandRequest,
|
| ButtonCommandRequest,
|
| CameraImageRequest,
|
| ),
|
| ):
|
| for entity in self.state.entities:
|
| yield from entity.handle_message(msg)
|
|
|
| if isinstance(msg, ListEntitiesRequest):
|
| yield ListEntitiesDoneResponse()
|
|
|
| elif isinstance(msg, VoiceAssistantConfigurationRequest):
|
| available_wake_words = [
|
| VoiceAssistantWakeWord(
|
| id=ww.id,
|
| wake_word=ww.wake_word,
|
| trained_languages=ww.trained_languages,
|
| )
|
| for ww in self.state.available_wake_words.values()
|
| ]
|
|
|
| for eww in msg.external_wake_words:
|
| if eww.model_type != "micro":
|
| continue
|
|
|
| available_wake_words.append(
|
| VoiceAssistantWakeWord(
|
| id=eww.id,
|
| wake_word=eww.wake_word,
|
| trained_languages=eww.trained_languages,
|
| )
|
| )
|
| self._external_wake_words[eww.id] = eww
|
|
|
| yield VoiceAssistantConfigurationResponse(
|
| available_wake_words=available_wake_words,
|
| active_wake_words=[
|
| ww.id
|
| for ww in self.state.wake_words.values()
|
| if ww.id in self.state.active_wake_words
|
| ],
|
| max_active_wake_words=2,
|
| )
|
|
|
| _LOGGER.info("Connected to Home Assistant")
|
|
|
| elif isinstance(msg, VoiceAssistantSetConfiguration):
|
|
|
| active_wake_words: Set[str] = set()
|
|
|
| for wake_word_id in msg.active_wake_words:
|
| if wake_word_id in self.state.wake_words:
|
|
|
| active_wake_words.add(wake_word_id)
|
| continue
|
|
|
| model_info = self.state.available_wake_words.get(wake_word_id)
|
| if not model_info:
|
|
|
| external_wake_word = self._external_wake_words.get(wake_word_id)
|
| if not external_wake_word:
|
| continue
|
|
|
| model_info = self._download_external_wake_word(external_wake_word)
|
| if not model_info:
|
| continue
|
|
|
| self.state.available_wake_words[wake_word_id] = model_info
|
|
|
| _LOGGER.debug("Loading wake word: %s", model_info.wake_word_path)
|
| self.state.wake_words[wake_word_id] = model_info.load()
|
| _LOGGER.info("Wake word set: %s", wake_word_id)
|
| active_wake_words.add(wake_word_id)
|
| break
|
|
|
| self.state.active_wake_words = active_wake_words
|
| _LOGGER.debug("Active wake words: %s", active_wake_words)
|
|
|
| self.state.preferences.active_wake_words = list(active_wake_words)
|
| self.state.save_preferences()
|
| self.state.wake_words_changed = True
|
|
|
| def handle_audio(self, audio_chunk: bytes) -> None:
|
| if not self._is_streaming_audio:
|
| return
|
| self.send_messages([VoiceAssistantAudio(data=audio_chunk)])
|
|
|
| def wakeup(self, wake_word: Union[MicroWakeWord, OpenWakeWord]) -> None:
|
| if self._timer_finished:
|
|
|
| self._timer_finished = False
|
| self.state.tts_player.stop()
|
| _LOGGER.debug("Stopping timer finished sound")
|
| return
|
|
|
| wake_word_phrase = wake_word.wake_word
|
| _LOGGER.debug("Detected wake word: %s", wake_word_phrase)
|
|
|
| self.send_messages(
|
| [VoiceAssistantRequest(start=True, wake_word_phrase=wake_word_phrase)]
|
| )
|
| self.duck()
|
| self._is_streaming_audio = True
|
| self.state.tts_player.play(self.state.wakeup_sound)
|
|
|
|
|
| self._update_doa_entities()
|
|
|
| def stop(self) -> None:
|
| self.state.active_wake_words.discard(self.state.stop_word.id)
|
| self.state.tts_player.stop()
|
|
|
| if self._timer_finished:
|
| self._timer_finished = False
|
| _LOGGER.debug("Stopping timer finished sound")
|
| else:
|
| _LOGGER.debug("TTS response stopped manually")
|
|
|
| self._tts_finished()
|
|
|
| def play_tts(self) -> None:
|
| if (not self._tts_url) or self._tts_played:
|
| return
|
|
|
| self._tts_played = True
|
| _LOGGER.debug("Playing TTS response: %s", self._tts_url)
|
|
|
| self.state.active_wake_words.add(self.state.stop_word.id)
|
| self.state.tts_player.play(self._tts_url, done_callback=self._tts_finished)
|
|
|
| def duck(self) -> None:
|
| _LOGGER.debug("Ducking music")
|
| self.state.music_player.duck()
|
|
|
| def unduck(self) -> None:
|
| _LOGGER.debug("Unducking music")
|
| self.state.music_player.unduck()
|
|
|
| def _tts_finished(self) -> None:
|
| self.state.active_wake_words.discard(self.state.stop_word.id)
|
| self.send_messages([VoiceAssistantAnnounceFinished()])
|
|
|
| if self._continue_conversation:
|
| self.send_messages([VoiceAssistantRequest(start=True)])
|
| self._is_streaming_audio = True
|
| _LOGGER.debug("Continuing conversation")
|
| else:
|
| self.unduck()
|
| _LOGGER.debug("TTS response finished")
|
|
|
| self._reachy_on_idle()
|
|
|
| def _play_timer_finished(self) -> None:
|
| if not self._timer_finished:
|
| self.unduck()
|
| return
|
|
|
| self.state.tts_player.play(
|
| self.state.timer_finished_sound,
|
| done_callback=lambda: call_all(
|
| lambda: time.sleep(1.0), self._play_timer_finished
|
| ),
|
| )
|
|
|
| def connection_lost(self, exc):
|
| super().connection_lost(exc)
|
| _LOGGER.info("Disconnected from Home Assistant")
|
|
|
| self._is_streaming_audio = False
|
| self._tts_url = None
|
| self._tts_played = False
|
| self._continue_conversation = False
|
|
|
| def _download_external_wake_word(
|
| self, external_wake_word: VoiceAssistantExternalWakeWord
|
| ) -> Optional[AvailableWakeWord]:
|
| eww_dir = self.state.download_dir / "external_wake_words"
|
| eww_dir.mkdir(parents=True, exist_ok=True)
|
|
|
| config_path = eww_dir / f"{external_wake_word.id}.json"
|
| should_download_config = not config_path.exists()
|
|
|
|
|
| model_path = eww_dir / f"{external_wake_word.id}.tflite"
|
| should_download_model = True
|
|
|
| if model_path.exists():
|
| model_size = model_path.stat().st_size
|
| if model_size == external_wake_word.model_size:
|
| with open(model_path, "rb") as model_file:
|
| model_hash = hashlib.sha256(model_file.read()).hexdigest()
|
|
|
| if model_hash == external_wake_word.model_hash:
|
| should_download_model = False
|
| _LOGGER.debug(
|
| "Model size and hash match for %s. Skipping download.",
|
| external_wake_word.id,
|
| )
|
|
|
| if should_download_config or should_download_model:
|
|
|
| _LOGGER.debug("Downloading %s to %s", external_wake_word.url, config_path)
|
| with urlopen(external_wake_word.url) as request:
|
| if request.status != 200:
|
| _LOGGER.warning(
|
| "Failed to download: %s, status=%s",
|
| external_wake_word.url,
|
| request.status,
|
| )
|
| return None
|
|
|
| with open(config_path, "wb") as model_file:
|
| shutil.copyfileobj(request, model_file)
|
|
|
| if should_download_model:
|
|
|
| parsed_url = urlparse(external_wake_word.url)
|
| parsed_url = parsed_url._replace(
|
| path=posixpath.join(posixpath.dirname(parsed_url.path), model_path.name)
|
| )
|
| model_url = urlunparse(parsed_url)
|
|
|
| _LOGGER.debug("Downloading %s to %s", model_url, model_path)
|
| with urlopen(model_url) as request:
|
| if request.status != 200:
|
| _LOGGER.warning(
|
| "Failed to download: %s, status=%s", model_url, request.status
|
| )
|
| return None
|
|
|
| with open(model_path, "wb") as model_file:
|
| shutil.copyfileobj(request, model_file)
|
|
|
| return AvailableWakeWord(
|
| id=external_wake_word.id,
|
| type=WakeWordType.MICRO_WAKE_WORD,
|
| wake_word=external_wake_word.wake_word,
|
| trained_languages=external_wake_word.trained_languages,
|
| wake_word_path=config_path,
|
| )
|
|
|
|
|
|
|
|
|
|
|
| def _update_doa_entities(self) -> None:
|
| """Update DOA and speech detection entities in Home Assistant."""
|
| try:
|
| if self._entity_registry.doa_angle_entity is not None:
|
| self._entity_registry.doa_angle_entity.update_state()
|
| _LOGGER.debug("DOA angle entity updated")
|
| if self._entity_registry.speech_detected_entity is not None:
|
| self._entity_registry.speech_detected_entity.update_state()
|
| _LOGGER.debug("Speech detected entity updated")
|
| except Exception as e:
|
| _LOGGER.error("Error updating DOA entities: %s", e)
|
|
|
| def _reachy_on_listening(self) -> None:
|
| """Called when listening for speech (HA state: Listening)."""
|
| if not self.state.motion_enabled or not self.state.reachy_mini:
|
| return
|
| try:
|
| _LOGGER.debug("Reachy Mini: Listening animation")
|
| if self.state.motion:
|
| self.state.motion.on_listening()
|
| except Exception as e:
|
| _LOGGER.error("Reachy Mini motion error: %s", e)
|
|
|
| def _reachy_on_thinking(self) -> None:
|
| """Called when processing speech (HA state: Processing)."""
|
| if not self.state.motion_enabled or not self.state.reachy_mini:
|
| return
|
| try:
|
| _LOGGER.debug("Reachy Mini: Thinking animation")
|
| if self.state.motion:
|
| self.state.motion.on_thinking()
|
| except Exception as e:
|
| _LOGGER.error("Reachy Mini motion error: %s", e)
|
|
|
| def _reachy_on_speaking(self) -> None:
|
| """Called when TTS is playing (HA state: Responding)."""
|
| if not self.state.motion_enabled or not self.state.reachy_mini:
|
| return
|
| try:
|
| _LOGGER.debug("Reachy Mini: Speaking animation")
|
| if self.state.motion:
|
| self.state.motion.on_speaking_start()
|
| except Exception as e:
|
| _LOGGER.error("Reachy Mini motion error: %s", e)
|
|
|
| def _reachy_on_idle(self) -> None:
|
| """Called when returning to idle state (HA state: Idle)."""
|
| if not self.state.motion_enabled or not self.state.reachy_mini:
|
| return
|
| try:
|
| _LOGGER.debug("Reachy Mini: Idle animation")
|
| if self.state.motion:
|
| self.state.motion.on_idle()
|
| except Exception as e:
|
| _LOGGER.error("Reachy Mini motion error: %s", e)
|
|
|
| def _reachy_on_timer_finished(self) -> None:
|
| """Called when a timer finishes."""
|
| if not self.state.motion_enabled or not self.state.reachy_mini:
|
| return
|
| try:
|
| _LOGGER.debug("Reachy Mini: Timer finished animation")
|
| if self.state.motion:
|
| self.state.motion.on_timer_finished()
|
| except Exception as e:
|
| _LOGGER.error("Reachy Mini motion error: %s", e)
|
|
|
| def _play_emotion(self, emotion_name: str) -> None:
|
| """Play an emotion/expression from the emotions library.
|
|
|
| Args:
|
| emotion_name: Name of the emotion (e.g., "happy1", "sad1", etc.)
|
| """
|
| try:
|
| import requests
|
|
|
|
|
| wlan_ip = "localhost"
|
| if self.state.reachy_mini is not None:
|
| try:
|
| status = self.state.reachy_mini.client.get_status(wait=False)
|
| wlan_ip = status.get('wlan_ip', 'localhost')
|
| except Exception:
|
| wlan_ip = "localhost"
|
|
|
|
|
|
|
| url = f"http://{wlan_ip}:8000/api/move/play/recorded-move-dataset/pollen-robotics/reachy-mini-emotions-library/{emotion_name}"
|
|
|
| response = requests.post(url, timeout=5)
|
| if response.status_code == 200:
|
| _LOGGER.info(f"Playing emotion: {emotion_name}")
|
| else:
|
| _LOGGER.warning(f"Failed to play emotion {emotion_name}: HTTP {response.status_code}")
|
|
|
| except Exception as e:
|
| _LOGGER.error(f"Error playing emotion {emotion_name}: {e}")
|
|
|