Add temporalio.contrib.pubsub module #1423
base: main
@@ -1,14 +1,31 @@

```python
import json
import logging
from collections.abc import AsyncGenerator, Callable
from datetime import datetime, timedelta, timezone

from google.adk.models import BaseLlm, LLMRegistry
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse

import temporalio.workflow
from temporalio import activity, workflow
from temporalio.contrib.pubsub import PubSubClient
from temporalio.workflow import ActivityConfig

logger = logging.getLogger(__name__)

EVENTS_TOPIC = "events"


def _make_event(event_type: str, **data: object) -> bytes:
    return json.dumps(
        {
            "type": event_type,
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "data": data,
        }
    ).encode()


@activity.defn
async def invoke_model(llm_request: LlmRequest) -> list[LlmResponse]:
```

Comment on lines +20 to +27 (Contributor): I think this is a bit of a tricky interface to publish to users that want to consume these events. Here, you have to read both this function def and all the call sites to see what shape the event payload takes.
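One way to address that concern, sketched below, is to define the event envelope once as a small typed structure so consumers have a single definition to read. The `StreamEvent` name and its fields are illustrative only, not something this PR provides:

```python
# Illustrative sketch, not part of this PR: a shared, typed envelope lets event
# consumers see the payload shape in one place instead of inferring it from
# _make_event and its call sites.
import json
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone


@dataclass
class StreamEvent:
    """Hypothetical typed event published to EVENTS_TOPIC."""

    type: str  # e.g. "TEXT_DELTA", "TOOL_CALL_START"
    data: dict[str, object] = field(default_factory=dict)
    timestamp: str = field(
        default_factory=lambda: datetime.now(timezone.utc).isoformat()
    )

    def to_bytes(self) -> bytes:
        return json.dumps(asdict(self)).encode()

    @classmethod
    def from_bytes(cls, raw: bytes) -> "StreamEvent":
        return cls(**json.loads(raw))
```

Publishers would then call something like `pubsub.publish(EVENTS_TOPIC, StreamEvent("TEXT_DELTA", {"delta": part.text}).to_bytes())`, and consumers round-trip the same type with `StreamEvent.from_bytes`.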
@@ -36,13 +53,78 @@ async def invoke_model(llm_request: LlmRequest) -> list[LlmResponse]:

```python
    ]


@activity.defn
async def invoke_model_streaming(llm_request: LlmRequest) -> list[LlmResponse]:
    """Streaming-aware model activity.

    Calls the LLM with stream=True, publishes TEXT_DELTA events via
    PubSubClient as tokens arrive, and returns the collected responses.

    The PubSubClient auto-detects the activity context to find the parent
    workflow for publishing.

    Args:
        llm_request: The LLM request containing model name and parameters.

    Returns:
        List of LLM responses from the model.
    """
    if llm_request.model is None:
        raise ValueError("No model name provided, could not create LLM.")

    llm = LLMRegistry.new_llm(llm_request.model)
    if not llm:
        raise ValueError(f"Failed to create LLM for model: {llm_request.model}")

    pubsub = PubSubClient.from_activity(batch_interval=0.1)
```
Contributor: Include units? I suspect the typical use case will be on the order of milliseconds, not seconds.
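For illustration only, that suggestion presumably points toward passing an explicit duration; assuming `from_activity` accepted a `timedelta` rather than a bare float of seconds (which this PR does not define), the call above might read:

```python
# Hypothetical variant of the call above: makes the units explicit.
# Assumes from_activity accepted a timedelta, which is not what this PR implements.
pubsub = PubSubClient.from_activity(batch_interval=timedelta(milliseconds=100))
```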
```python
    responses: list[LlmResponse] = []
    text_buffer = ""

    async with pubsub:
        pubsub.publish(EVENTS_TOPIC, _make_event("LLM_CALL_START"), force_flush=True)

        async for response in llm.generate_content_async(
            llm_request=llm_request, stream=True
        ):
            activity.heartbeat()
            responses.append(response)

            if response.content and response.content.parts:
```
Contributor: In the openai agents activity below you publish the full stream event; you don't unpack and inspect each event and publish specially named events. I think that's actually the better approach. I'd suggest just …
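A rough sketch of that suggestion (not this PR's code): publish each streamed response wholesale instead of unpacking it into named events. Serializing it this way assumes `LlmResponse` is a pydantic model exposing `model_dump_json()`, which is an assumption about ADK, not something shown in this diff:

```python
# Sketch of the reviewer's suggestion: forward the raw streaming response.
# Assumes LlmResponse supports pydantic's model_dump_json(); adjust if it does not.
pubsub.publish(EVENTS_TOPIC, response.model_dump_json().encode())
```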
```python
                for part in response.content.parts:
                    if part.text:
                        text_buffer += part.text
                        pubsub.publish(
                            EVENTS_TOPIC,
                            _make_event("TEXT_DELTA", delta=part.text),
                        )
                    if part.function_call:
                        pubsub.publish(
                            EVENTS_TOPIC,
                            _make_event(
                                "TOOL_CALL_START",
```
Contributor: I absolutely agree that we should be emitting lifecycle events like this. E.g. for this one in particular (…). Of course, this is moot if you accept the suggestion in my comment immediately above this one.
```python
                                tool_name=part.function_call.name,
                            ),
                        )

        if text_buffer:
            pubsub.publish(
                EVENTS_TOPIC,
                _make_event("TEXT_COMPLETE", text=text_buffer),
                force_flush=True,
```
Contributor: Do we need `force_flush=True` here? On a related note, I don't see any analogous …
```python
            )
        pubsub.publish(EVENTS_TOPIC, _make_event("LLM_CALL_COMPLETE"), force_flush=True)

    return responses


class TemporalModel(BaseLlm):
    """A Temporal-based LLM model that executes model invocations as activities."""

    def __init__(
        self,
        model_name: str,
        activity_config: ActivityConfig | None = None,
        streaming: bool = False,
        *,
        summary_fn: Callable[[LlmRequest], str | None] | None = None,
    ) -> None:
```
@@ -51,6 +133,9 @@ def __init__(

```python
        Args:
            model_name: The name of the model to use.
            activity_config: Configuration options for the activity execution.
            streaming: When True, the model activity uses the streaming LLM
                endpoint and publishes token events via PubSubClient. The
                workflow is unaffected -- it still receives complete responses.
            summary_fn: Optional callable that receives the LlmRequest and
                returns a summary string (or None) for the activity. Must be
                deterministic as it is called during workflow execution. If
```
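For context, a minimal usage sketch of the `streaming` flag as documented above; the model name and timeout value are placeholders, and the surrounding agent wiring is omitted:

```python
# Illustrative only: opt in to streaming so the model activity publishes token
# events over PubSubClient while the workflow still receives complete responses.
model = TemporalModel(
    "gemini-2.0-flash",  # placeholder model name
    activity_config=ActivityConfig(start_to_close_timeout=timedelta(minutes=2)),
    streaming=True,
)
```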
@@ -62,6 +147,7 @@ def __init__(

```python
        """
        super().__init__(model=model_name)
        self._model_name = model_name
        self._streaming = streaming
```
Contributor: Not always clear that this variable is a bool; renaming could help.
```python
        self._summary_fn = summary_fn
        self._activity_config = ActivityConfig(
            start_to_close_timeout=timedelta(seconds=60)
```
@@ -80,7 +166,8 @@ async def generate_content_async(

```python
        Args:
            llm_request: The LLM request containing model parameters and content.
            stream: Whether to stream the response (currently ignored; use the
                ``streaming`` constructor parameter instead).
```
Comment on lines +169 to +170 (Contributor): Seems like this could be safely honored since it's so seamless to swap between the underlying activities.

Contributor: It doesn't actually do the thing you would expect though, as it doesn't stream back.
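If the per-call `stream` argument were honored as the first comment suggests, the dispatch further down might look roughly like this; as the second comment notes, it still would not stream results back to the workflow, so this is purely illustrative:

```python
# Hypothetical: let a per-call stream=True select the streaming activity even
# when the model was constructed with streaming=False.
activity_fn = invoke_model_streaming if (stream or self._streaming) else invoke_model
```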
```python
        Yields:
            The responses from the model.
```
@@ -103,8 +190,9 @@ async def generate_content_async(

```python
        agent_name = llm_request.config.labels.get("adk_agent_name")
        if agent_name:
            config["summary"] = agent_name
        activity_fn = invoke_model_streaming if self._streaming else invoke_model
        responses = await workflow.execute_activity(
            activity_fn,
            args=[llm_request],
            **config,
        )
```
I see these `logger = ...` lines added in several files, but I don't see them used. Were these loggers for your local debugging?