diff --git a/src/smpclient/transport/serial/__init__.py b/src/smpclient/transport/serial/__init__.py new file mode 100644 index 0000000..bd60495 --- /dev/null +++ b/src/smpclient/transport/serial/__init__.py @@ -0,0 +1,7 @@ +"""Serial SMPTransports. + +In addition to UART, these transports can be used with USB CDC ACM and CAN. +""" + +from smpclient.transport.serial.encoded import SMPSerialTransport as SMPSerialTransport +from smpclient.transport.serial.unencoded import SMPSerialRawTransport as SMPSerialRawTransport diff --git a/src/smpclient/transport/serial/common.py b/src/smpclient/transport/serial/common.py new file mode 100644 index 0000000..b26cd68 --- /dev/null +++ b/src/smpclient/transport/serial/common.py @@ -0,0 +1,155 @@ +"""Shared connection management for the encoded and unencoded serial transports.""" + +import asyncio +import logging +from contextlib import contextmanager +from time import monotonic +from typing import Final, Generator, final + +try: + from serial import Serial, SerialException +except ModuleNotFoundError as e: + if e.name == "serial": + raise ImportError( + "Serial transport requires the 'serial' extra. Use smpclient[serial]" + ) from e + raise +from typing_extensions import override + +from smpclient.transport import SMPTransport, SMPTransportDisconnected + +logger = logging.getLogger(__name__) + + +class _SerialTransportBase(SMPTransport): + """Connection-management base class for serial-port-backed SMP transports. + + Holds the `pyserial` `Serial` instance, the open/retry connect loop, disconnect, + and the small TX/RX helpers that wrap `SerialException` into + `SMPTransportDisconnected`. + + Subclasses implement `send` and `receive` with their framing of choice and may + override `_reset_state` to clear per-connection state on `connect`. + """ + + _POLLING_INTERVAL_S: Final = 0.005 + _CONNECTION_RETRY_INTERVAL_S: Final = 0.500 + + def __init__( + self, + baudrate: int = 115200, + bytesize: int = 8, + parity: str = "N", + stopbits: float = 1, + timeout: float | None = None, + xonxoff: bool = False, + rtscts: bool = False, + write_timeout: float | None = None, + dsrdtr: bool = False, + inter_byte_timeout: float | None = None, + exclusive: bool | None = None, + ) -> None: + """Initialize the underlying `pyserial` `Serial` instance. + + Args: + baudrate: The baudrate of the serial connection. OK to ignore for + USB CDC ACM. + bytesize: The number of data bits. + parity: The parity setting. + stopbits: The number of stop bits. + timeout: The read timeout. + xonxoff: Enable software flow control. + rtscts: Enable hardware (RTS/CTS) flow control. + write_timeout: The write timeout. + dsrdtr: Enable hardware (DSR/DTR) flow control. + inter_byte_timeout: The inter-byte timeout. + exclusive: Set exclusive access mode (POSIX only). A port cannot be + opened in exclusive access mode if it is already open in + exclusive access mode. + """ + self._conn: Final = Serial( + baudrate=baudrate, + bytesize=bytesize, + parity=parity, + stopbits=stopbits, + timeout=timeout, + xonxoff=xonxoff, + rtscts=rtscts, + write_timeout=write_timeout, + dsrdtr=dsrdtr, + inter_byte_timeout=inter_byte_timeout, + exclusive=exclusive, + ) + + def _reset_state(self) -> None: + """Reset any per-connection state. Subclasses override as needed.""" + + @final + @override + async def connect(self, address: str, timeout_s: float) -> None: + self._reset_state() + self._conn.port = address + logger.debug(f"Connecting to {self._conn.port=}") + start_time: Final = monotonic() + while monotonic() - start_time <= timeout_s: + try: + self._conn.open() + self._conn.reset_input_buffer() + logger.debug(f"Connected to {self._conn.port=}") + return + except SerialException as e: + logger.debug( + f"Failed to connect to {self._conn.port=}: {e}, " + f"retrying in {self._CONNECTION_RETRY_INTERVAL_S} seconds" + ) + await asyncio.sleep(self._CONNECTION_RETRY_INTERVAL_S) + + raise TimeoutError(f"Failed to connect to {address=}") + + @final + @override + async def disconnect(self) -> None: + logger.debug(f"Disconnecting from {self._conn.port=}") + self._conn.close() + logger.debug(f"Disconnected from {self._conn.port=}") + + @final + @override + async def send_and_receive(self, data: bytes) -> bytes: + await self.send(data) + return await self.receive() + + @final + @contextmanager + def _serial_exception_to_disconnected(self) -> Generator[None, None, None]: + """Translate `SerialException` from `pyserial` to `SMPTransportDisconnected`.""" + try: + yield + except SerialException as e: + logger.error(f"Serial exception on {self._conn.port}: {e}") + raise SMPTransportDisconnected( + f"{self.__class__.__name__} disconnected from {self._conn.port}" + ) from e + + @final + async def _drain_tx(self) -> None: + """Block until the serial TX buffer is empty. + + Fake-async polling until `pyserial` is replaced. + """ + while self._conn.out_waiting > 0: + await asyncio.sleep(self._POLLING_INTERVAL_S) + + @final + async def _read_all(self) -> bytes: + """Return all currently-available bytes (or empty bytes). + + Wraps `SerialException` into `SMPTransportDisconnected`. `StopIteration` is + caught to keep mocked `read_all` side-effect lists usable in tests. + """ + try: + return self._conn.read_all() or b"" + except StopIteration: + return b"" + except SerialException as exc: + raise SMPTransportDisconnected(f"Failed to read from {self._conn.port}: {exc}") from exc diff --git a/src/smpclient/transport/serial.py b/src/smpclient/transport/serial/encoded.py similarity index 81% rename from src/smpclient/transport/serial.py rename to src/smpclient/transport/serial/encoded.py index 45e6924..fe5f38c 100644 --- a/src/smpclient/transport/serial.py +++ b/src/smpclient/transport/serial/encoded.py @@ -1,28 +1,27 @@ -"""A serial SMPTransport. +"""The base64-encoded serial SMPTransport. -In addition to UART, this transport can be used with USB CDC ACM and CAN. +Wraps each SMP packet in a length+CRC frame, base64-encodes it, and terminates +it with a newline. This is what Zephyr calls "SMP over console" -- the framing +shared by `CONFIG_MCUMGR_TRANSPORT_UART` and `CONFIG_MCUMGR_TRANSPORT_SHELL`, +and the only SMP-over-UART option that existed before Zephyr 4.4. Also the +only framing supported by MCUboot serial recovery (`MCUBOOT_SERIAL`), which +unconditionally selects `BASE64`. + +For `CONFIG_MCUMGR_TRANSPORT_RAW_UART` servers, use `SMPSerialRawTransport` +from `smpclient.transport.serial.unencoded`. """ import asyncio import logging import math -import time from enum import IntEnum, unique from functools import cached_property from typing import Final -try: - from serial import Serial, SerialException -except ModuleNotFoundError as e: - if e.name == "serial": - raise ImportError( - "Serial transport requires the 'serial' extra. Use smpclient[serial]" - ) from e - raise from smp import packet as smppacket from typing_extensions import override -from smpclient.transport import SMPTransport, SMPTransportDisconnected +from smpclient.transport.serial.common import _SerialTransportBase logger = logging.getLogger(__name__) @@ -43,10 +42,7 @@ def _base64_max(size: int) -> int: return math.floor(3 / 4 * size) - 2 -class SMPSerialTransport(SMPTransport): - _POLLING_INTERVAL_S = 0.005 - _CONNECTION_RETRY_INTERVAL_S = 0.500 - +class SMPSerialTransport(_SerialTransportBase): @unique class BufferState(IntEnum): SMP = 0 @@ -95,22 +91,12 @@ def __init__( # noqa: DOC301 write_timeout: The write timeout. dsrdtr: Enable hardware (DSR/DTR) flow control. inter_byte_timeout: The inter-byte timeout. - exclusive: The exclusive access timeout. + exclusive: Set exclusive access mode (POSIX only). A port cannot be + opened in exclusive access mode if it is already open in + exclusive access mode. """ - if max_smp_encoded_frame_size < line_length * line_buffers: - logger.error( - f"{max_smp_encoded_frame_size=} is less than {line_length=} * {line_buffers=}!" - ) - elif max_smp_encoded_frame_size != line_length * line_buffers: - logger.warning( - f"{max_smp_encoded_frame_size=} is not equal to {line_length=} * {line_buffers=}!" - ) - - self._max_smp_encoded_frame_size: Final = max_smp_encoded_frame_size - self._line_length: Final = line_length - self._line_buffers: Final = line_buffers - self._conn: Final = Serial( + super().__init__( baudrate=baudrate, bytesize=bytesize, parity=parity, @@ -124,6 +110,19 @@ def __init__( # noqa: DOC301 exclusive=exclusive, ) + if max_smp_encoded_frame_size < line_length * line_buffers: + logger.error( + f"{max_smp_encoded_frame_size=} is less than {line_length=} * {line_buffers=}!" + ) + elif max_smp_encoded_frame_size != line_length * line_buffers: + logger.warning( + f"{max_smp_encoded_frame_size=} is not equal to {line_length=} * {line_buffers=}!" + ) + + self._max_smp_encoded_frame_size: Final = max_smp_encoded_frame_size + self._line_length: Final = line_length + self._line_buffers: Final = line_buffers + self._smp_packet_queue: asyncio.Queue[bytes] = asyncio.Queue() """Contains full SMP packets.""" self._serial_buffer = bytearray() @@ -135,6 +134,7 @@ def __init__( # noqa: DOC301 logger.debug(f"Initialized {self.__class__.__name__}") + @override def _reset_state(self) -> None: """Reset internal state and queues for a fresh connection.""" self._smp_packet_queue = asyncio.Queue() @@ -142,33 +142,6 @@ def _reset_state(self) -> None: self._buffer = bytearray([]) self._buffer_state = SMPSerialTransport.BufferState.SERIAL - @override - async def connect(self, address: str, timeout_s: float) -> None: - self._reset_state() - self._conn.port = address - logger.debug(f"Connecting to {self._conn.port=}") - start_time: Final = time.time() - while time.time() - start_time <= timeout_s: - try: - self._conn.open() - self._conn.reset_input_buffer() - logger.debug(f"Connected to {self._conn.port=}") - return - except SerialException as e: - logger.debug( - f"Failed to connect to {self._conn.port=}: {e}, " - f"retrying in {SMPSerialTransport._CONNECTION_RETRY_INTERVAL_S} seconds" - ) - await asyncio.sleep(SMPSerialTransport._CONNECTION_RETRY_INTERVAL_S) - - raise TimeoutError(f"Failed to connect to {address=}") - - @override - async def disconnect(self) -> None: - logger.debug(f"Disconnecting from {self._conn.port=}") - self._conn.close() - logger.debug(f"Disconnected from {self._conn.port=}") - @override async def send(self, data: bytes) -> None: if len(data) > self.max_unencoded_size: @@ -176,19 +149,12 @@ async def send(self, data: bytes) -> None: f"Data size {len(data)} exceeds maximum unencoded size {self.max_unencoded_size}" ) logger.debug(f"Sending {len(data)} bytes") - try: + with self._serial_exception_to_disconnected(): for packet in smppacket.encode(data, line_length=self._line_length): self._conn.write(packet) logger.debug(f"Writing encoded packet of size {len(packet)}B; {self._line_length=}") - # fake async until I get around to replacing pyserial - while self._conn.out_waiting > 0: - await asyncio.sleep(SMPSerialTransport._POLLING_INTERVAL_S) - except SerialException as e: - logger.error(f"Failed to send {len(data)} bytes: {e}") - raise SMPTransportDisconnected( - f"{self.__class__.__name__} disconnected from {self._conn.port}" - ) + await self._drain_tx() logger.debug(f"Sent {len(data)} bytes") @@ -242,18 +208,13 @@ async def read_serial(self, delimiter: bytes | None = None) -> bytes: async def _read_and_process(self, read_until_one_smp_packet: bool) -> None: """Reads raw data from serial and processes it into SMP packets and regular serial data.""" while True: - try: - data = self._conn.read_all() or b"" - except StopIteration: - data = b"" - except SerialException as exc: - raise SMPTransportDisconnected(f"Failed to read from {self._conn.port}: {exc}") + data = await self._read_all() if data: self._buffer.extend(data) await self._process_buffer() else: - await asyncio.sleep(SMPSerialTransport._POLLING_INTERVAL_S) + await asyncio.sleep(self._POLLING_INTERVAL_S) if read_until_one_smp_packet: if self._smp_packet_queue.qsize(): @@ -342,11 +303,6 @@ def _could_be_smp_packet_start(self, byte: int) -> bool: """Return True if the given byte value matches the start of any SMP packet delimiter.""" return byte == smppacket.START_DELIMITER[0] or byte == smppacket.CONTINUE_DELIMITER[0] - @override - async def send_and_receive(self, data: bytes) -> bytes: - await self.send(data) - return await self.receive() - @override @property def mtu(self) -> int: diff --git a/src/smpclient/transport/serial/unencoded.py b/src/smpclient/transport/serial/unencoded.py new file mode 100644 index 0000000..05dcf0c --- /dev/null +++ b/src/smpclient/transport/serial/unencoded.py @@ -0,0 +1,138 @@ +"""The unencoded (raw) serial SMPTransport. + +This is the Zephyr "raw UART" SMP transport, enabled on the server by +`CONFIG_MCUMGR_TRANSPORT_RAW_UART` together with `CONFIG_UART_MCUMGR_RAW_PROTOCOL`. +Each SMP message is sent over the wire as the raw bytes +`[8-byte SMP header][header.length bytes of payload]` with no framing, encoding, +or CRC. The receiver parses the SMP header to determine the message length. + +This transport cannot coexist with shell or log output on the same UART. If +you need shell interleaving, use `SMPSerialTransport` from +`smpclient.transport.serial.encoded`. +""" + +import asyncio +import logging +from typing import Final + +from smp import header as smphdr +from typing_extensions import override + +from smpclient.exceptions import SMPClientException +from smpclient.transport.serial.common import _SerialTransportBase + +logger = logging.getLogger(__name__) + + +class SMPSerialRawTransport(_SerialTransportBase): + def __init__( + self, + mtu: int = 384, + baudrate: int = 115200, + bytesize: int = 8, + parity: str = "N", + stopbits: float = 1, + timeout: float | None = None, + xonxoff: bool = False, + rtscts: bool = False, + write_timeout: float | None = None, + dsrdtr: bool = False, + inter_byte_timeout: float | None = None, + exclusive: bool | None = None, + ) -> None: + """Initialize the raw serial transport. + + Args: + mtu: The maximum size of one SMP message (header + payload), in + bytes. A serial link has no MTU of its own, but the SMP + server's receive buffer does -- this should match the server's + `CONFIG_MCUMGR_TRANSPORT_NETBUF_SIZE` (Zephyr default 384). + baudrate: The baudrate of the serial connection. OK to ignore for + USB CDC ACM. + bytesize: The number of data bits. + parity: The parity setting. + stopbits: The number of stop bits. + timeout: The read timeout. + xonxoff: Enable software flow control. + rtscts: Enable hardware (RTS/CTS) flow control. + write_timeout: The write timeout. + dsrdtr: Enable hardware (DSR/DTR) flow control. + inter_byte_timeout: The inter-byte timeout. + exclusive: Set exclusive access mode (POSIX only). A port cannot be + opened in exclusive access mode if it is already open in + exclusive access mode. + """ + super().__init__( + baudrate=baudrate, + bytesize=bytesize, + parity=parity, + stopbits=stopbits, + timeout=timeout, + xonxoff=xonxoff, + rtscts=rtscts, + write_timeout=write_timeout, + dsrdtr=dsrdtr, + inter_byte_timeout=inter_byte_timeout, + exclusive=exclusive, + ) + self._mtu: Final = mtu + + logger.debug(f"Initialized {self.__class__.__name__}") + + @override + async def send(self, data: bytes) -> None: + if len(data) > self.max_unencoded_size: + raise ValueError( + f"Data size {len(data)} exceeds maximum unencoded size {self.max_unencoded_size}" + ) + logger.debug(f"Sending {len(data)} bytes") + with self._serial_exception_to_disconnected(): + self._conn.write(data) + await self._drain_tx() + logger.debug(f"Sent {len(data)} bytes") + + @override + async def receive(self) -> bytes: + logger.debug("Waiting for response") + message = bytearray() + + while len(message) < smphdr.Header.SIZE: + await self._poll_read_into(message) + + header: Final = smphdr.Header.loads(bytes(message[: smphdr.Header.SIZE])) + message_length: Final = header.length + smphdr.Header.SIZE + logger.debug(f"Received {header=}; awaiting {message_length} B total") + + # The header's length field is attacker/noise-controlled - bound it before + # we start waiting for that many bytes to arrive. + if message_length > self.max_unencoded_size: + error = ( + f"Header claims a {message_length} B message, " + f"exceeding max_unencoded_size={self.max_unencoded_size}" + ) + logger.error(error) + raise SMPClientException(error) + + while len(message) < message_length: + await self._poll_read_into(message) + + if len(message) > message_length: + error = f"Received more data than expected: {len(message)} B > {message_length} B" + logger.error(error) + raise SMPClientException(error) + + logger.debug(f"Finished receiving {message_length} B response") + return bytes(message) + + async def _poll_read_into(self, buf: bytearray) -> None: + """Read available bytes into `buf`; if none, yield via a short sleep.""" + data = await self._read_all() + if data: + buf.extend(data) + else: + await asyncio.sleep(self._POLLING_INTERVAL_S) + + @override + @property + def mtu(self) -> int: + return self._mtu diff --git a/tests/test_base64.py b/tests/test_base64.py index e1f600f..6f26ca9 100644 --- a/tests/test_base64.py +++ b/tests/test_base64.py @@ -3,7 +3,7 @@ import random from base64 import b64encode -from smpclient.transport.serial import _base64_cost, _base64_max +from smpclient.transport.serial.encoded import _base64_cost, _base64_max if not hasattr(random, 'randbytes'): from os import urandom diff --git a/tests/test_smp_client.py b/tests/test_smp_client.py index 20adc81..d94a0c5 100644 --- a/tests/test_smp_client.py +++ b/tests/test_smp_client.py @@ -36,7 +36,7 @@ from smpclient.requests.file_management import FileDownload, FileUpload from smpclient.requests.image_management import ImageUploadWrite from smpclient.requests.os_management import ResetWrite -from smpclient.transport.serial import SMPSerialTransport +from smpclient.transport.serial import SMPSerialRawTransport, SMPSerialTransport class SMPMockTransport: @@ -365,6 +365,53 @@ async def mock_request( assert reconstructed_image == image +@pytest.mark.asyncio +@pytest.mark.parametrize("mtu", [128, 256, 512, 1024, 2048, 4096, 8192]) +async def test_upload_hello_world_bin_raw(mtu: int) -> None: + with open( + str(Path("tests", "fixtures", "zephyr-v3.5.0-2795-g28ff83515d", "hello_world.signed.bin")), + 'rb', + ) as f: + image = f.read() + + m = SMPSerialRawTransport(mtu=mtu) + s = SMPClient(m, "address") + assert s._transport.mtu == mtu + assert s._transport.max_unencoded_size == mtu, "The raw transport has no encoding overhead" + + packets: list[bytes] = [] + + def mock_write(data: bytes) -> int: + """Accumulate the raw packets in the global `packets`.""" + packets.append(data) + return len(data) + + s._transport._conn.write = mock_write # type: ignore + + async def mock_request( + request: ImageUploadWrite, timeout_s: float = 120.000 + ) -> ImageUploadWriteResponse: + # call the real send method (with write mocked) but don't bother with receive + # this provides coverage for the MTU-limited chunking done by SMPClient.upload + await s._transport.send(request.BYTES) + return ImageUploadWrite._Response.get_default()(off=request.off + len(request.data)) # type: ignore # noqa + + s.request = mock_request # type: ignore + + # `out_waiting` is a property on the real Serial class - scope the patch so it + # restores cleanly when the test finishes. + with patch.object(type(s._transport._conn), 'out_waiting', 0): # type: ignore + async for _ in s.upload(image): + pass + + # Each captured write is one complete SMP message [header][payload], no decoding needed. + reconstructed_image = bytearray([]) + for packet in packets: + reconstructed_image.extend(ImageUploadWriteRequest.loads(packet).data) + + assert reconstructed_image == image + + @pytest.mark.asyncio async def test_upload_file() -> None: m = SMPMockTransport() diff --git a/tests/test_smp_serial_raw_transport.py b/tests/test_smp_serial_raw_transport.py new file mode 100644 index 0000000..53ef77e --- /dev/null +++ b/tests/test_smp_serial_raw_transport.py @@ -0,0 +1,254 @@ +"""Tests for `SMPSerialRawTransport`.""" + +from __future__ import annotations + +import asyncio +from collections.abc import Generator +from typing import Any +from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch + +import pytest +from serial import SerialException +from smp import header as smphdr + +from smpclient.exceptions import SMPClientException +from smpclient.requests.os_management import EchoWrite +from smpclient.transport import SMPTransportDisconnected +from smpclient.transport.serial import SMPSerialRawTransport + + +@pytest.fixture(autouse=True) +def mock_serial() -> Generator[None, Any, None]: + with patch("smpclient.transport.serial.common.Serial"): + yield + + +def test_constructor() -> None: + t = SMPSerialRawTransport(mtu=512) + assert t.mtu == 512 + assert t.max_unencoded_size == 512 + + +def test_constructor_defaults() -> None: + t = SMPSerialRawTransport() + assert t.mtu == 384 + + +@pytest.mark.asyncio +async def test_connect_disconnect() -> None: + ports: list[str] = ["COM2", "/dev/ttyACM0", "/dev/ttyUSB0"] + + t = SMPSerialRawTransport() + t._conn.read_all = MagicMock(return_value=b"") # type: ignore + + for p in ports: + await asyncio.wait_for(t.connect(p, 1.0), timeout=1.0) + t._conn.open.assert_called_once() # type: ignore + + assert t._conn.port == p + + await asyncio.wait_for(t.disconnect(), timeout=0.1) + t._conn.close.assert_called_once() # type: ignore + + t._conn.reset_mock() # type: ignore + + +@pytest.mark.asyncio +async def test_connect_retries_until_timeout() -> None: + t = SMPSerialRawTransport() + t._conn.open = MagicMock(side_effect=SerialException("nope")) # type: ignore + + with pytest.raises(TimeoutError): + await asyncio.wait_for(t.connect("/dev/ttyUSB0", 0.1), timeout=2.0) + + +@pytest.mark.asyncio +async def test_send() -> None: + t = SMPSerialRawTransport() + t._conn.write = MagicMock() # type: ignore + p = PropertyMock(return_value=0) + type(t._conn).out_waiting = p # type: ignore + + r = EchoWrite(d="Hello pytest!") + await t.send(r.BYTES) + + # Raw transport writes the bytes verbatim - no encoding. + t._conn.write.assert_called_once_with(r.BYTES) + p.assert_called_once_with() + + +@pytest.mark.asyncio +async def test_send_waits_for_tx_drain() -> None: + t = SMPSerialRawTransport() + t._conn.write = MagicMock() # type: ignore + p = PropertyMock(side_effect=(1, 0)) + type(t._conn).out_waiting = p # type: ignore + + await t.send(EchoWrite(d="x").BYTES) + assert p.call_count == 2 + + +@pytest.mark.asyncio +async def test_send_too_large_raises() -> None: + t = SMPSerialRawTransport(mtu=16) + with pytest.raises(ValueError): + await t.send(b"\x00" * 32) + + +@pytest.mark.asyncio +async def test_send_disconnected_raises() -> None: + t = SMPSerialRawTransport() + t._conn.write = MagicMock(side_effect=SerialException("disconnected")) # type: ignore + + with pytest.raises(SMPTransportDisconnected): + await t.send(EchoWrite(d="x").BYTES) + + +@pytest.mark.asyncio +async def test_receive_single_packet() -> None: + t = SMPSerialRawTransport() + await t.connect("/dev/ttyUSB0", timeout_s=1.0) + + m = EchoWrite._Response.get_default()(sequence=0, r="Hello pytest!") # type: ignore + t._conn.read_all = MagicMock(side_effect=[m.BYTES]) # type: ignore + + received = await t.receive() + assert received == m.BYTES + + await t.disconnect() + + +@pytest.mark.asyncio +async def test_receive_fragmented() -> None: + t = SMPSerialRawTransport() + await t.connect("/dev/ttyUSB0", timeout_s=1.0) + + m = EchoWrite._Response.get_default()(sequence=0, r="Hello pytest!") # type: ignore + fragments = [ + m.BYTES[:3], # less than a header + m.BYTES[3:8], # completes the header but no payload yet + m.BYTES[8:10], + m.BYTES[10:], # rest of payload + ] + t._conn.read_all = MagicMock(side_effect=fragments) # type: ignore + + received = await t.receive() + assert received == m.BYTES + + await t.disconnect() + + +@pytest.mark.asyncio +async def test_receive_byte_at_a_time() -> None: + t = SMPSerialRawTransport() + await t.connect("/dev/ttyUSB0", timeout_s=1.0) + + m = EchoWrite._Response.get_default()(sequence=0, r="Hi") # type: ignore + t._conn.read_all = MagicMock( # type: ignore + side_effect=[bytes([b]) for b in m.BYTES] + ) + + received = await t.receive() + assert received == m.BYTES + + await t.disconnect() + + +@pytest.mark.asyncio +async def test_receive_consecutive_messages() -> None: + t = SMPSerialRawTransport() + await t.connect("/dev/ttyUSB0", timeout_s=1.0) + + m1 = EchoWrite._Response.get_default()(sequence=0, r="SMP Message 1") # type: ignore + m2 = EchoWrite._Response.get_default()(sequence=1, r="SMP Message 2") # type: ignore + m3 = EchoWrite._Response.get_default()(sequence=2, r="SMP Message 3") # type: ignore + + # Each receive() reads one full message, just like a normal request/response loop. + t._conn.read_all = MagicMock(side_effect=[m1.BYTES, m2.BYTES, m3.BYTES]) # type: ignore + + assert await t.receive() == m1.BYTES + assert await t.receive() == m2.BYTES + assert await t.receive() == m3.BYTES + + await t.disconnect() + + +@pytest.mark.asyncio +async def test_receive_overrun_raises() -> None: + """A single read returning more bytes than the header advertises is an error. + + SMP is strictly request/response; the server should never send unsolicited bytes. + """ + t = SMPSerialRawTransport() + await t.connect("/dev/ttyUSB0", timeout_s=1.0) + + m = EchoWrite._Response.get_default()(sequence=0, r="Hello!") # type: ignore + t._conn.read_all = MagicMock(side_effect=[m.BYTES + b"\x00\x01\x02"]) # type: ignore + + with pytest.raises(SMPClientException): + await t.receive() + + await t.disconnect() + + +@pytest.mark.asyncio +async def test_receive_polls_when_nothing_available() -> None: + t = SMPSerialRawTransport() + await t.connect("/dev/ttyUSB0", timeout_s=1.0) + + m = EchoWrite._Response.get_default()(sequence=0, r="ok") # type: ignore + t._conn.read_all = MagicMock(side_effect=[b"", b"", m.BYTES]) # type: ignore + + received = await t.receive() + assert received == m.BYTES + assert t._conn.read_all.call_count >= 3 + + await t.disconnect() + + +@pytest.mark.asyncio +async def test_receive_oversized_header_raises() -> None: + """A header claiming more bytes than max_unencoded_size is rejected. + + Defensive bound against noisy or corrupted UART traffic that would + otherwise cause an unbounded wait. + """ + t = SMPSerialRawTransport(mtu=64) + await t.connect("/dev/ttyUSB0", timeout_s=1.0) + + bogus_header = smphdr.Header( + op=smphdr.OP.WRITE_RSP, + version=smphdr.Version.V2, + flags=smphdr.Flag(0), + length=10_000, + group_id=smphdr.GroupId.OS_MANAGEMENT, + sequence=0, + command_id=smphdr.CommandId.OSManagement.ECHO, + ).BYTES + t._conn.read_all = MagicMock(side_effect=[bogus_header]) # type: ignore + + with pytest.raises(SMPClientException): + await t.receive() + + await t.disconnect() + + +@pytest.mark.asyncio +async def test_receive_disconnected_raises() -> None: + t = SMPSerialRawTransport() + t._conn.read_all = MagicMock(side_effect=SerialException("disconnected")) # type: ignore + + with pytest.raises(SMPTransportDisconnected): + await t.receive() + + +@pytest.mark.asyncio +async def test_send_and_receive() -> None: + t = SMPSerialRawTransport() + t.send = AsyncMock() # type: ignore + t.receive = AsyncMock() # type: ignore + + await t.send_and_receive(b"some data") + + t.send.assert_awaited_once_with(b"some data") + t.receive.assert_awaited_once_with() diff --git a/tests/test_smp_serial_transport.py b/tests/test_smp_serial_transport.py index df64dab..436051f 100644 --- a/tests/test_smp_serial_transport.py +++ b/tests/test_smp_serial_transport.py @@ -18,7 +18,7 @@ @pytest.fixture(autouse=True) def mock_serial() -> Generator[None, Any, None]: - with patch("smpclient.transport.serial.Serial"): + with patch("smpclient.transport.serial.common.Serial"): yield