"""Partial ESPHome server implementation."""

import asyncio
import logging
from abc import abstractmethod
from collections.abc import Iterable
from typing import TYPE_CHECKING, List, Optional

# pylint: disable=no-name-in-module
from aioesphomeapi._frame_helper.packets import make_plain_text_packets
from aioesphomeapi.api_pb2 import (  # type: ignore[attr-defined]
    AuthenticationRequest,
    AuthenticationResponse,
    DisconnectRequest,
    DisconnectResponse,
    HelloRequest,
    HelloResponse,
    PingRequest,
    PingResponse,
)
from aioesphomeapi.core import MESSAGE_TYPE_TO_PROTO
from google.protobuf import message

# Reverse lookup: protobuf message class -> ESPHome message type id.
PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()}

_LOGGER = logging.getLogger(__name__)


class APIServer(asyncio.Protocol):
    """ESPHome API Server implementation.

    Parses plaintext frames of the form: a 0x00 preamble, a varuint payload
    length, a varuint message type, then the protobuf-encoded payload.
    Hello, authentication, disconnect and ping requests are answered here;
    every other message is passed to :meth:`handle_message`.

    NOTE: ``asyncio.Protocol`` does not use ``ABCMeta``, so the
    ``@abstractmethod`` on :meth:`handle_message` documents intent but is not
    enforced when the class is instantiated.
    """

    def __init__(self, name: str) -> None:
        """Create a server that reports ``name`` in its ``HelloResponse``."""
        self.name = name
        # Received bytes not yet consumed as complete frames (None when empty).
        self._buffer: Optional[bytes] = None
        self._buffer_len: int = 0
        # Read cursor into ``_buffer`` for the frame currently being parsed.
        self._pos: int = 0
        # Set by connection_made(); cleared on disconnect / connection loss.
        self._transport = None
        self._writelines = None

    @abstractmethod
    def handle_message(self, msg: message.Message) -> Iterable[message.Message]:
        """Handle a message that has no built-in handler.

        :param msg: the decoded request message.
        :returns: response message(s) to send back. A single message is also
            accepted; a falsy return value sends nothing.
        """

    def process_packet(self, msg_type: int, packet_data: bytes) -> None:
        """Decode one packet and answer it.

        :param msg_type: ESPHome message type id taken from the frame header.
        :param packet_data: protobuf-encoded payload (may be empty).
        """
        msg_class = MESSAGE_TYPE_TO_PROTO.get(msg_type)
        if msg_class is None:
            # Skip unknown types instead of letting a KeyError propagate out
            # of data_received(); the frame was already consumed, so the
            # stream stays in sync.
            _LOGGER.warning("Unknown message type: %s", msg_type)
            return
        try:
            msg_inst = msg_class.FromString(packet_data)
        except message.DecodeError:
            # Payload is corrupt but its length was valid: skip just this one.
            _LOGGER.warning("Failed to decode message of type %s", msg_type)
            return

        if isinstance(msg_inst, HelloRequest):
            self.send_messages(
                [
                    HelloResponse(
                        api_version_major=1,
                        api_version_minor=10,
                        name=self.name,
                    )
                ]
            )
        elif isinstance(msg_inst, AuthenticationRequest):
            self.send_messages([AuthenticationResponse()])
        elif isinstance(msg_inst, DisconnectRequest):
            self.send_messages([DisconnectResponse()])
            _LOGGER.debug("Disconnect requested")
            self._close_transport()
        elif isinstance(msg_inst, PingRequest):
            self.send_messages([PingResponse()])
        elif msgs := self.handle_message(msg_inst):
            if isinstance(msgs, message.Message):
                msgs = [msgs]
            self.send_messages(msgs)

    def send_messages(self, msgs: List[message.Message]) -> None:
        """Frame ``msgs`` as plaintext packets and write them to the client.

        Does nothing when there is no active connection.
        """
        if self._writelines is None:
            return
        packets = [
            (PROTO_TO_MESSAGE_TYPE[msg.__class__], msg.SerializeToString())
            for msg in msgs
        ]
        self._writelines(make_plain_text_packets(packets))

    def connection_made(self, transport) -> None:
        """Remember the transport so responses can be written to it."""
        self._transport = transport
        self._writelines = transport.writelines

    def data_received(self, data: bytes) -> None:
        """Append ``data`` to the buffer and process every complete frame."""
        if self._buffer is None:
            self._buffer = data
            self._buffer_len = len(data)
        else:
            self._buffer += data
            self._buffer_len += len(data)

        # Smallest frame: preamble + 1-byte length + 1-byte message type.
        while self._buffer_len >= 3:
            self._pos = 0

            # Read preamble, which should always be 0x00.
            if (preamble := self._read_varuint()) != 0x00:
                # Framing is lost and cannot be resynchronized: drop the
                # buffered bytes and the connection instead of keeping the
                # bad data (plus everything appended after it) forever.
                _LOGGER.error("Incorrect preamble: %s", preamble)
                self._clear_buffer()
                self._close_transport()
                return

            # -1 / None below only mean the frame is not complete yet; keep
            # the buffer and wait for the next data_received() call.
            if (length := self._read_varuint()) == -1:
                return
            if (msg_type := self._read_varuint()) == -1:
                return
            if (packet_data := self._read(length)) is None:
                return

            self._remove_from_buffer()
            self.process_packet(msg_type, packet_data)

    def _read(self, length: int) -> bytes | None:
        """Read exactly length bytes from the buffer or None if all the bytes are not yet available."""
        new_pos = self._pos + length
        if self._buffer_len < new_pos:
            return None
        original_pos = self._pos
        self._pos = new_pos
        if TYPE_CHECKING:
            assert self._buffer is not None, "Buffer should be set"
        return self._buffer[original_pos:new_pos]

    def connection_lost(self, exc) -> None:
        """Stop writing once the connection has gone away."""
        self._transport = None
        self._writelines = None

    def _close_transport(self) -> None:
        """Close the transport (if any) and stop writing to it."""
        if self._transport:
            self._transport.close()
        self._transport = None
        self._writelines = None

    def _clear_buffer(self) -> None:
        """Discard all buffered, unprocessed input."""
        self._buffer = None
        self._buffer_len = 0
        self._pos = 0

    def _read_varuint(self) -> int:
        """Read a varuint from the buffer or -1 if the buffer runs out of bytes."""
        if not self._buffer:
            return -1
        result = 0
        bitpos = 0
        cstr = self._buffer
        while self._buffer_len > self._pos:
            val = cstr[self._pos]
            self._pos += 1
            # Low 7 bits carry data; the high bit flags a continuation byte.
            result |= (val & 0x7F) << bitpos
            if (val & 0x80) == 0:
                return result
            bitpos += 7
        return -1

    def _remove_from_buffer(self) -> None:
        """Drop the frame just parsed (all bytes before ``_pos``)."""
        end_of_frame_pos = self._pos
        self._buffer_len -= end_of_frame_pos
        if self._buffer_len == 0:
            self._buffer = None
            return
        if TYPE_CHECKING:
            assert self._buffer is not None, "Buffer should be set"
        self._buffer = self._buffer[end_of_frame_pos:]