Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 78 additions & 30 deletions bumble/transport/usb.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@
)

MAX_SCO_PACKET_SIZE = 1024
MAX_SCO_IN_PACKETS = 128
NUMBER_OF_SCO_IN_TRANSFERS = 2


# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -388,20 +390,36 @@ async def terminate(self):
READ_SIZE = 4096


class ScoAccumulator:
def __init__(self, emit: Callable[[bytes], Any]) -> None:
class PacketSplitter:
"""Splitter than can parse a byte stream and extract packets that consist of a
header and a body, where the header includes an n-byte 'length' field at a
certain offset.
Extracted packets are emitted by calling a function passed to the constructor,
with the full packet (header + body) as argument.
"""

def __init__(
self, length_offset: int, length_size: int, emit: Callable[[bytes], Any]
) -> None:
self.emit = emit
self.packet = b''
self.length_offset = length_offset
self.length_size = length_size
self.header_size = length_offset + length_size

def feed(self, data: bytes) -> None:
while data:
# Accumulate until we have a complete 3-byte header
if (bytes_needed := 3 - len(self.packet)) > 0:
# Accumulate until we have a complete header
if (bytes_needed := self.header_size - len(self.packet)) > 0:
self.packet += data[:bytes_needed]
data = data[bytes_needed:]
continue
if len(self.packet) < self.header_size:
continue

packet_length = 3 + self.packet[2]
packet_length = self.header_size + int.from_bytes(
self.packet[self.length_offset : self.length_offset + self.length_size],
'little',
)
bytes_needed = packet_length - len(self.packet)
self.packet += data[:bytes_needed]
data = data[bytes_needed:]
Expand All @@ -411,6 +429,24 @@ def feed(self, data: bytes) -> None:
self.packet = b''


class ScoPacketSplitter(PacketSplitter):
def __init__(self, emit: Callable[[bytes], Any]) -> None:
# The length field is 1 byte at offset 2 in the HCI SCO packet header
super().__init__(length_offset=2, length_size=1, emit=emit)


class EventPacketSplitter(PacketSplitter):
def __init__(self, emit: Callable[[bytes], Any]) -> None:
# The length field is 1 byte at offset 1 in the HCI Event packet header
super().__init__(length_offset=1, length_size=1, emit=emit)


class AclPacketSplitter(PacketSplitter):
def __init__(self, emit: Callable[[bytes], Any]) -> None:
# The length field is 2 bytes at offset 2 in the HCI ACL packet header
super().__init__(length_offset=2, length_size=2, emit=emit)


class UsbPacketSource(asyncio.Protocol, BaseSource):
def __init__(self, device, metadata, interrupt_in, bulk_in, isochronous_in):
super().__init__()
Expand All @@ -421,17 +457,23 @@ def __init__(self, device, metadata, interrupt_in, bulk_in, isochronous_in):
self.bulk_in = bulk_in
self.bulk_in_transfer = None
self.isochronous_in = isochronous_in
self.isochronous_in_transfer = None
self.isochronous_accumulator = ScoAccumulator(
lambda packet: self.queue_packet(hci.HCI_SYNCHRONOUS_DATA_PACKET, packet)
)
self.isochronous_in_transfers = []
self.loop = asyncio.get_running_loop()
self.queue = asyncio.Queue()
self.dequeue_task = None
self.done = {
hci.HCI_EVENT_PACKET: asyncio.Event(),
hci.HCI_ACL_DATA_PACKET: asyncio.Event(),
hci.HCI_SYNCHRONOUS_DATA_PACKET: asyncio.Event(),
self.done = {}
self.splitters = {
hci.HCI_EVENT_PACKET: EventPacketSplitter(
lambda packet: self.queue_packet(hci.HCI_EVENT_PACKET, packet)
),
hci.HCI_ACL_DATA_PACKET: AclPacketSplitter(
lambda packet: self.queue_packet(hci.HCI_ACL_DATA_PACKET, packet)
),
hci.HCI_SYNCHRONOUS_DATA_PACKET: ScoPacketSplitter(
lambda packet: self.queue_packet(
hci.HCI_SYNCHRONOUS_DATA_PACKET, packet
)
),
}
self.closed = False
self.lock = threading.Lock()
Expand All @@ -445,6 +487,7 @@ def start(self):
callback=self.transfer_callback,
user_data=hci.HCI_EVENT_PACKET,
)
self.done[self.interrupt_in_transfer] = asyncio.Event()
self.interrupt_in_transfer.submit()

self.bulk_in_transfer = self.device.getTransfer()
Expand All @@ -454,17 +497,21 @@ def start(self):
callback=self.transfer_callback,
user_data=hci.HCI_ACL_DATA_PACKET,
)
self.done[self.bulk_in_transfer] = asyncio.Event()
self.bulk_in_transfer.submit()

if self.isochronous_in is not None:
self.isochronous_in_transfer = self.device.getTransfer(iso_packets=16)
self.isochronous_in_transfer.setIsochronous(
self.isochronous_in.getAddress(),
16 * self.isochronous_in.getMaxPacketSize(),
callback=self.transfer_callback,
user_data=hci.HCI_SYNCHRONOUS_DATA_PACKET,
)
self.isochronous_in_transfer.submit()
for _ in range(NUMBER_OF_SCO_IN_TRANSFERS):
transfer = self.device.getTransfer(iso_packets=MAX_SCO_IN_PACKETS)
transfer.setIsochronous(
self.isochronous_in.getAddress(),
MAX_SCO_IN_PACKETS * self.isochronous_in.getMaxPacketSize(),
callback=self.transfer_callback,
user_data=hci.HCI_SYNCHRONOUS_DATA_PACKET,
)
self.isochronous_in_transfers.append(transfer)
self.done[transfer] = asyncio.Event()
transfer.submit()

