diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index 7e1258ff..afebaa77 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -11,7 +11,9 @@ ## New Features - +- There's a new `Oneshot` channel, which returns a sender and a receiver. A single message can be sent using the sender, after which it will be closed. And the receiver will close as soon as the message is received. + +- `Sender`s now have an `aclose`, which must be called, when they are no-longer needed. ## Bug Fixes diff --git a/src/frequenz/channels/__init__.py b/src/frequenz/channels/__init__.py index 87c86a34..eaf125d1 100644 --- a/src/frequenz/channels/__init__.py +++ b/src/frequenz/channels/__init__.py @@ -80,7 +80,7 @@ """ from ._anycast import Anycast -from ._broadcast import Broadcast +from ._broadcast import Broadcast, BroadcastChannel from ._exceptions import ChannelClosedError, ChannelError, Error from ._generic import ( ChannelMessageT, @@ -92,6 +92,7 @@ ) from ._latest_value_cache import LatestValueCache from ._merge import Merger, merge +from ._oneshot import OneshotChannel, OneshotReceiver, OneshotSender from ._receiver import Receiver, ReceiverError, ReceiverStoppedError from ._select import ( Selected, @@ -100,19 +101,32 @@ select, selected_from, ) -from ._sender import Sender, SenderError +from ._sender import ( + ClonableSender, + ClonableSubscribableSender, + Sender, + SenderClosedError, + SenderError, + SubscribableSender, +) __all__ = [ "Anycast", "Broadcast", + "BroadcastChannel", "ChannelClosedError", "ChannelError", "ChannelMessageT", + "ClonableSender", + "ClonableSubscribableSender", "Error", "ErroredChannelT_co", "LatestValueCache", "MappedMessageT_co", "Merger", + "OneshotChannel", + "OneshotReceiver", + "OneshotSender", "Receiver", "ReceiverError", "ReceiverMessageT_co", @@ -120,9 +134,11 @@ "SelectError", "Selected", "Sender", + "SenderClosedError", "SenderError", "SenderMessageT_co", "SenderMessageT_contra", + "SubscribableSender", "UnhandledSelectedError", "merge", "select", diff --git a/src/frequenz/channels/_anycast.py b/src/frequenz/channels/_anycast.py index b5184a3f..5c3f8284 100644 --- a/src/frequenz/channels/_anycast.py +++ b/src/frequenz/channels/_anycast.py @@ -15,7 +15,7 @@ from ._exceptions import ChannelClosedError from ._generic import ChannelMessageT from ._receiver import Receiver, ReceiverStoppedError -from ._sender import Sender, SenderError +from ._sender import Sender, SenderClosedError, SenderError _logger = logging.getLogger(__name__) @@ -327,6 +327,9 @@ def __init__(self, channel: Anycast[_T], /) -> None: self._channel: Anycast[_T] = channel """The channel that this sender belongs to.""" + self._closed: bool = False + """Whether the sender is closed.""" + @override async def send(self, message: _T, /) -> None: """Send a message across the channel. @@ -343,7 +346,11 @@ async def send(self, message: _T, /) -> None: SenderError: If the underlying channel was closed. A [ChannelClosedError][frequenz.channels.ChannelClosedError] is set as the cause. + SenderClosedError: If this sender was closed. """ + if self._closed: + raise SenderClosedError(self) + # pylint: disable=protected-access if self._channel._closed: raise SenderError("The channel was closed", self) from ChannelClosedError( @@ -367,6 +374,16 @@ async def send(self, message: _T, /) -> None: self._channel._recv_cv.notify(1) # pylint: enable=protected-access + @override + async def aclose(self) -> None: + """Close this sender. + + After closing, the sender will not be able to send any more messages. Any + attempt to send a message through a closed sender will raise a + [SenderError][frequenz.channels.SenderError]. + """ + self._closed = True + def __str__(self) -> str: """Return a string representation of this sender.""" return f"{self._channel}:{type(self).__name__}" diff --git a/src/frequenz/channels/_broadcast.py b/src/frequenz/channels/_broadcast.py index 2c167d5e..73f908d0 100644 --- a/src/frequenz/channels/_broadcast.py +++ b/src/frequenz/channels/_broadcast.py @@ -16,12 +16,15 @@ from ._exceptions import ChannelClosedError from ._generic import ChannelMessageT from ._receiver import Receiver, ReceiverStoppedError -from ._sender import Sender, SenderError +from ._sender import ClonableSubscribableSender, SenderClosedError, SenderError _logger = logging.getLogger(__name__) -class Broadcast(Generic[ChannelMessageT]): +@deprecated("Please use BroadcastChannel channel instead.") +class Broadcast( # pylint: disable=too-many-instance-attributes + Generic[ChannelMessageT] +): """A channel that deliver all messages to all receivers. # Description @@ -184,7 +187,13 @@ async def main() -> None: ``` """ - def __init__(self, *, name: str, resend_latest: bool = False) -> None: + def __init__( + self, + *, + name: str, + resend_latest: bool = False, + auto_close: bool = False, + ) -> None: """Initialize this channel. Args: @@ -197,6 +206,8 @@ def __init__(self, *, name: str, resend_latest: bool = False) -> None: wait for the next message on the channel to arrive. It is safe to be set in data/reporting channels, but is not recommended for use in channels that stream control instructions. + auto_close: If True, the channel will be closed when all senders or all + receivers are closed. """ self._name: str = name """The name of the broadcast channel. @@ -207,8 +218,11 @@ def __init__(self, *, name: str, resend_latest: bool = False) -> None: self._recv_cv: Condition = Condition() """The condition to wait for data in the channel's buffer.""" + self._sender_count: int = 0 + """The number of senders attached to this channel.""" + self._receivers: dict[ - int, weakref.ReferenceType[_Receiver[ChannelMessageT]] + int, weakref.ReferenceType[BroadcastReceiver[ChannelMessageT]] ] = {} """The receivers attached to the channel, indexed by their hash().""" @@ -218,6 +232,9 @@ def __init__(self, *, name: str, resend_latest: bool = False) -> None: self._latest: ChannelMessageT | None = None """The latest message sent to the channel.""" + self._auto_close_enabled: bool = auto_close + """Whether to close the channel when all senders or all receivers are closed.""" + self.resend_latest: bool = resend_latest """Whether to resend the latest message to new receivers. @@ -269,13 +286,13 @@ async def close(self) -> None: # noqa: D402 """Close the channel, deprecated alias for `aclose()`.""" # noqa: D402 return await self.aclose() - def new_sender(self) -> Sender[ChannelMessageT]: + def new_sender(self) -> BroadcastSender[ChannelMessageT]: """Return a new sender attached to this channel.""" - return _Sender(self) + return BroadcastSender(self) def new_receiver( self, *, name: str | None = None, limit: int = 50, warn_on_overflow: bool = True - ) -> Receiver[ChannelMessageT]: + ) -> BroadcastReceiver[ChannelMessageT]: """Return a new receiver attached to this channel. Broadcast receivers have their own buffer, and when messages are not @@ -291,7 +308,7 @@ def new_receiver( Returns: A new receiver attached to this channel. """ - recv: _Receiver[ChannelMessageT] = _Receiver( + recv: BroadcastReceiver[ChannelMessageT] = BroadcastReceiver( self, name=name, limit=limit, warn_on_overflow=warn_on_overflow ) self._receivers[hash(recv)] = weakref.ref(recv) @@ -317,7 +334,7 @@ def __repr__(self) -> str: _T = TypeVar("_T") -class _Sender(Sender[_T]): +class BroadcastSender(ClonableSubscribableSender[_T]): """A sender to send messages to the broadcast channel. Should not be created directly, but through the @@ -334,6 +351,16 @@ def __init__(self, channel: Broadcast[_T], /) -> None: self._channel: Broadcast[_T] = channel """The broadcast channel this sender belongs to.""" + self._closed: bool = False + """Whether this sender is closed.""" + + self._channel._sender_count += 1 + + @property + def sender_count(self) -> int: + """Return the number of open senders attached to this sender's channel.""" + return self._channel._sender_count # pylint: disable=protected-access + @override async def send(self, message: _T, /) -> None: """Send a message to all broadcast receivers. @@ -345,12 +372,19 @@ async def send(self, message: _T, /) -> None: SenderError: If the underlying channel was closed. A [ChannelClosedError][frequenz.channels.ChannelClosedError] is set as the cause. + SenderClosedError: If this sender was closed. """ + if self._closed: + raise SenderClosedError(self) # pylint: disable=protected-access if self._channel._closed: raise SenderError("The channel was closed", self) from ChannelClosedError( self._channel ) + if self._channel._auto_close_enabled and len(self._channel._receivers) == 0: + raise SenderError("The channel was closed", self) from ChannelClosedError( + self._channel + ) self._channel._latest = message stale_refs = [] for _hash, recv_ref in self._channel._receivers.items(): @@ -365,6 +399,47 @@ async def send(self, message: _T, /) -> None: self._channel._recv_cv.notify_all() # pylint: enable=protected-access + @override + async def aclose(self) -> None: + """Close this sender. + + After a sender is closed, it can no longer be used to send messages. Any + attempt to send a message through a closed sender will raise a + [SenderClosedError][frequenz.channels.SenderClosedError]. + """ + if self._closed: + return + self._closed = True + self._channel._sender_count -= 1 + + if ( + self._channel._sender_count == 0 # pylint: disable=protected-access + and self._channel._auto_close_enabled # pylint: disable=protected-access + ): + await self._channel.aclose() + + def __del__(self) -> None: + """Clean up this sender.""" + if not self._closed: + self._channel._sender_count -= 1 + + @override + def clone(self) -> BroadcastSender[_T]: + """Return a clone of this sender.""" + return BroadcastSender(self._channel) + + @override + def subscribe( + self, + name: str | None = None, + limit: int = 50, + warn_on_overflow: bool = True, + ) -> BroadcastReceiver[_T]: + """Return a new receiver attached to this sender's channel.""" + return self._channel.new_receiver( + name=name, limit=limit, warn_on_overflow=warn_on_overflow + ) + def __str__(self) -> str: """Return a string representation of this sender.""" return f"{self._channel}:{type(self).__name__}" @@ -374,7 +449,7 @@ def __repr__(self) -> str: return f"{type(self).__name__}({self._channel!r})" -class _Receiver(Receiver[_T]): +class BroadcastReceiver(Receiver[_T]): """A receiver to receive messages from the broadcast channel. Should not be created directly, but through the @@ -476,6 +551,11 @@ async def ready(self) -> bool: while len(self._q) == 0: if self._channel._closed or self._closed: return False + if self._channel._auto_close_enabled and ( + self._channel._sender_count == 0 or len(self._channel._receivers) == 0 + ): + await self._channel.aclose() + return False async with self._channel._recv_cv: await self._channel._recv_cv.wait() return True @@ -525,3 +605,32 @@ def __repr__(self) -> str: f"{type(self).__name__}(name={self._name!r}, limit={limit!r}, " f"{self._channel!r}):" ) + + +class BroadcastChannel( + tuple[BroadcastSender[ChannelMessageT], BroadcastReceiver[ChannelMessageT]] +): + """A broadcast channel, deprecated alias for Broadcast.""" + + def __new__( + cls, + name: str, + resend_latest: bool = False, + limit: int = 50, + warn_on_overflow: bool = True, + ) -> BroadcastChannel[ChannelMessageT]: + """Create a new broadcast channel, deprecated alias for Broadcast.""" + channel = Broadcast[ChannelMessageT]( + name=name, resend_latest=resend_latest, auto_close=True + ) + return tuple.__new__( + cls, + ( + channel.new_sender(), + channel.new_receiver( + name=f"{name}_receiver", + limit=limit, + warn_on_overflow=warn_on_overflow, + ), + ), + ) diff --git a/src/frequenz/channels/_oneshot.py b/src/frequenz/channels/_oneshot.py new file mode 100644 index 00000000..15c8e3a6 --- /dev/null +++ b/src/frequenz/channels/_oneshot.py @@ -0,0 +1,142 @@ +# License: MIT +# Copyright © 2026 Frequenz Energy-as-a-Service GmbH + +"""A channel that can send a single message.""" + +from __future__ import annotations + +import asyncio +import typing + +from ._generic import ChannelMessageT +from ._receiver import Receiver, ReceiverStoppedError +from ._sender import Sender, SenderClosedError + + +class _Empty: + """A sentinel indicating that no message has been sent.""" + + +_EMPTY = _Empty() + + +class _Oneshot(typing.Generic[ChannelMessageT]): + """Internal representation of a one-shot channel. + + A one-shot channel is a channel that can only send one message. After the first + message is sent, the sender is closed and any further attempts to send a message + will raise a `SenderClosedError`. + """ + + def __init__(self) -> None: + """Create a new one-shot channel.""" + self.message: ChannelMessageT | _Empty = _EMPTY + self.closed: bool = False + self.drained: bool = False + self.event: asyncio.Event = asyncio.Event() + + +class OneshotSender(Sender[ChannelMessageT]): + """A sender for a one-shot channel.""" + + def __init__(self, channel: _Oneshot[ChannelMessageT]) -> None: + """Initialize this sender.""" + self._channel = channel + + async def send(self, message: ChannelMessageT, /) -> None: + """Send a message through this sender.""" + if self._channel.closed: + raise SenderClosedError(self) + self._channel.message = message + self._channel.closed = True + self._channel.event.set() + + async def aclose(self) -> None: + """Close this sender.""" + self._channel.closed = True + if isinstance(self._channel.message, _Empty): + self._channel.drained = True + self._channel.event.set() + + +class OneshotReceiver(Receiver[ChannelMessageT]): + """A receiver for a one-shot channel.""" + + def __init__(self, channel: _Oneshot[ChannelMessageT]) -> None: + """Initialize this receiver.""" + self._channel = channel + + async def ready(self) -> bool: + """Check if a message is ready to be received. + + Returns: + `True` if a message is ready to be received, `False` if the sender + is closed and no message will be sent. + """ + if self._channel.drained: + return False + while not self._channel.closed: + await self._channel.event.wait() + if isinstance(self._channel.message, _Empty): + return False + return True + + def consume(self) -> ChannelMessageT: + """Consume a message from this receiver. + + Returns: + The message that was sent through this channel. + + Raises: + ReceiverStoppedError: If the sender was closed without sending a message. + """ + if self._channel.drained: + raise ReceiverStoppedError(self) + + assert not isinstance( + self._channel.message, _Empty + ), "`consume()` must be preceded by a call to `ready()`." + + self._channel.drained = True + self._channel.event.clear() + return self._channel.message + + +class OneshotChannel( + tuple[OneshotSender[ChannelMessageT], OneshotReceiver[ChannelMessageT]] +): + """A channel that can send a single message. + + A one-shot channel is a channel that can only send one message. After the first + message is sent, the sender is closed and any further attempts to send a message + will raise a `SenderClosedError`. + + # Example + + This example demonstrates how to use a one-shot channel to send a message + from one task to another. + + ```python + import asyncio + + from frequenz.channels import OneshotChannel, OneshotSender + + async def send(sender: OneshotSender[int]) -> None: + await sender.send(42) + + async def main() -> None: + sender, receiver = OneshotChannel[int]() + + async with asyncio.TaskGroup() as tg: + tg.create_task(send(sender)) + assert await receiver.receive() == 42 + + asyncio.run(main()) + ``` + """ + + def __new__(cls) -> OneshotChannel[ChannelMessageT]: + """Create a new one-shot channel.""" + channel = _Oneshot[ChannelMessageT]() + + return tuple.__new__(cls, (OneshotSender(channel), OneshotReceiver(channel))) diff --git a/src/frequenz/channels/_sender.py b/src/frequenz/channels/_sender.py index e225e94c..4c1a6d07 100644 --- a/src/frequenz/channels/_sender.py +++ b/src/frequenz/channels/_sender.py @@ -49,11 +49,14 @@ ``` """ +from __future__ import annotations + from abc import ABC, abstractmethod from typing import Generic from ._exceptions import Error from ._generic import SenderMessageT_co, SenderMessageT_contra +from ._receiver import Receiver class Sender(ABC, Generic[SenderMessageT_contra]): @@ -70,6 +73,15 @@ async def send(self, message: SenderMessageT_contra, /) -> None: SenderError: If there was an error sending the message. """ + @abstractmethod + async def aclose(self) -> None: + """Close this sender. + + After a sender is closed, it can no longer be used to send messages. Any + attempt to send a message through a closed sender will raise a + [SenderClosedError][frequenz.channels.SenderClosedError]. + """ + class SenderError(Error, Generic[SenderMessageT_co]): """An error that originated in a [Sender][frequenz.channels.Sender]. @@ -88,3 +100,47 @@ def __init__(self, message: str, sender: Sender[SenderMessageT_co]): super().__init__(message) self.sender: Sender[SenderMessageT_co] = sender """The sender where the error happened.""" + + +class SenderClosedError(SenderError[SenderMessageT_co]): + """An error indicating that a send operation was attempted on a closed sender.""" + + def __init__(self, sender: Sender[SenderMessageT_co]): + """Initialize this error. + + Args: + sender: The [Sender][frequenz.channels.Sender] that was closed. + """ + super().__init__("Sender is closed", sender) + + +class SubscribableSender(Sender[SenderMessageT_contra], ABC): + """A [Sender][frequenz.channels.Sender] that can be subscribed to.""" + + @abstractmethod + def subscribe(self) -> Receiver[SenderMessageT_contra]: + """Subscribe to this sender. + + Returns: + A new sender that sends messages to the same channel as this sender. + """ + + +class ClonableSender(Sender[SenderMessageT_contra], ABC): + """A [Sender][frequenz.channels.Sender] that can be cloned.""" + + @abstractmethod + def clone(self) -> ClonableSender[SenderMessageT_contra]: + """Clone this sender. + + Returns: + A new sender that sends messages to the same channel as this sender. + """ + + +class ClonableSubscribableSender( + SubscribableSender[SenderMessageT_contra], + ClonableSender[SenderMessageT_contra], + ABC, +): + """A [Sender][frequenz.channels.Sender] that can be both cloned and subscribed to.""" diff --git a/src/frequenz/channels/experimental/_relay_sender.py b/src/frequenz/channels/experimental/_relay_sender.py index 398ba8d5..0a5d8063 100644 --- a/src/frequenz/channels/experimental/_relay_sender.py +++ b/src/frequenz/channels/experimental/_relay_sender.py @@ -7,7 +7,7 @@ to the senders it was created with. """ -import typing +import asyncio from typing_extensions import override @@ -15,7 +15,7 @@ from .._sender import Sender -class RelaySender(typing.Generic[SenderMessageT_contra], Sender[SenderMessageT_contra]): +class RelaySender(Sender[SenderMessageT_contra]): """A Sender for sending messages to multiple senders. The `RelaySender` class takes multiple senders and forwards all the messages sent to @@ -57,3 +57,8 @@ async def send(self, message: SenderMessageT_contra, /) -> None: """ for sender in self._senders: await sender.send(message) + + @override + async def aclose(self) -> None: + """Close this sender.""" + await asyncio.gather(*(sender.aclose() for sender in self._senders)) diff --git a/tests/test_broadcast.py b/tests/test_broadcast.py index 0cc89f33..1f391844 100644 --- a/tests/test_broadcast.py +++ b/tests/test_broadcast.py @@ -12,6 +12,7 @@ from frequenz.channels import ( Broadcast, + BroadcastChannel, ChannelClosedError, Receiver, ReceiverStoppedError, @@ -107,7 +108,7 @@ async def test_broadcast_after_close() -> None: async def test_broadcast_overflow() -> None: """Ensure messages sent to full broadcast receivers get dropped.""" from frequenz.channels._broadcast import ( # pylint: disable=import-outside-toplevel - _Receiver, + BroadcastReceiver, ) bcast: Broadcast[int] = Broadcast(name="meter_5") @@ -117,9 +118,9 @@ async def test_broadcast_overflow() -> None: sender = bcast.new_sender() big_receiver = bcast.new_receiver(name="named-recv", limit=big_recv_size) - assert isinstance(big_receiver, _Receiver) + assert isinstance(big_receiver, BroadcastReceiver) small_receiver = bcast.new_receiver(limit=small_recv_size) - assert isinstance(small_receiver, _Receiver) + assert isinstance(small_receiver, BroadcastReceiver) async def drain_receivers() -> tuple[int, int]: big_sum = 0 @@ -425,3 +426,50 @@ async def test_broadcast_close_receiver() -> None: with pytest.raises(ReceiverStoppedError): _ = await receiver_2.receive() + + +async def test_broadcast_auto_close_1() -> None: + """Ensure broadcast auto close works when all receivers are closed.""" + sender, receiver = BroadcastChannel[int](name="auto-close-test") + + receiver_2 = sender.subscribe() + + await sender.send(1) + + assert (await receiver.receive()) == 1 + assert (await receiver_2.receive()) == 1 + + receiver.close() + + await sender.send(2) + + assert (await receiver_2.receive()) == 2 + + receiver_2.close() + + with pytest.raises(SenderError) as excinfo: + await sender.send(3) + assert isinstance(excinfo.value.__cause__, ChannelClosedError) + + +async def test_broadcast_auto_close_2() -> None: + """Ensure broadcast auto close works when all senders are closed.""" + sender, receiver = BroadcastChannel[int](name="auto-close-test") + + await sender.send(1) + + assert (await receiver.receive()) == 1 + + sender_2 = sender.clone() + + await sender.aclose() + + await sender_2.send(2) + + await sender_2.aclose() + + assert (await receiver.receive()) == 2 + + with pytest.raises(ReceiverStoppedError) as excinfo: + await receiver.receive() + assert isinstance(excinfo.value.__cause__, ChannelClosedError) diff --git a/tests/test_oneshot.py b/tests/test_oneshot.py new file mode 100644 index 00000000..728bdcde --- /dev/null +++ b/tests/test_oneshot.py @@ -0,0 +1,87 @@ +# License: MIT +# Copyright © 2026 Frequenz Energy-as-a-Service GmbH + +"""Tests for the oneshot channel.""" + +import asyncio + +import pytest + +from frequenz.channels import ( + OneshotChannel, + ReceiverStoppedError, + SenderClosedError, +) + + +async def test_oneshot_recv_after_send() -> None: + """Test the oneshot function. + + `receiver.receive()` is called after `sender.send()`. + """ + sender, receiver = OneshotChannel[int]() + + await sender.send(42) + assert await receiver.receive() == 42 + + with pytest.raises(SenderClosedError): + await sender.send(43) + with pytest.raises(ReceiverStoppedError): + await receiver.receive() + + +async def test_oneshot_recv_before_send() -> None: + """Test the oneshot function. + + `receiver.receive()` is called before `sender.send()`. + """ + sender, receiver = OneshotChannel[int]() + + task = asyncio.create_task(receiver.receive()) + + # Give the receiver a chance to start waiting + await asyncio.sleep(0.0) + + await sender.send(42) + assert await task == 42 + + with pytest.raises(SenderClosedError): + await sender.send(43) + with pytest.raises(ReceiverStoppedError): + await receiver.receive() + + +async def test_oneshot_recv_after_sender_closed() -> None: + """Test that closing sender works without sending a message. + + `receiver.receive()` is called after `sender.aclose()`. + """ + sender, receiver = OneshotChannel[int]() + + await sender.aclose() + + with pytest.raises(ReceiverStoppedError): + await receiver.receive() + with pytest.raises(SenderClosedError): + await sender.send(4) + + +async def test_oneshot_recv_before_sender_closed() -> None: + """Test that closing sender works without sending a message. + + `receiver.receive()` is called before `sender.aclose()`. + """ + sender, receiver = OneshotChannel[int]() + + task = asyncio.create_task(receiver.receive()) + + # Give the receiver a chance to start waiting + await asyncio.sleep(0.0) + + await sender.aclose() + + with pytest.raises(ReceiverStoppedError): + await task + + with pytest.raises(SenderClosedError): + await sender.send(4)