async def _cleanup_tasks(self) -> None:
    """Cancel the session's background tasks and wait for them to finish.

    The tracked tasks are ``_listener_task``, ``_process_events_task``,
    ``_stream_audio_task`` and ``_connection_task``. The task that is running
    this coroutine is never cancelled or awaited, so the method can be called
    from inside one of the tracked tasks without deadlocking on itself.

    Each task receives exactly one ``cancel()`` request. Cancelling a second
    time while a task is already unwinding would throw another
    ``CancelledError`` into whatever it awaits in its ``except``/``finally``
    blocks and cut its own cleanup short.

    The attributes are re-read after every suspension point, so a task that
    is stored on the session while this coroutine is yielding is cancelled
    and awaited as well.

    Exceptions raised by the awaited tasks (including their cancellation) are
    collected by ``gather(..., return_exceptions=True)`` and not re-raised.
    """
    current_task = asyncio.current_task()
    tasks_to_await: set[asyncio.Task[Any]] = set()

    def _unhandled_tasks() -> list[asyncio.Task[Any]]:
        # Tracked tasks that exist, are not the caller itself, and have not
        # been cancelled by this cleanup pass yet.
        return [
            task
            for task in (
                self._listener_task,
                self._process_events_task,
                self._stream_audio_task,
                self._connection_task,
            )
            if task is not None and task is not current_task and task not in tasks_to_await
        ]

    while True:
        # Cancel everything currently visible, then yield once so the
        # cancelled tasks start unwinding; repeat while new tasks show up.
        tasks = _unhandled_tasks()
        while tasks:
            for task in tasks:
                if not task.done():
                    task.cancel()

            tasks_to_await.update(tasks)
            await asyncio.sleep(0)
            tasks = _unhandled_tasks()

        if not tasks_to_await:
            return

        # Wait for the cancelled tasks to finish their own cleanup. They are
        # deliberately not cancelled again here (see docstring).
        await asyncio.gather(*tasks_to_await, return_exceptions=True)

        # A task may have been stored on the session while we were waiting;
        # if so, run another round for it.
        if not _unhandled_tasks():
            return
@pytest.mark.asyncio
async def test_close_awaits_cancelled_background_tasks():
    """close() must cancel every background task and wait for it to finish.

    Four tasks that block until cancelled are installed as the session's
    background tasks. After ``close()`` returns, each task must be done and
    its ``finally`` block must have run.
    """
    task_attributes = (
        "_listener_task",
        "_process_events_task",
        "_stream_audio_task",
        "_connection_task",
    )

    session = OpenAISTTTranscriptionSession(
        input=await FakeStreamedAudioInput.get(count=0),
        client=AsyncMock(api_key="FAKE_KEY"),
        model="whisper-1",
        settings=STTModelSettings(),
        trace_include_sensitive_data=False,
        trace_include_sensitive_audio_data=False,
    )

    async def wait_until_cancelled(cleanup_event: asyncio.Event) -> None:
        # Block on an event nobody sets; only cancellation ends the wait.
        try:
            await asyncio.Event().wait()
        finally:
            cleanup_event.set()

    cleanup_events = []
    tasks = []
    for attribute in task_attributes:
        finished = asyncio.Event()
        background_task = asyncio.create_task(wait_until_cancelled(finished))
        setattr(session, attribute, background_task)
        cleanup_events.append(finished)
        tasks.append(background_task)

    # Yield once so every task is actually running before the session closes.
    await asyncio.sleep(0)
    await session.close()

    assert all(task.done() for task in tasks)
    assert all(event.is_set() for event in cleanup_events)