self.dequeue_task = self.loop.create_task(self.dequeue())

Expand All @@ -490,6 +537,8 @@ def transfer_callback(self, transfer):
with self.lock:
if self.closed:
logger.debug("packet source closed, discarding transfer")
elif (splitter := self.splitters.get(packet_type)) is None:
logger.warning(f'no splitter for packet type {packet_type}')
else:
if packet_type == hci.HCI_SYNCHRONOUS_DATA_PACKET:
for iso_status, iso_buffer in transfer.iterISO():
Expand All @@ -503,11 +552,10 @@ def transfer_callback(self, transfer):
len(iso_buffer),
iso_buffer.hex(),
)
self.isochronous_accumulator.feed(iso_buffer)
splitter.feed(iso_buffer)
else:
self.queue_packet(
packet_type,
transfer.getBuffer()[: transfer.getActualLength()],
splitter.feed(
transfer.getBuffer()[: transfer.getActualLength()]
)

# Re-submit the transfer so we can receive more data
Expand All @@ -518,12 +566,12 @@ def transfer_callback(self, transfer):
self.loop.call_soon_threadsafe(self.on_transport_lost)
elif status == usb1.TRANSFER_CANCELLED:
logger.debug(f"IN[{packet_type}] transfer canceled")
self.loop.call_soon_threadsafe(self.done[packet_type].set)
self.loop.call_soon_threadsafe(self.done[transfer].set)
else:
logger.warning(
color(f'!!! IN[{packet_type}] transfer not completed', 'red')
)
self.loop.call_soon_threadsafe(self.done[packet_type].set)
self.loop.call_soon_threadsafe(self.done[transfer].set)
self.loop.call_soon_threadsafe(self.on_transport_lost)

async def dequeue(self):
Expand Down Expand Up @@ -552,7 +600,7 @@ async def terminate(self):
for transfer in (
self.interrupt_in_transfer,
self.bulk_in_transfer,
self.isochronous_in_transfer,
*self.isochronous_in_transfers,
):
if transfer is None:
continue
Expand All @@ -568,7 +616,7 @@ async def terminate(self):
f'waiting for IN[{packet_type}] transfer cancellation '
'to be done...'
)
await self.done[packet_type].wait()
await self.done[transfer].wait()
logger.debug(f'IN[{packet_type}] transfer cancellation done')
except usb1.USBError as error:
logger.debug(
Expand Down
65 changes: 64 additions & 1 deletion tests/transport_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import pytest

from bumble import controller, device, hci, link, transport
from bumble.transport import common
from bumble.transport import common, usb


# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -252,6 +252,69 @@ async def test_open_transport_with_metadata(spec):
await controller_transport.close()


# -----------------------------------------------------------------------------
def test_packet_splitter_complete():
emitted = []
splitter = usb.AclPacketSplitter(emitted.append)
packet = bytes([0x01, 0x00, 0x04, 0x00, 0x11, 0x22, 0x33, 0x44])
splitter.feed(packet)
assert emitted == [packet]


def test_packet_splitter_chunks():
emitted = []
splitter = usb.AclPacketSplitter(emitted.append)
packet = bytes([0x01, 0x00, 0x04, 0x00, 0x11, 0x22, 0x33, 0x44])
splitter.feed(packet[:4])
assert emitted == []
splitter.feed(packet[4:])
assert emitted == [packet]


def test_packet_splitter_multiple():
emitted = []
splitter = usb.AclPacketSplitter(emitted.append)
packet1 = bytes([0x01, 0x00, 0x04, 0x00, 0x11, 0x22, 0x33, 0x44])
packet2 = bytes([0x02, 0x00, 0x02, 0x00, 0x55, 0x66])
splitter.feed(packet1 + packet2)
assert emitted == [packet1, packet2]


def test_packet_splitter_partial():
emitted = []
splitter = usb.AclPacketSplitter(emitted.append)
packet1 = bytes([0x01, 0x00, 0x04, 0x00, 0x11, 0x22, 0x33, 0x44])
packet2 = bytes([0x02, 0x00, 0x02, 0x00, 0x55, 0x66])
splitter.feed(packet1 + packet2[:4])
assert emitted == [packet1]
splitter.feed(packet2[4:])
assert emitted == [packet1, packet2]


def test_packet_splitter_empty_payload():
emitted = []
splitter = usb.AclPacketSplitter(emitted.append)
packet = bytes([0x01, 0x00, 0x00, 0x00])
splitter.feed(packet)
assert emitted == [packet]


def test_sco_packet_splitter():
emitted = []
splitter = usb.ScoPacketSplitter(emitted.append)
packet = bytes([0x01, 0x00, 0x03, 0x11, 0x22, 0x33])
splitter.feed(packet)
assert emitted == [packet]


def test_event_packet_splitter():
emitted = []
splitter = usb.EventPacketSplitter(emitted.append)
packet = bytes([0x04, 0x02, 0x11, 0x22])
splitter.feed(packet)
assert emitted == [packet]


# -----------------------------------------------------------------------------
if __name__ == '__main__':
test_parser()
Expand Down
Loading