From d82c25f7ef74337d46083496f90174ad325f19d3 Mon Sep 17 00:00:00 2001 From: JP Hutchins Date: Wed, 20 May 2026 17:18:43 -0700 Subject: [PATCH] feat: add SMPSerialRawTransport for Zephyr raw UART SMP Zephyr 4.4 introduced CONFIG_MCUMGR_TRANSPORT_RAW_UART -- an SMP-over-UART transport that sends each SMP message as raw [8-byte header][payload] bytes with no framing, base64, or CRC. Smaller code size and faster transfers than the historical "SMP over console" framing, at the cost of being unable to share the UART with shell or log output. Splits the single transport/serial.py into a package so the two transports share connection management via a final-method base class: serial/ __init__.py re-exports both transports common.py _SerialTransportBase: pyserial holder, connect/disconnect, SerialException -> SMPTransportDisconnected translation, TX drain, RX polling helper encoded.py SMPSerialTransport (unchanged: base64 + CRC + delimiters, shell interleave via read_serial) unencoded.py SMPSerialRawTransport (new: writes bytes verbatim, reads SMP header then payload, raises on overrun and on header lengths exceeding max_unencoded_size) The from smpclient.transport.serial import SMPSerialTransport import path is preserved -- existing users (including MCUboot serial recovery, which still requires the base64 framing) need no changes. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/smpclient/transport/serial/__init__.py | 7 + src/smpclient/transport/serial/common.py | 155 +++++++++++ .../{serial.py => serial/encoded.py} | 112 +++----- src/smpclient/transport/serial/unencoded.py | 138 ++++++++++ tests/test_base64.py | 2 +- tests/test_smp_client.py | 49 +++- tests/test_smp_serial_raw_transport.py | 254 ++++++++++++++++++ tests/test_smp_serial_transport.py | 2 +- 8 files changed, 638 insertions(+), 81 deletions(-) create mode 100644 src/smpclient/transport/serial/__init__.py create mode 100644 src/smpclient/transport/serial/common.py rename src/smpclient/transport/{serial.py => serial/encoded.py} (81%) create mode 100644 src/smpclient/transport/serial/unencoded.py create mode 100644 tests/test_smp_serial_raw_transport.py diff --git a/src/smpclient/transport/serial/__init__.py b/src/smpclient/transport/serial/__init__.py new file mode 100644 index 0000000..bd60495 --- /dev/null +++ b/src/smpclient/transport/serial/__init__.py @@ -0,0 +1,7 @@ +"""Serial SMPTransports. + +In addition to UART, these transports can be used with USB CDC ACM and CAN. +""" + +from smpclient.transport.serial.encoded import SMPSerialTransport as SMPSerialTransport +from smpclient.transport.serial.unencoded import SMPSerialRawTransport as SMPSerialRawTransport diff --git a/src/smpclient/transport/serial/common.py b/src/smpclient/transport/serial/common.py new file mode 100644 index 0000000..b26cd68 --- /dev/null +++ b/src/smpclient/transport/serial/common.py @@ -0,0 +1,155 @@ +"""Shared connection management for the encoded and unencoded serial transports.""" + +import asyncio +import logging +from contextlib import contextmanager +from time import monotonic +from typing import Final, Generator, final + +try: + from serial import Serial, SerialException +except ModuleNotFoundError as e: + if e.name == "serial": + raise ImportError( + "Serial transport requires the 'serial' extra. Use smpclient[serial]" + ) from e + raise +from typing_extensions import override + +from smpclient.transport import SMPTransport, SMPTransportDisconnected + +logger = logging.getLogger(__name__) + + +class _SerialTransportBase(SMPTransport): + """Connection-management base class for serial-port-backed SMP transports. + + Holds the `pyserial` `Serial` instance, the open/retry connect loop, disconnect, + and the small TX/RX helpers that wrap `SerialException` into + `SMPTransportDisconnected`. + + Subclasses implement `send` and `receive` with their framing of choice and may + override `_reset_state` to clear per-connection state on `connect`. + """ + + _POLLING_INTERVAL_S: Final = 0.005 + _CONNECTION_RETRY_INTERVAL_S: Final = 0.500 + + def __init__( + self, + baudrate: int = 115200, + bytesize: int = 8, + parity: str = "N", + stopbits: float = 1, + timeout: float | None = None, + xonxoff: bool = False, + rtscts: bool = False, + write_timeout: float | None = None, + dsrdtr: bool = False, + inter_byte_timeout: float | None = None, + exclusive: bool | None = None, + ) -> None: + """Initialize the underlying `pyserial` `Serial` instance. + + Args: + baudrate: The baudrate of the serial connection. OK to ignore for + USB CDC ACM. + bytesize: The number of data bits. + parity: The parity setting. + stopbits: The number of stop bits. + timeout: The read timeout. + xonxoff: Enable software flow control. + rtscts: Enable hardware (RTS/CTS) flow control. + write_timeout: The write timeout. + dsrdtr: Enable hardware (DSR/DTR) flow control. + inter_byte_timeout: The inter-byte timeout. + exclusive: Set exclusive access mode (POSIX only). A port cannot be + opened in exclusive access mode if it is already open in + exclusive access mode. + """ + self._conn: Final = Serial( + baudrate=baudrate, + bytesize=bytesize, + parity=parity, + stopbits=stopbits, + timeout=timeout, + xonxoff=xonxoff, + rtscts=rtscts, + write_timeout=write_timeout, + dsrdtr=dsrdtr, + inter_byte_timeout=inter_byte_timeout, + exclusive=exclusive, + ) + + def _reset_state(self) -> None: + """Reset any per-connection state. Subclasses override as needed.""" + + @final + @override + async def connect(self, address: str, timeout_s: float) -> None: + self._reset_state() + self._conn.port = address + logger.debug(f"Connecting to {self._conn.port=}") + start_time: Final = monotonic() + while monotonic() - start_time <= timeout_s: + try: + self._conn.open() + self._conn.reset_input_buffer() + logger.debug(f"Connected to {self._conn.port=}") + return + except SerialException as e: + logger.debug( + f"Failed to connect to {self._conn.port=}: {e}, " + f"retrying in {self._CONNECTION_RETRY_INTERVAL_S} seconds" + ) + await asyncio.sleep(self._CONNECTION_RETRY_INTERVAL_S) + + raise TimeoutError(f"Failed to connect to {address=}") + + @final + @override + async def disconnect(self) -> None: + logger.debug(f"Disconnecting from {self._conn.port=}") + self._conn.close() + logger.debug(f"Disconnected from {self._conn.port=}") + + @final + @override + async def send_and_receive(self, data: bytes) -> bytes: + await self.send(data) + return await self.receive() + + @final + @contextmanager + def _serial_exception_to_disconnected(self) -> Generator[None, None, None]: + """Translate `SerialException` from `pyserial` to `SMPTransportDisconnected`.""" + try: + yield + except SerialException as e: + logger.error(f"Serial exception on {self._conn.port}: {e}") + raise SMPTransportDisconnected( + f"{self.__class__.__name__} disconnected from {self._conn.port}" + ) from e + + @final + async def _drain_tx(self) -> None: + """Block until the serial TX buffer is empty. + + Fake-async polling until `pyserial` is replaced. + """ + while self._conn.out_waiting > 0: + await asyncio.sleep(self._POLLING_INTERVAL_S) + + @final + async def _read_all(self) -> bytes: + """Return all currently-available bytes (or empty bytes). + + Wraps `SerialException` into `SMPTransportDisconnected`. `StopIteration` is + caught to keep mocked `read_all` side-effect lists usable in tests. + """ + try: + return self._conn.read_all() or b"" + except StopIteration: + return b"" + except SerialException as exc: + raise SMPTransportDisconnected(f"Failed to read from {self._conn.port}: {exc}") from exc diff --git a/src/smpclient/transport/serial.py b/src/smpclient/transport/serial/encoded.py similarity index 81% rename from src/smpclient/transport/serial.py rename to src/smpclient/transport/serial/encoded.py index 45e6924..fe5f38c 100644 --- a/src/smpclient/transport/serial.py +++ b/src/smpclient/transport/serial/encoded.py @@ -1,28 +1,27 @@ -"""A serial SMPTransport. +"""The base64-encoded serial SMPTransport. -In addition to UART, this transport can be used with USB CDC ACM and CAN. +Wraps each SMP packet in a length+CRC frame, base64-encodes it, and terminates +it with a newline. This is what Zephyr calls "SMP over console" -- the framing +shared by `CONFIG_MCUMGR_TRANSPORT_UART` and `CONFIG_MCUMGR_TRANSPORT_SHELL`, +and the only SMP-over-UART option that existed before Zephyr 4.4. Also the +only framing supported by MCUboot serial recovery (`MCUBOOT_SERIAL`), which +unconditionally selects `BASE64`. + +For `CONFIG_MCUMGR_TRANSPORT_RAW_UART` servers, use `SMPSerialRawTransport` +from `smpclient.transport.serial.unencoded`. """ import asyncio import logging import math -import time from enum import IntEnum, unique from functools import cached_property from typing import Final -try: - from serial import Serial, SerialException -except ModuleNotFoundError as e: - if e.name == "serial": - raise ImportError( - "Serial transport requires the 'serial' extra. Use smpclient[serial]" - ) from e - raise from smp import packet as smppacket from typing_extensions import override -from smpclient.transport import SMPTransport, SMPTransportDisconnected +from smpclient.transport.serial.common import _SerialTransportBase logger = logging.getLogger(__name__) @@ -43,10 +42,7 @@ def _base64_max(size: int) -> int: return math.floor(3 / 4 * size) - 2 -class SMPSerialTransport(SMPTransport): - _POLLING_INTERVAL_S = 0.005 - _CONNECTION_RETRY_INTERVAL_S = 0.500 - +class SMPSerialTransport(_SerialTransportBase): @unique class BufferState(IntEnum): SMP = 0 @@ -95,22 +91,12 @@ def __init__( # noqa: DOC301 write_timeout: The write timeout. dsrdtr: Enable hardware (DSR/DTR) flow control. inter_byte_timeout: The inter-byte timeout. - exclusive: The exclusive access timeout. + exclusive: Set exclusive access mode (POSIX only). A port cannot be + opened in exclusive access mode if it is already open in + exclusive access mode. """ - if max_smp_encoded_frame_size < line_length * line_buffers: - logger.error( - f"{max_smp_encoded_frame_size=} is less than {line_length=} * {line_buffers=}!" - ) - elif max_smp_encoded_frame_size != line_length * line_buffers: - logger.warning( - f"{max_smp_encoded_frame_size=} is not equal to {line_length=} * {line_buffers=}!" - ) - - self._max_smp_encoded_frame_size: Final = max_smp_encoded_frame_size - self._line_length: Final = line_length - self._line_buffers: Final = line_buffers - self._conn: Final = Serial( + super().__init__( baudrate=baudrate, bytesize=bytesize, parity=parity, @@ -124,6 +110,19 @@ def __init__( # noqa: DOC301 exclusive=exclusive, ) + if max_smp_encoded_frame_size < line_length * line_buffers: + logger.error( + f"{max_smp_encoded_frame_size=} is less than {line_length=} * {line_buffers=}!" + ) + elif max_smp_encoded_frame_size != line_length * line_buffers: + logger.warning( + f"{max_smp_encoded_frame_size=} is not equal to {line_length=} * {line_buffers=}!" + ) + + self._max_smp_encoded_frame_size: Final = max_smp_encoded_frame_size + self._line_length: Final = line_length + self._line_buffers: Final = line_buffers + self._smp_packet_queue: asyncio.Queue[bytes] = asyncio.Queue() """Contains full SMP packets.""" self._serial_buffer = bytearray() @@ -135,6 +134,7 @@ def __init__( # noqa: DOC301 logger.debug(f"Initialized {self.__class__.__name__}") + @override def _reset_state(self) -> None: """Reset internal state and queues for a fresh connection.""" self._smp_packet_queue = asyncio.Queue() @@ -142,33 +142,6 @@ def _reset_state(self) -> None: self._buffer = bytearray([]) self._buffer_state = SMPSerialTransport.BufferState.SERIAL - @override - async def connect(self, address: str, timeout_s: float) -> None: - self._reset_state() - self._conn.port = address - logger.debug(f"Connecting to {self._conn.port=}") - start_time: Final = time.time() - while time.time() - start_time <= timeout_s: - try: - self._conn.open() - self._conn.reset_input_buffer() - logger.debug(f"Connected to {self._conn.port=}") - return - except SerialException as e: - logger.debug( - f"Failed to connect to {self._conn.port=}: {e}, " - f"retrying in {SMPSerialTransport._CONNECTION_RETRY_INTERVAL_S} seconds" - ) - await asyncio.sleep(SMPSerialTransport._CONNECTION_RETRY_INTERVAL_S) - - raise TimeoutError(f"Failed to connect to {address=}") - - @override - async def disconnect(self) -> None: - logger.debug(f"Disconnecting from {self._conn.port=}") - self._conn.close() - logger.debug(f"Disconnected from {self._conn.port=}") - @override async def send(self, data: bytes) -> None: if len(data) > self.max_unencoded_size: @@ -176,19 +149,12 @@ async def send(self, data: bytes) -> None: f"Data size {len(data)} exceeds maximum unencoded size {self.max_unencoded_size}" ) logger.debug(f"Sending {len(data)} bytes") - try: + with self._serial_exception_to_disconnected(): for packet in smppacket.encode(data, line_length=self._line_length): self._conn.write(packet) logger.debug(f"Writing encoded packet of size {len(packet)}B; {self._line_length=}") - # fake async until I get around to replacing pyserial - while self._conn.out_waiting > 0: - await asyncio.sleep(SMPSerialTransport._POLLING_INTERVAL_S) - except SerialException as e: - logger.error(f"Failed to send {len(data)} bytes: {e}") - raise SMPTransportDisconnected( - f"{self.__class__.__name__} disconnected from {self._conn.port}" - ) + await self._drain_tx() logger.debug(f"Sent {len(data)} bytes") @@ -242,18 +208,13 @@ async def read_serial(self, delimiter: bytes | None = None) -> bytes: async def _read_and_process(self, read_until_one_smp_packet: bool) -> None: """Reads raw data from serial and processes it into SMP packets and regular serial data.""" while True: - try: - data = self._conn.read_all() or b"" - except StopIteration: - data = b"" - except SerialException as exc: - raise SMPTransportDisconnected(f"Failed to read from {self._conn.port}: {exc}") + data = await self._read_all() if data: self._buffer.extend(data) await self._process_buffer() else: - await asyncio.sleep(SMPSerialTransport._POLLING_INTERVAL_S) + await asyncio.sleep(self._POLLING_INTERVAL_S) if read_until_one_smp_packet: if self._smp_packet_queue.qsize(): @@ -342,11 +303,6 @@ def _could_be_smp_packet_start(self, byte: int) -> bool: """Return True if the given byte value matches the start of any SMP packet delimiter.""" return byte == smppacket.START_DELIMITER[0] or byte == smppacket.CONTINUE_DELIMITER[0] - @override - async def send_and_receive(self, data: bytes) -> bytes: - await self.send(data) - return await self.receive() - @override @property def mtu(self) -> int: diff --git a/src/smpclient/transport/serial/unencoded.py b/src/smpclient/transport/serial/unencoded.py new file mode 100644 index 0000000..05dcf0c --- /dev/null +++ b/src/smpclient/transport/serial/unencoded.py @@ -0,0 +1,138 @@ +"""The unencoded (raw) serial SMPTransport. + +This is the Zephyr "raw UART" SMP transport, enabled on the server by +`CONFIG_MCUMGR_TRANSPORT_RAW_UART` together with `CONFIG_UART_MCUMGR_RAW_PROTOCOL`. +Each SMP message is sent over the wire as the raw bytes +`[8-byte SMP header][header.length bytes of payload]` with no framing, encoding, +or CRC. The receiver parses the SMP header to determine the message length. + +This transport cannot coexist with shell or log output on the same UART. If +you need shell interleaving, use `SMPSerialTransport` from +`smpclient.transport.serial.encoded`. +""" + +import asyncio +import logging +from typing import Final + +from smp import header as smphdr +from typing_extensions import override + +from smpclient.exceptions import SMPClientException +from smpclient.transport.serial.common import _SerialTransportBase + +logger = logging.getLogger(__name__) + + +class SMPSerialRawTransport(_SerialTransportBase): + def __init__( + self, + mtu: int = 384, + baudrate: int = 115200, + bytesize: int = 8, + parity: str = "N", + stopbits: float = 1, + timeout: float | None = None, + xonxoff: bool = False, + rtscts: bool = False, + write_timeout: float | None = None, + dsrdtr: bool = False, + inter_byte_timeout: float | None = None, + exclusive: bool | None = None, + ) -> None: + """Initialize the raw serial transport. + + Args: + mtu: The maximum size of one SMP message (header + payload), in + bytes. A serial link has no MTU of its own, but the SMP + server's receive buffer does -- this should match the server's + `CONFIG_MCUMGR_TRANSPORT_NETBUF_SIZE` (Zephyr default 384). + baudrate: The baudrate of the serial connection. OK to ignore for + USB CDC ACM. + bytesize: The number of data bits. + parity: The parity setting. + stopbits: The number of stop bits. + timeout: The read timeout. + xonxoff: Enable software flow control. + rtscts: Enable hardware (RTS/CTS) flow control. + write_timeout: The write timeout. + dsrdtr: Enable hardware (DSR/DTR) flow control. + inter_byte_timeout: The inter-byte timeout. + exclusive: Set exclusive access mode (POSIX only). A port cannot be + opened in exclusive access mode if it is already open in + exclusive access mode. + """ + super().__init__( + baudrate=baudrate, + bytesize=bytesize, + parity=parity, + stopbits=stopbits, + timeout=timeout, + xonxoff=xonxoff, + rtscts=rtscts, + write_timeout=write_timeout, + dsrdtr=dsrdtr, + inter_byte_timeout=inter_byte_timeout, + exclusive=exclusive, + ) + self._mtu: Final = mtu + + logger.debug(f"Initialized {self.__class__.__name__}") + + @override + async def send(self, data: bytes) -> None: + if len(data) > self.max_unencoded_size: + raise ValueError( + f"Data size {len(data)} exceeds maximum unencoded size {self.max_unencoded_size}" + ) + logger.debug(f"Sending {len(data)} bytes") + with self._serial_exception_to_disconnected(): + self._conn.write(data) + await self._drain_tx() + logger.debug(f"Sent {len(data)} bytes") + + @override + async def receive(self) -> bytes: + logger.debug("Waiting for response") + message = bytearray() + + while len(message) < smphdr.Header.SIZE: + await self._poll_read_into(message) + + header: Final = smphdr.Header.loads(bytes(message[: smphdr.Header.SIZE])) + message_length: Final = header.length + smphdr.Header.SIZE + logger.debug(f"Received {header=}; awaiting {message_length} B total") + + # The header's length field is attacker/noise-controlled - bound it before + # we start waiting for that many bytes to arrive. + if message_length > self.max_unencoded_size: + error = ( + f"Header claims a {message_length} B message, " + f"exceeding max_unencoded_size={self.max_unencoded_size}" + ) + logger.error(error) + raise SMPClientException(error) + + while len(message) < message_length: + await self._poll_read_into(message) + + if len(message) > message_length: + error = f"Received more data than expected: {len(message)} B > {message_length} B" + logger.error(error) + raise SMPClientException(error) + + logger.debug(f"Finished receiving {message_length} B response") + return bytes(message) + + async def _poll_read_into(self, buf: bytearray) -> None: + """Read available bytes into `buf`; if none, yield via a short sleep.""" + data = await self._read_all() + if data: + buf.extend(data) + else: + await asyncio.sleep(self._POLLING_INTERVAL_S) + + @override + @property + def mtu(self) -> int: + return self._mtu diff --git a/tests/test_base64.py b/tests/test_base64.py index e1f600f..6f26ca9 100644 --- a/tests/test_base64.py +++ b/tests/test_base64.py @@ -3,7 +3,7 @@ import random from base64 import b64encode -from smpclient.transport.serial import _base64_cost, _base64_max +from smpclient.transport.serial.encoded import _base64_cost, _base64_max if not hasattr(random, 'randbytes'): from os import urandom diff --git a/tests/test_smp_client.py b/tests/test_smp_client.py index 20adc81..d94a0c5 100644 --- a/tests/test_smp_client.py +++ b/tests/test_smp_client.py @@ -36,7 +36,7 @@ from smpclient.requests.file_management import FileDownload, FileUpload from smpclient.requests.image_management import ImageUploadWrite from smpclient.requests.os_management import ResetWrite -from smpclient.transport.serial import SMPSerialTransport +from smpclient.transport.serial import SMPSerialRawTransport, SMPSerialTransport class SMPMockTransport: @@ -365,6 +365,53 @@ async def mock_request( assert reconstructed_image == image +@pytest.mark.asyncio +@pytest.mark.parametrize("mtu", [128, 256, 512, 1024, 2048, 4096, 8192]) +async def test_upload_hello_world_bin_raw(mtu: int) -> None: + with open( + str(Path("tests", "fixtures", "zephyr-v3.5.0-2795-g28ff83515d", "hello_world.signed.bin")), + 'rb', + ) as f: + image = f.read() + + m = SMPSerialRawTransport(mtu=mtu) + s = SMPClient(m, "address") + assert s._transport.mtu == mtu + assert s._transport.max_unencoded_size == mtu, "The raw transport has no encoding overhead" + + packets: list[bytes] = [] + + def mock_write(data: bytes) -> int: + """Accumulate the raw packets in the global `packets`.""" + packets.append(data) + return len(data) + + s._transport._conn.write = mock_write # type: ignore + + async def mock_request( + request: ImageUploadWrite, timeout_s: float = 120.000 + ) -> ImageUploadWriteResponse: + # call the real send method (with write mocked) but don't bother with receive + # this provides coverage for the MTU-limited chunking done by SMPClient.upload + await s._transport.send(request.BYTES) + return ImageUploadWrite._Response.get_default()(off=request.off + len(request.data)) # type: ignore # noqa + + s.request = mock_request # type: ignore + + # `out_waiting` is a property on the real Serial class - scope the patch so it + # restores cleanly when the test finishes. + with patch.object(type(s._transport._conn), 'out_waiting', 0): # type: ignore + async for _ in s.upload(image): + pass + + # Each captured write is one complete SMP message [header][payload], no decoding needed. + reconstructed_image = bytearray([]) + for packet in packets: + reconstructed_image.extend(ImageUploadWriteRequest.loads(packet).data) + + assert reconstructed_image == image + + @pytest.mark.asyncio async def test_upload_file() -> None: m = SMPMockTransport() diff --git a/tests/test_smp_serial_raw_transport.py b/tests/test_smp_serial_raw_transport.py new file mode 100644 index 0000000..53ef77e --- /dev/null +++ b/tests/test_smp_serial_raw_transport.py @@ -0,0 +1,254 @@ +"""Tests for `SMPSerialRawTransport`.""" + +from __future__ import annotations + +import asyncio +from collections.abc import Generator +from typing import Any +from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch + +import pytest +from serial import SerialException +from smp import header as smphdr + +from smpclient.exceptions import SMPClientException +from smpclient.requests.os_management import EchoWrite +from smpclient.transport import SMPTransportDisconnected +from smpclient.transport.serial import SMPSerialRawTransport + + +@pytest.fixture(autouse=True) +def mock_serial() -> Generator[None, Any, None]: + with patch("smpclient.transport.serial.common.Serial"): + yield + + +def test_constructor() -> None: + t = SMPSerialRawTransport(mtu=512) + assert t.mtu == 512 + assert t.max_unencoded_size == 512 + + +def test_constructor_defaults() -> None: + t = SMPSerialRawTransport() + assert t.mtu == 384 + + +@pytest.mark.asyncio +async def test_connect_disconnect() -> None: + ports: list[str] = ["COM2", "/dev/ttyACM0", "/dev/ttyUSB0"] + + t = SMPSerialRawTransport() + t._conn.read_all = MagicMock(return_value=b"") # type: ignore + + for p in ports: + await asyncio.wait_for(t.connect(p, 1.0), timeout=1.0) + t._conn.open.assert_called_once() # type: ignore + + assert t._conn.port == p + + await asyncio.wait_for(t.disconnect(), timeout=0.1) + t._conn.close.assert_called_once() # type: ignore + + t._conn.reset_mock() # type: ignore + + +@pytest.mark.asyncio +async def test_connect_retries_until_timeout() -> None: + t = SMPSerialRawTransport() + t._conn.open = MagicMock(side_effect=SerialException("nope")) # type: ignore + + with pytest.raises(TimeoutError): + await asyncio.wait_for(t.connect("/dev/ttyUSB0", 0.1), timeout=2.0) + + +@pytest.mark.asyncio +async def test_send() -> None: + t = SMPSerialRawTransport() + t._conn.write = MagicMock() # type: ignore + p = PropertyMock(return_value=0) + type(t._conn).out_waiting = p # type: ignore + + r = EchoWrite(d="Hello pytest!") + await t.send(r.BYTES) + + # Raw transport writes the bytes verbatim - no encoding. + t._conn.write.assert_called_once_with(r.BYTES) + p.assert_called_once_with() + + +@pytest.mark.asyncio +async def test_send_waits_for_tx_drain() -> None: + t = SMPSerialRawTransport() + t._conn.write = MagicMock() # type: ignore + p = PropertyMock(side_effect=(1, 0)) + type(t._conn).out_waiting = p # type: ignore + + await t.send(EchoWrite(d="x").BYTES) + assert p.call_count == 2 + + +@pytest.mark.asyncio +async def test_send_too_large_raises() -> None: + t = SMPSerialRawTransport(mtu=16) + with pytest.raises(ValueError): + await t.send(b"\x00" * 32) + + +@pytest.mark.asyncio +async def test_send_disconnected_raises() -> None: + t = SMPSerialRawTransport() + t._conn.write = MagicMock(side_effect=SerialException("disconnected")) # type: ignore + + with pytest.raises(SMPTransportDisconnected): + await t.send(EchoWrite(d="x").BYTES) + + +@pytest.mark.asyncio +async def test_receive_single_packet() -> None: + t = SMPSerialRawTransport() + await t.connect("/dev/ttyUSB0", timeout_s=1.0) + + m = EchoWrite._Response.get_default()(sequence=0, r="Hello pytest!") # type: ignore + t._conn.read_all = MagicMock(side_effect=[m.BYTES]) # type: ignore + + received = await t.receive() + assert received == m.BYTES + + await t.disconnect() + + +@pytest.mark.asyncio +async def test_receive_fragmented() -> None: + t = SMPSerialRawTransport() + await t.connect("/dev/ttyUSB0", timeout_s=1.0) + + m = EchoWrite._Response.get_default()(sequence=0, r="Hello pytest!") # type: ignore + fragments = [ + m.BYTES[:3], # less than a header + m.BYTES[3:8], # completes the header but no payload yet + m.BYTES[8:10], + m.BYTES[10:], # rest of payload + ] + t._conn.read_all = MagicMock(side_effect=fragments) # type: ignore + + received = await t.receive() + assert received == m.BYTES + + await t.disconnect() + + +@pytest.mark.asyncio +async def test_receive_byte_at_a_time() -> None: + t = SMPSerialRawTransport() + await t.connect("/dev/ttyUSB0", timeout_s=1.0) + + m = EchoWrite._Response.get_default()(sequence=0, r="Hi") # type: ignore + t._conn.read_all = MagicMock( # type: ignore + side_effect=[bytes([b]) for b in m.BYTES] + ) + + received = await t.receive() + assert received == m.BYTES + + await t.disconnect() + + +@pytest.mark.asyncio +async def test_receive_consecutive_messages() -> None: + t = SMPSerialRawTransport() + await t.connect("/dev/ttyUSB0", timeout_s=1.0) + + m1 = EchoWrite._Response.get_default()(sequence=0, r="SMP Message 1") # type: ignore + m2 = EchoWrite._Response.get_default()(sequence=1, r="SMP Message 2") # type: ignore + m3 = EchoWrite._Response.get_default()(sequence=2, r="SMP Message 3") # type: ignore + + # Each receive() reads one full message, just like a normal request/response loop. + t._conn.read_all = MagicMock(side_effect=[m1.BYTES, m2.BYTES, m3.BYTES]) # type: ignore + + assert await t.receive() == m1.BYTES + assert await t.receive() == m2.BYTES + assert await t.receive() == m3.BYTES + + await t.disconnect() + + +@pytest.mark.asyncio +async def test_receive_overrun_raises() -> None: + """A single read returning more bytes than the header advertises is an error. + + SMP is strictly request/response; the server should never send unsolicited bytes. + """ + t = SMPSerialRawTransport() + await t.connect("/dev/ttyUSB0", timeout_s=1.0) + + m = EchoWrite._Response.get_default()(sequence=0, r="Hello!") # type: ignore + t._conn.read_all = MagicMock(side_effect=[m.BYTES + b"\x00\x01\x02"]) # type: ignore + + with pytest.raises(SMPClientException): + await t.receive() + + await t.disconnect() + + +@pytest.mark.asyncio +async def test_receive_polls_when_nothing_available() -> None: + t = SMPSerialRawTransport() + await t.connect("/dev/ttyUSB0", timeout_s=1.0) + + m = EchoWrite._Response.get_default()(sequence=0, r="ok") # type: ignore + t._conn.read_all = MagicMock(side_effect=[b"", b"", m.BYTES]) # type: ignore + + received = await t.receive() + assert received == m.BYTES + assert t._conn.read_all.call_count >= 3 + + await t.disconnect() + + +@pytest.mark.asyncio +async def test_receive_oversized_header_raises() -> None: + """A header claiming more bytes than max_unencoded_size is rejected. + + Defensive bound against noisy or corrupted UART traffic that would + otherwise cause an unbounded wait. + """ + t = SMPSerialRawTransport(mtu=64) + await t.connect("/dev/ttyUSB0", timeout_s=1.0) + + bogus_header = smphdr.Header( + op=smphdr.OP.WRITE_RSP, + version=smphdr.Version.V2, + flags=smphdr.Flag(0), + length=10_000, + group_id=smphdr.GroupId.OS_MANAGEMENT, + sequence=0, + command_id=smphdr.CommandId.OSManagement.ECHO, + ).BYTES + t._conn.read_all = MagicMock(side_effect=[bogus_header]) # type: ignore + + with pytest.raises(SMPClientException): + await t.receive() + + await t.disconnect() + + +@pytest.mark.asyncio +async def test_receive_disconnected_raises() -> None: + t = SMPSerialRawTransport() + t._conn.read_all = MagicMock(side_effect=SerialException("disconnected")) # type: ignore + + with pytest.raises(SMPTransportDisconnected): + await t.receive() + + +@pytest.mark.asyncio +async def test_send_and_receive() -> None: + t = SMPSerialRawTransport() + t.send = AsyncMock() # type: ignore + t.receive = AsyncMock() # type: ignore + + await t.send_and_receive(b"some data") + + t.send.assert_awaited_once_with(b"some data") + t.receive.assert_awaited_once_with() diff --git a/tests/test_smp_serial_transport.py b/tests/test_smp_serial_transport.py index df64dab..436051f 100644 --- a/tests/test_smp_serial_transport.py +++ b/tests/test_smp_serial_transport.py @@ -18,7 +18,7 @@ @pytest.fixture(autouse=True) def mock_serial() -> Generator[None, Any, None]: - with patch("smpclient.transport.serial.Serial"): + with patch("smpclient.transport.serial.common.Serial"): yield