diff --git a/pyproject.toml b/pyproject.toml index 9e0987b6b..a6837b3d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,6 +84,7 @@ dev = [ "opentelemetry-exporter-otlp-proto-grpc>=1.11.1,<2", "opentelemetry-semantic-conventions>=0.40b0,<1", "opentelemetry-sdk-extension-aws>=2.0.0,<3", + "async-timeout>=4.0,<6; python_version < '3.11'", ] [tool.poe.tasks] diff --git a/temporalio/contrib/google_adk_agents/_model.py b/temporalio/contrib/google_adk_agents/_model.py index 8b32a7432..afc5e2f70 100644 --- a/temporalio/contrib/google_adk_agents/_model.py +++ b/temporalio/contrib/google_adk_agents/_model.py @@ -7,6 +7,7 @@ import temporalio.workflow from temporalio import activity, workflow +from temporalio.contrib.pubsub import PubSubClient from temporalio.workflow import ActivityConfig @@ -36,6 +37,54 @@ async def invoke_model(llm_request: LlmRequest) -> list[LlmResponse]: ] +@activity.defn +async def invoke_model_streaming( + llm_request: LlmRequest, + streaming_event_topic: str | None, + streaming_event_batch_interval: timedelta, +) -> list[LlmResponse]: + """Streaming-aware model activity. + + Calls the LLM with ``stream=True`` and returns the collected list of + raw ``LlmResponse`` chunks. The workflow's ``TemporalModel.generate_content_async`` + yields these to the caller. + + When ``streaming_event_topic`` is set, each response is also + published to the workflow's pub/sub broker so external consumers + (UIs, tracing, etc.) can observe responses as they arrive. Set the + topic to ``None`` to skip publishing entirely; in that case no + :class:`PubSubClient` is constructed. + """ + if llm_request.model is None: + raise ValueError("No model name provided, could not create LLM.") + + llm = LLMRegistry.new_llm(llm_request.model) + if not llm: + raise ValueError(f"Failed to create LLM for model: {llm_request.model}") + + responses: list[LlmResponse] = [] + + async def consume(pubsub: PubSubClient | None, topic: str | None) -> None: + async for response in llm.generate_content_async( + llm_request=llm_request, stream=True + ): + activity.heartbeat() + responses.append(response) + if pubsub is not None and topic is not None: + pubsub.publish(topic, response) + + if streaming_event_topic is None: + await consume(None, None) + else: + pubsub = PubSubClient.from_activity( + batch_interval=streaming_event_batch_interval, + ) + async with pubsub: + await consume(pubsub, streaming_event_topic) + + return responses + + class TemporalModel(BaseLlm): """A Temporal-based LLM model that executes model invocations as activities.""" @@ -45,9 +94,15 @@ def __init__( activity_config: ActivityConfig | None = None, *, summary_fn: Callable[[LlmRequest], str | None] | None = None, + streaming_event_topic: str | None = "events", + streaming_event_batch_interval: timedelta = timedelta(milliseconds=100), ) -> None: """Initialize the TemporalModel. + Streaming is selected by the caller via the ADK + ``generate_content_async(stream=True)`` argument; no plugin-level + flag is needed. + Args: model_name: The name of the model to use. activity_config: Configuration options for the activity execution. @@ -56,6 +111,18 @@ def __init__( deterministic as it is called during workflow execution. If the callable raises, the exception will propagate and fail the workflow task. + streaming_event_topic: Pub/sub topic to publish raw + ``LlmResponse`` chunks to when streaming. Set to ``None`` + to skip publishing entirely (workflow-side iteration via + ``stream=True`` still works, no broker required). 
When + set, the workflow must host a + :class:`temporalio.contrib.pubsub.PubSub` broker to + receive the publishes; otherwise the signals are + unhandled and dropped. + streaming_event_batch_interval: Interval between automatic + flushes for the pub/sub publisher used by the streaming + activity. Ignored when ``streaming_event_topic`` is + ``None``. Raises: ValueError: If both ``ActivityConfig["summary"]`` and ``summary_fn`` are set. @@ -63,6 +130,8 @@ def __init__( super().__init__(model=model_name) self._model_name = model_name self._summary_fn = summary_fn + self._streaming_event_topic = streaming_event_topic + self._streaming_event_batch_interval = streaming_event_batch_interval self._activity_config = ActivityConfig( start_to_close_timeout=timedelta(seconds=60) ) @@ -80,7 +149,9 @@ async def generate_content_async( Args: llm_request: The LLM request containing model parameters and content. - stream: Whether to stream the response (currently ignored). + stream: Whether to use the streaming activity. When ``True``, + each chunk is also published to ``streaming_event_topic`` + (if set) for external consumers. Yields: The responses from the model. @@ -103,10 +174,22 @@ async def generate_content_async( agent_name = llm_request.config.labels.get("adk_agent_name") if agent_name: config["summary"] = agent_name - responses = await workflow.execute_activity( - invoke_model, - args=[llm_request], - **config, - ) + + if stream: + responses = await workflow.execute_activity( + invoke_model_streaming, + args=[ + llm_request, + self._streaming_event_topic, + self._streaming_event_batch_interval, + ], + **config, + ) + else: + responses = await workflow.execute_activity( + invoke_model, + args=[llm_request], + **config, + ) for response in responses: yield response diff --git a/temporalio/contrib/google_adk_agents/_plugin.py b/temporalio/contrib/google_adk_agents/_plugin.py index 9be321398..7344485c8 100644 --- a/temporalio/contrib/google_adk_agents/_plugin.py +++ b/temporalio/contrib/google_adk_agents/_plugin.py @@ -3,12 +3,16 @@ import dataclasses import time import uuid -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Callable from contextlib import asynccontextmanager +from typing import Any from temporalio import workflow from temporalio.contrib.google_adk_agents._mcp import TemporalMcpToolSetProvider -from temporalio.contrib.google_adk_agents._model import invoke_model +from temporalio.contrib.google_adk_agents._model import ( + invoke_model, + invoke_model_streaming, +) from temporalio.contrib.pydantic import ( PydanticPayloadConverter, ToJsonOptions, @@ -95,7 +99,13 @@ def workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: ) return runner - new_activities = [invoke_model] + # Annotate as Sequence[Callable[..., Any]] because invoke_model + # and invoke_model_streaming have different signatures, so the + # inferred list type would not satisfy SimplePlugin's parameter. 
+ new_activities: list[Callable[..., Any]] = [ + invoke_model, + invoke_model_streaming, + ] if toolset_providers is not None: for toolset_provider in toolset_providers: new_activities.extend(toolset_provider._get_activities()) diff --git a/temporalio/contrib/openai_agents/_invoke_model_activity.py b/temporalio/contrib/openai_agents/_invoke_model_activity.py index cffd8855e..f987d9314 100644 --- a/temporalio/contrib/openai_agents/_invoke_model_activity.py +++ b/temporalio/contrib/openai_agents/_invoke_model_activity.py @@ -4,9 +4,10 @@ """ import enum +import logging from dataclasses import dataclass from datetime import timedelta -from typing import Any +from typing import Any, NoReturn from agents import ( AgentOutputSchemaBase, @@ -27,6 +28,7 @@ UserError, WebSearchTool, ) +from agents.items import TResponseStreamEvent from agents.tool import ( ApplyPatchTool, LocalShellTool, @@ -43,8 +45,11 @@ from temporalio import activity from temporalio.contrib.openai_agents._heartbeat_decorator import _auto_heartbeater +from temporalio.contrib.pubsub import PubSubClient from temporalio.exceptions import ApplicationError +logger = logging.getLogger(__name__) + @dataclass class HandoffInput: @@ -185,6 +190,111 @@ class ActivityModelInput(TypedDict, total=False): previous_response_id: str | None conversation_id: str | None prompt: Any | None + streaming_event_topic: str | None + streaming_event_batch_interval: timedelta + + +async def _empty_on_invoke_tool(_ctx: RunContextWrapper[Any], _input: str) -> str: + return "" + + +async def _empty_on_invoke_handoff(_ctx: RunContextWrapper[Any], _input: str) -> Any: + return None + + +async def _noop_shell_executor(*_a: Any, **_kw: Any) -> str: + return "" + + +def _build_tool(tool: ToolInput) -> Tool: + """Reconstruct a Tool from its data-conversion-friendly input form.""" + if isinstance( + tool, + ( + FileSearchTool, + WebSearchTool, + ImageGenerationTool, + CodeInterpreterTool, + LocalShellTool, + ToolSearchTool, + ), + ): + return tool + elif isinstance(tool, ShellToolInput): + return ShellTool( + name=tool.name, + environment=tool.environment, + executor=_noop_shell_executor, + ) + elif isinstance(tool, ApplyPatchToolInput): + return ApplyPatchTool(name=tool.name, editor=_NoopApplyPatchEditor()) + elif isinstance(tool, HostedMCPToolInput): + return HostedMCPTool(tool_config=tool.tool_config) + elif isinstance(tool, FunctionToolInput): + return FunctionTool( + name=tool.name, + description=tool.description, + params_json_schema=tool.params_json_schema, + on_invoke_tool=_empty_on_invoke_tool, + strict_json_schema=tool.strict_json_schema, + ) + else: + raise UserError(f"Unknown tool type: {tool.name}") # type:ignore[reportUnreachable] + + +def _build_tools_and_handoffs( + input: ActivityModelInput, +) -> tuple[list[Tool], list[Handoff[Any, Any]]]: + tools = [_build_tool(x) for x in input.get("tools", [])] + handoffs: list[Handoff[Any, Any]] = [ + Handoff( + tool_name=x.tool_name, + tool_description=x.tool_description, + input_json_schema=x.input_json_schema, + agent_name=x.agent_name, + strict_json_schema=x.strict_json_schema, + on_invoke_handoff=_empty_on_invoke_handoff, + ) + for x in input.get("handoffs", []) + ] + return tools, handoffs + + +def _raise_for_openai_status(e: APIStatusError) -> NoReturn: + """Translate an OpenAI APIStatusError into the right retry posture.""" + retry_after: timedelta | None = None + retry_after_ms_header = e.response.headers.get("retry-after-ms") + if retry_after_ms_header is not None: + retry_after = 
timedelta(milliseconds=float(retry_after_ms_header)) + + if retry_after is None: + retry_after_header = e.response.headers.get("retry-after") + if retry_after_header is not None: + retry_after = timedelta(seconds=float(retry_after_header)) + + should_retry_header = e.response.headers.get("x-should-retry") + if should_retry_header == "true": + raise e + if should_retry_header == "false": + raise ApplicationError( + "Non retryable OpenAI error", + non_retryable=True, + next_retry_delay=retry_after, + ) from e + + # Retry on 408 (Request Timeout), 409 (Conflict / often transient + # state mismatch), 429 (Too Many Requests / rate-limited), and any + # 5xx (server-side errors). All other 4xx codes are caller errors + # that won't recover on retry. + retryable = ( + e.response.status_code in [408, 409, 429] or e.response.status_code >= 500 + ) + raise ApplicationError( + f"{'Retryable' if retryable else 'Non retryable'} OpenAI status code: " + f"{e.response.status_code}", + non_retryable=not retryable, + next_retry_delay=retry_after, + ) from e class ModelActivity: @@ -203,72 +313,7 @@ def __init__(self, model_provider: ModelProvider | None = None): async def invoke_model_activity(self, input: ActivityModelInput) -> ModelResponse: """Activity that invokes a model with the given input.""" model = self._model_provider.get_model(input.get("model_name")) - - async def empty_on_invoke_tool( - _ctx: RunContextWrapper[Any], _input: str - ) -> str: - return "" - - async def empty_on_invoke_handoff( - _ctx: RunContextWrapper[Any], _input: str - ) -> Any: - return None - - def make_tool(tool: ToolInput) -> Tool: - if isinstance( - tool, - ( - FileSearchTool, - WebSearchTool, - ImageGenerationTool, - CodeInterpreterTool, - LocalShellTool, - ToolSearchTool, - ), - ): - return tool - elif isinstance(tool, ShellToolInput): - - async def _noop_executor(*a: Any, **kw: Any) -> str: # type: ignore[reportUnusedParameter] - return "" - - return ShellTool( - name=tool.name, - environment=tool.environment, - executor=_noop_executor, - ) - elif isinstance(tool, ApplyPatchToolInput): - return ApplyPatchTool( - name=tool.name, - editor=_NoopApplyPatchEditor(), - ) - elif isinstance(tool, HostedMCPToolInput): - return HostedMCPTool( - tool_config=tool.tool_config, - ) - elif isinstance(tool, FunctionToolInput): - return FunctionTool( - name=tool.name, - description=tool.description, - params_json_schema=tool.params_json_schema, - on_invoke_tool=empty_on_invoke_tool, - strict_json_schema=tool.strict_json_schema, - ) - else: - raise UserError(f"Unknown tool type: {tool.name}") # type:ignore[reportUnreachable] - - tools = [make_tool(x) for x in input.get("tools", [])] - handoffs: list[Handoff[Any, Any]] = [ - Handoff( - tool_name=x.tool_name, - tool_description=x.tool_description, - input_json_schema=x.input_json_schema, - agent_name=x.agent_name, - strict_json_schema=x.strict_json_schema, - on_invoke_handoff=empty_on_invoke_handoff, - ) - for x in input.get("handoffs", []) - ] + tools, handoffs = _build_tools_and_handoffs(input) try: return await model.get_response( @@ -284,40 +329,67 @@ async def _noop_executor(*a: Any, **kw: Any) -> str: # type: ignore[reportUnuse prompt=input.get("prompt"), ) except APIStatusError as e: - # Listen to server hints - retry_after = None - retry_after_ms_header = e.response.headers.get("retry-after-ms") - if retry_after_ms_header is not None: - retry_after = timedelta(milliseconds=float(retry_after_ms_header)) - - if retry_after is None: - retry_after_header = 
e.response.headers.get("retry-after") - if retry_after_header is not None: - retry_after = timedelta(seconds=float(retry_after_header)) - - should_retry_header = e.response.headers.get("x-should-retry") - if should_retry_header == "true": - raise e - if should_retry_header == "false": - raise ApplicationError( - "Non retryable OpenAI error", - non_retryable=True, - next_retry_delay=retry_after, - ) from e - - # Specifically retryable status codes - if ( - e.response.status_code in [408, 409, 429] - or e.response.status_code >= 500 - ): - raise ApplicationError( - f"Retryable OpenAI status code: {e.response.status_code}", - non_retryable=False, - next_retry_delay=retry_after, - ) from e - - raise ApplicationError( - f"Non retryable OpenAI status code: {e.response.status_code}", - non_retryable=True, - next_retry_delay=retry_after, - ) from e + _raise_for_openai_status(e) + + @activity.defn + @_auto_heartbeater + async def invoke_model_activity_streaming( + self, input: ActivityModelInput + ) -> list[TResponseStreamEvent]: + """Streaming-aware model activity. + + Calls ``model.stream_response()`` and returns the collected list + of native OpenAI stream events. The workflow's + ``Model.stream_response`` stub yields these to the agents + framework, which builds the final ``ModelResponse`` from the + terminal ``ResponseCompletedEvent``. + + When ``streaming_event_topic`` is set, each event is also + published to the workflow's pub/sub broker so external consumers + (UIs, tracing, etc.) can observe events as they arrive. Set the + topic to ``None`` to skip publishing entirely; in that case no + :class:`PubSubClient` is constructed. + + Heartbeats run on a background task via ``_auto_heartbeater`` so + long initial-token latency or long pauses between chunks do not + trip ``heartbeat_timeout``. + """ + model = self._model_provider.get_model(input.get("model_name")) + tools, handoffs = _build_tools_and_handoffs(input) + + topic = input.get("streaming_event_topic") + events: list[TResponseStreamEvent] = [] + + async def consume(pubsub: PubSubClient | None, topic: str | None) -> None: + try: + async for event in model.stream_response( + system_instructions=input.get("system_instructions"), + input=input["input"], + model_settings=input["model_settings"], + tools=tools, + output_schema=input.get("output_schema"), + handoffs=handoffs, + tracing=ModelTracing(input["tracing"]), + previous_response_id=input.get("previous_response_id"), + conversation_id=input.get("conversation_id"), + prompt=input.get("prompt"), + ): + events.append(event) + if pubsub is not None and topic is not None: + pubsub.publish(topic, event) + except APIStatusError as e: + _raise_for_openai_status(e) + + if topic is None: + await consume(None, None) + else: + batch_interval = input.get( + "streaming_event_batch_interval", timedelta(milliseconds=100) + ) + pubsub = PubSubClient.from_activity( + batch_interval=batch_interval, + ) + async with pubsub: + await consume(pubsub, topic) + + return events diff --git a/temporalio/contrib/openai_agents/_model_parameters.py b/temporalio/contrib/openai_agents/_model_parameters.py index 55827e0d5..c0e6f4dfc 100644 --- a/temporalio/contrib/openai_agents/_model_parameters.py +++ b/temporalio/contrib/openai_agents/_model_parameters.py @@ -68,3 +68,19 @@ class ModelActivityParameters: use_local_activity: bool = False """Whether to use a local activity. 
If changed during a workflow execution, that would break determinism.""" + + streaming_event_topic: str | None = "events" + """Pub/sub topic to publish raw model stream events to when the workflow + calls ``Runner.run_streamed``. Set to ``None`` to skip publishing + entirely (workflow-side iteration via ``stream_events()`` still works, + no broker required). When set, the workflow must host a + :class:`temporalio.contrib.pubsub.PubSub` broker to receive the + publishes; otherwise the signals are unhandled and dropped. + + Streaming is incompatible with ``use_local_activity`` (local activities + do not support heartbeats or the pubsub signal channel).""" + + streaming_event_batch_interval: timedelta = timedelta(milliseconds=100) + """Interval between automatic flushes for the pub/sub publisher used + by the streaming activity. Ignored when ``streaming_event_topic`` is + ``None``.""" diff --git a/temporalio/contrib/openai_agents/_openai_runner.py b/temporalio/contrib/openai_agents/_openai_runner.py index 1884ff8a6..8307619f4 100644 --- a/temporalio/contrib/openai_agents/_openai_runner.py +++ b/temporalio/contrib/openai_agents/_openai_runner.py @@ -1,3 +1,4 @@ +import asyncio import dataclasses from collections.abc import Awaitable from typing import Any, Callable @@ -243,14 +244,123 @@ def run_streamed( input: str | list[TResponseInputItem] | RunState[TContext], **kwargs: Any, ) -> RunResultStreaming: - """Run the agent with streaming responses (not supported in Temporal workflows).""" + """Run the agent with streaming responses. + + Inside a workflow, model calls execute as the streaming model + activity. The workflow consumes events via + ``RunResultStreaming.stream_events()`` after each activity + completes; external clients can subscribe to the configured + pub/sub topic to receive events as they arrive. + """ if not workflow.in_workflow(): return self._runner.run_streamed( starting_agent, input, **kwargs, ) - raise RuntimeError("Temporal workflows do not support streaming.") + + for t in starting_agent.tools: + if callable(t): + raise ValueError( + "Provided tool is not a tool type. If using an activity, make sure to wrap it with openai_agents.workflow.activity_as_tool." + ) + + if starting_agent.mcp_servers: + from temporalio.contrib.openai_agents._mcp import ( + _StatefulMCPServerReference, + _StatelessMCPServerReference, + ) + + for s in starting_agent.mcp_servers: + if not isinstance( + s, + ( + _StatelessMCPServerReference, + _StatefulMCPServerReference, + ), + ): + raise ValueError( + f"Unknown mcp_server type {type(s)} may not work durably." + ) + + run_config = kwargs.get("run_config") + session = kwargs.get("session") + + if isinstance(session, SQLiteSession): + raise ValueError("Temporal workflows don't support SQLite sessions.") + + if run_config is None: + run_config = RunConfig() + kwargs["run_config"] = run_config + + if run_config.model and not isinstance(run_config.model, _TemporalModelStub): + if not isinstance(run_config.model, str): + raise ValueError( + "Temporal workflows require a model name to be a string in the run config." + ) + run_config = dataclasses.replace( + run_config, + model=_TemporalModelStub( + run_config.model, model_params=self.model_params, agent=None + ), + ) + kwargs["run_config"] = run_config + + if _has_sandbox_agent(starting_agent) or run_config.sandbox: + if run_config.sandbox is None: + raise ValueError( + "A SandboxAgent was provided but run_config.sandbox is not configured. " + "You must set run_config.sandbox to a SandboxRunConfig. 
" + "For example:\n" + " from temporalio.contrib.openai_agents.workflow import temporal_sandbox_client\n" + " run_config = RunConfig(sandbox=SandboxRunConfig(client=temporal_sandbox_client('my-backend')))" + ) + elif run_config.sandbox.client is None: + raise ValueError( + "run_config.sandbox.client must be set to a temporal sandbox client. " + "Use temporalio.contrib.openai_agents.workflow.temporal_sandbox_client(name) " + "to create one, where name matches a SandboxClientProvider registered on the plugin." + ) + elif not isinstance(run_config.sandbox.client, TemporalSandboxClient): + raise ValueError( + "run_config.sandbox.client must be created via " + "temporalio.contrib.openai_agents.workflow.temporal_sandbox_client(name). " + "Do not pass a raw sandbox client directly." + ) + + streamed_result = self._runner.run_streamed( + starting_agent=_convert_agent(self.model_params, starting_agent, None), + input=input, + **kwargs, + ) + + # Mirror the AgentsException -> AgentsWorkflowError rewrap done + # in run() above. The streaming runner attaches the actual run + # to ``run_loop_task``; wrap it so workflow-failure-bearing + # AgentsException instances surface as AgentsWorkflowError (the + # plugin registers that type in workflow_failure_exception_types, + # which is how durable failures propagate as terminal workflow + # failures rather than retrying workflow-task errors). + original_task = streamed_result.run_loop_task + if original_task is not None: + + async def _rewrap_agents_exception() -> Any: + try: + return await original_task + except AgentsException as e: + if e.__cause__ and workflow.is_failure_exception(e.__cause__): + reraise = AgentsWorkflowError( + f"Workflow failure exception in Agents Framework: {e}" + ) + reraise.__traceback__ = e.__traceback__ + raise reraise from e.__cause__ + raise + + streamed_result.run_loop_task = asyncio.create_task( + _rewrap_agents_exception() + ) + + return streamed_result def _model_name(agent: Agent[Any]) -> str | None: diff --git a/temporalio/contrib/openai_agents/_temporal_model_stub.py b/temporalio/contrib/openai_agents/_temporal_model_stub.py index 03e689f17..b6b47e261 100644 --- a/temporalio/contrib/openai_agents/_temporal_model_stub.py +++ b/temporalio/contrib/openai_agents/_temporal_model_stub.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +from datetime import timedelta from temporalio import workflow from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters @@ -60,8 +61,9 @@ def __init__( self.model_params = model_params self.agent = agent - async def get_response( + def _build_activity_input( self, + *, system_instructions: str | None, input: str | list[TResponseInputItem], model_settings: ModelSettings, @@ -69,11 +71,10 @@ async def get_response( output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], tracing: ModelTracing, - *, previous_response_id: str | None, conversation_id: str | None, prompt: ResponsePromptParam | None, - ) -> ModelResponse: + ) -> tuple[ActivityModelInput, str | None]: def make_tool_info(tool: Tool) -> ToolInput: if isinstance( tool, @@ -166,6 +167,35 @@ def make_tool_info(tool: Tool) -> ToolInput: else: summary = None + return activity_input, summary + + async def get_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + tracing: ModelTracing, + *, + previous_response_id: str | 
None, + conversation_id: str | None, + prompt: ResponsePromptParam | None, + ) -> ModelResponse: + activity_input, summary = self._build_activity_input( + system_instructions=system_instructions, + input=input, + model_settings=model_settings, + tools=tools, + output_schema=output_schema, + handoffs=handoffs, + tracing=tracing, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + prompt=prompt, + ) + if self.model_params.use_local_activity: return await workflow.execute_local_activity_method( ModelActivity.invoke_model_activity, @@ -177,23 +207,22 @@ def make_tool_info(tool: Tool) -> ToolInput: retry_policy=self.model_params.retry_policy, cancellation_type=self.model_params.cancellation_type, ) - else: - return await workflow.execute_activity_method( - ModelActivity.invoke_model_activity, - activity_input, - summary=summary, - task_queue=self.model_params.task_queue, - schedule_to_close_timeout=self.model_params.schedule_to_close_timeout, - schedule_to_start_timeout=self.model_params.schedule_to_start_timeout, - start_to_close_timeout=self.model_params.start_to_close_timeout, - heartbeat_timeout=self.model_params.heartbeat_timeout, - retry_policy=self.model_params.retry_policy, - cancellation_type=self.model_params.cancellation_type, - versioning_intent=self.model_params.versioning_intent, - priority=self.model_params.priority, - ) + return await workflow.execute_activity_method( + ModelActivity.invoke_model_activity, + activity_input, + summary=summary, + task_queue=self.model_params.task_queue, + schedule_to_close_timeout=self.model_params.schedule_to_close_timeout, + schedule_to_start_timeout=self.model_params.schedule_to_start_timeout, + start_to_close_timeout=self.model_params.start_to_close_timeout, + heartbeat_timeout=self.model_params.heartbeat_timeout, + retry_policy=self.model_params.retry_policy, + cancellation_type=self.model_params.cancellation_type, + versioning_intent=self.model_params.versioning_intent, + priority=self.model_params.priority, + ) - def stream_response( + async def stream_response( self, system_instructions: str | None, input: str | list[TResponseInputItem], @@ -207,4 +236,52 @@ def stream_response( conversation_id: str | None, prompt: ResponsePromptParam | None, ) -> AsyncIterator[TResponseStreamEvent]: - raise NotImplementedError("Temporal model doesn't support streams yet") + # Streaming relies on activity heartbeats to detect a stuck LLM + # call and on PubSubClient.from_activity() to signal partial + # results back to the workflow. Local activities support + # neither: their result commits with the workflow task, so there + # is no independent task to heartbeat against or to send signals + # from. + if self.model_params.use_local_activity: + raise ValueError( + "Streaming is incompatible with use_local_activity " + "(local activities do not support heartbeats or the " + "pubsub signal channel)." 
+ ) + + activity_input, summary = self._build_activity_input( + system_instructions=system_instructions, + input=input, + model_settings=model_settings, + tools=tools, + output_schema=output_schema, + handoffs=handoffs, + tracing=tracing, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + prompt=prompt, + ) + activity_input["streaming_event_topic"] = ( + self.model_params.streaming_event_topic + ) + activity_input["streaming_event_batch_interval"] = ( + self.model_params.streaming_event_batch_interval + ) + + events = await workflow.execute_activity_method( + ModelActivity.invoke_model_activity_streaming, + activity_input, + summary=summary, + task_queue=self.model_params.task_queue, + schedule_to_close_timeout=self.model_params.schedule_to_close_timeout, + schedule_to_start_timeout=self.model_params.schedule_to_start_timeout, + start_to_close_timeout=self.model_params.start_to_close_timeout, + heartbeat_timeout=self.model_params.heartbeat_timeout + or timedelta(seconds=30), + retry_policy=self.model_params.retry_policy, + cancellation_type=self.model_params.cancellation_type, + versioning_intent=self.model_params.versioning_intent, + priority=self.model_params.priority, + ) + for event in events: + yield event diff --git a/temporalio/contrib/openai_agents/_temporal_openai_agents.py b/temporalio/contrib/openai_agents/_temporal_openai_agents.py index f7757723c..60b4b36ef 100644 --- a/temporalio/contrib/openai_agents/_temporal_openai_agents.py +++ b/temporalio/contrib/openai_agents/_temporal_openai_agents.py @@ -205,7 +205,11 @@ def add_activities( if not register_activities: return activities or [] - new_activities = [ModelActivity(model_provider).invoke_model_activity] + model_activity = ModelActivity(model_provider) + new_activities = [ + model_activity.invoke_model_activity, + model_activity.invoke_model_activity_streaming, + ] server_names = [server.name for server in mcp_server_providers] if len(server_names) != len(set(server_names)): diff --git a/temporalio/contrib/pubsub/DESIGN-v2.md b/temporalio/contrib/pubsub/DESIGN-v2.md new file mode 100644 index 000000000..37f928d83 --- /dev/null +++ b/temporalio/contrib/pubsub/DESIGN-v2.md @@ -0,0 +1,1352 @@ +# Temporal Workflow Pub/Sub — Design Document v2 + +Consolidated design document reflecting the current implementation. + +> The Python code in `sdk-python/temporalio/contrib/pubsub/` is authoritative. +> Both this document and the Notion page +> ["Streaming API Design Considerations"](https://www.notion.so/3478fc567738803d9c22eeb64a296e21) +> track it. When API or wire-format facts change in code, update this doc in +> the same commit and mirror to Notion. When new narrative (a decision, a +> comparison) lands in either doc, port it to the other before the next +> review cycle. + +## Overview + +A reusable pub/sub module for Temporal workflows. The workflow acts as the +message broker — it holds an append-only log of `(topic, data)` entries. +External clients (activities, starters, other services) publish and subscribe +through the workflow handle using Temporal primitives (signals, updates, +queries). + +The module ships as `temporalio.contrib.pubsub` in the Python SDK and is +designed to be cross-language compatible. Payloads are opaque byte strings — +the workflow does not interpret them. + +## Architecture + +``` + ┌──────────────────────────────────┐ + │ Temporal Workflow │ + │ (PubSub broker) │ + │ │ + │ ┌────────────────────────────┐ │ + │ │ Append-only log │ │ + │ │ [(topic, data), ...] 
│ │ + │ │ base_offset: int │ │ + │ │ publisher_sequences: {} │ │ + │ └────────────────────────────┘ │ + │ │ + signal ──────────►│ __temporal_pubsub_publish (with dedup) │ + update ──────────►│ __temporal_pubsub_poll (long-poll) │◄── subscribe() + query ──────────►│ __temporal_pubsub_offset │ + │ │ + │ publish() ── workflow-side │ + └──────────────────────────────────┘ + │ + │ continue-as-new + ▼ + ┌──────────────────────────────────┐ + │ PubSubState carries: │ + │ log, base_offset, │ + │ publisher_sequences │ + └──────────────────────────────────┘ +``` + +## API Surface + +### Workflow side — `PubSub` + +A helper class instantiated from `@workflow.init`. Its constructor +registers the pub/sub signal, update, and query handlers on the current +workflow via `workflow.set_signal_handler`, `workflow.set_update_handler`, +and `workflow.set_query_handler` — there is no base class to inherit. +This matches how other-language SDKs will express the same pattern +(imperative handler registration from inside the workflow body). + +```python +from dataclasses import dataclass +from temporalio import workflow +from temporalio.contrib.pubsub import PubSub, PubSubState + +@dataclass +class MyInput: + pubsub_state: PubSubState | None = None + +@workflow.defn +class MyWorkflow: + @workflow.init + def __init__(self, input: MyInput) -> None: + self.pubsub = PubSub(prior_state=input.pubsub_state) + + @workflow.run + async def run(self, input: MyInput) -> None: + self.pubsub.publish("status", b"started") + await do_work() + self.pubsub.publish("status", b"done") +``` + +Construct `PubSub(...)` once from `@workflow.init`. Include a +`PubSubState | None` field on your workflow input and always pass it as +`prior_state`: it is `None` on fresh starts and carries accumulated +state across continue-as-new (see [Continue-as-New](#continue-as-new)). +Workflows that will never continue-as-new may call `PubSub()` with no +argument. Instantiating `PubSub` twice on the same workflow raises +`RuntimeError`, detected via `workflow.get_signal_handler("__temporal_pubsub_publish")`. + +| Method / Handler | Kind | Description | +|---|---|---| +| `PubSub(prior_state=None)` | constructor | Initialize internal state and register handlers on the current workflow. Must be called from `@workflow.init`. | +| `publish(topic, value)` | instance method | Append to the log from workflow code. `value` is converted via the workflow's sync payload converter (no codec). | +| `get_state(publisher_ttl=timedelta(seconds=900))` | instance method | Snapshot for CAN. Prunes dedup entries older than TTL. | +| `drain()` | instance method | Unblock polls and reject new ones for CAN. | +| `continue_as_new(build_args, *, publisher_ttl=timedelta(seconds=900))` | async instance method | Drain, wait for handlers, then `workflow.continue_as_new` with `build_args(post_drain_state)`. | +| `truncate(up_to_offset)` | instance method | Discard log entries before offset. | +| `__temporal_pubsub_publish` | signal handler | Receives publications from external clients (with dedup). | +| `__temporal_pubsub_poll` | update handler | Long-poll subscription: blocks until new items or drain. | +| `__temporal_pubsub_offset` | query handler | Returns the current global offset. | + +### Client side — `PubSubClient` + +Used by activities, starters, and any code with a workflow handle. 
+ +```python +from temporalio.contrib.pubsub import PubSubClient + +# Preferred: factory method (enables CAN following + activity auto-detect) +client = PubSubClient.create(temporal_client, workflow_id) + +# --- Publishing (with batching) --- +# Values go through the client's data converter — including the codec +# chain (encryption, PII-redaction, compression) — per item. +async with client: + client.publish("events", TextDelta(delta="hello")) + client.publish("events", TextDelta(delta=" world")) + client.publish("events", TextComplete(), force_flush=True) + client.publish("raw", my_prebuilt_payload) # zero-copy fast path + +# --- Subscribing --- +# Pass result_type=T to have item.data decoded to T via the same codec +# chain. Without result_type, item.data is the raw Payload and the +# caller dispatches on metadata. +async for item in client.subscribe(["events"], result_type=EventUnion): + print(item.topic, item.data) + if is_done(item): + break +``` + +| Method | Description | +|---|---| +| `PubSubClient.create(client, wf_id)` | Factory with explicit Temporal client and workflow id. Follows CAN in `subscribe()`. | +| `PubSubClient.from_activity()` | Factory that pulls client and workflow id from the current activity context. Follows CAN in `subscribe()`. | +| `PubSubClient(handle)` | From handle directly (no CAN following; no codec chain — falls back to the default converter). | +| `publish(topic, value, force_flush=False)` | Buffer a message. `value` may be any converter-compatible object or a pre-built `Payload`. `force_flush` triggers immediate flush (fire-and-forget). | +| `flush()` | Async. Block until items buffered at call time are confirmed by the server. No-op if nothing is buffered. | +| `subscribe(topics, from_offset, *, result_type=None, poll_cooldown=timedelta(milliseconds=100))` | Async iterator. `result_type` decodes `item.data` to the given type; omit for raw `Payload`. Always follows CAN chains when created via `create` or `from_activity`. | +| `get_offset()` | Query current global offset. | + +The client offers three complementary ways to flush: + +1. **Context manager exit** — drains and flushes on `__aexit__`. Best + when the publisher's lifetime maps cleanly to a scope. +2. **`force_flush=True` on `publish()`** — declarative, fire-and-forget. + Best when the *event being published* is itself the signal to flush + (e.g. a "stream complete" sentinel). +3. **`await client.flush()`** — explicit synchronization point that + returns once buffered items have been acknowledged by the server. + Best when the caller needs proof that prior publications landed but + the moment does not correspond to any particular event — e.g. + "before returning from this activity, make sure everything I have + published is durable." + +#### Activity convenience + +Inside an activity, use `PubSubClient.from_activity()` — the Temporal +client and target workflow id come from the activity context, so the +caller doesn't have to thread them through: + +```python +@activity.defn +async def stream_events() -> None: + client = PubSubClient.from_activity(batch_interval=timedelta(seconds=2)) + async with client: + for chunk in generate_chunks(): + client.publish("events", chunk) + activity.heartbeat() +``` + +`from_activity()` is a separate factory rather than an overload of +`create()` because silently inferring arguments outside an activity +masks a configuration bug as a runtime error in an unrelated code +path. 
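Tying this back to the streaming activities in the contrib diffs above: once `invoke_model_streaming` (ADK) or `invoke_model_activity_streaming` (OpenAI agents) publishes chunks to the default `events` topic, any process holding a Temporal client can follow the stream with the `create` factory. A minimal sketch, assuming the workflow hosts a `PubSub` broker, the contrib defaults (`streaming_event_topic="events"`) are in effect, `get_offset()` is awaitable, and `is_done` is an application-defined terminal check:

```python
from temporalio.client import Client

from temporalio.contrib.pubsub import PubSubClient


def is_done(item) -> bool:
    # Application-defined terminal check, e.g. match a sentinel event
    # type recorded in item.data.metadata by the publishing activity.
    return False


async def follow_model_stream(temporal_client: Client, workflow_id: str) -> None:
    # Factory form: follows continue-as-new chains and uses the Temporal
    # client's data converter, including any codec chain.
    pubsub = PubSubClient.create(temporal_client, workflow_id)

    # Start from the current end of the log instead of replaying history.
    start = await pubsub.get_offset()

    # Without result_type, item.data is the raw Payload; dispatch on its
    # metadata, or pass result_type=T to decode via the converter chain.
    async for item in pubsub.subscribe(["events"], from_offset=start):
        print(f"offset={item.offset} topic={item.topic}")
        if is_done(item):
            break
```

The same loop works whether the chunks were published by workflow code (`PubSub.publish`) or by an activity (`PubSubClient.from_activity`); the subscriber only ever sees the log.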
+ +## Data Types + +```python +from temporalio.api.common.v1 import Payload + +@dataclass +class PubSubItem: + topic: str + data: Any # Payload by default; decoded value when + # subscribe is called with result_type=T + offset: int = 0 # Populated at poll time + +@dataclass +class PublishEntry: + topic: str + data: str # Wire: base64(Payload.SerializeToString()) + +@dataclass +class PublishInput: + items: list[PublishEntry] + publisher_id: str = "" # For exactly-once dedup + sequence: int = 0 # Monotonic per publisher + +@dataclass +class PollInput: + topics: list[str] # Filter (empty = all) + from_offset: int = 0 # Global offset to resume from + +@dataclass +class PollResult: + items: list[_WireItem] # Wire-format items + next_offset: int = 0 # Offset for next poll + more_ready: bool = False # Truncated response; poll again + +@dataclass +class PubSubState: + log: list[_WireItem] = field(default_factory=list) + base_offset: int = 0 + publisher_sequences: dict[str, int] = field(default_factory=dict) + publisher_last_seen: dict[str, float] = field(default_factory=dict) +``` + +The containing workflow input must type the field as `PubSubState | None`, +not `Any` — `Any`-typed fields deserialize as plain dicts, losing the type. + +### Wire format for payloads + +The user-facing `data` on `PubSubItem` is a +`temporalio.api.common.v1.Payload`, which carries both the data bytes +and the encoding metadata written by the client's data converter and +codec chain. Subscribers can either decode by passing `result_type=T` +to `subscribe()` (runs the async converter chain, including the codec) +or inspect `Payload.metadata` directly for heterogeneous topics. + +On the wire, every `data` string is +`base64(Payload.SerializeToString())`. This is because the default +JSON payload converter can serialize a top-level `Payload` as a +signal argument but **cannot** serialize a `Payload` embedded inside +a dataclass (it raises `TypeError: Object of type Payload is not JSON +serializable`). Embedding the proto-serialized bytes keeps the wire +format JSON-compatible while preserving the full `Payload` — metadata +and all — across the signal and update round-trips. Round-trip is +guarded by +`tests/contrib/pubsub/test_payload_roundtrip_prototype.py`. + +## Design Decisions + +### 1. Durable streams + +All stream events flow through the workflow's append-only log, backed by +Temporal's persistence layer. There is no ephemeral streaming option. + +**Trade-off.** Ephemeral streams that skip the Temporal server, or transit it +with lower durability, would be less resource-intensive. We chose durable +streams because: + +1. **Simpler programming model.** One event path, one source of truth. The + application does not need merge logic, reconnection handling for a second + channel, or fallback behavior when the ephemeral path fails. +2. **Reliability.** Events survive worker crashes, workflow restarts, and + continue-as-new. A subscriber that connects after a failure sees the + complete history, not a gap where the ephemeral channel lost events. +3. **Correctness.** With a single path, subscriber code is the same whether + processing events live or replaying them after a reconnect. A separate + ephemeral path for latency-sensitive events (e.g., token deltas) would + create a second code path through the frontend — additional complexity + that is difficult to test. + +The cost is latency: events round-trip through the Temporal server before +reaching the subscriber. 
Batching (see [Batching is built into the +client](#7-batching-is-built-into-the-client)) manages this — a 0.1-second +interval for token streaming keeps latency acceptable while amortizing +per-signal overhead. + +Durability is Temporal's core value proposition. Making the stream durable by +default aligns with the platform. + +### 2. Topics are plain strings, no hierarchy + +Topics are exact-match strings. No prefix matching, no wildcards. A subscriber +provides a list of topic strings to filter on; an empty list means "all topics." + +### 3. Items are Temporal `Payload`s, not opaque bytes + +The workflow stores each item as a +`temporalio.api.common.v1.Payload` — the same type signals, updates, +and activities use. Publishers pass any value the client's data +converter accepts (or a pre-built `Payload` for zero-copy); +subscribers either receive the raw `Payload` (for heterogeneous +topics) or pass `result_type=T` to have it decoded. + +This replaces an earlier "opaque byte strings" design. We switched +because the opaque-bytes path **skipped the user's codec chain** — +encryption, PII-redaction, and compression codecs saw only the +outer `PublishInput` envelope, not the individual items. For users +who expect their codec chain to cover every piece of data flowing +through Temporal, that is a silent compliance/correctness gap. + +The three original arguments for opaque bytes don't hold up: + +1. **Decoupling from the data converter.** Signals and updates + accept `Any` without making handlers generic; `Payload.metadata` + carries per-value encoding info. Pub/sub can do the same. +2. **Layering — transport vs. application.** Every other Temporal + API surface (signals, updates, activity args, workflow args) + uses `Payload`. Pub/sub was the outlier. +3. **Type hints at decode time.** Subscribers pass `result_type` at + the subscribe boundary — the same pattern as + `execute_update(result_type=...)`. + +**Codec runs once, at the envelope level.** Both +`PubSubClient.publish` and `PubSub.publish` turn values into +`Payload` via the **sync** payload converter. The codec chain is +not applied per item. It runs once — on the `__temporal_pubsub_publish` +signal envelope (client → workflow path) and on the +`__temporal_pubsub_poll` update envelope (workflow → subscriber path) — +because Temporal's SDK already runs `DataConverter.encode` on +signal and update args. Running the codec per item *as well* +would double-encrypt / double-compress, and compressing +already-encrypted data is pointless. The per-item `Payload` still +carries the encoding metadata (`encoding: json/plain`, +`messageType`, etc.), so `subscribe(result_type=T)` works +without needing the codec to have run per item. + +**Wire format.** `PublishEntry.data` and `_WireItem.data` are +base64-encoded `Payload.SerializeToString()` bytes, not nested +`Payload` protos, because the default JSON converter cannot +serialize a `Payload` embedded inside a dataclass. See [Data +Types — Wire format for payloads](#wire-format-for-payloads). + +### 4. Global offsets, NATS JetStream model + +> 🚪 **One-way door.** Once subscribers persist and resume from global integer +> offsets — stored in SSE `Last-Event-ID`, BFF reconnection state, and +> client-side cursor logic — the offset semantics are baked into the wire +> protocol. Switching to per-topic offsets later would break every existing +> subscriber's resume path. 
This is the right choice (cursor portability and +> cross-topic ordering are valuable), but recognize that every consumer built +> against this API will assume a single integer is a complete stream position. + +Every entry gets a global offset from a single counter. Subscribers filter by +topic but advance through the global offset space. + +We surveyed offset models across Kafka, Redis Streams, NATS JetStream, PubNub, +Google Pub/Sub, RabbitMQ Streams, and Amazon SQS/SNS. No major system provides +a true global offset across independent topics. The two closest: + +- **NATS JetStream**: one stream captures multiple subjects via wildcards, with + a single sequence counter. This is our model. +- **PubNub**: wall-clock nanosecond timestamp as cursor across channels. + +We evaluated six alternatives for handling the information leakage that global +offsets create (a single-topic subscriber can infer other-topic activity from +gaps): per-topic counts, opaque cursors, encrypted cursors, per-topic lists, +per-topic offsets with cursor hints, and accepting the leakage. + +| Option | Systems | Leakage | Cross-topic ordering | Resume cost | Cursor portability | +|---|---|---|---|---|---| +| Per-topic count as cursor | *(theoretical)* | None | Preserved | O(n) or extra state | Coupled to filter | +| Opaque cursor wrapping global offset | *(theoretical)* | Observable | Preserved | O(1) | Filter-independent | +| Encrypted global offset | *(theoretical)* | None | Preserved | O(1) | Filter-independent | +| Per-topic / per-partition lists | Kafka, Redis Streams, RabbitMQ Streams, Google Pub/Sub, SQS/SNS | None | **Lost** | O(1) | N/A | +| **Global offsets (chosen)** | NATS JetStream, PubNub (timestamp variant) | Contained at BFF | Preserved | O(new items) | Filter-independent | +| Per-topic offsets with cursor hints | *(theoretical)* | None | Preserved | O(new items) | Per-topic only | + +**Decision:** Global offsets are the right choice for workflow-scoped pub/sub. + +**Why not per-topic offsets?** The most sophisticated alternative — per-topic +offsets with opaque cursors carrying global position hints — was rejected +for three reasons: + +1. **The threat model doesn't apply.** Information leakage assumes untrusted + multi-tenant subscribers who shouldn't learn about each other's traffic + volumes. That's Kafka's world — separate consumers for separate services. + In workflow-scoped pub/sub, the subscriber is the BFF: trusted server-side + code that could just as easily subscribe to all topics. + +2. **Cursor portability.** A global offset is a stream position that works + regardless of which topics you filter on. You can subscribe to `["events"]`, + then later subscribe to `["events", "thinking"]` with the same offset. + Per-topic cursors are coupled to the filter — you need a separate cursor per + topic, and adding a topic to your subscription requires starting it from the + beginning. + +3. **Unjustified complexity.** Per-topic cursors require cursor + parsing/formatting, a `topic_counts` dict that survives continue-as-new, a + multi-cursor alignment algorithm, and stale-hint fallback paths. For log + sizes of thousands of items where a filtered slice is microseconds, this + machinery adds cost without measurable benefit. + +**Leakage is contained at the BFF trust boundary.** The global offset stays +between workflow and BFF. The BFF assigns its own gapless SSE event IDs to the +browser. The global offset never reaches the end client. 
See +[Information Leakage and the BFF](#information-leakage-and-the-bff) for the +full mechanism. + +### 5. No topic creation + +Topics are implicit. Publishing to a topic creates it. Subscribing to a +nonexistent topic returns no items and waits for new ones. + +### 6. `force_flush` forces a flush, does not reorder + +`force_flush=True` causes the client to immediately flush its buffer. It +does NOT reorder items — the flushed item appears in its natural +position after any previously-buffered items. The purpose is +latency-sensitive delivery, not importance ranking. + +### 7. Session ordering + +Publications from a single client are ordered. This relies on two Temporal +guarantees: + +> "Signals are delivered in the order they are received by the Cluster and +> written to History." +> ([docs](https://docs.temporal.io/workflows#signal)) + +Specifically: (1) signals sent sequentially from the same client appear in +workflow history in send order, and (2) signal handlers are invoked in +history order. The guarantee breaks down only for *concurrent* signals — if +two signal RPCs are in flight simultaneously, their order in history is +nondeterministic. The `PubSubClient` flush lock (`_flush_lock`) ensures +signals are never in flight concurrently from a single client: + +1. Acquire lock +2. `await handle.signal(...)` — blocks until server writes to history +3. Release lock + +Combined with the workflow's single-threaded signal processing (the +`__temporal_pubsub_publish` handler is synchronous — no `await`), items within and +across batches from a single publisher preserve their publish order. + +Concurrent publishers get a total order in the log (the workflow serializes +all signal processing), but the interleaving is nondeterministic — it depends +on arrival order at the server. Per-publisher ordering is preserved. This is +formally verified as `OrderPreservedPerPublisher`. + +Once items are in the log, their order is stable — reads are repeatable. + +### 8. Batching is built into the client + +`PubSubClient` includes a Nagle-like batcher (buffer + timer). The async +context manager starts a background flush task; exiting cancels it and does a +final flush. Batching amortizes Temporal signal overhead. + +Parameters: +- `batch_interval` (`timedelta`, default 2 seconds): timer between automatic flushes. +- `max_batch_size` (optional): auto-flush when buffer reaches this size. + +### 9. Subscription is poll-based, exposed as async iterator + +The fundamental primitive is an offset-based long-poll: the subscriber sends +`from_offset` and gets back items plus `next_offset`. `__temporal_pubsub_poll` is a +Temporal update with `wait_condition`. `subscribe()` wraps this in an +`AsyncIterator` with a configurable `poll_cooldown` (`timedelta`, default +100ms) to rate-limit polls. + +**Trade-off.** The alternative is server-push — the pub/sub system executes +a callback on the subscriber. Pull is better aligned with durable streams: + +1. **Back-pressure is natural.** A slow subscriber just polls less + frequently. Push requires the server to implement flow control to avoid + overwhelming subscribers — or risk dropping messages, defeating the + durable-stream purpose. +2. **The subscriber controls its own read position.** It can replay from an + earlier offset, skip ahead, or resume from exactly where it left off. + Push requires the server to track per-subscriber delivery state. +3. **Durable streams are data at rest.** The log exists regardless of + whether anyone is reading it. 
Pull treats the log as something to read + from; push treats it as a pipe to deliver through, which fights the + durability model. + +Temporal's architecture reinforces this — there is no server-push mechanism +for external clients. Updates with `wait_condition` are the closest +approximation: the workflow blocks until data is available, making it +behave like push from the subscriber's perspective while remaining pull on +the wire. + +**Both layers are exposed.** The offset-based poll is a first-class part +of the API, not hidden behind the iterator. The BFF uses offsets directly +to map SSE event IDs to global offsets for reconnection. Application code +that just wants to process items in order uses the iterator. Different +consumers use different layers. + +**Poll efficiency.** The poll slices `self._log[from_offset - base_offset:]` +and filters by topic. The common case — single topic, continuing from last +poll — is O(new items since last poll). The global offset points directly to +the resume position with no scanning or cursor alignment. Multi-topic polls +are the same cost: one slice, one filter pass. The worst case is a poll from +offset 0 (full log scan), which only happens on first connection or after the +subscriber falls behind. + +**Fan-out is per-poll, not shared.** Each `__temporal_pubsub_poll` update is an +independent Temporal update RPC. The handler has no registry of active +subscribers; every call executes `_on_poll` from scratch with its own +`from_offset` closure and topic set. When a publish grows the log, +Temporal's `wait_condition` machinery re-evaluates every pending predicate +and wakes each one whose condition is now true. Each then slices the same +shared log independently, applies its own topic filter, and returns its own +`PollResult` on its own update response. + +The consequences: + +- Two subscribers on the same topics from the same offset both receive the + items — each item travels the wire **twice**, once per update response. +- Two subscribers from different offsets each see their own slice; the + overlapping range is serialized into both responses. +- Two subscribers with disjoint topics each see a filtered subset; no items + are duplicated across their responses, but the log is walked twice. + +This is deliberate. Temporal updates are 1:1 RPCs, not a shared delivery +fabric. There is no intra-workflow subscriber registry, no cross-poll +dedup, no broadcast. Fan-out cost scales linearly with subscriber count, +but there's no shared state between polls to get wrong and no delivery-order +ambiguity between them. Applications that need to multiplex a single +subscription across many local consumers should do so on the client side, +below the `subscribe()` iterator — one poll stream feeding N in-process +consumers. A workflow-side shared fan-out is listed under +[Future Work](#future-work). + +### 10. Workflow can publish but should not subscribe + +Workflow code can call `self.publish()` directly — this is deterministic. +Reading from the log within workflow code is possible but breaks the +failure-free abstraction because external publishers send data via signals +(non-deterministic inputs), and branching on signal content creates +replay-sensitive code paths. + +### 11. `base_offset` for truncation + +The log carries a `base_offset`. All offset arithmetic uses +`offset - base_offset` to index into the log, so discarding a prefix of +consumed entries and advancing `base_offset` keeps global offsets +monotonic. 
If a poll's `from_offset` is below `base_offset`, the +subscriber has fallen behind truncation and the poll fails with a +non-retryable `TruncatedOffset` error. + +Because the module targets continue-as-new as the standard pattern for +long-running workflows, workflow history size is not the binding +constraint — CAN rolls history forward indefinitely. The binding +constraint is the in-memory log growing between CAN boundaries. Voice +streaming workflows have shown this matters in practice: a session can +accumulate tens of thousands of small audio/text events long before CAN +is triggered, and the workflow needs a way to release entries the +subscriber has already consumed without waiting for a CAN cycle. +`PubSub.truncate(up_to_offset)` exposes this. + +### 12. No timeout on long-poll + +`wait_condition` in the poll handler has no timeout. The poll blocks +indefinitely until one of three things happens: + +1. **New data arrives** — the `len(log) > offset` condition fires. +2. **Draining for continue-as-new** — `PubSub.drain()` sets the flag. +3. **Client disconnects** — the BFF drops the SSE connection, cancels the + update RPC, and the handler becomes an inert coroutine cleaned up at + the next drain cycle. + +A previous design used a 5-minute timeout as a defensive "don't block +forever" mechanism. This was removed because: + +- **It adds unnecessary history events.** Every poll creates a `TimerStarted` + event. For a streaming session doing hundreds of polls, this doubles the + history event count and accelerates approach to the ~50K event CAN threshold. +- **The drain mechanism already handles cleanup.** `PubSub.drain()` unblocks + all waiting polls, and the update validator rejects new polls, so + `all_handlers_finished()` converges without timers. +- **Zombie polls are harmless.** If a client crashes without cancelling, its + poll handler is just an in-memory coroutine waiting on a condition. It + consumes no Temporal actions and is cleaned up at the next CAN cycle. + +### 13. Signals for publish, updates for poll + +Publishing uses signals (fire-and-forget); subscription uses updates +(request-response with `wait_condition`). These choices are deliberate. + +**Why signals for publish:** + +- **Non-blocking flush.** The activity can buffer tokens at whatever rate + the LLM produces them. `handle.signal(...)` enqueues at the server and + returns immediately — the publisher is never throttled by the workflow's + processing speed. +- **Lower history cost.** Each signal adds 1 event (`WorkflowSignalReceived`). + An update adds 2 (`UpdateAccepted` + `UpdateCompleted`). For a streaming + session with hundreds of flushes, signals halve the history growth rate and + delay the CAN threshold. +- **No concurrency limits.** Temporal Cloud enforces per-workflow update + limits. Signals have no equivalent limit, making them safer for + high-throughput publishing. + +**Why updates for poll:** + +- The caller needs a result (the items). Blocking is the desired behavior + (long-poll semantics). `wait_condition` inside an update handler is the + natural fit. + +**Why not updates for publish?** The main attraction would be platform-native +exactly-once via Update ID, eliminating application-level dedup. However: + +1. Update ID dedup does not persist across continue-as-new. For CAN workflows, + application-level dedup is required regardless + ([temporal/temporal#6375](https://github.com/temporalio/temporal/issues/6375)). +2. 
Each flush would block for a round-trip to the worker (~10-50ms), throttling
+   the publisher.
+3. The 2x history cost accelerates approach to the CAN threshold.
+
+If the cross-CAN dedup gap is fixed and backpressure becomes desirable,
+switching publish to updates is a mechanical change — the dedup protocol
+and the mixin handler logic are unchanged.
+
+## Design Principles
+
+### Deduplication follows the end-to-end principle
+
+**The end-to-end principle** (Saltzer, Reed, Clark, "End-to-End Arguments in
+System Design," 1984): a function can be correctly and completely
+implemented only with the knowledge available at the endpoints of a
+communication system. Implementing it at intermediate layers may be
+redundant or of little value, because the endpoints must handle it
+regardless. The corollary: implement a function at the lowest layer that
+can implement it *completely*. Don't partially implement it at an
+intermediate layer.
+
+> 🚪 **One-way door.** The contract that the stream is an append-only log of
+> *all* attempts — including failed ones — is irreversible once subscribers
+> build reducers around it. Every frontend reducer expects to see interleaved
+> retries and uses application-level events (e.g., `AGENT_START` resetting the
+> text accumulator) to reconcile. If the transport later started filtering
+> retries, existing reducers would break — they would miss the state
+> transitions they depend on, and there would be two different behaviors
+> depending on whether the subscriber was connected live (saw the failed
+> attempt) or replayed after reconnect (didn't). This is the correct design,
+> but it is a permanent commitment.
+
+**Our design decision.** We do not filter out events from failed activity
+attempts. When an activity retries — for example, an LLM call that times
+out, or a tool call that fails because a worker crashes — its previous
+attempt's streaming events remain in the log. The new attempt publishes
+fresh events. The subscriber sees both.
+
+**Why the pub/sub layer cannot handle this completely.** When an LLM
+activity retries, the model runs again and produces different output —
+different tokens, different wording, a different response. The pub/sub
+layer sees two different message sequences. It has no way to know these
+represent the same logical operation. Only the application knows that the
+second response supersedes the first.
+
+We could have added retry semantics to the pub/sub protocol — for example,
+tagging messages with attempt numbers and letting the transport filter
+superseded attempts, similar to signal-level dedup. But this would be
+incomplete, and the incompleteness creates a real problem: if the
+transport scrubs failed-attempt events, but the subscriber already saw
+them in real time (before the retry happened), the subscriber now has two
+code paths — one for the live stream (which included the failed attempt)
+and one for replay after reconnect (which doesn't). Two paths through the
+frontend for the same logical scenario are a source of bugs and are
+difficult to test. The transport's filtering doesn't save the subscriber
+any work; the subscriber needs robust reconciliation logic regardless.
+
+**The contract: an append-only log of attempts.** The stream records what
+happened, including failed attempts. The subscriber decides how to present
+this to the user.
In our frontend, the application-layer reducer handles
+reconciliation: a new `TEXT_COMPLETE` event overwrites the previous one
+(set semantics), and an `AGENT_START` event resets the text accumulator so
+the retry's tokens replace the failed attempt's partial output. This
+reducer produces the same state whether it processes events live or
+replays them on reconnect — there is only one code path.
+
+**The pub/sub layer handles what it can handle completely.** Signal-level
+dedup (same publisher ID + same sequence number) is fully resolvable at the
+transport layer — the layer has all the information it needs, so it
+deduplicates there. Activity-level dedup cannot be fully resolved at the
+transport layer — it requires application context — so the pub/sub layer
+does not attempt it. Each layer handles the duplicates it can completely
+resolve.
+
+## Exactly-Once Publish Delivery
+
+External publishers get exactly-once delivery through publisher ID + sequence
+number deduplication, following the Kafka producer model.
+
+### Problem
+
+`flush()` sends items via a Temporal signal. If the signal call raises after
+the server accepted it (e.g., network timeout on the response), the client
+cannot distinguish delivered from not-delivered. Without dedup, the client
+must choose between at-most-once (data loss) and at-least-once (silent
+duplication).
+
+### Solution
+
+Each `PubSubClient` instance generates a UUID (`publisher_id`) on creation.
+Each `flush()` increments a monotonic `sequence` counter. The signal payload
+includes both. The workflow tracks the highest seen sequence per publisher in
+`_publisher_sequences: dict[str, int]` and rejects any signal with
+`sequence <= last_seen`.
+
+```
+Client                                 Workflow
+  │                                      │
+  │ signal(publisher_id, seq=1)          │
+  │─────────────────────────────────────►│ seq 1 > 0 → accept, record seq=1
+  │                                      │
+  │ signal(publisher_id, seq=1)          │ (retry after timeout)
+  │─────────────────────────────────────►│ seq 1 <= 1 → reject (duplicate)
+  │                                      │
+  │ signal(publisher_id, seq=2)          │
+  │─────────────────────────────────────►│ seq 2 > 1 → accept, record seq=2
+```
+
+### Client-side flush
+
+```python
+async def _flush(self) -> None:
+    async with self._flush_lock:
+        if self._pending is not None:
+            # Retry failed batch with same sequence
+            batch = self._pending
+            seq = self._pending_seq
+        elif self._buffer:
+            # New batch
+            seq = self._sequence + 1
+            batch = self._buffer
+            self._buffer = []
+            self._pending = batch
+            self._pending_seq = seq
+        else:
+            return
+        try:
+            await self._handle.signal(
+                "__temporal_pubsub_publish",
+                PublishInput(
+                    items=batch,
+                    publisher_id=self._publisher_id,
+                    sequence=seq,
+                ),
+            )
+        except Exception:
+            # _pending/_pending_seq stay set so the next _flush() retries
+            # the same batch with the same sequence number.
+            raise
+        self._sequence = seq  # advance confirmed sequence
+        self._pending = None  # clear pending
+```
+
+- **Separate pending from buffer**: failed batches stay in `_pending`, not
+  restored to `_buffer`. New `publish()` calls during retry go to the fresh
+  buffer. This prevents the data-loss bug where items would be merged into a
+  retry batch under a different sequence number.
+- **Retry with same sequence**: on failure, the next `_flush()` retries the
+  same `_pending` with the same `_pending_seq`. If the signal was delivered
+  but the client saw an error, the workflow deduplicates the retry.
+- **Sequence advances only on success**: `_sequence` (confirmed) is updated
+  only after the signal call returns without error.
+- **Lock for coalescing**: concurrent `_flush()` callers queue on the lock.
+- **max_retry_duration**: if set, the client gives up retrying after this + duration and raises `TimeoutError`. Must be less than the workflow's + `publisher_ttl` to preserve exactly-once guarantees. + +### Dedup state and TTL pruning + +`publisher_sequences` is `dict[str, int]` — bounded by number of publishers +(typically 1-2), not number of flushes. Carried through continue-as-new in +`PubSubState`. If `publisher_id` is empty (workflow-internal publish), +dedup is skipped. + +`publisher_last_seen` tracks the last `workflow.time()` each publisher was +seen. During `PubSub.get_state(publisher_ttl=timedelta(seconds=900))`, +entries older than TTL are pruned to bound memory across long-lived +workflow chains. + +**Safety constraint**: `publisher_ttl` must exceed the client's +`max_retry_duration`. If a publisher's dedup entry is pruned while it still +has a pending retry, the retry could be accepted as new, creating duplicates. + +### Scope: what pub/sub dedup does and does not handle + +Duplicates arise at three points in the pipeline. Each layer handles the +duplicates it introduces — applying the end-to-end principle (Saltzer, Reed, +Clark 1984). + +``` +LLM API --> Activity --> PubSubClient --> Workflow Log --> BFF/SSE --> Browser + (A) (B) (C) +``` + +| Type | Cause | Handled by | +|---|---|---| +| A: Duplicate LLM work | Activity retry produces a second, semantically equivalent but textually different response | Application layer (activity idempotency keys, workflow orchestration) | +| B: Duplicate signal batches | Signal retry after ambiguous failure delivers the same `(publisher_id, sequence)` batch twice | **Pub/sub layer** (`sequence <= last_seen` rejection) | +| C: Duplicate SSE events | Browser reconnects and BFF replays previously-delivered events | Delivery layer (SSE `Last-Event-ID`, idempotent frontend reducers) | + +**Why Type A doesn't belong here.** Data escapes to the subscriber during the +first LLM call — tokens are consumed, forwarded to the browser, and rendered +before any retry occurs. By the time a retry produces a duplicate response, +the original is already consumed. The pub/sub layer has no opportunity to +suppress it, and resolution requires application semantics (discard, replace, +merge) that the transport layer has no knowledge of. + +**Why Type B must be here.** The consumer sees `PubSubItem(topic, data)` with +no unique ID. If the workflow accepted a duplicate batch, the duplicates would +get fresh offsets and be indistinguishable from originals. Content-based dedup +has false positives (an LLM legitimately produces the same token twice; a +status event like `{"type":"THINKING_START"}` repeats across turns). The +`(publisher_id, sequence)` check is the only correct implementation — it +preserves transport encapsulation and uses context only the transport layer +has. + +**Why Type C doesn't belong here.** SSE reconnection is below the pub/sub +layer. The BFF assigns gapless event IDs and maps `Last-Event-ID` back to +global offsets (see [Information Leakage and the BFF](#information-leakage-and-the-bff)). + +## Continue-as-New + +### Problem + +The pub/sub mixin accumulates workflow history through signals (each +`__temporal_pubsub_publish`) and updates (each `__temporal_pubsub_poll` response). Over a +streaming session, history grows toward the ~50K event threshold. CAN resets +the history while carrying the canonical log copy forward. 
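+
+As a rough illustration of that growth (the per-event costs come from the
+signal/update comparison above; the session numbers are invented):
+
+```python
+# Illustrative only; numbers are invented, not measured from a real session.
+publish_signals = 10_000  # __temporal_pubsub_publish: 1 history event each
+poll_updates = 2_000      # __temporal_pubsub_poll: 2 history events each
+other_events = 500        # activities, timers, markers, ...
+total = publish_signals * 1 + poll_updates * 2 + other_events
+print(total)  # 14500; comfortable, but a long session keeps climbing toward ~50K
+```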
+ +### State + +```python +@dataclass +class PubSubState: + log: list[PubSubItem] = field(default_factory=list) + base_offset: int = 0 + publisher_sequences: dict[str, int] = field(default_factory=dict) + publisher_last_seen: dict[str, float] = field(default_factory=dict) +``` + +`PubSub(prior_state=...)` restores all four fields. `PubSub.get_state()` +snapshots them. + +### Draining + +A long-poll `__temporal_pubsub_poll` blocks indefinitely until new data arrives. To +allow CAN to proceed, draining uses two mechanisms: + +1. **`PubSub.drain()`** sets a flag that unblocks all waiting poll handlers + (the `or self._draining` clause in `wait_condition`). +2. **Update validator** rejects new polls when draining, so no new handlers + start and `all_handlers_finished()` stabilizes. + +`PubSub.continue_as_new(build_args)` packages the three steps: + +```python +# CAN sequence in the parent workflow: +await self.pubsub.continue_as_new(lambda state: [WorkflowInput( + pubsub_state=state, +)]) +``` + +`build_args` runs *after* drain stabilizes, with the post-drain +`PubSubState` as its single argument. Workflows that need to override +other CAN parameters (`task_queue`, `retry_policy`, `run_timeout`, etc.) +fall back to the explicit recipe: + +```python +self.pubsub.drain() +await workflow.wait_condition(workflow.all_handlers_finished) +workflow.continue_as_new(args=[WorkflowInput( + pubsub_state=self.pubsub.get_state(), +)], task_queue="other-tq") +``` + +### Client-side CAN following + +`subscribe()` always follows CAN chains when the client was created via +`for_workflow()`. When a poll fails with +`WorkflowUpdateRPCTimeoutOrCancelledError`, the client calls `describe()` on +the handle. If the status is `CONTINUED_AS_NEW`, it gets a fresh handle for +the same workflow ID (targeting the latest run) and retries the poll from the +same offset. + +```python +async def _follow_continue_as_new(self) -> bool: + if self._client is None: + return False + try: + desc = await self._handle.describe() + except Exception: + return False + if desc.status == WorkflowExecutionStatus.CONTINUED_AS_NEW: + self._handle = self._client.get_workflow_handle(self._workflow_id) + return True + return False +``` + +The `describe()` check prevents infinite loops: if the workflow completed or +failed (not CAN), the subscriber stops instead of retrying. + +### Offset continuity + +Since the full log is carried forward: + +- Pre-CAN: offsets `0..N-1`, log length N. +- Post-CAN: `PubSub(prior_state=...)` restores N items. New appends start + at offset N. +- A subscriber at offset K resumes seamlessly against the new run. + +### Edge cases + +**Payload size limit.** The full log in CAN input could approach Temporal's +2 MB limit for very long sessions. Mitigation: truncation (discarding consumed +entries before CAN) is the natural extension, supported by `base_offset`. + +**Signal delivery during CAN.** A publisher sending mid-CAN may get errors if +its handle is pinned to the old run. The workflow should ensure activities +complete before triggering CAN. + +**Concurrent subscribers.** Each maintains its own offset. Sharing a +`PubSubClient` across concurrent `subscribe()` calls is safe. + +## Information Leakage and the BFF + +Global offsets leak cross-topic activity (a single-topic subscriber sees gaps). +This is acceptable within the pub/sub API because the subscriber is the BFF — +trusted server-side code. The leakage must not reach the end client (browser). 
+ +### The problem + +If the BFF forwarded `PollResult.next_offset` to the browser (e.g., as an SSE +reconnection cursor), the browser could observe gaps and infer activity on +topics it is not subscribed to. Even if the offset is "opaque," a monotonic +integer with gaps is trivially inspectable. + +### Options considered + +We evaluated four approaches for browser-side reconnection: + +1. **BFF tracks the cursor server-side.** The BFF maintains a per-session + `session_id → last_offset` mapping. The browser reconnects with just the + session ID. On BFF restart, cursors are lost — fall back to replaying from + turn start. + +2. **Opaque token from the BFF.** The BFF wraps the global offset in an + encoded or encrypted token. The browser passes it back on reconnect. + `base64(offset)` is trivially reversible (security theater); real encryption + needs a key and adds a layer for marginal benefit over option 1. + +3. **BFF assigns SSE event IDs with `Last-Event-ID`.** The BFF emits SSE + events with `id: 1`, `id: 2`, `id: 3` (a BFF-local counter per stream). + On reconnect, the browser sends `Last-Event-ID` (built into the SSE spec). + The BFF maps that back to a global offset internally. + +4. **No mid-stream resume.** Browser reconnects, BFF replays from start of + the current turn. Frontend deduplicates. Simplest, but replays more data + than necessary. + +### Decision: SSE event IDs (option 3) + +The BFF assigns gapless integer IDs to SSE events and maintains a small +mapping from SSE event index to global offset. The browser never sees the +workflow's offset — it sees the BFF's event numbering. + +```python +sse_id = 0 +sse_id_to_offset: dict[int, int] = {} + +start_offset = await pubsub.get_offset() +async for item in pubsub.subscribe(topics=["events"], from_offset=start_offset): + sse_id += 1 + sse_id_to_offset[sse_id] = item_global_offset + yield f"id: {sse_id}\ndata: {item.data}\n\n" +``` + +On reconnect, the browser sends `Last-Event-ID: 47`. The BFF looks up the +corresponding global offset and resumes the subscription from there. + +The BFF is already per-session and stateful (it holds the SSE connection). +The `sse_id → global_offset` mapping is negligible additional state. On BFF +restart, the mapping is lost — fall back to replaying from turn start (option +4), which is acceptable because agent turns produce modest event volumes and +the frontend reducer is idempotent. + +This uses the SSE spec as designed: `Last-Event-ID` exists for exactly this +reconnection pattern. + +## Cross-Language Protocol + +Any Temporal client in any language can interact with a pub/sub workflow by: + +1. **Publishing**: Signal `__temporal_pubsub_publish` with `PublishInput` payload +2. **Subscribing**: Execute update `__temporal_pubsub_poll` with `PollInput`, loop +3. **Checking offset**: Query `__temporal_pubsub_offset` + +Double-underscore prefix on handler names avoids collisions with application +signals/updates. The envelope types are simple composites of strings, bytes, +and ints — representable in every Temporal SDK's default data converter. + +**Requires the default (JSON) data converter.** The wire protocol depends on +all participants — workflow, publishers, and subscribers — using the default +JSON data converter. A custom converter (protobuf, encryption codecs) would +change how the envelope types serialize, breaking cross-language interop. 
+This is also why payload data is opaque bytes: the pub/sub layer controls the
+envelope format (guaranteed JSON-safe), while the application controls payload
+serialization independently.
+
+## Compatibility
+
+> 🚪 **One-way door (two parts).**
+>
+> **Immutable handler names.** `__temporal_pubsub_publish`, `__temporal_pubsub_poll`, and
+> `__temporal_pubsub_offset` are permanent wire-level entry points. The escape hatch —
+> versioned handler names like `__temporal_pubsub_v2_poll` — gets more expensive over
+> time: the mixin must register all supported versions, with no discovery
+> mechanism for which versions a workflow supports.
+>
+> **No version field.** Committing to additive-only evolution means the *only*
+> path for a true breaking change is versioned handler names. If the
+> additive-only discipline ever fails — an existing field's semantics need to
+> change, not just a new field added — there is no graceful migration path
+> within a single handler. The argument against a version field is sound
+> (signals are fire-and-forget, so version rejection equals silent data loss),
+> but it means the protocol's evolvability hinges entirely on never needing to
+> change existing field semantics.
+
+The wire protocol evolves under four rules to prevent accidental breakage by
+future contributors.
+
+### Alternatives considered
+
+We evaluated and rejected five approaches to protocol evolution in favor of
+additive-only.
+
+**Version field in payloads.** Add `version: int` to each wire type and have
+the receiver check it. Fatal flaw: signals are fire-and-forget. If a v1
+workflow receives a v2 signal and rejects it based on version, the publisher
+never learns the signal was rejected — silent data loss. Strictly worse than
+the current behavior, where unknown fields are harmlessly dropped by
+Temporal's JSON deserializer. For updates (poll), a version mismatch could
+return an error, but this only helps if you change the semantics of an
+existing field — which you should not do (that is a new handler, not a
+version bump).
+
+**Versioned handler names** (e.g., `__temporal_pubsub_v2_poll`). The most robust
+option — creates entirely separate protocol surfaces so old and new code
+never interact. But premature: the mixin must register handlers for all
+supported versions, the client must probe which versions exist (Temporal
+has no "does this handler exist?" primitive), and dead code accumulates.
+Reserved as the escape hatch for a future true breaking change.
+
+**Protocol negotiation.** Client declares version in poll, workflow
+responds with what it supports. Turns the mixin into a version-dispatching
+router. Disproportionate complexity. Temporal's Worker Versioning (Build ID
+routing) solves this better at the infrastructure level — route tasks to
+compatible workers rather than negotiating at the message level.
+
+**SDK version embedding.** Couples the protocol to the SDK release cadence.
+SDK version 2.0 might change zero protocol fields; SDK version 1.7 might
+change three. The version number becomes meaningless noise.
+
+**Accepting silent incompatibility.** Letting versions drift and break
+silently. Unacceptable for a durable-stream contract: a v2 subscriber
+hitting a v1 workflow should see the missing fields take their defaults, not
+corrupt state.
+
+**Why additive-only works.** Every protocol change to date has followed
+the same pattern: new field with a default that preserves pre-feature
+behavior.
This matches Protocol Buffers wire compatibility rules (never +change the meaning of an existing field number; always provide defaults +for new fields) and Avro's schema evolution model. Temporal's own +mechanisms cover the hard cases: + +- **Worker Versioning (Build IDs):** For true breaking changes, deploy v2 + mixin on a new Build ID. Old workflows continue on old workers; new + workflows start on new workers. Strictly more powerful than + message-level versioning because it operates at the workflow execution + level. +- **`workflow.patched()`:** For in-workflow behavior branching during + replay. Gates old vs. new logic within the same workflow code during + transition periods. + +**Ecosystem parallel.** Kafka's inter-broker protocol uses explicit version +numbers because brokers in a cluster must negotiate capabilities at +connection time — a fundamentally different topology from our +single-workflow-instance model. Our pattern is closer to protobuf wire +evolution: the schema is the contract, defaults handle absence, and +breaking changes get a new message type (handler name). + +### 1. Additive-only wire evolution + +New fields on `PublishInput`, `PollInput`, `PollResult`, and `PubSubState` must +have defaults. Existing field semantics must not change. Temporal's JSON data +converter drops unknown fields on deserialization and uses defaults for missing +fields, so additive changes are safe in both directions (new client → old +workflow, and vice versa). This is the same model as Protocol Buffers wire +compatibility. + +### 2. Handler names are immutable + +`__temporal_pubsub_publish`, `__temporal_pubsub_poll`, and `__temporal_pubsub_offset` will never change +meaning. If a future change is incompatible with additive evolution, the correct +mechanism is a new handler name (e.g., `__temporal_pubsub_v2_poll`) — creating an +entirely separate protocol surface so old and new code never interact. + +### 3. `PubSubState` must be forward-compatible + +New fields use `field(default_factory=...)` or scalar defaults. Old state loaded +into new code works (new fields get defaults). New state loaded into old code +works (unknown fields dropped by the JSON deserializer). This ensures seamless +continue-as-new across mixed-version deployments. + +### 4. No application-level version negotiation + +We do not add version fields to payloads, and we do not negotiate protocol +versions between client and workflow. The reasons: + +- **Signals cannot return errors.** A version field that the workflow checks on a + signal creates silent data loss: the workflow rejects the signal, but the + client (which used fire-and-forget delivery) never learns it was rejected. + This is strictly worse than the current behavior, where unknown fields are + harmlessly ignored. +- **Temporal Worker Versioning handles the hard cases.** For a true breaking + change, deploy the new mixin on a new Build ID. Old running workflows continue + on old workers; new workflows start on new workers. This operates at the + infrastructure level — handling in-flight workflows, replay, and mixed-version + fleets — which message-level version fields cannot. +- **`workflow.patched()` handles in-workflow transitions.** If a new mixin + version changes behavior (e.g., how it processes a signal), `patched()` gates + old vs. new logic within the same workflow code during the transition period. 
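+
+To make rules 1 and 3 concrete, here is a minimal sketch of what an additive
+change looks like. The `max_items` field is hypothetical and the field list is
+illustrative, not the module's actual definition; the point is the shape of
+the change: a new field whose default preserves pre-feature behavior.
+
+```python
+from dataclasses import dataclass, field
+
+
+@dataclass
+class PollInput:  # illustrative stand-in for the real wire type
+    topics: list[str] = field(default_factory=list)
+    from_offset: int = 0
+    # Hypothetical additive field. The default (0 = "no cap") means an old
+    # client that never sends it still gets pre-feature behavior, and an old
+    # workflow receiving it from a new client drops the unknown field.
+    max_items: int = 0
+```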
+ +### Field defaults + +All fields follow rule 1: + +| Field | Default | Behavior when absent | +|---|---|---| +| `PublishInput.publisher_id` | `""` | Empty string skips dedup | +| `PublishInput.sequence` | `0` | Zero skips dedup | +| `_WireItem.offset` | `0` | Zero means "unknown" | +| `PollResult.more_ready` | `False` | No truncation signaled | +| `PubSubState.publisher_last_seen` | `{}` | No TTL pruning state | + +## Ecosystem analogs + +The closest analogs in established messaging systems, for orientation: + +- **Offset model** — NATS JetStream: one stream, multiple subjects, a + single monotonic sequence number. Subscribers filter by subject but + advance through the global sequence space. This is our model. +- **Idempotent producer** — Kafka's producer ID + monotonic sequence + number, scoped to the broker. Our `publisher_id` + `sequence` at the + workflow does the same job, scoped to signal delivery into one workflow. +- **Blocking pull** — Redis Streams `XREAD BLOCK`. Our `__temporal_pubsub_poll` + update with `wait_condition` is the Temporal-native equivalent. +- **Durable-execution peer** — the Workflow SDK ([workflow-sdk.dev](https://workflow-sdk.dev)) + has a first-class streaming model with indexed resumption and buffered + writes, but uses external storage (Redis/filesystem) as the broker + rather than the workflow itself. + +Full comparison tables (same/different with Kafka, NATS JetStream, Redis +Streams, and Workflow SDK) live on the +[Streaming API Design Considerations Notion page](https://www.notion.so/3478fc567738803d9c22eeb64a296e21). + +## Future Work + +### Shared workflow-side fan-out + +Each `__temporal_pubsub_poll` update today is serviced independently, and an item +published to N interested subscribers crosses the wire N times (see +[Design Decision 9](#9-subscription-is-poll-based-exposed-as-async-iterator)). +For low fan-out (1–2 consumers) this is fine; for workloads with many +concurrent subscribers on overlapping topics the duplication becomes the +dominant cost. + +A shared fan-out would keep a registry of active polls inside the +workflow, coalesce them by `(from_offset, topics)` key, and have one +poll wake-up build a shared response that the handler returns to every +matching caller. The tricky parts are: (a) offsets and topic filters +usually differ per subscriber, limiting coalescing; (b) the registry is +workflow state that must survive continue-as-new; (c) cancelled polls +must be reaped cleanly so the registry doesn't leak across replays. +Until a concrete workload shows the linear-in-subscribers cost matters, +the simpler per-poll model is the right default — applications that need +local fan-out can share one `subscribe()` iterator across N in-process +consumers on the client side, where state is trivial. + +### Workflow-defined filters and transforms + +Today the only filter is "topic in topics". A richer model would let +the workflow register named filters or transforms — e.g., `filter="high_priority"` +or `transform="redact_pii"` — that run inside the poll handler before +items are returned. This keeps computation close to the log, avoids +shipping items the subscriber will discard, and lets workflows enforce +access control per subscriber rather than delegating it to clients. 
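+
+One possible shape, purely illustrative; the constructor keywords, the filter
+name, and the `redact_pii` helper are assumptions rather than current API, and
+the open questions below apply to any variant:
+
+```python
+# Hypothetical sketch of workflow-side filter/transform registration.
+def redact_pii(item: PubSubItem) -> PubSubItem:  # assumed helper
+    ...  # rewrite the payload before it leaves the workflow
+
+
+pubsub = PubSub(
+    filters={
+        # Would run inside the poll handler; items failing the predicate
+        # are skipped before they are serialized into the PollResult.
+        "high_priority": lambda item: item.topic == "alerts",
+    },
+    transforms={"redact_pii": redact_pii},
+)
+# A subscriber would then name the filter on the wire, e.g. in PollInput;
+# exactly how is one of the open questions below.
+```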
+ +Design questions left open: filter/transform registration API (at +`PubSub` construction, or later?), whether transforms may change the +item count (e.g., aggregation), how filter state interacts with +continue-as-new, and how filter identity is named on the wire for +cross-language clients. + +### Replace workflow-side dedup with server-side `request_id` + +Workflow-side `(publisher_id, sequence)` dedup +([Exactly-Once Publish Delivery](#exactly-once-publish-delivery)) +exists because Temporal's built-in signal `request_id` dedup does not +cover the cases the contrib needs: + +1. **Within a single `_flush()` call**, sdk-core's retry layer reuses the + same `request_id` across attempts, so the server already dedups + transient RPC failures. We get this for free. +2. **Across `_flush()` calls** (the `_pending` retry loop), each call + to `await handle.signal(...)` allocates a fresh `request_id` — + `temporalio/client.py:8357` hardcodes `request_id=str(uuid.uuid4())`, + with no way to override. The server cannot recognize that two such + calls are the same logical batch, so the workflow-side check is + what guarantees exactly-once. +3. **Across continue-as-new**, even if (1) and (2) were perfect, + `pendingSignalRequestedIDs` is per-run mutable state and is not + carried by `addWorkflowExecutionStartedEventForContinueAsNew`. A + retry whose first attempt landed on run N and whose retry lands on + run N+1 is accepted as fresh. Verified empirically on the Temporal + dev server and Temporal Cloud (see + `experiments/can-signal-dup/README.md` in the repo root for the + reproduction). [temporalio/temporal#4021](https://github.com/temporalio/temporal/issues/4021) + tracks the related state-growth concern that has historically + discouraged extending the dedup set across CAN. + +If both (a) the SDK exposes `request_id` on +`WorkflowHandle.signal()` and (b) the server dedups by `request_id` +across continue-as-new, the workflow-side check becomes redundant and +can be removed. The migration is mechanical because the dedup keys at +both layers are already aligned. + +**What changes:** + +```python +# In _client.py, _flush() — pin a deterministic request_id: +await self._handle.signal( + "__temporal_pubsub_publish", + PublishInput( + items=batch, + publisher_id=self._publisher_id, + sequence=seq, + ), + request_id=f"{self._publisher_id}:{seq}", # NEW +) +``` + +```python +# In _mixin.py, __temporal_pubsub_publish handler — drop the dedup branch: +def _pubsub_publish(self, input: PublishInput) -> None: + # remove: if input.publisher_id and input.sequence ... + self._log.extend(input.items) +``` + +```python +# In PubSubState — these fields become unused and can be removed in a +# follow-up wire migration (see Compatibility): +# publisher_sequences: dict[str, int] +# publisher_last_seen: dict[str, float] +``` + +**What stays:** + +- The client-side `_pending` retry loop and `_flush_lock`. Server-side + `request_id` dedup makes retries safe; it does not eliminate the + reasons we retry (long outages, worker restarts). +- The `(publisher_id, sequence)` shape on the wire. We continue to + send them — they are the inputs we'd derive `request_id` from, and + keeping them on the wire preserves observability and lets older + workflow versions that still maintain the dedup table interoperate + with newer clients during rollout. +- `force_flush=True`, `flush()`, `__aexit__` flush — orthogonal. + +**What goes away:** + +- `publisher_sequences` and `publisher_last_seen` in `PubSubState`. 
+- `publisher_ttl` and the `publisher_ttl > max_retry_duration` safety + constraint — there is no longer a per-publisher map to expire. +- The TLA+ retry-algorithm verification (`PubSubDedupTTL.tla`); the + on-workflow check it models has been removed. The + ordering/correctness specs that don't mention dedup still apply. + +**Migration path:** + +1. Land the SDK change to expose `request_id` on signals. +2. Confirm server `request_id` dedup spans CAN (re-run + `experiments/can-signal-dup` against a server build that includes + the fix). +3. Bump the contrib protocol minor version. Newer clients send the + pinned `request_id`; older clients still send fresh UUIDs. Both + continue to set `(publisher_id, sequence)` so a workflow that has + not yet been re-deployed remains correct. +4. After all clients are upgraded, deploy a workflow version that + ignores `(publisher_id, sequence)` and relies on the server. Drop + the dedup fields from `PubSubState` in a subsequent wire-format + pass once the old fields are no longer read by any deployed + version. + +Until both prerequisites are real, the workflow-side dedup is +load-bearing and must stay. + +### Workflow-side subscription + +[Design Decision 10](#10-workflow-can-publish-but-should-not-subscribe) +explains why workflow code shouldn't read the log directly today — the +log contains data from non-deterministic signal inputs, and branching on +it creates replay-sensitive code paths. There are workflow-side use +cases (aggregator workflows, workflows that fan events out to child +workflows, workflows that trigger activities based on stream content) +where a proper subscription API would be useful. + +A safe workflow-side `subscribe()` would need to tag reads so they go +through the same determinism machinery as other non-deterministic +inputs — likely surfaced as an async iterator that yields at +deterministic checkpoints. The simplest cut is probably a pull-based +iterator over `self._log` slices that integrates with `wait_condition` +for the "no data yet" case, mirroring the external poll API but +bypassing the update RPC layer. + +### Why `continue_as_new` takes a state-bound builder + +The helper's signature is +`continue_as_new(build_args: Callable[[PubSubState], Sequence[Any]])`. +Two earlier shapes were rejected: + +1. **Eager-args form** — + `continue_as_new(args=[WorkflowInput(pubsub_state=self.get_state(), ...)])`. + Python evaluates call-site arguments before the method body runs, + so `self.get_state()` would snapshot state *before* `drain()` and + `all_handlers_finished` — the opposite of the recipe's intent. +2. **Zero-arg builder** — `build_args: Callable[[], Sequence[Any]]`, + the lambda inspecting `self.pubsub` directly. Defers evaluation + correctly, but leaves the caller free to write + `self.pubsub.get_state()` inside the lambda, where the + "evaluated post-drain" contract is implicit and only documented in + prose. + +Passing the post-drain `PubSubState` as the lambda's parameter makes +the contract structural: there is one path to the state and the helper +controls when it's read. The signature itself reads as "here is the +state, return the CAN args." + +The helper deliberately does *not* mirror the full +`workflow.continue_as_new` signature (12 parameters today). Workflows +that need to override `task_queue` / `retry_policy` / `run_timeout` / +etc. 
fall back to the explicit `drain` / `wait_condition` / +`workflow.continue_as_new(...)` recipe — keeping the helper's surface +area stable as new CAN options land in `temporalio.workflow`. + +## File Layout + +``` +temporalio/contrib/pubsub/ +├── __init__.py # Public API exports +├── _broker.py # PubSub (workflow-side) +├── _client.py # PubSubClient (external-side) +├── _types.py # Shared data types +├── README.md # Usage documentation +└── DESIGN-v2.md # This document +``` diff --git a/temporalio/contrib/pubsub/README.md b/temporalio/contrib/pubsub/README.md new file mode 100644 index 000000000..b6c18c222 --- /dev/null +++ b/temporalio/contrib/pubsub/README.md @@ -0,0 +1,288 @@ +# Temporal Workflow Pub/Sub + +Workflows sometimes need to push incremental updates to external observers. +Examples include providing customer updates during order processing, creating +interactive experiences with AI agents, or reporting progress from a +long-running data pipeline. Temporal's core primitives (workflows, signals, and +updates) already provide the building blocks, but wiring up batching, offset +tracking, topic filtering, and continue-as-new hand-off is non-trivial. + +This module packages that boilerplate into a reusable broker and client. The +workflow acts as a message broker that maintains an append-only log. +Applications can interact directly from the workflow, or from external clients +such as activities, starters, and other workflows. Under the hood, publishing +uses signals (fire-and-forget) while subscribing uses updates (long-poll). A +configurable batching coalesces high-frequency events, improving efficiency. + +Payloads are Temporal `Payload`s carrying the encoding metadata needed for +typed decode (`subscribe(result_type=T)`) and heterogeneous-topic dispatch +(`Payload.metadata`). The codec chain (encryption, PII-redaction, +compression) runs once on the signal/update envelope that carries each +batch — **not** per item — so there is no double-encryption, and codec +behavior is symmetric between workflow-side and client-side publishing. + +## Quick Start + +### Workflow side + +Construct a `PubSub` from your `@workflow.init`. The constructor +dynamically registers the pub/sub signal, update, and query handlers on +the current workflow, and raises `RuntimeError` if called twice. If you +want the workflow to support continue-as-new, include a +`PubSubState | None` field on the input and pass it through — it's +`None` on fresh starts and carries state across CAN otherwise: + +```python +from dataclasses import dataclass +from temporalio import workflow +from temporalio.contrib.pubsub import PubSub, PubSubState + +@dataclass +class MyInput: + pubsub_state: PubSubState | None = None + +@workflow.defn +class MyWorkflow: + @workflow.init + def __init__(self, input: MyInput) -> None: + self.pubsub = PubSub(prior_state=input.pubsub_state) + + @workflow.run + async def run(self, input: MyInput) -> None: + self.pubsub.publish("status", StatusEvent(state="started")) + await do_work() + self.pubsub.publish("status", StatusEvent(state="done")) +``` + +Both workflow-side and client-side `publish()` use the sync payload +converter for per-item `Payload` construction. The codec chain runs +once at the envelope level (`__temporal_pubsub_publish` signal, +`__temporal_pubsub_poll` update) — never per item — so encryption, +PII-redaction, and compression are applied once each way. 
+ +### Activity side (publishing) + +Use `PubSubClient.from_activity()` with the async context manager for +batched publishing. The Temporal client and target workflow ID are taken +from the activity context: + +```python +from datetime import timedelta + +from temporalio import activity +from temporalio.contrib.pubsub import PubSubClient + +@activity.defn +async def stream_events() -> None: + client = PubSubClient.from_activity(batch_interval=timedelta(seconds=2)) + async with client: + for chunk in generate_chunks(): + client.publish("events", chunk) + activity.heartbeat() + # Buffer is flushed automatically on context manager exit +``` + +Use `force_flush=True` to trigger an immediate flush for latency-sensitive events: + +```python +client.publish("events", data, force_flush=True) +``` + +### Subscribing + +Use `PubSubClient.create()` and the `subscribe()` async iterator: + +```python +from temporalio.contrib.pubsub import PubSubClient + +client = PubSubClient.create(temporal_client, workflow_id) +async for item in client.subscribe(["events"], result_type=MyEvent): + print(item.topic, item.data) + if is_done(item): + break +``` + +`item.data` is a `temporalio.api.common.v1.Payload` when no +`result_type` is given; passing `result_type=T` decodes each item to +`T` via the client's data converter (including the codec chain). + +## Topics + +Topics allow subscribers to receive a subset of the messages in the pub/sub system. +Subscribers can request a list of specific topics, or provide an empty list to receive +messages from all topics. Publishing to a topic implicitly creates it. + +## Continue-as-new + +Carry both your application state and pub/sub state across continue-as-new +boundaries: + +```python +from dataclasses import dataclass, field +from temporalio import workflow +from temporalio.contrib.pubsub import PubSub, PubSubState + +@dataclass +class AppState: + # Whatever your workflow needs to carry forward. + ... + +@dataclass +class WorkflowInput: + app_state: AppState = field(default_factory=AppState) + pubsub_state: PubSubState | None = None + +@workflow.defn +class MyWorkflow: + @workflow.init + def __init__(self, input: WorkflowInput) -> None: + self.app_state = input.app_state + self.pubsub = PubSub(prior_state=input.pubsub_state) + + @workflow.run + async def run(self, input: WorkflowInput) -> None: + # ... do work, updating self.app_state ... + + if workflow.info().is_continue_as_new_suggested(): + await self.pubsub.continue_as_new(lambda pubsub_state: [WorkflowInput( + app_state=self.app_state, + pubsub_state=pubsub_state, + )]) +``` + +`PubSub.continue_as_new(build_args)` drains waiting subscribers, +waits for in-flight handlers to finish, then calls +`workflow.continue_as_new` with `build_args(post_drain_state)`. The +lambda receives the post-drain `PubSubState` so the snapshot is +guaranteed to happen *after* drain. Subscribers created via +`PubSubClient.create()` or `PubSubClient.from_activity()` automatically +follow continue-as-new chains. + +Workflows that need to pass other CAN parameters (`task_queue`, +`retry_policy`, `run_timeout`, etc.) 
fall back to the explicit recipe: + +```python +self.pubsub.drain() +await workflow.wait_condition(workflow.all_handlers_finished) +workflow.continue_as_new(args=[WorkflowInput( + app_state=self.app_state, + pubsub_state=self.pubsub.get_state(), +)], task_queue="other-tq") +``` + +## Gotcha: sync handlers racing `__temporal_pubsub_publish` + +If you add a **custom synchronous** `@workflow.update` or +`@workflow.signal` handler that reads `PubSub` state, and an +external client calls `handle.signal("__temporal_pubsub_publish", ...)` +immediately followed by that handler, the handler may observe +pre-publish state when both land in the same workflow activation. +Root cause: `PubSub` installs `__temporal_pubsub_publish` *dynamically* from +`@workflow.init`, so in the first activation the signal is buffered +until after your class-level handler has already been scheduled. + +Two framings for when you need to care: + +- If your producer and your update caller are **independent + services** (the common shape for `PubSub`), the handler already + has to be robust to "update arrived before publish" for reasons + unrelated to this race — network timing, missing publishes, bad + offsets. Whatever policy you have for those covers this race too. +- If your code does **sequential same-client** ordering — await + `handle.signal(...)`, then await `handle.execute_update(...)` on + the same handle, and expect the signal's effects to be visible — + use the recipe below. + +### Recipe + +Make the handler `async` and yield once before touching `PubSub` +state: + +```python +import asyncio +from temporalio import workflow + +@workflow.defn +class MyWorkflow: + @workflow.init + def __init__(self) -> None: + self.pubsub = PubSub() + + @workflow.update + async def truncate_at(self, offset: int) -> None: + await asyncio.sleep(0) # let pending publishes apply + self.pubsub.truncate(offset) # now sees post-signal state +``` + +`asyncio.sleep(0)` is a pure asyncio-level yield — one event-loop +tick, no Temporal timer, no history events, no server round trip. +Do **not** substitute `workflow.sleep(0)`; that schedules a Temporal +timer and adds history events on every call. + +Already-safe patterns, no recipe needed: + +- The module's own `__temporal_pubsub_poll` update (it is already `async` and + `await`s `workflow.wait_condition` internally). +- Any `async` handler that `await`s something before reading + `PubSub` state. +- Handlers whose semantics are naturally "wait for the state I'm + asking about" — use `await workflow.wait_condition(lambda: ...)` + with a meaningful predicate instead of `asyncio.sleep(0)`. +- Workflow-internal publishes (`self.pubsub.publish(...)` from + `run()` or from an activity); these do not race. + +See `SIGNAL-UPDATE-RACE.md` in this directory for the full +activation-ordering mechanics. + +## API Reference + +### PubSub + +| Method | Description | +|---|---| +| `PubSub(prior_state=None)` | Constructor. Call once from `@workflow.init`; registers handlers on the current workflow. Raises `RuntimeError` if a `PubSub` is already registered. Pass `prior_state` if the input declares one (`None` on fresh starts). | +| `publish(topic, value)` | Append to the log from workflow code. `value` is converted via the sync workflow payload converter (no codec). | +| `get_state(*, publisher_ttl=timedelta(seconds=900))` | Snapshot for continue-as-new. Drops publisher dedup entries older than `publisher_ttl` (a `timedelta`, default 15 minutes). | +| `drain()` | Unblock polls and reject new ones. 
| +| `continue_as_new(build_args, *, publisher_ttl=timedelta(seconds=900))` | Async. Drain, wait for handlers, then `workflow.continue_as_new` with `build_args(post_drain_state)`. Use the explicit recipe to override other CAN parameters. | +| `truncate(up_to_offset)` | Discard log entries below the given offset. Workflow-side only — no external API; wire up your own signal or update if external control is needed. | + +Handlers registered by the constructor: + +| Kind | Name | Description | +|---|---|---| +| Signal | `__temporal_pubsub_publish` | Receive external publications. | +| Update | `__temporal_pubsub_poll` | Long-poll subscription. | +| Query | `__temporal_pubsub_offset` | Current global offset. | + +### PubSubClient + +| Method | Description | +|---|---| +| `PubSubClient.create(client, workflow_id, *, batch_interval, max_batch_size, max_retry_duration)` | Factory with an explicit Temporal client and workflow id. Follows CAN. | +| `PubSubClient.from_activity(*, batch_interval, max_batch_size, max_retry_duration)` | Factory that takes client and workflow id from the current activity context. Follows CAN. | +| `PubSubClient(handle, *, batch_interval, max_batch_size, max_retry_duration)` | From handle (no CAN follow). | +| `publish(topic, value, force_flush=False)` | Buffer a message. `value` may be any converter-compatible object or a pre-built `Payload`. Per-item conversion uses the sync payload converter; the codec chain runs once on the signal envelope. | +| `subscribe(topics, from_offset, *, result_type=None, poll_cooldown=timedelta(milliseconds=100))` | Async iterator. With `result_type=T`, `item.data` is decoded to `T`; otherwise it is a raw `Payload`. Follows CAN chains when created via `create` or `from_activity`. | +| `get_offset()` | Query current global offset. | + +Use as `async with` for batched publishing with automatic flush. + +## Cross-Language Protocol + +Any Temporal client can interact with a pub/sub workflow using these +fixed handler names: + +1. **Publish:** Signal `__temporal_pubsub_publish` with `PublishInput` +2. **Subscribe:** Update `__temporal_pubsub_poll` with `PollInput` -> `PollResult` +3. **Offset:** Query `__temporal_pubsub_offset` -> `int` + +The Python API exposes Temporal `Payload`s and decodes via the client's +data converter. On the wire, each `PublishEntry.data` / `_WireItem.data` +is a base64-encoded `Payload.SerializeToString()` so the transport +remains JSON-serializable while preserving `Payload.metadata` (used by +codecs and by the decode path). Cross-language clients can publish and +subscribe by following the same base64-of-serialized-`Payload` shape. +The signal/update envelopes (`PublishInput`, `PollResult`, `PubSubState`) +require the default (JSON) data converter; custom converters on the +envelope layer will break cross-language interop. diff --git a/temporalio/contrib/pubsub/SIGNAL-UPDATE-RACE.md b/temporalio/contrib/pubsub/SIGNAL-UPDATE-RACE.md new file mode 100644 index 000000000..6fb3450de --- /dev/null +++ b/temporalio/contrib/pubsub/SIGNAL-UPDATE-RACE.md @@ -0,0 +1,337 @@ +# Dynamic-signal vs. class-level-update race in `contrib/pubsub` + +**Status:** design note for team review. +**Context:** surfaced while stabilizing PR #1423 CI; one test +(`test_poll_truncated_offset_returns_application_error`) failed +deterministically under parallel load before being patched around. + +## TL;DR + +`PubSub` registers `__temporal_pubsub_publish` as a **dynamic** signal handler +(via `workflow.set_signal_handler` inside `PubSub.__init__`). 
Any +**class-level, synchronous** `@workflow.update` that reads `PubSub` +state and fires in the **same activation** as a just-arrived +`__temporal_pubsub_publish` signal will observe pre-signal state — zero items +in the log — and raise from the handler before the buffered signal +task gets a chance to run. + +The PR works around this by seeding log state from `@workflow.init` +in the test workflow. That keeps CI green but does not fix the race +for users who follow the pattern `handle.signal(...)` → synchronous +user update. We document the gotcha and publish a one-line recipe +(`await asyncio.sleep(0)` at the top of sync update handlers that +read `PubSub` state); see the Recommendation section. + +## How the race is triggered + +Consider a workflow using `PubSub` with a user-defined synchronous +update that reads the log: + +```python +@workflow.defn +class MyWorkflow: + @workflow.init + def __init__(self) -> None: + self.pubsub = PubSub() + + @workflow.update # class-level, synchronous + def truncate(self, offset: int) -> None: + self.pubsub.truncate(offset) + + @workflow.run + async def run(self) -> None: + await workflow.wait_condition(lambda: False) +``` + +And a client that publishes then immediately truncates: + +```python +handle = await client.start_workflow(MyWorkflow.run, ...) +await handle.signal("__temporal_pubsub_publish", PublishInput(items=[...5 items...])) +await handle.execute_update("truncate", 3) +``` + +Under parallel test load (or just bad luck on the server), all three +events — `InitializeWorkflow`, `SignalWorkflow(__temporal_pubsub_publish)`, +`DoUpdate(truncate)` — can arrive at the worker in a **single** +`WorkflowActivation`. + +### What the worker does with that activation + +From `temporalio/worker/_workflow_instance.py`: + +1. `activate()` groups jobs into buckets + (`_workflow_instance.py:440–455`): + - `job_sets[1]` = signals **and** updates + - `job_sets[2]` = initialize_workflow, activity resolutions, etc. + +2. Process `job_sets[1]` (signals + updates) **first** + (`_workflow_instance.py:461–466`): + - `_apply(Signal(__temporal_pubsub_publish))` → looks up `self._signals`. + `__temporal_pubsub_publish` is registered **dynamically inside + `PubSub.__init__`**, which has not run yet. No handler → signal + goes into `_buffered_signals` + (`_workflow_instance.py:1061–1063`). + - `_apply(Update(truncate))` → looks up `self._updates`. `truncate` + is a **class-level** `@workflow.update`, so it is present in + `self._updates` from the workflow instance context's `__init__` + (`self._updates = dict(self._defn.updates)` at + `_workflow_instance.py:316`). A task is created immediately + and scheduled via `loop.call_soon`, appending to `self._ready` + (`_apply_do_update` → `create_task` at + `_workflow_instance.py:721`). + +3. `_run_once` (`_workflow_instance.py:2478–2511`): + - Lazy-instantiate the workflow object + (`_workflow_instance.py:2485–2486`). This runs `__init__` + **synchronously**. `PubSub.__init__` calls + `workflow.set_signal_handler("__temporal_pubsub_publish", self._on_publish)`. + - `workflow_set_signal_handler` (`_workflow_instance.py:1401–1424`) + installs the handler **and immediately drains the buffer** — + dispatching each buffered signal job through + `_process_signal_job` + (`_workflow_instance.py:2415–2453`), which creates an `asyncio.Task`. + That task's first `__step` lands in `self._ready` — **after** the + update task already there. 
+ - The event loop drains `self._ready` in FIFO order + (`_workflow_instance.py:2489–2493`): + - **Update task runs first.** `truncate(3)` sees `self._log == []`. + - **Signal task runs second.** `_on_publish` appends 5 items — + too late. + +4. Before this PR, the update handler raised `ValueError` on the + empty-log check. That is not an `ApplicationError`, so it fails + the **entire workflow task**, not just the update. Subsequent + `execute_update("__temporal_pubsub_poll", …)` then returns + `WorkflowNotReadyFailure` and the test aborts. + +### Why `__temporal_pubsub_poll` is not affected + +`_on_poll` is `async` and contains + +```python +await workflow.wait_condition( + lambda: len(self._log) > log_offset or self._draining, +) +``` + +Even if the poll task runs before the signal task, the first `await` +yields back to the loop, the buffered-signal task gets its turn, +`_log` gets populated, the condition unblocks, and poll returns +items. The race is invisible for async handlers that yield. + +## Who is affected + +A user workflow hits this iff **all** are true: +- The workflow uses `PubSub` (so `__temporal_pubsub_publish` is dynamic). +- The workflow defines a class-level `@workflow.update` or + `@workflow.signal` that reads `PubSub` state synchronously (no + `await`). +- A client issues `handle.signal("__temporal_pubsub_publish", …)` immediately + followed by a call to that sync update, and the server batches + init + signal + update into one activation. + +The module's own `__temporal_pubsub_poll` avoids it (async). Workflow-internal +publishes (`self.pubsub.publish(...)` from `run()` or an activity) +avoid it (no client-initiated signal race). The failure mode is a +narrow slice but very real: it reproduced deterministically under +`pytest -n auto` load in CI and locally. + +## Zooming out: this race is a subset of a broader concern + +In most real applications that use `PubSub`, the publisher and the +caller of any custom update/query are independent actors — a +producer service publishes; a control-plane client reads or mutates. +From the update handler's perspective, these scenarios are +indistinguishable: + +1. **SDK race.** Publish buffered in the same activation as the + update; signal handler not yet installed; update reads pre-signal + state. +2. **Network race.** Publish hasn't reached the server yet; update + arrives first. +3. **Genuinely early / out-of-range.** Publish is never coming; the + caller passed a bad offset. + +All three surface to the handler as "log is shorter than what the +caller asked about." Any handler that is robust to (2) and (3) — +which it must be, because those are inherent to distributed systems +— is automatically robust to (1). Whatever policy the handler picks +for "asked to act on state that isn't here yet" (error, wait with +timeout, no-op) covers the SDK race too. + +The case where "application robustness is enough" breaks down is +**sequential same-client ordering**: + +```python +await handle.signal("__temporal_pubsub_publish", items) # awaited +await handle.execute_update("custom_op", ...) # expects items visible +``` + +Here the caller completed the signal before issuing the update and +reasonably expects ordering to hold. The SDK race violates that +expectation. In practice, this single-client shape is rare in +`PubSub` use — the whole module shape is "one side writes, a +different side reads/mutates." Callers who *do* depend on sequential +ordering should use the recipe in the Recommendation section. + +## Options + +### 1. 
Do nothing (leave the PR's test-only workaround) + +**What:** keep `prepub_count` seeding in `TruncateWorkflow.__init__`. +Tests pass. Users with the affected pattern still hit the race. + +**Pros:** zero extra work, unblocks #1423. +**Cons:** silent footgun for users. Likely to resurface as a support +ticket. + +### 2. Document the caveat with a concrete recipe + +**What:** add a section to the `PubSub` docstring / contrib README +with the specific fix: + +> Custom synchronous `@workflow.update` or `@workflow.signal` +> handlers that read `PubSub` state seeded by `__temporal_pubsub_publish` +> may observe stale state when the external signal and the custom +> handler arrive in the same workflow activation. To close the +> window, make the handler `async` and yield once before touching +> `PubSub` state: +> +> ```python +> import asyncio +> +> @workflow.update +> async def my_update(self, ...) -> None: +> await asyncio.sleep(0) # let pending __temporal_pubsub_publish apply +> self.pubsub.truncate(...) # now sees post-signal state +> ``` +> +> `asyncio.sleep(0)` is a pure asyncio-level yield — no Temporal +> timer, no history events, no server round trip. Do not use +> `workflow.sleep(0)` (that *does* schedule a timer). +> +> Already-safe patterns: async handlers that `await` anything +> (including `workflow.wait_condition`); the module's own +> `__temporal_pubsub_poll`; any handler whose semantics already include +> "wait for the state I'm asking about" (use `wait_condition` on a +> meaningful predicate). + +**Pros:** honest; cheap; steers users toward a concrete, correct +pattern. Recipe matches what the SDK-level fix would do implicitly. +**Cons:** still a sharp edge — relies on users reading. See the +"Zooming out" section above: most applications have to be robust to +the same out-of-order arrival for reasons unrelated to this race, +so the recipe is only needed when users rely on strict sequential +same-client ordering. + +### 3. Make `__temporal_pubsub_publish` class-level (revert to a mixin) + +**What:** undo 72d296ea — expose `PubSubMixin` with +`@workflow.signal def __temporal_pubsub_publish(...)`. Users opt in by +inheritance. Class-level signals are present in `self._signals` from +instance-context construction, so `_apply(Signal)` schedules a +**signal** task, not buffers, and FIFO dispatch runs signal before +update. + +**Pros:** fully fixes the race at the library layer with no SDK +change. Zero user-visible footgun. +**Cons:** reintroduces all the reasons we moved to dynamic: +multiple-inheritance conflicts, users forgetting to inherit, +awkward composition with other mixins, forced class hierarchy. +We already rejected this. + +### 4. Fix the dispatch order in the SDK + +**What:** in `workflow_set_signal_handler` (or in `_run_once`), +arrange for buffered-signal tasks to be dispatched **ahead of** any +update tasks already queued from `_apply(job_sets[1])`. Concretely, +either: + +- Run buffered signal handlers synchronously (no `create_task`) when + drained from the buffer during `set_signal_handler`, so their state + mutations land before any task in `self._ready` runs; or +- Swap the grouping in `activate()` so `initialize_workflow` is + applied before signals+updates — so `PubSub.__init__` runs, the + signal handler is live at `_apply(Signal)` time, and the signal + task is created before the update task. + +**Pros:** real fix. Benefits every dynamic-signal user, not just +`PubSub`. Preserves current PubSub API. 
+**Cons:** non-trivial SDK change with broader blast radius.
+Needs design review, wider test coverage (queries, continue-as-new,
+updates with validators, async signals…). Not something we ship
+alongside this PR.
+
+### 5. Make `PubSub.truncate` require an async context / add a publisher barrier
+
+**What:** explicitly disallow sync updates from reading `PubSub` state by
+making the read-path APIs async — e.g., `await pubsub.truncate(...)`
+that internally `wait_condition`s on a "signal handler has run at least N
+times" barrier. Or expose an `await pubsub.wait_for_publish_applied()`
+primitive that users call at the top of sync updates (which makes them
+no longer sync, defeating the purpose).
+
+**Pros:** race-safe if users follow the API.
+**Cons:** leaky — pushes SDK-activation-ordering concerns into the
+user API. Compromises ergonomics of what should be a simple
+in-memory mutation.
+
+## Recommendation
+
+Ship **(1) + (2)** now. Treat **(4)** as optional follow-up, not a
+blocker.
+
+- Keep the `prepub_count` change in the test (it is legitimate test
+  scaffolding and avoids baking SDK-ordering assumptions into the
+  test surface).
+- Add the caveat + `asyncio.sleep(0)` recipe from option (2) to the
+  contrib README as a visible "Gotcha" section, not a footnote, with
+  a link to this document for the full mechanics.
+- Optionally file an issue against sdk-python for option (4). It is
+  a principled fix (dispatch buffered signals ahead of updates on
+  the same activation, or reorder the job-set buckets), but given
+  the "Zooming out" analysis, the payoff is narrow: it only helps
+  users who rely on sequential same-client publish→update ordering,
+  which is an uncommon pattern for `PubSub`.
+
+Rationale:
+- Applications using `PubSub` with independent producers and
+  consumers must already handle "update arrives before publish" as
+  a general concern — the SDK race is a narrow special case covered
+  by that same robustness.
+- (3) reverses a deliberate API decision we already made.
+- (4) is correct but is a core-sdk-behavior change that deserves its
+  own PR, reviewers, and wider test coverage (queries,
+  continue-as-new, validators, async signals…); the benefit is
+  limited to the sequential-same-client case.
+- (5) bleeds SDK internals into the user API.
+- (1) alone is not enough — we need (2) so the escape hatch is
+  discoverable by users who do depend on sequential ordering.
+
+## Appendix: Minimal repro (already in the test file, pre-patch)
+
+```python
+@workflow.defn
+class TruncateWorkflow:
+    @workflow.init
+    def __init__(self) -> None:
+        self.pubsub = PubSub()
+
+    @workflow.update
+    def truncate(self, offset: int) -> None:  # sync
+        self.pubsub.truncate(offset)
+
+    @workflow.run
+    async def run(self) -> None:
+        await workflow.wait_condition(lambda: False)
+
+# client
+handle = await client.start_workflow(TruncateWorkflow.run, ...)
+await handle.signal("__temporal_pubsub_publish", PublishInput(items=[...5 items...]))
+await handle.execute_update("truncate", 3)  # racy
+```
+
+Under `pytest -n auto --dist=worksteal` the update reliably observes
+`len(self._log) == 0` and fails the workflow task. Running the test
+in isolation passes every time.
diff --git a/temporalio/contrib/pubsub/__init__.py b/temporalio/contrib/pubsub/__init__.py
new file mode 100644
index 000000000..b124ec960
--- /dev/null
+++ b/temporalio/contrib/pubsub/__init__.py
@@ -0,0 +1,35 @@
+"""Pub/sub support for Temporal workflows.
+
+This module provides a reusable pub/sub pattern where a workflow acts as a
+message broker. 
External clients (activities, starters, other services) publish +and subscribe through the workflow handle using Temporal primitives. + +Payloads are Temporal ``Payload`` values. Publishing values are +converted to ``Payload`` per item by the client's payload converter; +the codec chain (encryption, PII-redaction, compression) runs once on +the surrounding signal/update envelope rather than per item. Subscribers +yield raw ``Payload`` by default, or decode each item to a concrete +type via ``subscribe(result_type=T)``. +""" + +from temporalio.contrib.pubsub._broker import PubSub +from temporalio.contrib.pubsub._client import PubSubClient +from temporalio.contrib.pubsub._types import ( + PollInput, + PollResult, + PublishEntry, + PublishInput, + PubSubItem, + PubSubState, +) + +__all__ = [ + "PollInput", + "PollResult", + "PubSub", + "PubSubClient", + "PubSubItem", + "PubSubState", + "PublishEntry", + "PublishInput", +] diff --git a/temporalio/contrib/pubsub/_broker.py b/temporalio/contrib/pubsub/_broker.py new file mode 100644 index 000000000..fa70f9aae --- /dev/null +++ b/temporalio/contrib/pubsub/_broker.py @@ -0,0 +1,383 @@ +"""Workflow-side pub/sub broker. + +Instantiate :class:`PubSub` once from your workflow's ``@workflow.init`` +method. The constructor registers the pub/sub signal, update, and query +handlers on the current workflow via +:func:`temporalio.workflow.set_signal_handler`, +:func:`temporalio.workflow.set_update_handler`, and +:func:`temporalio.workflow.set_query_handler`. + +For workflows that support continue-as-new, include a +``PubSubState | None`` field on the workflow input and pass it as +``prior_state`` — it is ``None`` on fresh starts and carries accumulated +state on continue-as-new. + +Both workflow-side :meth:`PubSub.publish` and client-side +:meth:`PubSubClient.publish` use the synchronous payload converter for +per-item ``Payload`` construction. The codec chain (encryption, +PII-redaction, compression) is **not** run per item on either side — +it runs once at the envelope level when Temporal's SDK encodes the +signal/update that carries the batch. Running it per item as well +would double-encrypt, because every signal arg already goes through +the client's ``DataConverter.encode`` at dispatch time. +""" + +from __future__ import annotations + +import sys +from collections.abc import Sequence +from datetime import timedelta +from typing import Any, Callable, NoReturn + +from temporalio import workflow +from temporalio.api.common.v1 import Payload +from temporalio.exceptions import ApplicationError + +from ._types import ( + PollInput, + PollResult, + PublishInput, + PubSubItem, + PubSubState, + _decode_payload, + _encode_payload, + _WireItem, +) + +_PUBLISH_SIGNAL = "__temporal_pubsub_publish" +_POLL_UPDATE = "__temporal_pubsub_poll" +_OFFSET_QUERY = "__temporal_pubsub_offset" + +_MAX_POLL_RESPONSE_BYTES = 1_000_000 + + +def _payload_wire_size(payload: Payload, topic: str) -> int: + """Approximate poll-response contribution of a single item. + + Wire form is ``_WireItem(topic, base64(proto(Payload)), offset)``. + Base64 inflates by ~4/3; we use the exact serialized length as a + close-enough proxy. + """ + return (payload.ByteSize() * 4 + 2) // 3 + len(topic) + + +class PubSub: + """Workflow-side pub/sub broker. + + Construct once from ``@workflow.init``; the constructor registers + the pub/sub signal, update, and query handlers on the current + workflow. Raises :class:`RuntimeError` if a ``PubSub`` has already + been registered on the workflow. 
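+
+    A minimal usage sketch (the workflow name, topic, and published
+    value below are illustrative, not part of the API):
+
+    .. code-block:: python
+
+        from temporalio import workflow
+        from temporalio.contrib.pubsub import PubSub
+
+        @workflow.defn
+        class BrokerWorkflow:
+            @workflow.init
+            def __init__(self) -> None:
+                # Registers the pub/sub signal, update, and query handlers.
+                self.pubsub = PubSub()
+
+            @workflow.run
+            async def run(self) -> None:
+                # Workflow-side publish; external producers and consumers
+                # go through PubSubClient against this workflow's handle.
+                self.pubsub.publish("events", b"started")
+                await workflow.wait_condition(lambda: False)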
+ + Registered handlers: + + - ``__temporal_pubsub_publish`` signal — external publish with dedup + - ``__temporal_pubsub_poll`` update — long-poll subscription + - ``__temporal_pubsub_offset`` query — current log length + + Note: + Because ``__temporal_pubsub_publish`` is registered *dynamically* from + ``__init__``, custom **synchronous** update/signal handlers + that read ``PubSub`` state can observe pre-publish state when + both land in the same activation. Make such handlers ``async`` + and ``await asyncio.sleep(0)`` before reading state. See the + "Gotcha" section of this module's ``README.md`` for the + full explanation and recipe. + """ + + def __init__(self, prior_state: PubSubState | None = None) -> None: + """Initialize pub/sub state and register workflow handlers. + + Must be called directly from the workflow's ``@workflow.init`` + method. Calls made from ``@workflow.run``, helper methods, or + signal/update/query handlers raise :class:`RuntimeError`. + + The check inspects the immediate caller's frame and requires the + function name to be ``__init__``. A history-length check (expect + length 3 on the first workflow task) is not used because + pre-start signals inflate the first-task history and cache + evictions legitimately re-run ``__init__`` from later tasks. + + Args: + prior_state: State carried from a previous run via + :meth:`get_state` through continue-as-new, or ``None`` + on first start. + + Raises: + RuntimeError: If not called directly from a method named + ``__init__``, or if the pub/sub signal handler is + already registered on this workflow (i.e., ``PubSub`` + was instantiated twice). + + Note: + When carrying state across continue-as-new, type the + carrying field as ``PubSubState | None`` — not ``Any``. The + default data converter deserializes ``Any`` fields as plain + dicts, which silently strips the ``PubSubState`` type and + breaks the new run. + """ + caller = sys._getframe(1) + caller_name = caller.f_code.co_name + if caller_name != "__init__": + raise RuntimeError( + "PubSub must be constructed directly from the workflow's " + f"@workflow.init method, not from {caller_name!r}." + ) + if workflow.get_signal_handler(_PUBLISH_SIGNAL) is not None: + raise RuntimeError( + "PubSub is already registered on this workflow. " + "Construct PubSub(...) at most once from @workflow.init." + ) + + if prior_state is not None: + self._log: list[PubSubItem] = [ + PubSubItem(topic=item.topic, data=_decode_payload(item.data)) + for item in prior_state.log + ] + self._base_offset: int = prior_state.base_offset + self._publisher_sequences: dict[str, int] = dict( + prior_state.publisher_sequences + ) + self._publisher_last_seen: dict[str, float] = dict( + prior_state.publisher_last_seen + ) + else: + self._log = [] + self._base_offset = 0 + self._publisher_sequences = {} + self._publisher_last_seen = {} + self._draining: bool = False + + workflow.set_signal_handler(_PUBLISH_SIGNAL, self._on_publish) + workflow.set_update_handler( + _POLL_UPDATE, self._on_poll, validator=self._validate_poll + ) + workflow.set_query_handler(_OFFSET_QUERY, self._on_offset) + + def publish(self, topic: str, value: Any) -> None: + """Publish an item from within workflow code. + + ``value`` may be any Python value the workflow's payload + converter can handle, or a pre-built + :class:`temporalio.api.common.v1.Payload` for zero-copy. + + The codec chain is not applied here (it runs on the + ``__temporal_pubsub_poll`` update envelope that later delivers the + item to a subscriber). 
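+
+        For example (topic and value are illustrative):
+
+        .. code-block:: python
+
+            self.pubsub.publish("progress", {"step": 3, "total": 10})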
+ """ + if isinstance(value, Payload): + payload = value + else: + payload = workflow.payload_converter().to_payloads([value])[0] + self._log.append(PubSubItem(topic=topic, data=payload)) + + def get_state( + self, *, publisher_ttl: timedelta = timedelta(seconds=900) + ) -> PubSubState: + """Return a serializable snapshot of pub/sub state for continue-as-new. + + Prunes publisher dedup entries older than ``publisher_ttl``. The + TTL must exceed the ``max_retry_duration`` of any client that + may still be retrying a failed flush. + + Args: + publisher_ttl: Duration after which a publisher's dedup + entry is pruned. Default 15 minutes. + """ + now = workflow.time() + ttl_secs = publisher_ttl.total_seconds() + + active_sequences: dict[str, int] = {} + active_last_seen: dict[str, float] = {} + for pid, seq in self._publisher_sequences.items(): + ts = self._publisher_last_seen.get(pid, 0.0) + if now - ts < ttl_secs: + active_sequences[pid] = seq + active_last_seen[pid] = ts + + return PubSubState( + log=[ + _WireItem(topic=item.topic, data=_encode_payload(item.data)) + for item in self._log + ], + base_offset=self._base_offset, + publisher_sequences=active_sequences, + publisher_last_seen=active_last_seen, + ) + + def drain(self) -> None: + """Unblock all waiting poll handlers and reject new polls. + + Call this before + ``await workflow.wait_condition(workflow.all_handlers_finished)`` + and ``workflow.continue_as_new()``. + """ + self._draining = True + + async def continue_as_new( + self, + build_args: Callable[[PubSubState], Sequence[Any]], + *, + publisher_ttl: timedelta = timedelta(seconds=900), + ) -> NoReturn: + """Drain, wait for handlers, then continue-as-new with built args. + + Replaces the three-line recipe ``drain()`` → + ``wait_condition(all_handlers_finished)`` → + ``workflow.continue_as_new(args=...)`` for the common case where + the only CAN parameter that varies is ``args``. + + ``build_args`` is invoked *after* drain has stabilized, with the + post-drain :class:`PubSubState` as its single argument. The + caller threads that state into whatever input dataclass the + workflow expects: + + .. code-block:: python + + await self.pubsub.continue_as_new(lambda state: [WorkflowInput( + items_processed=self.items_processed, + pubsub_state=state, + )]) + + Workflows that need to override other CAN parameters + (``task_queue``, ``retry_policy``, ``run_timeout``, etc.) should + keep using the explicit ``drain`` / ``wait_condition`` / + ``workflow.continue_as_new(...)`` recipe. + + Args: + build_args: Callable that receives the post-drain pub/sub + state and returns the positional ``args`` for the new + run. + publisher_ttl: Forwarded to :meth:`get_state`. + + Does not return; ``workflow.continue_as_new`` raises an internal + exception that the SDK uses to close the run. + """ + self.drain() + await workflow.wait_condition(workflow.all_handlers_finished) + workflow.continue_as_new( + args=build_args(self.get_state(publisher_ttl=publisher_ttl)), + ) + + def truncate(self, up_to_offset: int) -> None: + """Discard log entries before ``up_to_offset``. + + After truncation, polls requesting an offset before the new + base will receive an ApplicationError. All global offsets + remain monotonic. + + Raises ApplicationError (not ValueError) when ``up_to_offset`` + is past the end of the log so that callers invoking this from + an update handler surface it as an update failure rather than + a workflow-task poison pill. + + Args: + up_to_offset: The global offset to truncate up to + (exclusive). 
Entries at offsets + ``[base_offset, up_to_offset)`` are discarded. + """ + log_index = up_to_offset - self._base_offset + if log_index <= 0: + return + if log_index > len(self._log): + raise ApplicationError( + f"Cannot truncate to offset {up_to_offset}: only " + f"{self._base_offset + len(self._log)} items exist", + type="TruncateOutOfRange", + non_retryable=True, + ) + self._log = self._log[log_index:] + self._base_offset = up_to_offset + + def _on_publish(self, payload: PublishInput) -> None: + """Receive publications from external clients (activities, starters). + + Deduplicates using (publisher_id, sequence). If publisher_id is + set and the sequence is <= the last seen sequence for that + publisher, the entire batch is dropped as a duplicate. Batches + are atomic: the dedup decision applies to the whole batch, not + individual items. + + This block is a polyfill for missing server-side ``request_id`` + dedup across continue-as-new. If the SDK ever exposes + ``request_id`` on signals and the server dedups it across CAN, + this branch and the ``_publisher_sequences`` / + ``_publisher_last_seen`` state become redundant. See DESIGN-v2 + §"Replace workflow-side dedup with server-side request_id" for + the migration plan. + """ + if payload.publisher_id: + last_seq = self._publisher_sequences.get(payload.publisher_id, 0) + if payload.sequence <= last_seq: + return + self._publisher_sequences[payload.publisher_id] = payload.sequence + self._publisher_last_seen[payload.publisher_id] = workflow.time() + for entry in payload.items: + self._log.append( + PubSubItem(topic=entry.topic, data=_decode_payload(entry.data)) + ) + + async def _on_poll(self, payload: PollInput) -> PollResult: + """Long-poll: block until new items available or draining, then return.""" + log_offset = payload.from_offset - self._base_offset + if log_offset < 0: + if payload.from_offset == 0: + # "From the beginning" — start at whatever is available. + log_offset = 0 + else: + # Subscriber had a specific position that's been + # truncated. ApplicationError fails this update (client + # gets the error) without crashing the workflow task — + # avoids a poison pill during replay. + raise ApplicationError( + f"Requested offset {payload.from_offset} has been truncated. " + f"Current base offset is {self._base_offset}.", + type="TruncatedOffset", + non_retryable=True, + ) + await workflow.wait_condition( + lambda: len(self._log) > log_offset or self._draining, + ) + all_new = self._log[log_offset:] + if payload.topics: + topic_set = set(payload.topics) + candidates = [ + (self._base_offset + log_offset + i, item) + for i, item in enumerate(all_new) + if item.topic in topic_set + ] + else: + candidates = [ + (self._base_offset + log_offset + i, item) + for i, item in enumerate(all_new) + ] + # Cap response size to ~1MB wire bytes. + wire_items: list[_WireItem] = [] + size = 0 + more_ready = False + next_offset = self._base_offset + len(self._log) + for off, item in candidates: + item_size = _payload_wire_size(item.data, item.topic) + if size + item_size > _MAX_POLL_RESPONSE_BYTES and wire_items: + # Resume from this item on the next poll. 
+ next_offset = off + more_ready = True + break + size += item_size + wire_items.append( + _WireItem(topic=item.topic, data=_encode_payload(item.data), offset=off) + ) + return PollResult( + items=wire_items, + next_offset=next_offset, + more_ready=more_ready, + ) + + def _validate_poll(self, _payload: PollInput) -> None: + """Reject new polls when draining for continue-as-new.""" + if self._draining: + raise RuntimeError("Workflow is draining for continue-as-new") + + def _on_offset(self) -> int: + """Return the current global offset (base_offset + log length).""" + return self._base_offset + len(self._log) diff --git a/temporalio/contrib/pubsub/_client.py b/temporalio/contrib/pubsub/_client.py new file mode 100644 index 000000000..10dde57e5 --- /dev/null +++ b/temporalio/contrib/pubsub/_client.py @@ -0,0 +1,473 @@ +"""External-side pub/sub client. + +Used by activities, starters, and any code with a workflow handle to +publish messages and subscribe to topics on a pub/sub workflow. + +Each published value is turned into a :class:`Payload` via the client's +sync payload converter. The **codec chain** (encryption, PII-redaction, +compression) is **not** run per item — it runs once at the envelope +level when Temporal's SDK encodes the ``__temporal_pubsub_publish`` signal args +and the ``__temporal_pubsub_poll`` update result. Running the codec per item as +well would double-encrypt / double-compress, because the envelope path +covers the items again. The per-item ``Payload`` still carries the +encoding metadata (``encoding: json/plain``, ``messageType``, etc.) +required by ``subscribe(result_type=T)`` on the consumer side. +""" + +from __future__ import annotations + +import asyncio +import time +import uuid +from collections.abc import AsyncIterator +from datetime import timedelta +from typing import Any + +from typing_extensions import Self + +from temporalio import activity +from temporalio.api.common.v1 import Payload +from temporalio.client import ( + Client, + WorkflowExecutionStatus, + WorkflowHandle, + WorkflowUpdateFailedError, + WorkflowUpdateRPCTimeoutOrCancelledError, +) +from temporalio.converter import DataConverter, PayloadConverter + +from ._types import ( + PollInput, + PollResult, + PublishEntry, + PublishInput, + PubSubItem, + _decode_payload, + _encode_payload, +) + + +class PubSubClient: + """Client for publishing to and subscribing from a pub/sub workflow. + + Create via :py:meth:`create` (explicit client + workflow id), + :py:meth:`from_activity` (infer both from the current activity + context), or by passing a handle directly to the constructor. + + For publishing, use as an async context manager to get automatic + batching:: + + client = PubSubClient.create(temporal_client, workflow_id) + async with client: + client.publish("events", my_event) + client.publish("events", another_event, force_flush=True) + # Optional synchronization point — wait until everything + # buffered so far has been confirmed by the server. + await client.flush() + + For subscribing:: + + client = PubSubClient.create(temporal_client, workflow_id) + async for item in client.subscribe(["events"], result_type=MyEvent): + process(item.data) + """ + + def __init__( + self, + handle: WorkflowHandle[Any, Any], + *, + client: Client | None = None, + batch_interval: timedelta = timedelta(seconds=2), + max_batch_size: int | None = None, + max_retry_duration: timedelta = timedelta(seconds=600), + ) -> None: + """Create a pub/sub client from a workflow handle. 
+ + Prefer :py:meth:`create` — it enables continue-as-new following + in ``subscribe()`` and supplies the :class:`Client` needed to + reach the data converter chain. + + Args: + handle: Workflow handle to the pub/sub workflow. + client: Temporal client whose payload converter will be used + to turn published values into ``Payload`` objects and to + decode subscriptions when ``result_type`` is set. The + codec chain is **not** applied per item (doing so would + double-encrypt — see module docstring). If ``None``, the + default payload converter is used. + batch_interval: Interval between automatic flushes. + max_batch_size: Auto-flush when buffer reaches this size. + max_retry_duration: Maximum time to retry a failed flush + before raising TimeoutError. Must be less than the + workflow's ``publisher_ttl`` (default 15 minutes) to + preserve exactly-once delivery. Default: 10 minutes. + """ + self._handle: WorkflowHandle[Any, Any] = handle + self._client: Client | None = client + self._workflow_id = handle.id + self._batch_interval = batch_interval + self._max_batch_size = max_batch_size + self._max_retry_duration = max_retry_duration + self._buffer: list[tuple[str, Any]] = [] + self._flush_event = asyncio.Event() + self._flush_task: asyncio.Task[None] | None = None + self._flush_lock = asyncio.Lock() + self._publisher_id: str = uuid.uuid4().hex[:16] + self._sequence: int = 0 + self._pending: list[PublishEntry] | None = None + self._pending_seq: int = 0 + self._pending_since: float | None = None + + @classmethod + def create( + cls, + client: Client, + workflow_id: str, + *, + batch_interval: timedelta = timedelta(seconds=2), + max_batch_size: int | None = None, + max_retry_duration: timedelta = timedelta(seconds=600), + ) -> PubSubClient: + """Create a pub/sub client from a Temporal client and workflow ID. + + Use this when the caller has an explicit ``Client`` and + ``workflow_id`` in hand (starters, BFFs, other workflows' + activities). For code running inside an activity that targets + its own parent workflow, see :py:meth:`from_activity`. + + A client created through this method follows continue-as-new + chains in ``subscribe()`` and uses the client's payload + converter for per-item ``Payload`` construction. + + Args: + client: Temporal client. + workflow_id: ID of the pub/sub workflow. + batch_interval: Interval between automatic flushes. + max_batch_size: Auto-flush when buffer reaches this size. + max_retry_duration: Maximum time to retry a failed flush + before raising TimeoutError. Default: 10 minutes. + """ + handle = client.get_workflow_handle(workflow_id) + return cls( + handle, + client=client, + batch_interval=batch_interval, + max_batch_size=max_batch_size, + max_retry_duration=max_retry_duration, + ) + + @classmethod + def from_activity( + cls, + *, + batch_interval: timedelta = timedelta(seconds=2), + max_batch_size: int | None = None, + max_retry_duration: timedelta = timedelta(seconds=600), + ) -> PubSubClient: + """Create a pub/sub client targeting the current activity's parent workflow. + + Must be called from within an activity. The Temporal client and + parent workflow id are taken from the activity context. + + Args: + batch_interval: Interval between automatic flushes. + max_batch_size: Auto-flush when buffer reaches this size. + max_retry_duration: Maximum time to retry a failed flush + before raising TimeoutError. Default: 10 minutes. 
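+
+        A minimal sketch from inside an activity (the activity and
+        topic names are illustrative)::
+
+            @activity.defn
+            async def publish_progress(total: int) -> None:
+                async with PubSubClient.from_activity() as pubsub:
+                    for i in range(total):
+                        activity.heartbeat()
+                        pubsub.publish("progress", {"done": i, "total": total})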
+ """ + info = activity.info() + workflow_id = info.workflow_id + if workflow_id is None: + raise RuntimeError( + "from_activity requires an activity with a parent workflow" + ) + return cls.create( + activity.client(), + workflow_id, + batch_interval=batch_interval, + max_batch_size=max_batch_size, + max_retry_duration=max_retry_duration, + ) + + async def __aenter__(self) -> Self: + """Start the background flusher task.""" + self._flush_task = asyncio.create_task(self._run_flusher()) + return self + + async def __aexit__(self, *_exc: object) -> None: + """Stop the flusher and flush any remaining buffered entries.""" + if self._flush_task: + self._flush_task.cancel() + try: + await self._flush_task + except asyncio.CancelledError: + pass + self._flush_task = None + # Drain both pending and buffer. A single _flush() processes + # either pending OR buffer, not both — so if the flusher was + # cancelled mid-signal (pending set) while the producer added + # more items (buffer non-empty), a single final flush would + # orphan the buffer. + while self._pending is not None or self._buffer: + await self._flush() + + def publish(self, topic: str, value: Any, force_flush: bool = False) -> None: + """Buffer a message for publishing. + + ``value`` may be any Python value the client's payload + converter can handle, or a pre-built + :class:`temporalio.api.common.v1.Payload` for zero-copy. The + codec chain is not applied per item — it runs once on the + signal envelope that delivers the batch. + + Args: + topic: Topic string. + value: Value to publish. Converted to a ``Payload`` via + the client's sync payload converter at flush time. + Pre-built ``Payload`` instances bypass conversion. + force_flush: If True, wake the flusher to send immediately + (fire-and-forget — does not block the caller). + """ + self._buffer.append((topic, value)) + if force_flush or ( + self._max_batch_size is not None + and len(self._buffer) >= self._max_batch_size + ): + self._flush_event.set() + + async def flush(self) -> None: + """Flush buffered (and pending) items and wait for server confirmation. + + Returns once the items buffered at call time have been signaled to + the workflow and acknowledged by the server. Returns immediately + if there is nothing to send. + + This is in addition to the declarative ``force_flush=True`` on + :py:meth:`publish` and to the automatic flush on context-manager + exit. Use this when you need a synchronization point — proof + that prior publications have reached the server — at a moment + that does not naturally correspond to a specific event. + + Safe to call concurrently with ``publish()`` and with the + background flusher: the flush lock serializes signal sends. + Items added concurrently after entry may piggyback on this + flush or be deferred to a subsequent one. + + Raises: + TimeoutError: If a pending batch from a prior failure cannot + be sent within ``max_retry_duration``. The pending batch + is dropped; subsequent publications use a fresh sequence. + """ + while self._pending is not None or self._buffer: + await self._flush() + + def _payload_converter(self) -> PayloadConverter: + """Return the sync payload converter for per-item encode/decode. + + Uses the configured client's payload converter when available; + otherwise falls back to the default. The codec chain + (encryption, compression, PII-redaction) is intentionally not + invoked here — it runs once at the envelope level when the + signal/update goes over the wire. See module docstring. 
+ """ + if self._client is not None: + return self._client.data_converter.payload_converter + return DataConverter.default.payload_converter + + def _encode_buffer(self, entries: list[tuple[str, Any]]) -> list[PublishEntry]: + """Convert buffered (topic, value) pairs to wire entries. + + Non-Payload values go through the sync payload converter so the + resulting ``Payload`` carries encoding metadata for + ``result_type=`` decode on the consumer side. Pre-built + Payloads bypass conversion. + """ + converter = self._payload_converter() + out: list[PublishEntry] = [] + for topic, value in entries: + if isinstance(value, Payload): + payload = value + else: + payload = converter.to_payloads([value])[0] + out.append(PublishEntry(topic=topic, data=_encode_payload(payload))) + return out + + async def _flush(self) -> None: + """Send buffered or pending messages to the workflow via signal. + + On failure, the pending batch and sequence are kept for retry. + Only advances the confirmed sequence on success. + """ + async with self._flush_lock: + if self._pending is not None: + # Retry path: check max_retry_duration + if ( + self._pending_since is not None + and time.monotonic() - self._pending_since + > self._max_retry_duration.total_seconds() + ): + # Advance confirmed sequence so the next batch gets + # a fresh sequence number. Without this, the next + # batch reuses pending_seq, which the workflow may + # have already accepted — causing silent dedup + # (data loss). See DropPendingFixed / + # SequenceFreshness in the design doc. + self._sequence = self._pending_seq + self._pending = None + self._pending_seq = 0 + self._pending_since = None + raise TimeoutError( + f"Flush retry exceeded max_retry_duration " + f"({self._max_retry_duration}). Pending batch dropped. " + f"If the signal was delivered, items are in the log. " + f"If not, they are lost." + ) + batch = self._pending + seq = self._pending_seq + elif self._buffer: + # New batch path. Encode before clearing the buffer so + # a payload-converter exception leaves the items in + # place for inspection or retry rather than silently + # dropping them. + batch = self._encode_buffer(self._buffer) + self._buffer = [] + seq = self._sequence + 1 + self._pending = batch + self._pending_seq = seq + self._pending_since = time.monotonic() + else: + return + + try: + # If the SDK ever exposes request_id on signal() and the + # server dedups it across CAN, pinning + # request_id=f"{publisher_id}:{seq}" here lets the + # workflow-side dedup go away. See DESIGN-v2 §"Replace + # workflow-side dedup with server-side request_id". 
+ await self._handle.signal( + "__temporal_pubsub_publish", + PublishInput( + items=batch, + publisher_id=self._publisher_id, + sequence=seq, + ), + ) + # Success: advance confirmed sequence, clear pending + self._sequence = seq + self._pending = None + self._pending_seq = 0 + self._pending_since = None + except Exception: + # Pending stays set for retry on the next _flush() call + raise + + async def _run_flusher(self) -> None: + """Background task: wait for timer OR force_flush wakeup, then flush.""" + while True: + try: + await asyncio.wait_for( + self._flush_event.wait(), + timeout=self._batch_interval.total_seconds(), + ) + except asyncio.TimeoutError: + pass + self._flush_event.clear() + await self._flush() + + async def subscribe( + self, + topics: str | list[str] | None = None, + from_offset: int = 0, + *, + result_type: type | None = None, + poll_cooldown: timedelta = timedelta(milliseconds=100), + ) -> AsyncIterator[PubSubItem]: + """Async iterator that polls for new items. + + Automatically follows continue-as-new chains when the client + was created via :py:meth:`create`. + + Args: + topics: Topic filter. A single topic name, a list of topic + names, or None. None or empty list means all topics. + from_offset: Global offset to start reading from. + result_type: Optional target type. When provided, each + yielded :class:`PubSubItem` has its ``data`` decoded + via the client's sync payload converter to the + specified type. When omitted, ``data`` is the raw + ``temporalio.api.common.v1.Payload`` — useful for + heterogeneous topics where the caller dispatches on + ``Payload.metadata``. + poll_cooldown: Minimum interval between polls to avoid + overwhelming the workflow when items arrive faster + than the poll round-trip. Defaults to 100ms. + + Yields: + :class:`PubSubItem` for each matching item. + """ + topic_filter: list[str] + if topics is None: + topic_filter = [] + elif isinstance(topics, str): + topic_filter = [topics] + else: + topic_filter = topics + offset = from_offset + while True: + try: + result: PollResult = await self._handle.execute_update( + "__temporal_pubsub_poll", + PollInput(topics=topic_filter, from_offset=offset), + result_type=PollResult, + ) + except asyncio.CancelledError: + return + except WorkflowUpdateFailedError as e: + if e.cause and getattr(e.cause, "type", None) == "TruncatedOffset": + # Subscriber fell behind truncation. Retry from + # offset 0 which the mixin treats as "from the + # beginning of whatever exists" (i.e., from + # base_offset). + offset = 0 + continue + raise + except WorkflowUpdateRPCTimeoutOrCancelledError: + if await self._follow_continue_as_new(): + continue + return + converter = self._payload_converter() + for wire_item in result.items: + payload = _decode_payload(wire_item.data) + if result_type is not None: + data: Any = converter.from_payload(payload, result_type) + else: + data = payload + yield PubSubItem( + topic=wire_item.topic, + data=data, + offset=wire_item.offset, + ) + offset = result.next_offset + cooldown_secs = poll_cooldown.total_seconds() + if not result.more_ready and cooldown_secs > 0: + await asyncio.sleep(cooldown_secs) + + async def _follow_continue_as_new(self) -> bool: + """Check if the workflow continued-as-new and re-target the handle. + + Returns True if the handle was updated (caller should retry). 
+ """ + if self._client is None: + return False + try: + desc = await self._handle.describe() + except Exception: + return False + if desc.status == WorkflowExecutionStatus.CONTINUED_AS_NEW: + self._handle = self._client.get_workflow_handle(self._workflow_id) + return True + return False + + async def get_offset(self) -> int: + """Query the current global offset (base_offset + log length).""" + return await self._handle.query("__temporal_pubsub_offset", result_type=int) diff --git a/temporalio/contrib/pubsub/_types.py b/temporalio/contrib/pubsub/_types.py new file mode 100644 index 000000000..de3929fe1 --- /dev/null +++ b/temporalio/contrib/pubsub/_types.py @@ -0,0 +1,134 @@ +"""Shared data types for the pub/sub contrib module. + +The user-facing ``data`` fields on :class:`PubSubItem` are +:class:`temporalio.api.common.v1.Payload`. Per-item values are +converted to ``Payload`` by the payload converter at publish time, and +the resulting bytes/metadata are preserved per item so subscribers can +decode with ``subscribe(result_type=T)``. The codec chain (encryption, +PII-redaction, compression) applies once at the outer signal/update +envelope level — not separately to each embedded item — so codec +behavior is symmetric between workflow-side and client-side +publishing. See ``DESIGN-v2.md`` §5 and +``docs/pubsub-payload-migration.md``. + +The wire representation (``PublishEntry``, ``_WireItem``) uses +base64-encoded ``Payload.SerializeToString()`` bytes because the default +JSON payload converter cannot serialize a ``Payload`` embedded inside a +dataclass (it only special-cases top-level Payloads on signal/update +args). Round-trip validated in +``tests/contrib/pubsub/test_payload_roundtrip_prototype.py``. +""" + +from __future__ import annotations + +import base64 +from dataclasses import dataclass, field +from typing import Any + +from temporalio.api.common.v1 import Payload + + +def _encode_payload(payload: Payload) -> str: # pyright: ignore[reportUnusedFunction] + """Wire format: base64(Payload.SerializeToString()).""" + return base64.b64encode(payload.SerializeToString()).decode("ascii") + + +def _decode_payload(wire: str) -> Payload: # pyright: ignore[reportUnusedFunction] + """Inverse of :func:`_encode_payload`.""" + payload = Payload() + payload.ParseFromString(base64.b64decode(wire)) + return payload + + +@dataclass +class PubSubItem: + """A single item in the pub/sub log. + + The ``data`` field is a :class:`temporalio.api.common.v1.Payload` + as stored by the mixin and yielded by + :meth:`PubSubClient.subscribe` when no ``result_type`` is given. + When ``result_type`` is passed to ``subscribe``, ``data`` holds the + decoded value of that type instead — the dataclass is typed as + ``Any`` to accommodate both. + + The ``offset`` field is populated at poll time from the item's + position in the global log. + """ + + topic: str + data: Any + offset: int = 0 + + +@dataclass +class PublishEntry: + """A single entry to publish via signal (wire type). + + ``data`` is base64-encoded ``Payload.SerializeToString()`` output — + see module docstring for why a nested ``Payload`` cannot be used + directly. + """ + + topic: str + data: str + + +@dataclass +class PublishInput: + """Signal payload: batch of entries to publish. + + Includes publisher_id and sequence to ensure exactly-once delivery. 
+ """ + + items: list[PublishEntry] = field(default_factory=list) + publisher_id: str = "" + sequence: int = 0 + + +@dataclass +class PollInput: + """Update payload: request to poll for new items.""" + + topics: list[str] = field(default_factory=list) + from_offset: int = 0 + + +@dataclass +class _WireItem: + """Wire representation of a PubSubItem (base64 of serialized Payload).""" + + topic: str + data: str + offset: int = 0 + + +@dataclass +class PollResult: + """Update response: items matching the poll request. + + ``items`` use the wire representation. When ``more_ready`` is True, + the response was truncated to stay within size limits and the + subscriber should poll again immediately rather than applying a + cooldown delay. + """ + + items: list[_WireItem] = field(default_factory=list) + next_offset: int = 0 + more_ready: bool = False + + +@dataclass +class PubSubState: + """Serializable snapshot of pub/sub state for continue-as-new. + + The containing workflow input must type the field as + ``PubSubState | None``, not ``Any``, so the default data converter + can reconstruct the dataclass from JSON. + + Log items use the wire representation for serialization stability. + """ + + log: list[_WireItem] = field(default_factory=list) + base_offset: int = 0 + publisher_sequences: dict[str, int] = field(default_factory=dict) + publisher_last_seen: dict[str, float] = field(default_factory=dict) diff --git a/tests/contrib/google_adk_agents/test_adk_streaming.py b/tests/contrib/google_adk_agents/test_adk_streaming.py new file mode 100644 index 000000000..8f257a398 --- /dev/null +++ b/tests/contrib/google_adk_agents/test_adk_streaming.py @@ -0,0 +1,193 @@ +"""Integration tests for ADK streaming support. + +Verifies that the streaming model activity publishes raw ``LlmResponse`` +chunks via the PubSub broker and that non-streaming mode remains +backward-compatible. 
+""" + +import asyncio +import logging +import uuid +from collections.abc import AsyncGenerator +from datetime import timedelta + +import pytest +from google.adk import Agent +from google.adk.agents.run_config import RunConfig, StreamingMode +from google.adk.models import BaseLlm, LLMRegistry +from google.adk.models.llm_request import LlmRequest +from google.adk.models.llm_response import LlmResponse +from google.adk.runners import InMemoryRunner +from google.genai.types import Content, Part + +from temporalio import workflow +from temporalio.client import Client +from temporalio.contrib.google_adk_agents import GoogleAdkPlugin, TemporalModel +from temporalio.contrib.pubsub import PubSub, PubSubClient +from temporalio.worker import Worker + +logger = logging.getLogger(__name__) + + +class StreamingTestModel(BaseLlm): + """Test model that yields multiple partial responses to simulate streaming.""" + + @classmethod + def supported_models(cls) -> list[str]: + return ["streaming_test_model"] + + async def generate_content_async( + self, llm_request: LlmRequest, stream: bool = False + ) -> AsyncGenerator[LlmResponse, None]: + yield LlmResponse(content=Content(role="model", parts=[Part(text="Hello ")])) + yield LlmResponse(content=Content(role="model", parts=[Part(text="world!")])) + + +@workflow.defn +class StreamingAdkWorkflow: + """Test workflow that opts into streaming via RunConfig.streaming_mode.""" + + @workflow.init + def __init__(self, prompt: str) -> None: + self.pubsub = PubSub() + + @workflow.run + async def run(self, prompt: str) -> str: + model = TemporalModel("streaming_test_model") + agent = Agent( + name="test_agent", + model=model, + instruction="You are a test agent.", + ) + + runner = InMemoryRunner(agent=agent, app_name="test-app") + session = await runner.session_service.create_session( + app_name="test-app", user_id="test" + ) + + final_text = "" + async for event in runner.run_async( + user_id="test", + session_id=session.id, + new_message=Content(role="user", parts=[Part(text=prompt)]), + run_config=RunConfig(streaming_mode=StreamingMode.SSE), + ): + if event.content and event.content.parts: + for part in event.content.parts: + if part.text: + final_text = part.text + + return final_text + + +@workflow.defn +class NonStreamingAdkWorkflow: + """Test workflow without streaming.""" + + @workflow.run + async def run(self, prompt: str) -> str: + model = TemporalModel("streaming_test_model") + agent = Agent( + name="test_agent", + model=model, + instruction="You are a test agent.", + ) + + runner = InMemoryRunner(agent=agent, app_name="test-app") + session = await runner.session_service.create_session( + app_name="test-app", user_id="test" + ) + + final_text = "" + async for event in runner.run_async( + user_id="test", + session_id=session.id, + new_message=Content(role="user", parts=[Part(text=prompt)]), + ): + if event.content and event.content.parts: + for part in event.content.parts: + if part.text: + final_text = part.text + + return final_text + + +@pytest.mark.asyncio +async def test_streaming_publishes_events(client: Client): + """Streaming activity publishes raw LlmResponse chunks to the topic.""" + LLMRegistry.register(StreamingTestModel) + + new_config = client.config() + new_config["plugins"] = [GoogleAdkPlugin()] + client = Client(**new_config) + + workflow_id = f"adk-streaming-test-{uuid.uuid4()}" + + async with Worker( + client, + task_queue="adk-streaming-test", + workflows=[StreamingAdkWorkflow], + max_cached_workflows=0, + ): + handle = await 
client.start_workflow( + StreamingAdkWorkflow.run, + "Hello", + id=workflow_id, + task_queue="adk-streaming-test", + execution_timeout=timedelta(seconds=30), + ) + + pubsub = PubSubClient.create(client, workflow_id) + responses: list[LlmResponse] = [] + + async def collect_events() -> None: + async for item in pubsub.subscribe( + ["events"], + from_offset=0, + result_type=LlmResponse, + poll_cooldown=timedelta(milliseconds=50), + ): + responses.append(item.data) + if len(responses) >= 2: + break + + collect_task = asyncio.create_task(collect_events()) + result = await handle.result() + await asyncio.wait_for(collect_task, timeout=10.0) + + assert result is not None + + texts: list[str] = [] + for r in responses: + if r.content and r.content.parts: + for part in r.content.parts: + if part.text: + texts.append(part.text) + assert texts == ["Hello ", "world!"], f"Unexpected text deltas: {texts}" + + +@pytest.mark.asyncio +async def test_non_streaming_backward_compatible(client: Client): + """Verify non-streaming mode still works (backward compatibility).""" + LLMRegistry.register(StreamingTestModel) + + new_config = client.config() + new_config["plugins"] = [GoogleAdkPlugin()] + client = Client(**new_config) + + async with Worker( + client, + task_queue="adk-non-streaming-test", + workflows=[NonStreamingAdkWorkflow], + max_cached_workflows=0, + ): + handle = await client.start_workflow( + NonStreamingAdkWorkflow.run, + "Hello", + id=f"adk-non-streaming-test-{uuid.uuid4()}", + task_queue="adk-non-streaming-test", + execution_timeout=timedelta(seconds=30), + ) + result = await handle.result() + + assert result is not None diff --git a/tests/contrib/openai_agents/test_openai_streaming.py b/tests/contrib/openai_agents/test_openai_streaming.py new file mode 100644 index 000000000..80917f21e --- /dev/null +++ b/tests/contrib/openai_agents/test_openai_streaming.py @@ -0,0 +1,376 @@ +"""Integration tests for OpenAI Agents streaming support. + +Streaming is opt-in via ``Runner.run_streamed``. Events flow back to the +workflow through ``RunResultStreaming.stream_events()`` (in batch after +each model activity completes) and to external consumers in real time +via the configured pub/sub topic. 
+""" + +import asyncio +import logging +import uuid +from collections.abc import AsyncIterator +from datetime import timedelta +from typing import Any + +import pytest +from agents import ( + Agent, + AgentOutputSchemaBase, + Handoff, + Model, + ModelResponse, + ModelSettings, + ModelTracing, + Runner, + Tool, + TResponseInputItem, + Usage, +) +from agents.items import TResponseStreamEvent +from openai.types.responses import ( + Response, + ResponseCompletedEvent, + ResponseOutputMessage, + ResponseOutputText, + ResponseTextConfig, + ResponseTextDeltaEvent, + ResponseUsage, +) +from openai.types.responses.response_usage import ( + InputTokensDetails, + OutputTokensDetails, +) +from openai.types.shared.response_format_text import ResponseFormatText + +from temporalio import workflow +from temporalio.client import Client +from temporalio.contrib.openai_agents import ModelActivityParameters +from temporalio.contrib.openai_agents.testing import AgentEnvironment +from temporalio.contrib.pubsub import PubSub, PubSubClient +from tests.helpers import new_worker + +logger = logging.getLogger(__name__) + + +class StreamingTestModel(Model): + """Test model that yields text deltas followed by a ResponseCompletedEvent.""" + + __test__ = False + + async def get_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + tracing: ModelTracing, + **kwargs: Any, + ) -> ModelResponse: + return ModelResponse( + output=[ + ResponseOutputMessage( + id="msg_test", + content=[ + ResponseOutputText( + text="Hello world!", + annotations=[], + type="output_text", + logprobs=[], + ) + ], + role="assistant", + status="completed", + type="message", + ) + ], + usage=Usage(), + response_id=None, + ) + + async def stream_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + tracing: ModelTracing, + **kwargs: Any, + ) -> AsyncIterator[TResponseStreamEvent]: + # Yield text deltas + yield ResponseTextDeltaEvent( + content_index=0, + delta="Hello ", + item_id="item1", + output_index=0, + sequence_number=0, + type="response.output_text.delta", + logprobs=[], + ) + yield ResponseTextDeltaEvent( + content_index=0, + delta="world!", + item_id="item1", + output_index=0, + sequence_number=1, + type="response.output_text.delta", + logprobs=[], + ) + + # Yield the final completed event + response = Response( + id="resp_test", + created_at=0, + error=None, + incomplete_details=None, + instructions=None, + metadata={}, + model="test", + object="response", + output=[ + ResponseOutputMessage( + id="msg_test", + content=[ + ResponseOutputText( + text="Hello world!", + annotations=[], + type="output_text", + logprobs=[], + ) + ], + role="assistant", + status="completed", + type="message", + ) + ], + parallel_tool_calls=True, + temperature=1.0, + tool_choice="auto", + tools=[], + top_p=1.0, + status="completed", + text=ResponseTextConfig(format=ResponseFormatText(type="text")), + truncation="disabled", + usage=ResponseUsage( + input_tokens=10, + output_tokens=5, + total_tokens=15, + input_tokens_details=InputTokensDetails(cached_tokens=0), + output_tokens_details=OutputTokensDetails(reasoning_tokens=0), + ), + ) + yield ResponseCompletedEvent( + response=response, sequence_number=2, type="response.completed" + ) 
+ + +@workflow.defn +class StreamingOpenAIWorkflow: + """Test workflow that opts into streaming via ``Runner.run_streamed``. + + Workflow code consumes events from ``stream_events()`` and exposes + the seen event types via a query so the test can verify both the + workflow-side iteration and the pub/sub side channel observe the + same events. + """ + + @workflow.init + def __init__(self, prompt: str) -> None: + self.pubsub = PubSub() + self.workflow_event_types: list[str] = [] + + @workflow.run + async def run(self, prompt: str) -> str: + agent = Agent[None]( + name="Assistant", + instructions="You are a test agent.", + ) + result = Runner.run_streamed(starting_agent=agent, input=prompt) + async for event in result.stream_events(): + raw = getattr(event, "data", None) + event_type = getattr(raw, "type", None) + if event_type is not None: + self.workflow_event_types.append(event_type) + return result.final_output + + @workflow.query + def get_workflow_event_types(self) -> list[str]: + return self.workflow_event_types + + +@workflow.defn +class NonStreamingOpenAIWorkflow: + """Test workflow that uses the non-streaming Runner.run path.""" + + @workflow.run + async def run(self, prompt: str) -> str: + agent = Agent[None]( + name="Assistant", + instructions="You are a test agent.", + ) + result = await Runner.run(starting_agent=agent, input=prompt) + return result.final_output + + +@workflow.defn +class StreamingNoPubSubWorkflow: + """Test workflow that uses run_streamed without a PubSub broker.""" + + @workflow.run + async def run(self, prompt: str) -> str: + agent = Agent[None]( + name="Assistant", + instructions="You are a test agent.", + ) + result = Runner.run_streamed(starting_agent=agent, input=prompt) + async for _ in result.stream_events(): + pass + return result.final_output + + +@pytest.mark.asyncio +async def test_streaming_publishes_raw_events(client: Client): + """Both the workflow consumer (via stream_events) and the pubsub + topic see the same native OpenAI events, in order, with no + normalization.""" + async with AgentEnvironment( + model=StreamingTestModel(), + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30), + ), + ) as env: + client = env.applied_on_client(client) + workflow_id = f"openai-streaming-test-{uuid.uuid4()}" + + async with new_worker( + client, StreamingOpenAIWorkflow, max_cached_workflows=0 + ) as worker: + handle = await client.start_workflow( + StreamingOpenAIWorkflow.run, + "Hello", + id=workflow_id, + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=30), + ) + + pubsub = PubSubClient.create(client, workflow_id) + published: list[TResponseStreamEvent] = [] + + async def collect_events() -> None: + async for item in pubsub.subscribe( + ["events"], + from_offset=0, + # TResponseStreamEvent is a discriminated union + # (Annotated[..., Discriminator]); Pydantic decodes + # it via TypeAdapter at runtime, but pyright sees + # ``Annotated`` rather than ``type``. + result_type=TResponseStreamEvent, # type: ignore[arg-type] + poll_cooldown=timedelta(milliseconds=50), + ): + published.append(item.data) + if item.data.type == "response.completed": + break + + collect_task = asyncio.create_task(collect_events()) + result = await handle.result() + await asyncio.wait_for(collect_task, timeout=10.0) + + workflow_event_types = await handle.query( + StreamingOpenAIWorkflow.get_workflow_event_types + ) + + assert result == "Hello world!" 
+ + published_types = [e.type for e in published] + assert published_types == [ + "response.output_text.delta", + "response.output_text.delta", + "response.completed", + ], f"Unexpected pub/sub event sequence: {published_types}" + + deltas = [e.delta for e in published if e.type == "response.output_text.delta"] + assert deltas == ["Hello ", "world!"] + + # Workflow-side iteration sees the same model events. + assert "response.output_text.delta" in workflow_event_types + assert "response.completed" in workflow_event_types + + +@pytest.mark.asyncio +async def test_non_streaming_path(client: Client): + """Runner.run still uses the non-streaming activity.""" + model = StreamingTestModel() + async with AgentEnvironment( + model=model, + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30), + ), + ) as env: + client = env.applied_on_client(client) + + async with new_worker( + client, + NonStreamingOpenAIWorkflow, + max_cached_workflows=0, + ) as worker: + result = await client.execute_workflow( + NonStreamingOpenAIWorkflow.run, + "Hello", + id=f"openai-non-streaming-test-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=30), + ) + + assert result == "Hello world!" + + +class TruncatedStreamingTestModel(Model): + """Fake model whose stream ends without a ResponseCompletedEvent.""" + + __test__ = False + + async def get_response(self, *a: Any, **kw: Any) -> ModelResponse: + raise NotImplementedError + + async def stream_response( + self, *a: Any, **kw: Any + ) -> AsyncIterator[TResponseStreamEvent]: + yield ResponseTextDeltaEvent( + content_index=0, + delta="partial", + item_id="item1", + output_index=0, + sequence_number=0, + type="response.output_text.delta", + logprobs=[], + ) + + +@pytest.mark.asyncio +async def test_streaming_no_pubsub_topic(client: Client): + """Setting streaming_event_topic=None disables publishing; the + workflow can still consume events via stream_events().""" + async with AgentEnvironment( + model=StreamingTestModel(), + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30), + streaming_event_topic=None, + ), + ) as env: + client = env.applied_on_client(client) + async with new_worker( + client, StreamingNoPubSubWorkflow, max_cached_workflows=0 + ) as worker: + result = await client.execute_workflow( + StreamingNoPubSubWorkflow.run, + "Hi", + id=f"openai-streaming-no-pubsub-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=30), + ) + + assert result == "Hello world!" diff --git a/tests/contrib/pubsub/__init__.py b/tests/contrib/pubsub/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/contrib/pubsub/test_payload_roundtrip_prototype.py b/tests/contrib/pubsub/test_payload_roundtrip_prototype.py new file mode 100644 index 000000000..b020d3e4f --- /dev/null +++ b/tests/contrib/pubsub/test_payload_roundtrip_prototype.py @@ -0,0 +1,145 @@ +"""Prototype tests that de-risked the pubsub bytes -> Payload migration. + +The migration doc (``docs/pubsub-payload-migration.md``) flagged two +load-bearing questions, answered empirically here: + +1. Does the default JSON converter handle ``Payload`` embedded in a + dataclass? **No** — serialization fails with ``TypeError``. This + rules out a naive nested-Payload wire format. +2. Does a proto-serialized ``Payload`` inside a dataclass round-trip? + **Yes**. 
This is the wire format the migration adopts: base64 of + ``Payload.SerializeToString()`` inside ``PublishEntry``/``_WireItem``, + surfacing ``Payload`` (or a decoded value via ``result_type=``) at + the user API. + +Kept as a regression guard: if a future payload converter change makes +(1) succeed, the migration could in principle reclaim a zero-copy wire +format; if (2) regresses, the migration breaks. +""" + +from __future__ import annotations + +import base64 +import uuid +from dataclasses import dataclass, field + +import pytest + +from temporalio import workflow +from temporalio.api.common.v1 import Payload +from temporalio.client import Client +from tests.helpers import new_worker + + +@dataclass +class NestedPayloadEnvelope: + items: list[Payload] = field(default_factory=list) + + +@dataclass +class SerializedEntry: + topic: str + data: str # base64(Payload.SerializeToString()) + + +@dataclass +class SerializedEnvelope: + items: list[SerializedEntry] = field(default_factory=list) + + +@workflow.defn +class NestedPayloadWorkflow: + def __init__(self) -> None: + self._received: NestedPayloadEnvelope | None = None + + @workflow.signal + def receive(self, envelope: NestedPayloadEnvelope) -> None: + self._received = envelope + + @workflow.query + def decoded_strings(self) -> list[str]: + assert self._received is not None + conv = workflow.payload_converter() + return [conv.from_payload(p, str) for p in self._received.items] + + @workflow.run + async def run(self) -> None: + await workflow.wait_condition(lambda: self._received is not None) + + +@workflow.defn +class SerializedPayloadWorkflow: + def __init__(self) -> None: + self._received: SerializedEnvelope | None = None + + @workflow.signal + def receive(self, envelope: SerializedEnvelope) -> None: + self._received = envelope + + @workflow.query + def decoded_strings(self) -> list[str]: + assert self._received is not None + conv = workflow.payload_converter() + out: list[str] = [] + for entry in self._received.items: + p = Payload() + p.ParseFromString(base64.b64decode(entry.data)) + out.append(conv.from_payload(p, str)) + return out + + @workflow.query + def topics(self) -> list[str]: + assert self._received is not None + return [e.topic for e in self._received.items] + + @workflow.run + async def run(self) -> None: + await workflow.wait_condition(lambda: self._received is not None) + + +@pytest.mark.asyncio +async def test_nested_payload_in_dataclass_fails(client: Client) -> None: + """Confirm the load-bearing negative result: Payload inside dataclass doesn't serialize.""" + conv = client.data_converter.payload_converter + payloads = [conv.to_payloads([v])[0] for v in ["hello", "world"]] + envelope = NestedPayloadEnvelope(items=payloads) + + async with new_worker(client, NestedPayloadWorkflow) as worker: + handle = await client.start_workflow( + NestedPayloadWorkflow.run, + id=f"nested-payload-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + with pytest.raises(TypeError, match="Payload is not JSON serializable"): + await handle.signal(NestedPayloadWorkflow.receive, envelope) + await handle.terminate() + + +@pytest.mark.asyncio +async def test_serialized_payload_fallback_round_trips(client: Client) -> None: + """Proto-serialize Payload -> base64 -> dataclass round-trips through signal.""" + conv = client.data_converter.payload_converter + originals = ["hello", "world", "payload"] + payloads = [conv.to_payloads([v])[0] for v in originals] + envelope = SerializedEnvelope( + items=[ + SerializedEntry( + topic=f"t{i}", + 
data=base64.b64encode(p.SerializeToString()).decode("ascii"), + ) + for i, p in enumerate(payloads) + ] + ) + + async with new_worker(client, SerializedPayloadWorkflow) as worker: + handle = await client.start_workflow( + SerializedPayloadWorkflow.run, + id=f"serialized-payload-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + await handle.signal(SerializedPayloadWorkflow.receive, envelope) + decoded = await handle.query(SerializedPayloadWorkflow.decoded_strings) + assert decoded == originals + topics = await handle.query(SerializedPayloadWorkflow.topics) + assert topics == ["t0", "t1", "t2"] + await handle.result() diff --git a/tests/contrib/pubsub/test_pubsub.py b/tests/contrib/pubsub/test_pubsub.py new file mode 100644 index 000000000..9a4ec9c9d --- /dev/null +++ b/tests/contrib/pubsub/test_pubsub.py @@ -0,0 +1,2030 @@ +"""E2E integration tests for temporalio.contrib.pubsub.""" + +from __future__ import annotations + +import asyncio +import sys +import uuid +from dataclasses import dataclass +from datetime import timedelta +from typing import Any +from unittest.mock import patch + +if sys.version_info >= (3, 11): + from asyncio import timeout as _async_timeout # pyright: ignore[reportUnreachable] +else: + from async_timeout import ( # pyright: ignore[reportUnreachable] + timeout as _async_timeout, + ) + +import google.protobuf.duration_pb2 +import nexusrpc +import nexusrpc.handler +import pytest + +import temporalio.api.nexus.v1 +import temporalio.api.operatorservice.v1 +import temporalio.api.workflowservice.v1 +from temporalio import activity, nexus, workflow +from temporalio.client import Client, WorkflowHandle, WorkflowUpdateFailedError +from temporalio.contrib.pubsub import ( + PollInput, + PollResult, + PublishEntry, + PublishInput, + PubSub, + PubSubClient, + PubSubItem, + PubSubState, +) +from temporalio.contrib.pubsub._types import _encode_payload +from temporalio.converter import DataConverter +from temporalio.exceptions import ApplicationError +from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation +from temporalio.testing import WorkflowEnvironment +from temporalio.worker import Worker +from tests.helpers import assert_eq_eventually, new_worker +from tests.helpers.nexus import make_nexus_endpoint_name + + +def _wire_bytes(data: bytes) -> str: + """Build a PublishEntry.data string from raw bytes. + + Mirrors what :class:`PubSubClient` produces on the encode path: + default payload converter turns the bytes into a ``Payload``, which + is then proto-serialized and base64-encoded for the wire. 
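+    Per the adopted wire format (base64 of ``Payload.SerializeToString()``),
+    the result should be equivalent to
+    ``base64.b64encode(payload.SerializeToString()).decode("ascii")``.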
+ """ + payload = DataConverter.default.payload_converter.to_payloads([data])[0] + return _encode_payload(payload) + + +# --------------------------------------------------------------------------- +# Test workflows (must be module-level, not local classes) +# --------------------------------------------------------------------------- + + +@workflow.defn +class BasicPubSubWorkflow: + @workflow.init + def __init__(self) -> None: + self.pubsub = PubSub() + self._closed = False + + @workflow.signal + def close(self) -> None: + self._closed = True + + @workflow.run + async def run(self) -> None: + await workflow.wait_condition(lambda: self._closed) + + +@workflow.defn +class ActivityPublishWorkflow: + @workflow.init + def __init__(self, count: int) -> None: + self.pubsub = PubSub() + self._closed = False + + @workflow.signal + def close(self) -> None: + self._closed = True + + @workflow.run + async def run(self, count: int) -> None: + await workflow.execute_activity( + "publish_items", + count, + start_to_close_timeout=timedelta(seconds=30), + heartbeat_timeout=timedelta(seconds=10), + ) + self.pubsub.publish("status", b"activity_done") + await workflow.wait_condition(lambda: self._closed) + + +@dataclass +class AgentEvent: + kind: str + payload: dict[str, Any] + + +@workflow.defn +class StructuredPublishWorkflow: + @workflow.init + def __init__(self, count: int) -> None: + self.pubsub = PubSub() + self._closed = False + + @workflow.signal + def close(self) -> None: + self._closed = True + + @workflow.run + async def run(self, count: int) -> None: + for i in range(count): + self.pubsub.publish("events", AgentEvent(kind="tick", payload={"i": i})) + await workflow.wait_condition(lambda: self._closed) + + +@workflow.defn +class WorkflowSidePublishWorkflow: + @workflow.init + def __init__(self, count: int) -> None: + self.pubsub = PubSub() + self._closed = False + + @workflow.signal + def close(self) -> None: + self._closed = True + + @workflow.run + async def run(self, count: int) -> None: + for i in range(count): + self.pubsub.publish("events", f"item-{i}".encode()) + await workflow.wait_condition(lambda: self._closed) + + +@workflow.defn +class MultiTopicWorkflow: + @workflow.init + def __init__(self, count: int) -> None: + self.pubsub = PubSub() + self._closed = False + + @workflow.signal + def close(self) -> None: + self._closed = True + + @workflow.run + async def run(self, count: int) -> None: + await workflow.execute_activity( + "publish_multi_topic", + count, + start_to_close_timeout=timedelta(seconds=30), + heartbeat_timeout=timedelta(seconds=10), + ) + await workflow.wait_condition(lambda: self._closed) + + +@workflow.defn +class InterleavedWorkflow: + @workflow.init + def __init__(self, count: int) -> None: + self.pubsub = PubSub() + self._closed = False + + @workflow.signal + def close(self) -> None: + self._closed = True + + @workflow.run + async def run(self, count: int) -> None: + self.pubsub.publish("status", b"started") + await workflow.execute_activity( + "publish_items", + count, + start_to_close_timeout=timedelta(seconds=30), + heartbeat_timeout=timedelta(seconds=10), + ) + self.pubsub.publish("status", b"done") + await workflow.wait_condition(lambda: self._closed) + + +@workflow.defn +class PriorityWorkflow: + @workflow.init + def __init__(self) -> None: + self.pubsub = PubSub() + self._closed = False + + @workflow.signal + def close(self) -> None: + self._closed = True + + @workflow.run + async def run(self) -> None: + await workflow.execute_activity( + 
"publish_with_priority", + start_to_close_timeout=timedelta(seconds=30), + heartbeat_timeout=timedelta(seconds=10), + ) + await workflow.wait_condition(lambda: self._closed) + + +@workflow.defn +class FlushOnExitWorkflow: + @workflow.init + def __init__(self, count: int) -> None: + self.pubsub = PubSub() + self._closed = False + + @workflow.signal + def close(self) -> None: + self._closed = True + + @workflow.run + async def run(self, count: int) -> None: + await workflow.execute_activity( + "publish_batch_test", + count, + start_to_close_timeout=timedelta(seconds=30), + heartbeat_timeout=timedelta(seconds=10), + ) + await workflow.wait_condition(lambda: self._closed) + + +@workflow.defn +class MaxBatchWorkflow: + @workflow.init + def __init__(self, count: int) -> None: + self.pubsub = PubSub() + self._closed = False + + @workflow.signal + def close(self) -> None: + self._closed = True + + @workflow.query + def publisher_sequences(self) -> dict[str, int]: + return dict(self.pubsub._publisher_sequences) + + @workflow.run + async def run(self, count: int) -> None: + await workflow.execute_activity( + "publish_with_max_batch", + count, + start_to_close_timeout=timedelta(seconds=30), + heartbeat_timeout=timedelta(seconds=10), + ) + self.pubsub.publish("status", b"activity_done") + await workflow.wait_condition(lambda: self._closed) + + +@workflow.defn +class LatePubSubWorkflow: + """Calls PubSub() from @workflow.run, not from @workflow.init. + + The constructor inspects the caller's frame and requires the + function name to be ``__init__``; called from ``run``, it must + raise ``RuntimeError``. The workflow returns the error message so + the test can assert on it without forcing a workflow task failure. + """ + + @workflow.run + async def run(self) -> str: + try: + PubSub() + except RuntimeError as e: + return str(e) + return "no error raised" + + +@workflow.defn +class DoubleInitWorkflow: + """Calls PubSub() twice from @workflow.init. + + The first call succeeds; the second must raise RuntimeError because + the pub/sub signal handler is already registered. The workflow + stashes the error message so the test can assert on it without + forcing a workflow task failure. 
+ """ + + @workflow.init + def __init__(self) -> None: + self.pubsub = PubSub() + self._closed = False + self.double_init_error: str | None = None + try: + PubSub() + except RuntimeError as e: + self.double_init_error = str(e) + + @workflow.signal + def close(self) -> None: + self._closed = True + + @workflow.query + def get_double_init_error(self) -> str | None: + return self.double_init_error + + @workflow.run + async def run(self) -> None: + await workflow.wait_condition(lambda: self._closed) + + +# --------------------------------------------------------------------------- +# Activities +# --------------------------------------------------------------------------- + + +@activity.defn(name="publish_items") +async def publish_items(count: int) -> None: + client = PubSubClient.from_activity(batch_interval=timedelta(milliseconds=500)) + async with client: + for i in range(count): + activity.heartbeat() + client.publish("events", f"item-{i}".encode()) + + +@activity.defn(name="publish_multi_topic") +async def publish_multi_topic(count: int) -> None: + topics = ["a", "b", "c"] + client = PubSubClient.from_activity(batch_interval=timedelta(milliseconds=500)) + async with client: + for i in range(count): + activity.heartbeat() + topic = topics[i % len(topics)] + client.publish(topic, f"{topic}-{i}".encode()) + + +@activity.defn(name="publish_with_priority") +async def publish_with_priority() -> None: + # Long batch_interval AND long post-publish hold ensure that only a + # working force_flush wakeup can deliver items before __aexit__ flushes. + # The hold is deliberately much longer than the test's collect timeout + # so a regression (force_flush no-op) surfaces as a missing item rather + # than flaking on slow CI. + client = PubSubClient.from_activity(batch_interval=timedelta(seconds=60)) + async with client: + client.publish("events", b"normal-0") + client.publish("events", b"normal-1") + client.publish("events", b"priority", force_flush=True) + for _ in range(100): + activity.heartbeat() + await asyncio.sleep(0.1) + + +@activity.defn(name="publish_batch_test") +async def publish_batch_test(count: int) -> None: + client = PubSubClient.from_activity(batch_interval=timedelta(seconds=60)) + async with client: + for i in range(count): + activity.heartbeat() + client.publish("events", f"item-{i}".encode()) + + +@activity.defn(name="publish_with_max_batch") +async def publish_with_max_batch(count: int) -> None: + client = PubSubClient.from_activity( + batch_interval=timedelta(seconds=60), max_batch_size=3 + ) + async with client: + for i in range(count): + activity.heartbeat() + client.publish("events", f"item-{i}".encode()) + # Yield so the flusher task can run when max_batch_size triggers + # _flush_event. Real workloads (e.g. agents awaiting LLM streams) + # yield constantly; a tight loop with no awaits would never let + # the flusher fire and would collapse back to exit-only flushing. + await asyncio.sleep(0) + # Long batch_interval ensures only max_batch_size triggers flushes. + # Context manager exit flushes any remainder. 
+ + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +async def _is_different_run( + old_handle: WorkflowHandle[Any, Any], + new_handle: WorkflowHandle[Any, Any], +) -> bool: + """Check if new_handle points to a different run than old_handle.""" + try: + desc = await new_handle.describe() + return desc.run_id != old_handle.result_run_id + except Exception: + return False + + +async def collect_items( + client: Client, + handle: WorkflowHandle[Any, Any], + topics: list[str] | None, + from_offset: int, + expected_count: int, + timeout: float = 15.0, + *, + result_type: type | None = bytes, +) -> list[PubSubItem]: + """Subscribe and collect exactly expected_count items, with timeout. + + Default ``result_type=bytes`` matches the bytes-oriented tests that + compare ``item.data`` against literal byte strings. Pass + ``result_type=None`` to receive raw ``Payload`` objects. + """ + pubsub = PubSubClient.create(client, handle.id) + items: list[PubSubItem] = [] + try: + async with _async_timeout(timeout): + async for item in pubsub.subscribe( + topics=topics, + from_offset=from_offset, + poll_cooldown=timedelta(0), + result_type=result_type, + ): + items.append(item) + if len(items) >= expected_count: + break + except asyncio.TimeoutError: + pass + return items + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_activity_publish_and_subscribe(client: Client) -> None: + """Activity publishes items, external client subscribes and receives them.""" + count = 10 + async with new_worker( + client, + ActivityPublishWorkflow, + activities=[publish_items], + ) as worker: + handle = await client.start_workflow( + ActivityPublishWorkflow.run, + count, + id=f"pubsub-basic-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + # Collect activity items + the "activity_done" status item + items = await collect_items(client, handle, None, 0, count + 1) + assert len(items) == count + 1 + + # Check activity items + for i in range(count): + assert items[i].topic == "events" + assert items[i].data == f"item-{i}".encode() + + # Check workflow-side status item + assert items[count].topic == "status" + assert items[count].data == b"activity_done" + + await handle.signal(ActivityPublishWorkflow.close) + + +@pytest.mark.asyncio +async def test_structured_type_round_trip(client: Client) -> None: + """Workflow publishes dataclass values; subscriber decodes via result_type.""" + count = 4 + async with new_worker(client, StructuredPublishWorkflow) as worker: + handle = await client.start_workflow( + StructuredPublishWorkflow.run, + count, + id=f"pubsub-structured-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + items = await collect_items( + client, handle, None, 0, count, result_type=AgentEvent + ) + assert len(items) == count + for i, item in enumerate(items): + assert isinstance(item.data, AgentEvent) + assert item.data == AgentEvent(kind="tick", payload={"i": i}) + + await handle.signal(StructuredPublishWorkflow.close) + + +@pytest.mark.asyncio +async def test_topic_filtering(client: Client) -> None: + """Publish to multiple topics, subscribe with filter.""" + count = 9 # 3 per topic + async with new_worker( + client, + MultiTopicWorkflow, + activities=[publish_multi_topic], + ) as worker: + handle = await client.start_workflow( + 
MultiTopicWorkflow.run, + count, + id=f"pubsub-filter-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # Subscribe to topic "a" only — should get 3 items + a_items = await collect_items(client, handle, ["a"], 0, 3) + assert len(a_items) == 3 + assert all(item.topic == "a" for item in a_items) + + # Subscribe to ["a", "c"] — should get 6 items + ac_items = await collect_items(client, handle, ["a", "c"], 0, 6) + assert len(ac_items) == 6 + assert all(item.topic in ("a", "c") for item in ac_items) + + # Subscribe to all (None) — should get all 9 + all_items = await collect_items(client, handle, None, 0, 9) + assert len(all_items) == 9 + + await handle.signal(MultiTopicWorkflow.close) + + +@pytest.mark.asyncio +async def test_subscribe_from_offset_and_per_item_offsets(client: Client) -> None: + """Subscribe from zero and non-zero offsets; each item carries its global offset.""" + count = 5 + async with new_worker( + client, + WorkflowSidePublishWorkflow, + ) as worker: + handle = await client.start_workflow( + WorkflowSidePublishWorkflow.run, + count, + id=f"pubsub-offset-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # Subscribe from offset 0 — all items, offsets 0..count-1 + all_items = await collect_items(client, handle, None, 0, count) + assert len(all_items) == count + for i, item in enumerate(all_items): + assert item.offset == i + assert item.data == f"item-{i}".encode() + + # Subscribe from offset 3 — items 3, 4 with offsets 3, 4 + later_items = await collect_items(client, handle, None, 3, 2) + assert len(later_items) == 2 + assert later_items[0].offset == 3 + assert later_items[0].data == b"item-3" + assert later_items[1].offset == 4 + assert later_items[1].data == b"item-4" + + await handle.signal(WorkflowSidePublishWorkflow.close) + + +@pytest.mark.asyncio +async def test_per_item_offsets_with_topic_filter(client: Client) -> None: + """Per-item offsets are global (not per-topic) even when filtering.""" + count = 9 # 3 per topic (a, b, c round-robin) + async with new_worker( + client, + MultiTopicWorkflow, + activities=[publish_multi_topic], + ) as worker: + handle = await client.start_workflow( + MultiTopicWorkflow.run, + count, + id=f"pubsub-item-offset-filter-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # Subscribe to topic "a" only — items are at global offsets 0, 3, 6 + a_items = await collect_items(client, handle, ["a"], 0, 3) + assert len(a_items) == 3 + assert a_items[0].offset == 0 + assert a_items[1].offset == 3 + assert a_items[2].offset == 6 + + # Subscribe to topic "b" — items are at global offsets 1, 4, 7 + b_items = await collect_items(client, handle, ["b"], 0, 3) + assert len(b_items) == 3 + assert b_items[0].offset == 1 + assert b_items[1].offset == 4 + assert b_items[2].offset == 7 + + await handle.signal(MultiTopicWorkflow.close) + + +@pytest.mark.asyncio +async def test_poll_truncated_offset_returns_application_error(client: Client) -> None: + """Polling a truncated offset raises ApplicationError (not ValueError) + and does not crash the workflow task.""" + async with new_worker( + client, + TruncateWorkflow, + ) as worker: + handle = await client.start_workflow( + TruncateWorkflow.run, + 5, + id=f"pubsub-trunc-error-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # Truncate up to offset 3 via update — completion is explicit. + await handle.execute_update("truncate", 3) + + # Poll from offset 1 (truncated) — should get ApplicationError, + # NOT crash the workflow task. 
Catching WorkflowUpdateFailedError is + # sufficient to prove the handler raised ApplicationError: Temporal's + # update protocol completes the update with this error only when the + # handler raises ApplicationError. A bare ValueError (or any other + # exception) would fail the workflow task instead, causing + # execute_update to hang — not raise. The follow-up collect_items + # below proves the workflow task wasn't poisoned. + with pytest.raises(WorkflowUpdateFailedError) as exc_info: + await handle.execute_update( + "__temporal_pubsub_poll", + PollInput(topics=[], from_offset=1), + result_type=PollResult, + ) + cause = exc_info.value.cause + assert isinstance(cause, ApplicationError) + assert cause.type == "TruncatedOffset" + + # Workflow should still be usable — poll from valid offset 3 + items = await collect_items(client, handle, None, 3, 2) + assert len(items) == 2 + assert items[0].offset == 3 + + await handle.signal("close") + + +@pytest.mark.asyncio +async def test_truncate_past_end_raises_application_error(client: Client) -> None: + """truncate() with an offset past the log end raises ApplicationError + (type=TruncateOutOfRange) — the update surfaces as a clean failure + without poisoning the workflow task.""" + async with new_worker( + client, + TruncateWorkflow, + ) as worker: + handle = await client.start_workflow( + TruncateWorkflow.run, + 2, + id=f"pubsub-trunc-oor-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # Only 2 items exist; asking to truncate to offset 5 is out of range. + with pytest.raises(WorkflowUpdateFailedError) as exc_info: + await handle.execute_update("truncate", 5) + cause = exc_info.value.cause + assert isinstance(cause, ApplicationError) + assert cause.type == "TruncateOutOfRange" + + # Workflow task wasn't poisoned — a valid poll still completes. + items = await collect_items(client, handle, None, 0, 2) + assert len(items) == 2 + + await handle.signal("close") + + +@pytest.mark.asyncio +async def test_subscribe_recovers_from_truncation(client: Client) -> None: + """subscribe() auto-recovers when offset falls behind truncation.""" + async with new_worker( + client, + TruncateWorkflow, + ) as worker: + handle = await client.start_workflow( + TruncateWorkflow.run, + 5, + id=f"pubsub-trunc-recover-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # Truncate first 3. The update returns after the handler completes. 
+ await handle.execute_update("truncate", 3) + + # subscribe from offset 1 (truncated) — should auto-recover + # and deliver items from base_offset (3) + pubsub = PubSubClient(handle) + items: list[PubSubItem] = [] + try: + async with _async_timeout(5): + async for item in pubsub.subscribe( + from_offset=1, poll_cooldown=timedelta(0), result_type=bytes + ): + items.append(item) + if len(items) >= 2: + break + except asyncio.TimeoutError: + pass + assert len(items) == 2 + assert items[0].offset == 3 + + await handle.signal("close") + + +@pytest.mark.asyncio +async def test_workflow_and_activity_publish_interleaved(client: Client) -> None: + """Workflow publishes status events around activity publishing.""" + count = 5 + async with new_worker( + client, + InterleavedWorkflow, + activities=[publish_items], + ) as worker: + handle = await client.start_workflow( + InterleavedWorkflow.run, + count, + id=f"pubsub-interleave-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # Total: 1 (started) + count (activity) + 1 (done) = count + 2 + items = await collect_items(client, handle, None, 0, count + 2) + assert len(items) == count + 2 + + # First item is workflow-side "started" + assert items[0].topic == "status" + assert items[0].data == b"started" + + # Middle items are from activity + for i in range(count): + assert items[i + 1].topic == "events" + assert items[i + 1].data == f"item-{i}".encode() + + # Last item is workflow-side "done" + assert items[count + 1].topic == "status" + assert items[count + 1].data == b"done" + + await handle.signal(InterleavedWorkflow.close) + + +@pytest.mark.asyncio +async def test_priority_flush(client: Client) -> None: + """Priority publish triggers immediate flush without waiting for timer.""" + async with new_worker( + client, + PriorityWorkflow, + activities=[publish_with_priority], + ) as worker: + handle = await client.start_workflow( + PriorityWorkflow.run, + id=f"pubsub-priority-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # If priority works, items arrive within milliseconds of the publish. + # The activity holds for ~10s after priority publish; this timeout + # gives plenty of margin for workflow/worker scheduling on slow CI + # while staying well below the activity hold so a regression (no + # priority wakeup) surfaces as a missing item, not a pass via + # __aexit__ flush. + items = await collect_items(client, handle, None, 0, 3, timeout=5.0) + assert len(items) == 3 + assert items[2].data == b"priority" + + await handle.signal(PriorityWorkflow.close) + + +@pytest.mark.asyncio +async def test_iterator_cancellation(client: Client) -> None: + """Cancelling a subscription iterator after it has yielded an item + completes cleanly.""" + async with new_worker( + client, + BasicPubSubWorkflow, + ) as worker: + handle = await client.start_workflow( + BasicPubSubWorkflow.run, + id=f"pubsub-cancel-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # Seed one item so the iterator provably reaches an active state + # before we cancel — no sleep-based wait. 
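+        # The raw "__temporal_pubsub_publish" signal is the same delivery
+        # path PubSubClient's flush uses; sending it directly keeps
+        # client-side batching out of the picture for this seed item.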
+ await handle.signal( + "__temporal_pubsub_publish", + PublishInput( + items=[PublishEntry(topic="events", data=_wire_bytes(b"seed"))] + ), + ) + + pubsub_client = PubSubClient.create(client, handle.id) + first_item = asyncio.Event() + items: list[PubSubItem] = [] + + async def subscribe_and_collect() -> None: + async for item in pubsub_client.subscribe( + from_offset=0, poll_cooldown=timedelta(0), result_type=bytes + ): + items.append(item) + first_item.set() + + task = asyncio.create_task(subscribe_and_collect()) + # Bounded wait so a subscribe regression fails fast instead of hanging. + async with _async_timeout(5): + await first_item.wait() + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + assert len(items) == 1 + assert items[0].data == b"seed" + + await handle.signal(BasicPubSubWorkflow.close) + + +@pytest.mark.asyncio +async def test_context_manager_flushes_on_exit(client: Client) -> None: + """Context manager exit flushes all buffered items.""" + count = 5 + async with new_worker( + client, + FlushOnExitWorkflow, + activities=[publish_batch_test], + ) as worker: + handle = await client.start_workflow( + FlushOnExitWorkflow.run, + count, + id=f"pubsub-flush-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # Despite 60s batch interval, all items arrive because __aexit__ flushes + items = await collect_items(client, handle, None, 0, count, timeout=15.0) + assert len(items) == count + for i in range(count): + assert items[i].data == f"item-{i}".encode() + + await handle.signal(FlushOnExitWorkflow.close) + + +@pytest.mark.asyncio +async def test_explicit_flush_barrier(client: Client) -> None: + """``await client.flush()`` is a synchronization point. + + Verifies the documented contract: + 1. Returns immediately when the buffer is empty. + 2. After it returns, items published before the call are durable + on the workflow side (observable via ``get_offset()``) — even + when the timer-driven flush would not yet have fired. + 3. Calling it again after a successful flush is a no-op. + + Uses a 60s ``batch_interval`` so a regression where ``flush()`` + silently relies on the background timer surfaces as a hang + against the test's 5s timeout, not a slow pass. + """ + async with new_worker( + client, + BasicPubSubWorkflow, + ) as worker: + handle = await client.start_workflow( + BasicPubSubWorkflow.run, + id=f"pubsub-flush-barrier-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + pubsub = PubSubClient.create( + client, handle.id, batch_interval=timedelta(seconds=60) + ) + + async with _async_timeout(5): + # 1. Empty-buffer flush is a no-op (must not block). + assert await pubsub.get_offset() == 0 + await pubsub.flush() + assert await pubsub.get_offset() == 0 + + # 2. Flush makes prior publishes visible without waiting on + # the 60s batch timer. + pubsub.publish("events", b"a") + pubsub.publish("events", b"b") + pubsub.publish("events", b"c") + await pubsub.flush() + assert await pubsub.get_offset() == 3 + + # 3. Second flush with no new items is a no-op. + await pubsub.flush() + assert await pubsub.get_offset() == 3 + + await handle.signal(BasicPubSubWorkflow.close) + + +@pytest.mark.asyncio +async def test_concurrent_subscribers(client: Client) -> None: + """Two subscribers on different topics make interleaved progress. + + Publishes A-0, waits for subscriber A to observe it; publishes B-0, + waits for subscriber B to observe it. 
At this point both subscribers + have received exactly one item and are polling for their second, + so both subscriptions are provably in flight at the same time. + Then publishes A-1, B-1 the same way. A sequential execution (A drains + then B starts) cannot satisfy the ordering because B's first item + isn't published until after A has already received its first. + """ + async with new_worker( + client, + BasicPubSubWorkflow, + ) as worker: + handle = await client.start_workflow( + BasicPubSubWorkflow.run, + id=f"pubsub-concurrent-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + pubsub = PubSubClient(handle) + a_items: list[PubSubItem] = [] + b_items: list[PubSubItem] = [] + a_got = [asyncio.Event(), asyncio.Event()] + b_got = [asyncio.Event(), asyncio.Event()] + + async def collect( + topic: str, + collected: list[PubSubItem], + events: list[asyncio.Event], + ) -> None: + async for item in pubsub.subscribe( + topics=[topic], + from_offset=0, + poll_cooldown=timedelta(0), + result_type=bytes, + ): + collected.append(item) + events[len(collected) - 1].set() + if len(collected) >= len(events): + break + + a_task = asyncio.create_task(collect("a", a_items, a_got)) + b_task = asyncio.create_task(collect("b", b_items, b_got)) + + async def publish(topic: str, data: bytes) -> None: + await handle.signal( + "__temporal_pubsub_publish", + PublishInput(items=[PublishEntry(topic=topic, data=_wire_bytes(data))]), + ) + + try: + async with _async_timeout(10): + await publish("a", b"a-0") + await a_got[0].wait() + await publish("b", b"b-0") + await b_got[0].wait() + # Both subscribers are now mid-subscription, each having + # seen one item and polling for the next. + await publish("a", b"a-1") + await a_got[1].wait() + await publish("b", b"b-1") + await b_got[1].wait() + + await asyncio.gather(a_task, b_task) + finally: + a_task.cancel() + b_task.cancel() + + assert [i.data for i in a_items] == [b"a-0", b"a-1"] + assert [i.data for i in b_items] == [b"b-0", b"b-1"] + + await handle.signal(BasicPubSubWorkflow.close) + + +@pytest.mark.asyncio +async def test_max_batch_size(client: Client) -> None: + """max_batch_size triggers auto-flush without waiting for timer.""" + count = 7 # with max_batch_size=3: flushes at 3, 6, then remainder 1 on exit + async with new_worker( + client, + MaxBatchWorkflow, + activities=[publish_with_max_batch], + max_cached_workflows=0, + ) as worker: + handle = await client.start_workflow( + MaxBatchWorkflow.run, + count, + id=f"pubsub-maxbatch-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + # count items from activity + 1 "activity_done" from workflow + items = await collect_items(client, handle, None, 0, count + 1, timeout=15.0) + assert len(items) == count + 1 + for i in range(count): + assert items[i].data == f"item-{i}".encode() + + # max_batch_size actually engages: at least one flush fires during + # the publish loop, so 7 items ship as >=2 signals. Without this + # assertion the test would pass even if max_batch_size were ignored + # and all 7 items went out in a single exit-time flush (batch_count + # == 1). Note: max_batch_size is a *trigger* threshold, not a cap — + # the flusher may take more items from the buffer than max_batch_size + # if more were added while a prior signal was in flight, so the exact + # batch count depends on interleaving. Asserting >= 2 is the + # non-flaky way to verify the mechanism is live. 
+ seqs = await handle.query(MaxBatchWorkflow.publisher_sequences) + assert len(seqs) == 1, f"expected one publisher, got {seqs}" + (batch_count,) = seqs.values() + assert batch_count >= 2, ( + f"expected >=2 batches with max_batch_size=3 and 7 items, got " + f"{batch_count} — max_batch_size did not trigger a mid-loop flush" + ) + + await handle.signal(MaxBatchWorkflow.close) + + +@pytest.mark.asyncio +async def test_replay_safety(client: Client) -> None: + """Pub/sub broker survives workflow replay (max_cached_workflows=0).""" + async with new_worker( + client, + InterleavedWorkflow, + activities=[publish_items], + max_cached_workflows=0, + ) as worker: + handle = await client.start_workflow( + InterleavedWorkflow.run, + 5, + id=f"pubsub-replay-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + # 1 (started) + 5 (activity) + 1 (done) = 7 + items = await collect_items(client, handle, None, 0, 7) + # Full ordered sequence — endpoint-only checks would miss mid-stream + # replay corruption (reordering, duplication, dropped items). + assert [i.data for i in items] == [ + b"started", + b"item-0", + b"item-1", + b"item-2", + b"item-3", + b"item-4", + b"done", + ] + assert [i.offset for i in items] == list(range(7)) + await handle.signal(InterleavedWorkflow.close) + + +@pytest.mark.asyncio +async def test_flush_retry_preserves_items_after_failures( + client: Client, +) -> None: + """After flush failures, a subsequent successful flush delivers all items + in publish order, exactly once. + + Exercises the retry code path behaviorally: simulated delivery failures + must not drop items, must not duplicate them on retry, and must not + reorder items published during the failed state. + """ + async with new_worker(client, BasicPubSubWorkflow) as worker: + handle = await client.start_workflow( + BasicPubSubWorkflow.run, + id=f"pubsub-flush-retry-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + pubsub = PubSubClient(handle) + real_signal = handle.signal + fail_remaining = 2 + + async def maybe_failing_signal(*args: Any, **kwargs: Any) -> Any: + nonlocal fail_remaining + if fail_remaining > 0: + fail_remaining -= 1 + raise RuntimeError("simulated delivery failure") + return await real_signal(*args, **kwargs) + + with patch.object(handle, "signal", side_effect=maybe_failing_signal): + pubsub.publish("events", b"item-0") + pubsub.publish("events", b"item-1") + with pytest.raises(RuntimeError): + await pubsub._flush() + + # Publish more during the failed state — must not overtake the + # pending retry on eventual delivery. + pubsub.publish("events", b"item-2") + with pytest.raises(RuntimeError): + await pubsub._flush() + + # Third flush succeeds, delivering the pending retry batch. + await pubsub._flush() + # Fourth flush delivers the buffered "item-2". + await pubsub._flush() + + items = await collect_items(client, handle, None, 0, 3) + assert [i.data for i in items] == [b"item-0", b"item-1", b"item-2"] + + await handle.signal(BasicPubSubWorkflow.close) + + +@pytest.mark.asyncio +async def test_flush_raises_after_max_retry_duration(client: Client) -> None: + """When max_retry_duration is exceeded, flush raises TimeoutError and the + client can resume publishing without losing subsequent items.""" + async with new_worker(client, BasicPubSubWorkflow) as worker: + handle = await client.start_workflow( + BasicPubSubWorkflow.run, + id=f"pubsub-retry-expiry-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # Inject a controllable clock into the client module. 
The client's + # retry check compares `time.monotonic() - _pending_since` against + # `max_retry_duration`, so advancing the clock between flushes makes + # the timeout fire deterministically regardless of wall-clock speed + # or clock resolution. + pubsub = PubSubClient(handle, max_retry_duration=timedelta(milliseconds=100)) + real_signal = handle.signal + fail_signals = True + + async def maybe_failing_signal(*args: Any, **kwargs: Any) -> Any: + if fail_signals: + raise RuntimeError("simulated failure") + return await real_signal(*args, **kwargs) + + clock = [0.0] + with ( + patch( + "temporalio.contrib.pubsub._client.time.monotonic", + side_effect=lambda: clock[0], + ), + patch.object(handle, "signal", side_effect=maybe_failing_signal), + ): + pubsub.publish("events", b"lost") + + # First flush fails and enters the pending-retry state. + with pytest.raises(RuntimeError): + await pubsub._flush() + + # Advance the clock well past max_retry_duration. + clock[0] = 10.0 + + # Next flush raises TimeoutError — the pending batch is abandoned. + with pytest.raises(TimeoutError, match="max_retry_duration"): + await pubsub._flush() + + # Stop failing signals; subsequent publishes must succeed. + fail_signals = False + pubsub.publish("events", b"kept") + await pubsub._flush() + + items = await collect_items(client, handle, None, 0, 1) + assert len(items) == 1 + assert items[0].data == b"kept" + + await handle.signal(BasicPubSubWorkflow.close) + + +@pytest.mark.asyncio +async def test_dedup_rejects_duplicate_signal(client: Client) -> None: + """Workflow deduplicates signals with the same publisher_id + sequence.""" + async with new_worker( + client, + BasicPubSubWorkflow, + ) as worker: + handle = await client.start_workflow( + BasicPubSubWorkflow.run, + id=f"pubsub-dedup-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # Send a batch with publisher_id and sequence + await handle.signal( + "__temporal_pubsub_publish", + PublishInput( + items=[PublishEntry(topic="events", data=_wire_bytes(b"item-0"))], + publisher_id="test-pub", + sequence=1, + ), + ) + + # Send the same sequence again — should be deduped + await handle.signal( + "__temporal_pubsub_publish", + PublishInput( + items=[PublishEntry(topic="events", data=_wire_bytes(b"duplicate"))], + publisher_id="test-pub", + sequence=1, + ), + ) + + # Send a new sequence — should go through + await handle.signal( + "__temporal_pubsub_publish", + PublishInput( + items=[PublishEntry(topic="events", data=_wire_bytes(b"item-1"))], + publisher_id="test-pub", + sequence=2, + ), + ) + + # Should have 2 items, not 3 (collect_items' update call acts as barrier) + items = await collect_items(client, handle, None, 0, 2) + assert len(items) == 2 + assert items[0].data == b"item-0" + assert items[1].data == b"item-1" + + # Verify offset is 2 (not 3) + pubsub_client = PubSubClient(handle) + offset = await pubsub_client.get_offset() + assert offset == 2 + + await handle.signal(BasicPubSubWorkflow.close) + + +@pytest.mark.asyncio +async def test_double_init_raises(client: Client) -> None: + """Instantiating PubSub twice from @workflow.init raises RuntimeError. + + The first PubSub() registers the __temporal_pubsub_publish signal handler; the + second call detects the existing handler and raises rather than + silently overwriting it. 
+ """ + async with new_worker(client, DoubleInitWorkflow) as worker: + handle = await client.start_workflow( + DoubleInitWorkflow.run, + id=f"pubsub-double-init-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + err = await handle.query(DoubleInitWorkflow.get_double_init_error) + assert err is not None + assert "already registered" in err + await handle.signal(DoubleInitWorkflow.close) + + +@pytest.mark.asyncio +async def test_pubsub_outside_init_raises(client: Client) -> None: + """Constructing PubSub outside @workflow.init raises RuntimeError. + + The workflow calls PubSub() from @workflow.run; the caller-frame + guard must reject the call because the caller's function name is + ``run``, not ``__init__``. + """ + async with new_worker(client, LatePubSubWorkflow) as worker: + result = await client.execute_workflow( + LatePubSubWorkflow.run, + id=f"pubsub-late-init-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + assert "must be constructed directly from the workflow's" in result + assert "'run'" in result + + +@pytest.mark.asyncio +async def test_truncate_pubsub(client: Client) -> None: + """PubSub.truncate discards prefix and adjusts base_offset.""" + async with new_worker( + client, + TruncateWorkflow, + ) as worker: + handle = await client.start_workflow( + TruncateWorkflow.run, + 5, + id=f"pubsub-truncate-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # Verify all 5 items + items = await collect_items(client, handle, None, 0, 5) + assert len(items) == 5 + + # Truncate up to offset 3 (discard items 0, 1, 2). The update + # returns after the handler completes. + await handle.execute_update("truncate", 3) + + # Offset should still be 5 (truncation moves base_offset, not tail) + pubsub_client = PubSubClient(handle) + offset = await pubsub_client.get_offset() + assert offset == 5 + + # Reading from offset 3 should work (items 3, 4) + items_after = await collect_items(client, handle, None, 3, 2) + assert len(items_after) == 2 + assert items_after[0].data == b"item-3" + assert items_after[1].data == b"item-4" + + await handle.signal("close") + + +@pytest.mark.asyncio +async def test_ttl_pruning_in_get_pubsub_state(client: Client) -> None: + """PubSub.get_state prunes publishers whose last-seen time exceeds the + TTL while retaining newer publishers. The log itself is unaffected. + + Uses a wall-clock gap between publishes so that workflow.time() + advances between the two publishers' tasks. workflow.time() can't be + cleanly injected from outside, so a short real sleep is the mechanism. + """ + async with new_worker( + client, + TTLTestWorkflow, + ) as worker: + handle = await client.start_workflow( + TTLTestWorkflow.run, + id=f"pubsub-ttl-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # pub-old arrives first. + await handle.signal( + "__temporal_pubsub_publish", + PublishInput( + items=[PublishEntry(topic="events", data=_wire_bytes(b"old"))], + publisher_id="pub-old", + sequence=1, + ), + ) + + # Sanity: pub-old is recorded (generous TTL retains it). + state_before = await handle.query(TTLTestWorkflow.get_state_with_ttl, 9999.0) + assert "pub-old" in state_before.publisher_sequences + + # Let workflow.time() advance by real wall-clock time. Use a + # generous gap (1.0s) relative to the TTL (0.5s) so the test + # tolerates CI scheduling delays — pub-old must be >=0.5s past, + # pub-new must be <0.5s past, at the moment of the query. + await asyncio.sleep(1.0) + + # pub-new arrives after the gap. 
+ await handle.signal( + "__temporal_pubsub_publish", + PublishInput( + items=[PublishEntry(topic="events", data=_wire_bytes(b"new"))], + publisher_id="pub-new", + sequence=1, + ), + ) + + # TTL=0.5s prunes pub-old (~1.0s old) but keeps pub-new (~0s). + state = await handle.query(TTLTestWorkflow.get_state_with_ttl, 0.5) + assert "pub-old" not in state.publisher_sequences + assert "pub-new" in state.publisher_sequences + # Log contents are not touched by publisher pruning. + assert len(state.log) == 2 + + await handle.signal("close") + + +# --------------------------------------------------------------------------- +# Truncate and TTL test workflows +# --------------------------------------------------------------------------- + + +@workflow.defn +class TruncateWorkflow: + """Test scaffolding that exposes PubSub.truncate via a user-authored + update. + + The contrib module does not define a built-in external truncate API — + truncation is a workflow-internal decision (typically driven by + consumer progress or a retention policy). Workflows that want external + control wire up their own signal or update. We use an update here so + callers get explicit completion (signals are fire-and-forget). + + The ``truncate`` update is ``async`` and opens with + ``await asyncio.sleep(0)`` — the documented recipe from the + contrib/pubsub README for sync-shaped handlers that read ``PubSub`` + state. The yield lets any buffered ``__temporal_pubsub_publish`` signal in + the same activation apply before the handler inspects ``self._log``. + This keeps the test workflow aligned with the pattern users are + directed to follow. + + ``prepub_count`` seeds the log with N byte-payload items during + ``@workflow.init`` as test convenience, so the error-path tests + have deterministic log content without an extra round trip to + publish from the client. + """ + + @workflow.init + def __init__(self, prepub_count: int = 0) -> None: + self.pubsub = PubSub() + self._closed = False + for i in range(prepub_count): + self.pubsub.publish("events", f"item-{i}".encode()) + + @workflow.signal + def close(self) -> None: + self._closed = True + + @workflow.update + async def truncate(self, up_to_offset: int) -> None: + # Recipe from README.md "Gotcha" section: yield once so any + # buffered __temporal_pubsub_publish in the same activation applies + # before we read self._log. asyncio.sleep(0) is a pure asyncio + # yield — no Temporal timer, no history event. + await asyncio.sleep(0) + self.pubsub.truncate(up_to_offset) + + @workflow.run + async def run(self, _prepub_count: int = 0) -> None: + # _prepub_count is consumed in @workflow.init above. @workflow.run + # must accept the same positional args, but the names are free + # to differ. + del _prepub_count + await workflow.wait_condition(lambda: self._closed) + + +@workflow.defn +class TTLTestWorkflow: + """Workflow that exposes PubSub.get_state via query for TTL testing.""" + + @workflow.init + def __init__(self) -> None: + self.pubsub = PubSub() + self._closed = False + + @workflow.signal + def close(self) -> None: + self._closed = True + + @workflow.query + def get_state_with_ttl(self, ttl_seconds: float) -> PubSubState: + # Query arg is passed as float because the default JSON payload + # converter does not serialize ``timedelta``; convert here. 
+ return self.pubsub.get_state(publisher_ttl=timedelta(seconds=ttl_seconds)) + + @workflow.run + async def run(self) -> None: + await workflow.wait_condition(lambda: self._closed) + + +# --------------------------------------------------------------------------- +# Continue-as-new workflow and test +# --------------------------------------------------------------------------- + + +@dataclass +class CANWorkflowInputTyped: + """Uses proper typing.""" + + pubsub_state: PubSubState | None = None + + +@workflow.defn +class ContinueAsNewTypedWorkflow: + """CAN workflow using properly-typed pubsub_state.""" + + @workflow.init + def __init__(self, input: CANWorkflowInputTyped) -> None: + self.pubsub = PubSub(prior_state=input.pubsub_state) + self._should_continue = False + self._closed = False + + @workflow.signal + def close(self) -> None: + self._closed = True + + @workflow.signal + def trigger_continue(self) -> None: + self._should_continue = True + + @workflow.query + def publisher_sequences(self) -> dict[str, int]: + return dict(self.pubsub._publisher_sequences) + + @workflow.run + async def run(self, _input: CANWorkflowInputTyped) -> None: + # _input is consumed in @workflow.init above. @workflow.run must + # accept the same positional args, but the names are free to differ. + del _input + while True: + await workflow.wait_condition(lambda: self._should_continue or self._closed) + if self._closed: + return + if self._should_continue: + self._should_continue = False + self.pubsub.drain() + await workflow.wait_condition(workflow.all_handlers_finished) + workflow.continue_as_new( + args=[ + CANWorkflowInputTyped( + pubsub_state=self.pubsub.get_state(), + ) + ] + ) + + +@pytest.mark.asyncio +async def test_continue_as_new_properly_typed(client: Client) -> None: + """CAN preserves the log, global offsets, AND publisher dedup state + when pubsub_state is properly typed as ``PubSubState | None``.""" + async with new_worker( + client, + ContinueAsNewTypedWorkflow, + ) as worker: + handle = await client.start_workflow( + ContinueAsNewTypedWorkflow.run, + CANWorkflowInputTyped(), + id=f"pubsub-can-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # Publish 3 items with an explicit publisher_id/sequence so dedup + # state is seeded and we can verify it survives CAN. + await handle.signal( + "__temporal_pubsub_publish", + PublishInput( + items=[ + PublishEntry(topic="events", data=_wire_bytes(b"item-0")), + PublishEntry(topic="events", data=_wire_bytes(b"item-1")), + PublishEntry(topic="events", data=_wire_bytes(b"item-2")), + ], + publisher_id="pub", + sequence=1, + ), + ) + + items_before = await collect_items(client, handle, None, 0, 3) + assert len(items_before) == 3 + + await handle.signal(ContinueAsNewTypedWorkflow.trigger_continue) + + new_handle = client.get_workflow_handle(handle.id) + await assert_eq_eventually( + True, + lambda: _is_different_run(handle, new_handle), + ) + + # Log contents and offsets preserved across CAN. + items_after = await collect_items(client, new_handle, None, 0, 3) + assert [i.data for i in items_after] == [b"item-0", b"item-1", b"item-2"] + assert [i.offset for i in items_after] == [0, 1, 2] + + # Dedup state preserved: the carried publisher_sequences dict has + # pub -> 1 after CAN. + seqs_after_can = await new_handle.query( + ContinueAsNewTypedWorkflow.publisher_sequences + ) + assert seqs_after_can == {"pub": 1} + + # Re-sending publisher_id="pub", sequence=1 must be rejected by + # dedup — both the log and the publisher_sequences entry stay put. 
+ await new_handle.signal( + "__temporal_pubsub_publish", + PublishInput( + items=[ + PublishEntry(topic="events", data=_wire_bytes(b"dup")), + ], + publisher_id="pub", + sequence=1, + ), + ) + seqs_after_dup = await new_handle.query( + ContinueAsNewTypedWorkflow.publisher_sequences + ) + assert seqs_after_dup == {"pub": 1} + + # A fresh sequence from the same publisher is accepted, advances + # publisher_sequences to 2, and the new item gets offset 3. + await new_handle.signal( + "__temporal_pubsub_publish", + PublishInput( + items=[ + PublishEntry(topic="events", data=_wire_bytes(b"item-3")), + ], + publisher_id="pub", + sequence=2, + ), + ) + seqs_after_accept = await new_handle.query( + ContinueAsNewTypedWorkflow.publisher_sequences + ) + assert seqs_after_accept == {"pub": 2} + items_all = await collect_items(client, new_handle, None, 0, 4) + assert [i.data for i in items_all] == [ + b"item-0", + b"item-1", + b"item-2", + b"item-3", + ] + assert items_all[3].offset == 3 + + await new_handle.signal(ContinueAsNewTypedWorkflow.close) + + +@workflow.defn +class ContinueAsNewHelperWorkflow: + """CAN workflow that uses the packaged ``PubSub.continue_as_new`` helper.""" + + @workflow.init + def __init__(self, input: CANWorkflowInputTyped) -> None: + self.pubsub = PubSub(prior_state=input.pubsub_state) + self._should_continue = False + self._closed = False + + @workflow.signal + def close(self) -> None: + self._closed = True + + @workflow.signal + def trigger_continue(self) -> None: + self._should_continue = True + + @workflow.run + async def run(self, _input: CANWorkflowInputTyped) -> None: + del _input + while True: + await workflow.wait_condition(lambda: self._should_continue or self._closed) + if self._closed: + return + if self._should_continue: + self._should_continue = False + await self.pubsub.continue_as_new( + lambda state: [CANWorkflowInputTyped(pubsub_state=state)], + ) + + +@pytest.mark.asyncio +async def test_continue_as_new_helper(client: Client) -> None: + """The ``PubSub.continue_as_new`` helper preserves log and dedup state + just like the explicit drain/wait/CAN recipe.""" + async with new_worker( + client, + ContinueAsNewHelperWorkflow, + ) as worker: + handle = await client.start_workflow( + ContinueAsNewHelperWorkflow.run, + CANWorkflowInputTyped(), + id=f"pubsub-can-helper-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + await handle.signal( + "__temporal_pubsub_publish", + PublishInput( + items=[ + PublishEntry(topic="events", data=_wire_bytes(b"item-0")), + PublishEntry(topic="events", data=_wire_bytes(b"item-1")), + ], + publisher_id="pub", + sequence=1, + ), + ) + + items_before = await collect_items(client, handle, None, 0, 2) + assert [i.data for i in items_before] == [b"item-0", b"item-1"] + + await handle.signal(ContinueAsNewHelperWorkflow.trigger_continue) + + new_handle = client.get_workflow_handle(handle.id) + await assert_eq_eventually( + True, + lambda: _is_different_run(handle, new_handle), + ) + + items_after = await collect_items(client, new_handle, None, 0, 2) + assert [i.data for i in items_after] == [b"item-0", b"item-1"] + assert [i.offset for i in items_after] == [0, 1] + + await new_handle.signal(ContinueAsNewHelperWorkflow.close) + + +# --------------------------------------------------------------------------- +# Cross-workflow pub/sub (Scenario 1) +# --------------------------------------------------------------------------- + + +@dataclass +class CrossWorkflowInput: + broker_workflow_id: str + expected_count: int + + +@workflow.defn 
+class BrokerWorkflow: + @workflow.init + def __init__(self, count: int) -> None: + self.pubsub = PubSub() + self._closed = False + + @workflow.signal + def close(self) -> None: + self._closed = True + + @workflow.run + async def run(self, count: int) -> None: + for i in range(count): + self.pubsub.publish("events", f"broker-{i}".encode()) + await workflow.wait_condition(lambda: self._closed) + + +@workflow.defn +class SubscriberWorkflow: + @workflow.run + async def run(self, input: CrossWorkflowInput) -> list[str]: + return await workflow.execute_activity( + "subscribe_to_broker", + input, + start_to_close_timeout=timedelta(seconds=30), + heartbeat_timeout=timedelta(seconds=10), + ) + + +@activity.defn(name="subscribe_to_broker") +async def subscribe_to_broker(input: CrossWorkflowInput) -> list[str]: + client = PubSubClient.create( + client=activity.client(), + workflow_id=input.broker_workflow_id, + ) + items: list[str] = [] + async with _async_timeout(15.0): + async for item in client.subscribe( + topics=["events"], + from_offset=0, + poll_cooldown=timedelta(0), + result_type=bytes, + ): + items.append(item.data.decode()) + activity.heartbeat() + if len(items) >= input.expected_count: + break + return items + + +@pytest.mark.asyncio +async def test_cross_workflow_pubsub(client: Client) -> None: + """Workflow B's activity subscribes to events published by Workflow A.""" + count = 5 + task_queue = str(uuid.uuid4()) + + async with new_worker( + client, + BrokerWorkflow, + SubscriberWorkflow, + activities=[subscribe_to_broker], + task_queue=task_queue, + ): + broker_id = f"pubsub-broker-{uuid.uuid4()}" + broker_handle = await client.start_workflow( + BrokerWorkflow.run, + count, + id=broker_id, + task_queue=task_queue, + ) + + sub_handle = await client.start_workflow( + SubscriberWorkflow.run, + CrossWorkflowInput( + broker_workflow_id=broker_id, + expected_count=count, + ), + id=f"pubsub-subscriber-{uuid.uuid4()}", + task_queue=task_queue, + ) + + result = await sub_handle.result() + assert result == [f"broker-{i}" for i in range(count)] + + # Also verify external subscription still works + external_items = await collect_items( + client, broker_handle, ["events"], 0, count + ) + assert len(external_items) == count + + await broker_handle.signal(BrokerWorkflow.close) + + +# --------------------------------------------------------------------------- +# Cross-namespace pub/sub via Nexus (Scenario 2) +# --------------------------------------------------------------------------- + + +@dataclass +class StartBrokerInput: + count: int + broker_id: str + + +@dataclass +class NexusCallerInput: + count: int + broker_id: str + endpoint: str + + +@workflow.defn +class NexusBrokerWorkflow: + @workflow.init + def __init__(self, count: int) -> None: + self.pubsub = PubSub() + self._closed = False + + @workflow.signal + def close(self) -> None: + self._closed = True + + @workflow.run + async def run(self, count: int) -> str: + for i in range(count): + self.pubsub.publish("events", f"nexus-{i}".encode()) + await workflow.wait_condition(lambda: self._closed) + return "done" + + +@nexusrpc.service +class PubSubNexusService: + start_broker: nexusrpc.Operation[StartBrokerInput, str] + + +@nexusrpc.handler.service_handler(service=PubSubNexusService) +class PubSubNexusHandler: + @workflow_run_operation + async def start_broker( + self, ctx: WorkflowRunOperationContext, input: StartBrokerInput + ) -> nexus.WorkflowHandle[str]: + return await ctx.start_workflow( + NexusBrokerWorkflow.run, + input.count, + 
id=input.broker_id, + ) + + +@workflow.defn +class NexusCallerWorkflow: + @workflow.run + async def run(self, input: NexusCallerInput) -> str: + nc = workflow.create_nexus_client( + service=PubSubNexusService, + endpoint=input.endpoint, + ) + return await nc.execute_operation( + PubSubNexusService.start_broker, + StartBrokerInput(count=input.count, broker_id=input.broker_id), + ) + + +async def create_cross_namespace_endpoint( + client: Client, + endpoint_name: str, + target_namespace: str, + task_queue: str, +) -> None: + await client.operator_service.create_nexus_endpoint( + temporalio.api.operatorservice.v1.CreateNexusEndpointRequest( + spec=temporalio.api.nexus.v1.EndpointSpec( + name=endpoint_name, + target=temporalio.api.nexus.v1.EndpointTarget( + worker=temporalio.api.nexus.v1.EndpointTarget.Worker( + namespace=target_namespace, + task_queue=task_queue, + ) + ), + ) + ) + ) + + +@pytest.mark.asyncio +async def test_poll_more_ready_when_response_exceeds_size_limit( + client: Client, +) -> None: + """Poll response sets more_ready=True when items exceed ~1MB wire size.""" + async with new_worker( + client, + BasicPubSubWorkflow, + ) as worker: + handle = await client.start_workflow( + BasicPubSubWorkflow.run, + id=f"pubsub-more-ready-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # Publish items that total well over 1MB in the poll response. + # Send in separate signals to stay under the RPC size limit. + # Each item is ~200KB; 8 items = ~1.6MB wire (base64 inflates ~33%). + chunk = b"x" * 200_000 + for _ in range(8): + await handle.signal( + "__temporal_pubsub_publish", + PublishInput( + items=[PublishEntry(topic="big", data=_wire_bytes(chunk))] + ), + ) + + # First poll from offset 0 — should get some items but not all. + # (The update acts as a barrier for all prior publish signals.) 
+ result1: PollResult = await handle.execute_update( + "__temporal_pubsub_poll", + PollInput(topics=[], from_offset=0), + result_type=PollResult, + ) + assert result1.more_ready is True + assert len(result1.items) < 8 + assert result1.next_offset < 8 + + # Continue polling until we have all items + all_items = list(result1.items) + offset = result1.next_offset + last_result: PollResult = result1 + while len(all_items) < 8: + last_result = await handle.execute_update( + "__temporal_pubsub_poll", + PollInput(topics=[], from_offset=offset), + result_type=PollResult, + ) + all_items.extend(last_result.items) + offset = last_result.next_offset + assert len(all_items) == 8 + # The final poll that drained the log should set more_ready=False + assert last_result.more_ready is False + + await handle.signal(BasicPubSubWorkflow.close) + + +@pytest.mark.asyncio +async def test_subscribe_iterates_through_more_ready(client: Client) -> None: + """Subscriber correctly yields all items when polls are size-truncated.""" + async with new_worker( + client, + BasicPubSubWorkflow, + ) as worker: + handle = await client.start_workflow( + BasicPubSubWorkflow.run, + id=f"pubsub-more-ready-iter-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # Publish 8 x 200KB items (~2MB+ wire, exceeds 1MB cap) + chunk = b"x" * 200_000 + for _ in range(8): + await handle.signal( + "__temporal_pubsub_publish", + PublishInput( + items=[PublishEntry(topic="big", data=_wire_bytes(chunk))] + ), + ) + + # subscribe() should seamlessly iterate through all 8 items + items = await collect_items(client, handle, None, 0, 8, timeout=10.0) + assert len(items) == 8 + for item in items: + assert item.data == chunk + + await handle.signal(BasicPubSubWorkflow.close) + + +@pytest.mark.asyncio +async def test_cross_namespace_nexus_pubsub( + client: Client, env: WorkflowEnvironment +) -> None: + """Nexus operation starts a pub/sub broker in another namespace; test subscribes.""" + if env.supports_time_skipping: + pytest.skip("Nexus not supported with time-skipping server") + + count = 5 + handler_ns = f"handler-ns-{uuid.uuid4().hex[:8]}" + task_queue = str(uuid.uuid4()) + endpoint_name = make_nexus_endpoint_name(task_queue) + broker_id = f"nexus-broker-{uuid.uuid4()}" + + # Register the handler namespace with the dev server + await client.workflow_service.register_namespace( + temporalio.api.workflowservice.v1.RegisterNamespaceRequest( + namespace=handler_ns, + workflow_execution_retention_period=google.protobuf.duration_pb2.Duration( + seconds=86400, + ), + ) + ) + + handler_client = await Client.connect( + client.service_client.config.target_host, + namespace=handler_ns, + ) + + # Create endpoint targeting the handler namespace + await create_cross_namespace_endpoint( + client, + endpoint_name, + target_namespace=handler_ns, + task_queue=task_queue, + ) + + # Handler worker in handler namespace + async with Worker( + handler_client, + task_queue=task_queue, + workflows=[NexusBrokerWorkflow], + nexus_service_handlers=[PubSubNexusHandler()], + ): + # Caller worker in default namespace + caller_tq = str(uuid.uuid4()) + async with new_worker( + client, + NexusCallerWorkflow, + task_queue=caller_tq, + ): + # Start caller — invokes Nexus op which starts broker in handler ns + caller_handle = await client.start_workflow( + NexusCallerWorkflow.run, + NexusCallerInput( + count=count, + broker_id=broker_id, + endpoint=endpoint_name, + ), + id=f"nexus-caller-{uuid.uuid4()}", + task_queue=caller_tq, + ) + + # Wait for the broker workflow to be 
started by the Nexus operation + broker_handle = handler_client.get_workflow_handle(broker_id) + + async def broker_started() -> bool: + try: + await broker_handle.describe() + return True + except Exception: + return False + + await assert_eq_eventually( + True, broker_started, timeout=timedelta(seconds=15) + ) + + # Subscribe to broker events from the handler namespace + items = await collect_items( + handler_client, broker_handle, ["events"], 0, count + ) + assert len(items) == count + for i in range(count): + assert items[i].topic == "events" + assert items[i].data == f"nexus-{i}".encode() + + # Clean up — signal broker to close so caller can complete + await broker_handle.signal("close") + result = await caller_handle.result() + assert result == "done" diff --git a/uv.lock b/uv.lock index bdc25a507..bb75e49dc 100644 --- a/uv.lock +++ b/uv.lock @@ -8,6 +8,13 @@ resolution-markers = [ "python_full_version < '3.11'", ] +[options] +exclude-newer = "2026-04-18T05:37:33.920196Z" +exclude-newer-span = "P1W" + +[options.exclude-newer-package] +openai-agents = false + [[package]] name = "aioboto3" version = "15.5.0" @@ -1812,7 +1819,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/38/3f/9859f655d11901e7b2996c6e3d33e0caa9a1d4572c3bc61ed0faa64b2f4c/greenlet-3.3.2-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:9bc885b89709d901859cf95179ec9f6bb67a3d2bb1f0e88456461bd4b7f8fd0d", size = 277747, upload-time = "2026-02-20T20:16:21.325Z" }, { url = "https://files.pythonhosted.org/packages/fb/07/cb284a8b5c6498dbd7cba35d31380bb123d7dceaa7907f606c8ff5993cbf/greenlet-3.3.2-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b568183cf65b94919be4438dc28416b234b678c608cafac8874dfeeb2a9bbe13", size = 579202, upload-time = "2026-02-20T20:47:28.955Z" }, { url = "https://files.pythonhosted.org/packages/ed/45/67922992b3a152f726163b19f890a85129a992f39607a2a53155de3448b8/greenlet-3.3.2-cp310-cp310-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:527fec58dc9f90efd594b9b700662ed3fb2493c2122067ac9c740d98080a620e", size = 590620, upload-time = "2026-02-20T20:55:55.581Z" }, - { url = "https://files.pythonhosted.org/packages/03/5f/6e2a7d80c353587751ef3d44bb947f0565ec008a2e0927821c007e96d3a7/greenlet-3.3.2-cp310-cp310-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:508c7f01f1791fbc8e011bd508f6794cb95397fdb198a46cb6635eb5b78d85a7", size = 602132, upload-time = "2026-02-20T21:02:43.261Z" }, { url = "https://files.pythonhosted.org/packages/ad/55/9f1ebb5a825215fadcc0f7d5073f6e79e3007e3282b14b22d6aba7ca6cb8/greenlet-3.3.2-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ad0c8917dd42a819fe77e6bdfcb84e3379c0de956469301d9fd36427a1ca501f", size = 591729, upload-time = "2026-02-20T20:20:58.395Z" }, { url = "https://files.pythonhosted.org/packages/24/b4/21f5455773d37f94b866eb3cf5caed88d6cea6dd2c6e1f9c34f463cba3ec/greenlet-3.3.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:97245cc10e5515dbc8c3104b2928f7f02b6813002770cfaffaf9a6e0fc2b94ef", size = 1551946, upload-time = "2026-02-20T20:49:31.102Z" }, { url = "https://files.pythonhosted.org/packages/00/68/91f061a926abead128fe1a87f0b453ccf07368666bd59ffa46016627a930/greenlet-3.3.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8c1fdd7d1b309ff0da81d60a9688a8bd044ac4e18b250320a96fc68d31c209ca", size = 1618494, upload-time = "2026-02-20T20:21:06.541Z" }, @@ -1820,7 +1826,6 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/f3/47/16400cb42d18d7a6bb46f0626852c1718612e35dcb0dffa16bbaffdf5dd2/greenlet-3.3.2-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:c56692189a7d1c7606cb794be0a8381470d95c57ce5be03fb3d0ef57c7853b86", size = 278890, upload-time = "2026-02-20T20:19:39.263Z" }, { url = "https://files.pythonhosted.org/packages/a3/90/42762b77a5b6aa96cd8c0e80612663d39211e8ae8a6cd47c7f1249a66262/greenlet-3.3.2-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1ebd458fa8285960f382841da585e02201b53a5ec2bac6b156fc623b5ce4499f", size = 581120, upload-time = "2026-02-20T20:47:30.161Z" }, { url = "https://files.pythonhosted.org/packages/bf/6f/f3d64f4fa0a9c7b5c5b3c810ff1df614540d5aa7d519261b53fba55d4df9/greenlet-3.3.2-cp311-cp311-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:a443358b33c4ec7b05b79a7c8b466f5d275025e750298be7340f8fc63dff2a55", size = 594363, upload-time = "2026-02-20T20:55:56.965Z" }, - { url = "https://files.pythonhosted.org/packages/9c/8b/1430a04657735a3f23116c2e0d5eb10220928846e4537a938a41b350bed6/greenlet-3.3.2-cp311-cp311-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:4375a58e49522698d3e70cc0b801c19433021b5c37686f7ce9c65b0d5c8677d2", size = 605046, upload-time = "2026-02-20T21:02:45.234Z" }, { url = "https://files.pythonhosted.org/packages/72/83/3e06a52aca8128bdd4dcd67e932b809e76a96ab8c232a8b025b2850264c5/greenlet-3.3.2-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8e2cd90d413acbf5e77ae41e5d3c9b3ac1d011a756d7284d7f3f2b806bbd6358", size = 594156, upload-time = "2026-02-20T20:20:59.955Z" }, { url = "https://files.pythonhosted.org/packages/70/79/0de5e62b873e08fe3cef7dbe84e5c4bc0e8ed0c7ff131bccb8405cd107c8/greenlet-3.3.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:442b6057453c8cb29b4fb36a2ac689382fc71112273726e2423f7f17dc73bf99", size = 1554649, upload-time = "2026-02-20T20:49:32.293Z" }, { url = "https://files.pythonhosted.org/packages/5a/00/32d30dee8389dc36d42170a9c66217757289e2afb0de59a3565260f38373/greenlet-3.3.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:45abe8eb6339518180d5a7fa47fa01945414d7cca5ecb745346fc6a87d2750be", size = 1619472, upload-time = "2026-02-20T20:21:07.966Z" }, @@ -1829,7 +1834,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ea/ab/1608e5a7578e62113506740b88066bf09888322a311cff602105e619bd87/greenlet-3.3.2-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:ac8d61d4343b799d1e526db579833d72f23759c71e07181c2d2944e429eb09cd", size = 280358, upload-time = "2026-02-20T20:17:43.971Z" }, { url = "https://files.pythonhosted.org/packages/a5/23/0eae412a4ade4e6623ff7626e38998cb9b11e9ff1ebacaa021e4e108ec15/greenlet-3.3.2-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3ceec72030dae6ac0c8ed7591b96b70410a8be370b6a477b1dbc072856ad02bd", size = 601217, upload-time = "2026-02-20T20:47:31.462Z" }, { url = "https://files.pythonhosted.org/packages/f8/16/5b1678a9c07098ecb9ab2dd159fafaf12e963293e61ee8d10ecb55273e5e/greenlet-3.3.2-cp312-cp312-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:a2a5be83a45ce6188c045bcc44b0ee037d6a518978de9a5d97438548b953a1ac", size = 611792, upload-time = "2026-02-20T20:55:58.423Z" }, - { url = "https://files.pythonhosted.org/packages/5c/c5/cc09412a29e43406eba18d61c70baa936e299bc27e074e2be3806ed29098/greenlet-3.3.2-cp312-cp312-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = 
"sha256:ae9e21c84035c490506c17002f5c8ab25f980205c3e61ddb3a2a2a2e6c411fcb", size = 626250, upload-time = "2026-02-20T21:02:46.596Z" }, { url = "https://files.pythonhosted.org/packages/50/1f/5155f55bd71cabd03765a4aac9ac446be129895271f73872c36ebd4b04b6/greenlet-3.3.2-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:43e99d1749147ac21dde49b99c9abffcbc1e2d55c67501465ef0930d6e78e070", size = 613875, upload-time = "2026-02-20T20:21:01.102Z" }, { url = "https://files.pythonhosted.org/packages/fc/dd/845f249c3fcd69e32df80cdab059b4be8b766ef5830a3d0aa9d6cad55beb/greenlet-3.3.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4c956a19350e2c37f2c48b336a3afb4bff120b36076d9d7fb68cb44e05d95b79", size = 1571467, upload-time = "2026-02-20T20:49:33.495Z" }, { url = "https://files.pythonhosted.org/packages/2a/50/2649fe21fcc2b56659a452868e695634722a6655ba245d9f77f5656010bf/greenlet-3.3.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:6c6f8ba97d17a1e7d664151284cb3315fc5f8353e75221ed4324f84eb162b395", size = 1640001, upload-time = "2026-02-20T20:21:09.154Z" }, @@ -1838,7 +1842,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ac/48/f8b875fa7dea7dd9b33245e37f065af59df6a25af2f9561efa8d822fde51/greenlet-3.3.2-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:aa6ac98bdfd716a749b84d4034486863fd81c3abde9aa3cf8eff9127981a4ae4", size = 279120, upload-time = "2026-02-20T20:19:01.9Z" }, { url = "https://files.pythonhosted.org/packages/49/8d/9771d03e7a8b1ee456511961e1b97a6d77ae1dea4a34a5b98eee706689d3/greenlet-3.3.2-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ab0c7e7901a00bc0a7284907273dc165b32e0d109a6713babd04471327ff7986", size = 603238, upload-time = "2026-02-20T20:47:32.873Z" }, { url = "https://files.pythonhosted.org/packages/59/0e/4223c2bbb63cd5c97f28ffb2a8aee71bdfb30b323c35d409450f51b91e3e/greenlet-3.3.2-cp313-cp313-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:d248d8c23c67d2291ffd47af766e2a3aa9fa1c6703155c099feb11f526c63a92", size = 614219, upload-time = "2026-02-20T20:55:59.817Z" }, - { url = "https://files.pythonhosted.org/packages/94/2b/4d012a69759ac9d77210b8bfb128bc621125f5b20fc398bce3940d036b1c/greenlet-3.3.2-cp313-cp313-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:ccd21bb86944ca9be6d967cf7691e658e43417782bce90b5d2faeda0ff78a7dd", size = 628268, upload-time = "2026-02-20T21:02:48.024Z" }, { url = "https://files.pythonhosted.org/packages/7a/34/259b28ea7a2a0c904b11cd36c79b8cef8019b26ee5dbe24e73b469dea347/greenlet-3.3.2-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b6997d360a4e6a4e936c0f9625b1c20416b8a0ea18a8e19cabbefc712e7397ab", size = 616774, upload-time = "2026-02-20T20:21:02.454Z" }, { url = "https://files.pythonhosted.org/packages/0a/03/996c2d1689d486a6e199cb0f1cf9e4aa940c500e01bdf201299d7d61fa69/greenlet-3.3.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:64970c33a50551c7c50491671265d8954046cb6e8e2999aacdd60e439b70418a", size = 1571277, upload-time = "2026-02-20T20:49:34.795Z" }, { url = "https://files.pythonhosted.org/packages/d9/c4/2570fc07f34a39f2caf0bf9f24b0a1a0a47bc2e8e465b2c2424821389dfc/greenlet-3.3.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:1a9172f5bf6bd88e6ba5a84e0a68afeac9dc7b6b412b245dd64f52d83c81e55b", size = 1640455, upload-time = "2026-02-20T20:21:10.261Z" }, @@ -1847,7 +1850,6 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/3f/ae/8bffcbd373b57a5992cd077cbe8858fff39110480a9d50697091faea6f39/greenlet-3.3.2-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:8d1658d7291f9859beed69a776c10822a0a799bc4bfe1bd4272bb60e62507dab", size = 279650, upload-time = "2026-02-20T20:18:00.783Z" }, { url = "https://files.pythonhosted.org/packages/d1/c0/45f93f348fa49abf32ac8439938726c480bd96b2a3c6f4d949ec0124b69f/greenlet-3.3.2-cp314-cp314-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:18cb1b7337bca281915b3c5d5ae19f4e76d35e1df80f4ad3c1a7be91fadf1082", size = 650295, upload-time = "2026-02-20T20:47:34.036Z" }, { url = "https://files.pythonhosted.org/packages/b3/de/dd7589b3f2b8372069ab3e4763ea5329940fc7ad9dcd3e272a37516d7c9b/greenlet-3.3.2-cp314-cp314-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c2e47408e8ce1c6f1ceea0dffcdf6ebb85cc09e55c7af407c99f1112016e45e9", size = 662163, upload-time = "2026-02-20T20:56:01.295Z" }, - { url = "https://files.pythonhosted.org/packages/cd/ac/85804f74f1ccea31ba518dcc8ee6f14c79f73fe36fa1beba38930806df09/greenlet-3.3.2-cp314-cp314-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:e3cb43ce200f59483eb82949bf1835a99cf43d7571e900d7c8d5c62cdf25d2f9", size = 675371, upload-time = "2026-02-20T21:02:49.664Z" }, { url = "https://files.pythonhosted.org/packages/d2/d8/09bfa816572a4d83bccd6750df1926f79158b1c36c5f73786e26dbe4ee38/greenlet-3.3.2-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:63d10328839d1973e5ba35e98cccbca71b232b14051fd957b6f8b6e8e80d0506", size = 664160, upload-time = "2026-02-20T20:21:04.015Z" }, { url = "https://files.pythonhosted.org/packages/48/cf/56832f0c8255d27f6c35d41b5ec91168d74ec721d85f01a12131eec6b93c/greenlet-3.3.2-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:8e4ab3cfb02993c8cc248ea73d7dae6cec0253e9afa311c9b37e603ca9fad2ce", size = 1619181, upload-time = "2026-02-20T20:49:36.052Z" }, { url = "https://files.pythonhosted.org/packages/0a/23/b90b60a4aabb4cec0796e55f25ffbfb579a907c3898cd2905c8918acaa16/greenlet-3.3.2-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:94ad81f0fd3c0c0681a018a976e5c2bd2ca2d9d94895f23e7bb1af4e8af4e2d5", size = 1687713, upload-time = "2026-02-20T20:21:11.684Z" }, @@ -1856,7 +1858,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/98/6d/8f2ef704e614bcf58ed43cfb8d87afa1c285e98194ab2cfad351bf04f81e/greenlet-3.3.2-cp314-cp314t-macosx_11_0_universal2.whl", hash = "sha256:e26e72bec7ab387ac80caa7496e0f908ff954f31065b0ffc1f8ecb1338b11b54", size = 286617, upload-time = "2026-02-20T20:19:29.856Z" }, { url = "https://files.pythonhosted.org/packages/5e/0d/93894161d307c6ea237a43988f27eba0947b360b99ac5239ad3fe09f0b47/greenlet-3.3.2-cp314-cp314t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8b466dff7a4ffda6ca975979bab80bdadde979e29fc947ac3be4451428d8b0e4", size = 655189, upload-time = "2026-02-20T20:47:35.742Z" }, { url = "https://files.pythonhosted.org/packages/f5/2c/d2d506ebd8abcb57386ec4f7ba20f4030cbe56eae541bc6fd6ef399c0b41/greenlet-3.3.2-cp314-cp314t-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:b8bddc5b73c9720bea487b3bffdb1840fe4e3656fba3bd40aa1489e9f37877ff", size = 658225, upload-time = "2026-02-20T20:56:02.527Z" }, - { url = "https://files.pythonhosted.org/packages/d1/67/8197b7e7e602150938049d8e7f30de1660cfb87e4c8ee349b42b67bdb2e1/greenlet-3.3.2-cp314-cp314t-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = 
"sha256:59b3e2c40f6706b05a9cd299c836c6aa2378cabe25d021acd80f13abf81181cf", size = 666581, upload-time = "2026-02-20T21:02:51.526Z" }, { url = "https://files.pythonhosted.org/packages/8e/30/3a09155fbf728673a1dea713572d2d31159f824a37c22da82127056c44e4/greenlet-3.3.2-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b26b0f4428b871a751968285a1ac9648944cea09807177ac639b030bddebcea4", size = 657907, upload-time = "2026-02-20T20:21:05.259Z" }, { url = "https://files.pythonhosted.org/packages/f3/fd/d05a4b7acd0154ed758797f0a43b4c0962a843bedfe980115e842c5b2d08/greenlet-3.3.2-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:1fb39a11ee2e4d94be9a76671482be9398560955c9e568550de0224e41104727", size = 1618857, upload-time = "2026-02-20T20:49:37.309Z" }, { url = "https://files.pythonhosted.org/packages/6f/e1/50ee92a5db521de8f35075b5eff060dd43d39ebd46c2181a2042f7070385/greenlet-3.3.2-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:20154044d9085151bc309e7689d6f7ba10027f8f5a8c0676ad398b951913d89e", size = 1680010, upload-time = "2026-02-20T20:21:13.427Z" }, @@ -5056,6 +5057,7 @@ pydantic = [ [package.dev-dependencies] dev = [ + { name = "async-timeout", marker = "python_full_version < '3.11'" }, { name = "basedpyright" }, { name = "cibuildwheel" }, { name = "googleapis-common-protos" }, @@ -5118,6 +5120,7 @@ provides-extras = ["grpc", "opentelemetry", "pydantic", "openai-agents", "google [package.metadata.requires-dev] dev = [ + { name = "async-timeout", marker = "python_full_version < '3.11'", specifier = ">=4.0,<6" }, { name = "basedpyright", specifier = "==1.34.0" }, { name = "cibuildwheel", specifier = ">=2.22.0,<3" }, { name = "googleapis-common-protos", specifier = "==1.70.0" },