Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 43 additions & 10 deletions src/agents/voice/models/openai_stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,18 +321,51 @@ def _check_errors(self) -> None:
if exc and isinstance(exc, Exception):
self._stored_exception = exc

def _cleanup_tasks(self) -> None:
if self._listener_task and not self._listener_task.done():
self._listener_task.cancel()
async def _cleanup_tasks(self) -> None:
current_task = asyncio.current_task()
tasks_to_await: set[asyncio.Task[Any]] = set()

if self._process_events_task and not self._process_events_task.done():
self._process_events_task.cancel()
while True:
while True:
tasks = [
task
for task in (
self._listener_task,
self._process_events_task,
self._stream_audio_task,
self._connection_task,
)
if task is not None and task is not current_task and task not in tasks_to_await
]
if not tasks:
break

for task in tasks:
if not task.done():
task.cancel()

tasks_to_await.update(tasks)
await asyncio.sleep(0)

if self._stream_audio_task and not self._stream_audio_task.done():
self._stream_audio_task.cancel()
if not tasks_to_await:
return

if self._connection_task and not self._connection_task.done():
self._connection_task.cancel()
for task in tasks_to_await:
if not task.done():
task.cancel()

await asyncio.gather(*tasks_to_await, return_exceptions=True)

if all(
task is None or task is current_task or task in tasks_to_await
for task in (
self._listener_task,
self._process_events_task,
self._stream_audio_task,
self._connection_task,
)
):
return

async def transcribe_turns(self) -> AsyncIterator[str]:
self._connection_task = asyncio.create_task(self._process_websocket_connection())
Expand Down Expand Up @@ -367,7 +400,7 @@ async def close(self) -> None:
if self._websocket:
await self._websocket.close()

self._cleanup_tasks()
await self._cleanup_tasks()


class OpenAISTTModel(STTModel):
Expand Down
36 changes: 36 additions & 0 deletions tests/voice/test_openai_stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,42 @@ async def test_stream_audio_sends_correct_json():
await session.close()


@pytest.mark.asyncio
async def test_close_awaits_cancelled_background_tasks():
input_audio = await FakeStreamedAudioInput.get(count=0)
stt_settings = STTModelSettings()
session = OpenAISTTTranscriptionSession(
input=input_audio,
client=AsyncMock(api_key="FAKE_KEY"),
model="whisper-1",
settings=stt_settings,
trace_include_sensitive_data=False,
trace_include_sensitive_audio_data=False,
)

cleanup_events = [asyncio.Event() for _ in range(4)]

async def wait_until_cancelled(cleanup_event: asyncio.Event) -> None:
try:
await asyncio.Event().wait()
finally:
cleanup_event.set()

tasks = [asyncio.create_task(wait_until_cancelled(event)) for event in cleanup_events]
(
session._listener_task,
session._process_events_task,
session._stream_audio_task,
session._connection_task,
) = tasks

await asyncio.sleep(0)
await session.close()

assert all(task.done() for task in tasks)
assert all(event.is_set() for event in cleanup_events)


@pytest.mark.asyncio
@pytest.mark.parametrize(
"created,updated,completed",
Expand Down