diff --git a/pyproject.toml b/pyproject.toml index 484b5a3..35675cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,12 +23,12 @@ classifiers = [ dependencies = [ "smp>=4.0.2", "intelhex>=2.3.0", -"async-timeout>=4.0.3; python_version < '3.11'", +"async-timeout>=5.0.1; python_version < '3.11'", ] [project.optional-dependencies] serial = ["pyserial>=3.5"] -ble = ["bleak>=2.0.0"] +ble = ["bleak>=3.0.2,<4"] udp = [] all = ["smpclient[serial,ble,udp]"] diff --git a/src/smpclient/transport/ble.py b/src/smpclient/transport/ble.py index 0cc2c3f..d72d0d0 100644 --- a/src/smpclient/transport/ble.py +++ b/src/smpclient/transport/ble.py @@ -4,7 +4,8 @@ import logging import re import sys -from typing import Final, Protocol, TypeGuard +from collections.abc import Coroutine +from typing import Any, Final, Protocol, TypeGuard, TypeVar from uuid import UUID try: @@ -66,6 +67,8 @@ class SMPBLETransportNotSMPServer(SMPBLETransportException): logger = logging.getLogger(__name__) +_T = TypeVar("_T") + class SMPBLETransport(SMPTransport): """A Bluetooth Low Energy (BLE) SMPTransport.""" @@ -84,6 +87,13 @@ def __init__(self, winrt: WinRTClientArgs = {}) -> None: @override async def connect(self, address: str, timeout_s: float) -> None: + try: + await asyncio.wait_for(self._connect(address, timeout_s), timeout=timeout_s) + except (Exception, asyncio.CancelledError): + await self._best_effort_disconnect() + raise + + async def _connect(self, address: str, timeout_s: float) -> None: logger.debug(f"Scanning for {address=}") device: BLEDevice | None = ( await BleakScanner.find_device_by_address(address, timeout=timeout_s) @@ -96,6 +106,7 @@ async def connect(self, address: str, timeout_s: float) -> None: device, services=(str(SMP_SERVICE_UUID),), winrt=self._winrt, + timeout=timeout_s, disconnected_callback=self._set_disconnected_event, ) else: @@ -139,7 +150,9 @@ async def connect(self, address: str, timeout_s: float) -> None: self._smp_characteristic = smp_characteristic logger.debug(f"Starting notify on {SMP_CHARACTERISTIC_UUID=}") - await self._client.start_notify(SMP_CHARACTERISTIC_UUID, self._notify_callback) + await self._await_or_disconnect( + self._client.start_notify(SMP_CHARACTERISTIC_UUID, self._notify_callback) + ) logger.debug(f"Started notify on {SMP_CHARACTERISTIC_UUID=}") @override @@ -246,3 +259,37 @@ async def _notify_or_disconnect(self) -> None: raise SMPTransportDisconnected( f"{self.__class__.__name__} disconnected from {self._client.address}" ) + + async def _await_or_disconnect(self, coro: Coroutine[Any, Any, _T]) -> _T: + """Await `coro`; raise `SMPTransportDisconnected` if the peer disconnects first. + + Guards GATT operations that can hang indefinitely when the peer + disconnects mid-flow (e.g. failed pairing) — see + https://github.com/intercreate/smpmgr/issues/97. + """ + op_task: Final = asyncio.create_task(coro) + disconnected_task: Final = asyncio.create_task(self._disconnected_event.wait()) + try: + done, _ = await asyncio.wait( + (op_task, disconnected_task), return_when=asyncio.FIRST_COMPLETED + ) + finally: + for task in (op_task, disconnected_task): + if not task.done(): + task.cancel() + await asyncio.gather(op_task, disconnected_task, return_exceptions=True) + if disconnected_task in done: + raise SMPTransportDisconnected( + f"{self.__class__.__name__} disconnected from {self._client.address}" + ) + return op_task.result() + + async def _best_effort_disconnect(self) -> None: + """Best-effort cleanup after a failed `connect()`; never raises.""" + client: Final = getattr(self, "_client", None) + if client is None: + return + try: + await client.disconnect() + except Exception: + logger.warning("Best-effort disconnect after failed connect raised", exc_info=True) diff --git a/tests/test_smp_ble_transport.py b/tests/test_smp_ble_transport.py index c451876..2250cdf 100644 --- a/tests/test_smp_ble_transport.py +++ b/tests/test_smp_ble_transport.py @@ -10,6 +10,7 @@ from bleak.backends.device import BLEDevice from smpclient.requests.os_management import EchoWrite +from smpclient.transport import SMPTransportDisconnected from smpclient.transport.ble import ( MAC_ADDRESS_PATTERN, SMP_CHARACTERISTIC_UUID, @@ -209,3 +210,110 @@ def test_max_unencoded_size_mcumgr_param() -> None: t._client = MagicMock(spec=BleakClient) t._smp_server_transport_buffer_size = 9001 assert t.max_unencoded_size == 9001 + + +class _HangingBleakClient: + """A `BleakClient` stand-in whose `start_notify` never returns. + + Reproduces the failure mode reported in intercreate/smpmgr#97: the BlueZ + `StartNotify` D-Bus call hangs indefinitely when the peer disconnects + mid-pairing. + """ + + def __new__(cls, *args: object, **kwargs: object) -> "_HangingBleakClient": # type: ignore[misc] # noqa: E501 + captured_callback = kwargs.get("disconnected_callback") + client = MagicMock(spec=BleakClient, name="HangingBleakClient") + client._backend = type("Backend", (), {})() + client.connect = AsyncMock(name="connect") + + async def _hang(*_a: object, **_kw: object) -> None: + await asyncio.Event().wait() # never fires + + client.start_notify = AsyncMock(side_effect=_hang) + client.disconnect = AsyncMock(name="disconnect") + client.address = "00:00:00:00:00:00" + client._captured_disconnected_callback = captured_callback # type: ignore[attr-defined] + return client + + +@patch( + "smpclient.transport.ble.BleakScanner.find_device_by_address", + return_value=BLEDevice("00:00:00:00:00:00", "name", None), +) +@patch("smpclient.transport.ble.BleakClient", new=_HangingBleakClient) +@pytest.mark.asyncio +async def test_connect_raises_on_peer_disconnect_during_start_notify( + _mock_find_device_by_address: MagicMock, +) -> None: + """Regression test for intercreate/smpmgr#97. + + When the peer disconnects mid-`start_notify` (e.g. failed pairing), `connect()` + must surface `SMPTransportDisconnected` rather than hang. + """ + t = SMPBLETransport() + + async def _trip_disconnect_callback() -> None: + # Wait until the transport reaches start_notify and clears the event, + # then simulate the bleak `disconnected_callback` firing. + while t._disconnected_event.is_set(): + await asyncio.sleep(0) + await asyncio.sleep(0) # let start_notify await begin + t._set_disconnected_event(t._client) + + connect_task = asyncio.create_task(t.connect("00:00:00:00:00:00", 5.0)) + trip_task = asyncio.create_task(_trip_disconnect_callback()) + + with pytest.raises(SMPTransportDisconnected): + await connect_task + await trip_task + + # `_best_effort_disconnect` should have been called to release the client. + t._client.disconnect.assert_awaited() # type: ignore[attr-defined] + + +@patch( + "smpclient.transport.ble.BleakScanner.find_device_by_address", + return_value=BLEDevice("00:00:00:00:00:00", "name", None), +) +@patch("smpclient.transport.ble.BleakClient", new=_HangingBleakClient) +@pytest.mark.asyncio +async def test_connect_raises_on_timeout_during_start_notify( + _mock_find_device_by_address: MagicMock, +) -> None: + """`connect()` must honor `timeout_s` even when `start_notify` hangs.""" + t = SMPBLETransport() + with pytest.raises(asyncio.TimeoutError): + await t.connect("00:00:00:00:00:00", 0.05) + t._client.disconnect.assert_awaited() # type: ignore[attr-defined] + + +@patch( + "smpclient.transport.ble.BleakScanner.find_device_by_address", + return_value=BLEDevice("00:00:00:00:00:00", "name", None), +) +@patch("smpclient.transport.ble.BleakClient", new=_HangingBleakClient) +@pytest.mark.asyncio +async def test_connect_does_not_leak_tasks_on_external_cancel( + _mock_find_device_by_address: MagicMock, +) -> None: + """Caller-driven cancellation must not leave `_await_or_disconnect` sub-tasks running.""" + t = SMPBLETransport() + tasks_before = {id(task) for task in asyncio.all_tasks()} + + connect_task = asyncio.create_task(t.connect("00:00:00:00:00:00", 60.0)) + while t._disconnected_event.is_set(): + await asyncio.sleep(0) # wait until BleakClient.connect() returned + await asyncio.sleep(0) # let start_notify await begin + + connect_task.cancel() + with pytest.raises(asyncio.CancelledError): + await connect_task + + # Let any cancellations propagate to the spawned sub-tasks. + for _ in range(5): + await asyncio.sleep(0) + + leaked = [ + task for task in asyncio.all_tasks() if id(task) not in tasks_before and not task.done() + ] + assert not leaked, f"sub-tasks leaked after external cancel: {leaked}" diff --git a/uv.lock b/uv.lock index 4badf63..bcdea01 100644 --- a/uv.lock +++ b/uv.lock @@ -182,7 +182,7 @@ wheels = [ [[package]] name = "bleak" -version = "2.1.1" +version = "3.0.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "async-timeout", marker = "python_full_version < '3.11'" }, @@ -201,9 +201,9 @@ dependencies = [ { name = "winrt-windows-foundation-collections", marker = "sys_platform == 'win32'" }, { name = "winrt-windows-storage-streams", marker = "sys_platform == 'win32'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/45/8a/5acbd4da6a5a301fab56ff6d6e9e6b6945e6e4a2d1d213898c21b1d3a19b/bleak-2.1.1.tar.gz", hash = "sha256:4600cc5852f2392ce886547e127623f188e689489c5946d422172adf80635cf9", size = 120634, upload-time = "2025-12-31T20:43:28.697Z" } +sdist = { url = "https://files.pythonhosted.org/packages/16/df/05a3f80ca8e3f7f5b0dba68a9e618147c909ccdba1468f07487dc8d72a9d/bleak-3.0.2.tar.gz", hash = "sha256:c2229cb8238d5876b4bd05c74bf7a1aea1f88da39d2e51ac9dfd5cc319d5265f", size = 125293, upload-time = "2026-05-02T23:01:04.066Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/99/fe/22aec895f040c1e457d6e6fcc79286fbb17d54602600ab2a58837bec7be1/bleak-2.1.1-py3-none-any.whl", hash = "sha256:61ac1925073b580c896a92a8c404088c5e5ec9dc3c5bd6fc17554a15779d83de", size = 141258, upload-time = "2025-12-31T20:43:27.302Z" }, + { url = "https://files.pythonhosted.org/packages/26/54/05aceb9cd80073805b3ed8522e3196e8cb22f70e741873fa51406c31f4e7/bleak-3.0.2-py3-none-any.whl", hash = "sha256:39092feb9e83f1df5ad2f88e837723c7211c982ce9e9cda6235104bc2ebe0d0d", size = 146490, upload-time = "2026-05-02T23:01:02.592Z" }, ] [[package]] @@ -687,6 +687,7 @@ dependencies = [ { name = "griffecli" }, { name = "griffelib" }, ] +sdist = { url = "https://files.pythonhosted.org/packages/04/56/28a0accac339c164b52a92c6cfc45a903acc0c174caa5c1713803467b533/griffe-2.0.0.tar.gz", hash = "sha256:c68979cd8395422083a51ea7cf02f9c119d889646d99b7b656ee43725de1b80f", size = 293906, upload-time = "2026-03-23T21:06:53.402Z" } wheels = [ { url = "https://files.pythonhosted.org/packages/8b/94/ee21d41e7eb4f823b94603b9d40f86d3c7fde80eacc2c3c71845476dddaa/griffe-2.0.0-py3-none-any.whl", hash = "sha256:5418081135a391c3e6e757a7f3f156f1a1a746cc7b4023868ff7d5e2f9a980aa", size = 5214, upload-time = "2026-02-09T19:09:44.105Z" }, ] @@ -711,6 +712,7 @@ dependencies = [ { name = "colorama" }, { name = "griffelib" }, ] +sdist = { url = "https://files.pythonhosted.org/packages/a4/f8/2e129fd4a86e52e58eefe664de05e7d502decf766e7316cc9e70fdec3e18/griffecli-2.0.0.tar.gz", hash = "sha256:312fa5ebb4ce6afc786356e2d0ce85b06c1c20d45abc42d74f0cda65e159f6ef", size = 56213, upload-time = "2026-03-23T21:06:54.8Z" } wheels = [ { url = "https://files.pythonhosted.org/packages/e6/ed/d93f7a447bbf7a935d8868e9617cbe1cadf9ee9ee6bd275d3040fbf93d60/griffecli-2.0.0-py3-none-any.whl", hash = "sha256:9f7cd9ee9b21d55e91689358978d2385ae65c22f307a63fb3269acf3f21e643d", size = 9345, upload-time = "2026-02-09T19:09:42.554Z" }, ] @@ -719,6 +721,7 @@ wheels = [ name = "griffelib" version = "2.0.0" source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ad/06/eccbd311c9e2b3ca45dbc063b93134c57a1ccc7607c5e545264ad092c4a9/griffelib-2.0.0.tar.gz", hash = "sha256:e504d637a089f5cab9b5daf18f7645970509bf4f53eda8d79ed71cce8bd97934", size = 166312, upload-time = "2026-03-23T21:06:55.954Z" } wheels = [ { url = "https://files.pythonhosted.org/packages/4d/51/c936033e16d12b627ea334aaaaf42229c37620d0f15593456ab69ab48161/griffelib-2.0.0-py3-none-any.whl", hash = "sha256:01284878c966508b6d6f1dbff9b6fa607bc062d8261c5c7253cb285b06422a7f", size = 142004, upload-time = "2026-02-09T19:09:40.561Z" }, ] @@ -2018,9 +2021,9 @@ doc = [ [package.metadata] requires-dist = [ - { name = "async-timeout", marker = "python_full_version < '3.11'", specifier = ">=4.0.3" }, - { name = "bleak", marker = "extra == 'all'", specifier = ">=2.0.0" }, - { name = "bleak", marker = "extra == 'ble'", specifier = ">=2.0.0" }, + { name = "async-timeout", marker = "python_full_version < '3.11'", specifier = ">=5.0.1" }, + { name = "bleak", marker = "extra == 'all'", specifier = ">=3.0.2,<4" }, + { name = "bleak", marker = "extra == 'ble'", specifier = ">=3.0.2,<4" }, { name = "intelhex", specifier = ">=2.3.0" }, { name = "pyserial", marker = "extra == 'all'", specifier = ">=3.5" }, { name = "pyserial", marker = "extra == 'serial'", specifier = ">=3.5" },