From 3788a6b7cdf9c45179bf080a7f86dc1f52e0700b Mon Sep 17 00:00:00 2001 From: Danny Mccormick Date: Mon, 23 Mar 2026 10:36:26 -0400 Subject: [PATCH 01/12] Add ADK model handler --- .../ml/inference/agent_development_kit.py | 284 ++++++++++++++ .../inference/agent_development_kit_test.py | 356 ++++++++++++++++++ 2 files changed, 640 insertions(+) create mode 100644 sdks/python/apache_beam/ml/inference/agent_development_kit.py create mode 100644 sdks/python/apache_beam/ml/inference/agent_development_kit_test.py diff --git a/sdks/python/apache_beam/ml/inference/agent_development_kit.py b/sdks/python/apache_beam/ml/inference/agent_development_kit.py new file mode 100644 index 000000000000..59dc0cfb2e08 --- /dev/null +++ b/sdks/python/apache_beam/ml/inference/agent_development_kit.py @@ -0,0 +1,284 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""ModelHandler for running agents built with the Google Agent Development Kit. + +This module provides :class:`ADKAgentModelHandler`, a Beam +:class:`~apache_beam.ml.inference.base.ModelHandler` that wraps an ADK +:class:`google.adk.agents.llm_agent.LlmAgent` so it can be used with the +:class:`~apache_beam.ml.inference.base.RunInference` transform. + +**NOTE:** This API and its implementation are under development and do not +provide backward compatibility guarantees. + +Typical usage:: + + import apache_beam as beam + from apache_beam.ml.inference.base import RunInference + from apache_beam.ml.inference.agent_development_kit import ADKAgentModelHandler + from google.adk.agents import LlmAgent + + agent = LlmAgent( + name="my_agent", + model="gemini-2.0-flash", + instruction="You are a helpful assistant.", + ) + + with beam.Pipeline() as p: + results = ( + p + | beam.Create(["What is the capital of France?"]) + | RunInference(ADKAgentModelHandler(agent=agent)) + ) + +If your agent contains state that is not picklable (e.g. tool closures that +capture unpicklable objects), pass a zero-arg factory callable instead:: + + handler = ADKAgentModelHandler(agent=lambda: LlmAgent(...)) + +""" + +import asyncio +import logging +import uuid +from collections.abc import Callable +from collections.abc import Iterable +from collections.abc import Sequence +from typing import Any +from typing import Optional +from typing import Union + +from apache_beam.ml.inference.base import ModelHandler +from apache_beam.ml.inference.base import PredictionResult + +try: + from google.adk.agents import Agent + from google.adk.runners import Runner + from google.adk.sessions import BaseSessionService + from google.adk.sessions import InMemorySessionService + from google.genai import types as genai_types + ADK_AVAILABLE = True +except ImportError: + ADK_AVAILABLE = False + +LOGGER = logging.getLogger("ADKAgentModelHandler") + +# Type alias for an agent or factory that produces one +_AgentOrFactory = Union["Agent", Callable[[], "Agent"]] + + +class ADKAgentModelHandler(ModelHandler[Union[str, Any], PredictionResult, + "Runner"]): + """ModelHandler for running ADK agents with the Beam RunInference transform. + + Accepts either a fully constructed :class:`google.adk.agents.Agent` or a + zero-arg factory callable that produces one. The factory form is useful when + the agent contains state that is not picklable and therefore cannot be + serialized alongside the pipeline graph. + + Each call to :meth:`run_inference` invokes the agent once per element in the + batch. By default every invocation uses a fresh, isolated session (stateless). + Stateful multi-turn conversations can be achieved by passing a ``session_id`` + key inside ``inference_args``; elements sharing the same ``session_id`` will + continue the same conversation history. + + Args: + agent: A pre-constructed :class:`~google.adk.agents.Agent` instance, or a + zero-arg callable that returns one. The callable form defers agent + construction to worker ``load_model`` time, which is useful when the + agent cannot be serialized. + app_name: The ADK application name used to namespace sessions. Defaults to + ``"beam_inference"``. + session_service_factory: Optional zero-arg callable returning a + :class:`~google.adk.sessions.BaseSessionService`. When ``None``, an + :class:`~google.adk.sessions.InMemorySessionService` is created + automatically. + min_batch_size: Optional minimum batch size. + max_batch_size: Optional maximum batch size. + max_batch_duration_secs: Optional maximum time to buffer a batch before + emitting; used in streaming contexts. + max_batch_weight: Optional maximum total weight of a batch. + element_size_fn: Optional function that returns the size (weight) of an + element. + """ + + def __init__( + self, + agent: _AgentOrFactory, + app_name: str = "beam_inference", + session_service_factory: Optional[Callable[[], "BaseSessionService"]] = + None, + *, + min_batch_size: Optional[int] = None, + max_batch_size: Optional[int] = None, + max_batch_duration_secs: Optional[int] = None, + max_batch_weight: Optional[int] = None, + element_size_fn: Optional[Callable[[Any], int]] = None, + **kwargs): + if not ADK_AVAILABLE: + raise ImportError( + "google-adk is required to use ADKAgentModelHandler. " + "Install it with: pip install google-adk") + + if agent is None: + raise ValueError("'agent' must be an Agent instance or a callable.") + + self._agent_or_factory = agent + self._app_name = app_name + self._session_service_factory = session_service_factory + + super().__init__( + min_batch_size=min_batch_size, + max_batch_size=max_batch_size, + max_batch_duration_secs=max_batch_duration_secs, + max_batch_weight=max_batch_weight, + element_size_fn=element_size_fn, + **kwargs) + + def load_model(self) -> "Runner": + """Instantiates the ADK Runner on the worker. + + Resolves the agent (calling the factory if a callable was provided), then + creates a :class:`~google.adk.runners.Runner` backed by the configured + session service. + + Returns: + A fully initialised :class:`~google.adk.runners.Runner`. + """ + if callable(self._agent_or_factory) and not isinstance( + self._agent_or_factory, Agent): + agent = self._agent_or_factory() + else: + agent = self._agent_or_factory + + if self._session_service_factory is not None: + session_service = self._session_service_factory() + else: + session_service = InMemorySessionService() + + runner = Runner( + agent=agent, + app_name=self._app_name, + session_service=session_service, + ) + LOGGER.info( + "Loaded ADK Runner for agent '%s' (app_name='%s')", + agent.name, + self._app_name, + ) + return runner + + def run_inference( + self, + batch: Sequence[Union[str, Any]], + model: "Runner", + inference_args: Optional[dict[str, Any]] = None, + ) -> Iterable[PredictionResult]: + """Runs the ADK agent on each element in the batch. + + Each element is sent to the agent as a new user turn. The final response + text from the agent is returned as the ``inference`` field of a + :class:`~apache_beam.ml.inference.base.PredictionResult`. + + Args: + batch: A sequence of inputs, each of which is either a ``str`` (the user + message text) or a :class:`google.genai.types.Content` object (for + richer multi-part messages). + model: The :class:`~google.adk.runners.Runner` returned by + :meth:`load_model`. + inference_args: Optional dict of extra arguments. Supported keys: + + - ``"session_id"`` (:class:`str`): If supplied, all elements in this + batch share this session ID, enabling stateful multi-turn + conversations. If omitted, each element receives a unique auto- + generated session ID. + - ``"user_id"`` (:class:`str`): The user identifier to pass to the + runner. Defaults to ``"beam_user"``. + + Returns: + An iterable of :class:`~apache_beam.ml.inference.base.PredictionResult`, + one per input element. + """ + if inference_args is None: + inference_args = {} + + user_id: str = inference_args.get("user_id", "beam_user") + + results = [] + for element in batch: + session_id: str = inference_args.get("session_id", str(uuid.uuid4())) + + # Ensure a session exists for this invocation + model.session_service.create_session( + app_name=self._app_name, + user_id=user_id, + session_id=session_id, + ) + + # Wrap plain strings in a Content object + if isinstance(element, str): + message = genai_types.Content( + role="user", parts=[genai_types.Part(text=element)]) + else: + # Assume the caller has already constructed a types.Content object + message = element + + response_text = asyncio.run( + self._invoke_agent(model, user_id, session_id, message)) + + results.append( + PredictionResult( + example=element, + inference=response_text, + model_id=model.agent.name, + )) + + return results + + @staticmethod + async def _invoke_agent( + runner: "Runner", + user_id: str, + session_id: str, + message: Any, + ) -> Optional[str]: + """Drives the ADK event loop and returns the final response text. + + Args: + runner: The ADK Runner to invoke. + user_id: The user ID for this invocation. + session_id: The session ID for this invocation. + message: The :class:`google.genai.types.Content` to send. + + Returns: + The text of the agent's final response, or ``None`` if the agent + produced no final text response. + """ + final_text: Optional[str] = None + async for event in runner.run_async( + user_id=user_id, + session_id=session_id, + new_message=message, + ): + if event.is_final_response(): + if event.content and event.content.parts: + final_text = event.content.parts[0].text + break + return final_text + + def get_metrics_namespace(self) -> str: + return "ADKAgentModelHandler" diff --git a/sdks/python/apache_beam/ml/inference/agent_development_kit_test.py b/sdks/python/apache_beam/ml/inference/agent_development_kit_test.py new file mode 100644 index 000000000000..7bd77c52ff12 --- /dev/null +++ b/sdks/python/apache_beam/ml/inference/agent_development_kit_test.py @@ -0,0 +1,356 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# pytype: skip-file + +import asyncio +import unittest +from unittest import mock + +try: + from apache_beam.ml.inference.agent_development_kit import ADKAgentModelHandler + from apache_beam.ml.inference.base import PredictionResult +except ImportError: + raise unittest.SkipTest('google-adk dependencies are not installed') + + +def _make_mock_agent(name: str = "test_agent") -> mock.MagicMock: + """Returns a mock that quacks like a google.adk.agents.Agent.""" + agent = mock.MagicMock() + agent.name = name + return agent + + +def _make_mock_runner( + agent: mock.MagicMock, + final_text: str = "Hello from agent", +) -> mock.MagicMock: + """Returns a mock Runner whose run_async yields one final-response event.""" + # Build a mock event that looks like a final response + part = mock.MagicMock() + part.text = final_text + + content = mock.MagicMock() + content.parts = [part] + + event = mock.MagicMock() + event.is_final_response.return_value = True + event.content = content + + async def _async_gen(*args, **kwargs): + yield event + + runner = mock.MagicMock() + runner.agent = agent + runner.run_async = mock.MagicMock(side_effect=_async_gen) + runner.session_service = mock.MagicMock() + return runner + + +# --------------------------------------------------------------------------- +# Helper: patch ADK imports inside the module under test so tests work even +# when google-adk is installed (avoids constructing real ADK objects). +# --------------------------------------------------------------------------- +_MODULE = "apache_beam.ml.inference.agent_development_kit" + + +class TestADKAgentModelHandlerInit(unittest.TestCase): + """Tests for __init__ argument validation.""" + + def test_raises_if_agent_is_none(self): + with self.assertRaises((ValueError, TypeError)): + ADKAgentModelHandler(agent=None) # type: ignore[arg-type] + + def test_accepts_agent_object(self): + agent = _make_mock_agent() + handler = ADKAgentModelHandler(agent=agent) + self.assertEqual(handler._agent_or_factory, agent) + + def test_accepts_agent_factory_callable(self): + agent = _make_mock_agent() + factory = lambda: agent + handler = ADKAgentModelHandler(agent=factory) + self.assertTrue(callable(handler._agent_or_factory)) + + def test_default_app_name(self): + agent = _make_mock_agent() + handler = ADKAgentModelHandler(agent=agent) + self.assertEqual(handler._app_name, "beam_inference") + + def test_custom_app_name(self): + agent = _make_mock_agent() + handler = ADKAgentModelHandler(agent=agent, app_name="my_app") + self.assertEqual(handler._app_name, "my_app") + + def test_metrics_namespace(self): + agent = _make_mock_agent() + handler = ADKAgentModelHandler(agent=agent) + self.assertEqual(handler.get_metrics_namespace(), "ADKAgentModelHandler") + + +class TestLoadModel(unittest.TestCase): + """Tests for load_model / Runner construction.""" + + @mock.patch(f"{_MODULE}.Runner") + @mock.patch(f"{_MODULE}.InMemorySessionService") + def test_load_model_with_agent_object( + self, mock_session_cls, mock_runner_cls): + agent = _make_mock_agent() + handler = ADKAgentModelHandler(agent=agent, app_name="test_app") + + handler.load_model() + + mock_session_cls.assert_called_once() + mock_runner_cls.assert_called_once_with( + agent=agent, + app_name="test_app", + session_service=mock_session_cls.return_value, + ) + + @mock.patch(f"{_MODULE}.Runner") + @mock.patch(f"{_MODULE}.InMemorySessionService") + def test_load_model_calls_factory(self, mock_session_cls, mock_runner_cls): + agent = _make_mock_agent() + factory = mock.MagicMock(return_value=agent) + + handler = ADKAgentModelHandler(agent=factory) + handler.load_model() + + factory.assert_called_once() + mock_runner_cls.assert_called_once_with( + agent=agent, + app_name="beam_inference", + session_service=mock_session_cls.return_value, + ) + + @mock.patch(f"{_MODULE}.Runner") + def test_load_model_uses_custom_session_service(self, mock_runner_cls): + agent = _make_mock_agent() + custom_session_service = mock.MagicMock() + session_factory = mock.MagicMock(return_value=custom_session_service) + + handler = ADKAgentModelHandler( + agent=agent, session_service_factory=session_factory) + handler.load_model() + + session_factory.assert_called_once() + mock_runner_cls.assert_called_once_with( + agent=agent, + app_name="beam_inference", + session_service=custom_session_service, + ) + + +class TestRunInference(unittest.TestCase): + """Tests for run_inference output and batching.""" + + def test_string_input_returns_prediction_result(self): + agent = _make_mock_agent() + runner = _make_mock_runner(agent, final_text="Paris") + + handler = ADKAgentModelHandler(agent=agent) + results = list( + handler.run_inference( + batch=["What is the capital of France?"], model=runner)) + + self.assertEqual(len(results), 1) + pr = results[0] + self.assertIsInstance(pr, PredictionResult) + self.assertEqual(pr.example, "What is the capital of France?") + self.assertEqual(pr.inference, "Paris") + self.assertEqual(pr.model_id, "test_agent") + + def test_batch_of_strings(self): + agent = _make_mock_agent() + runner = _make_mock_runner(agent, final_text="answer") + + handler = ADKAgentModelHandler(agent=agent) + results = list( + handler.run_inference(batch=["q1", "q2", "q3"], model=runner)) + + self.assertEqual(len(results), 3) + self.assertEqual([r.example for r in results], ["q1", "q2", "q3"]) + + def test_content_object_input(self): + """Non-string inputs (types.Content) are passed through unchanged.""" + agent = _make_mock_agent() + runner = _make_mock_runner(agent, final_text="Berlin") + + content_input = mock.MagicMock() # simulates types.Content + + handler = ADKAgentModelHandler(agent=agent) + results = list(handler.run_inference(batch=[content_input], model=runner)) + + self.assertEqual(len(results), 1) + self.assertEqual(results[0].example, content_input) + self.assertEqual(results[0].inference, "Berlin") + + def test_none_inference_args_uses_defaults(self): + agent = _make_mock_agent() + runner = _make_mock_runner(agent) + + handler = ADKAgentModelHandler(agent=agent) + results = list( + handler.run_inference( + batch=["hello"], model=runner, inference_args=None)) + self.assertEqual(len(results), 1) + + def test_custom_user_id_passed_to_runner(self): + agent = _make_mock_agent() + runner = _make_mock_runner(agent) + + handler = ADKAgentModelHandler(agent=agent) + handler.run_inference( + batch=["hi"], + model=runner, + inference_args={"user_id": "custom_user"}, + ) + + call_kwargs = runner.run_async.call_args[1] + self.assertEqual(call_kwargs["user_id"], "custom_user") + + +class TestSessionManagement(unittest.TestCase): + """Tests for session creation and session_id handling.""" + + def test_each_element_gets_unique_session_by_default(self): + agent = _make_mock_agent() + runner = _make_mock_runner(agent) + + handler = ADKAgentModelHandler(agent=agent) + handler.run_inference(batch=["a", "b", "c"], model=runner) + + # create_session should have been called 3 times with distinct session IDs + calls = runner.session_service.create_session.call_args_list + self.assertEqual(len(calls), 3) + session_ids = [c[1]["session_id"] for c in calls] + self.assertEqual(len(set(session_ids)), 3, "Expected unique session IDs") + + def test_shared_session_id_from_inference_args(self): + agent = _make_mock_agent() + runner = _make_mock_runner(agent) + + handler = ADKAgentModelHandler(agent=agent) + handler.run_inference( + batch=["turn1", "turn2"], + model=runner, + inference_args={"session_id": "my-session"}, + ) + + calls = runner.session_service.create_session.call_args_list + session_ids = [c[1]["session_id"] for c in calls] + self.assertTrue( + all(sid == "my-session" for sid in session_ids), + "All elements should share the provided session_id", + ) + + def test_session_created_with_correct_app_name(self): + agent = _make_mock_agent() + runner = _make_mock_runner(agent) + + handler = ADKAgentModelHandler(agent=agent, app_name="my_app") + handler.run_inference(batch=["hello"], model=runner) + + call_kwargs = runner.session_service.create_session.call_args[1] + self.assertEqual(call_kwargs["app_name"], "my_app") + + +class TestResponseExtraction(unittest.TestCase): + """Tests for extraction of the final response from the event stream.""" + + def test_returns_none_when_no_final_response(self): + """Agent emits only non-final events; inference should be None.""" + agent = _make_mock_agent() + + # Build a runner that yields only non-final events + non_final_event = mock.MagicMock() + non_final_event.is_final_response.return_value = False + + async def _async_gen(*args, **kwargs): + yield non_final_event + + runner = mock.MagicMock() + runner.agent = agent + runner.run_async = mock.MagicMock(side_effect=_async_gen) + runner.session_service = mock.MagicMock() + + handler = ADKAgentModelHandler(agent=agent) + results = list(handler.run_inference(batch=["hello"], model=runner)) + + self.assertEqual(len(results), 1) + self.assertIsNone(results[0].inference) + + def test_returns_none_when_final_event_has_no_content(self): + agent = _make_mock_agent() + + event = mock.MagicMock() + event.is_final_response.return_value = True + event.content = None + + async def _async_gen(*args, **kwargs): + yield event + + runner = mock.MagicMock() + runner.agent = agent + runner.run_async = mock.MagicMock(side_effect=_async_gen) + runner.session_service = mock.MagicMock() + + handler = ADKAgentModelHandler(agent=agent) + results = list(handler.run_inference(batch=["hello"], model=runner)) + + self.assertIsNone(results[0].inference) + + def test_stops_after_first_final_response(self): + """Multiple final events: only the first one's text should be used.""" + agent = _make_mock_agent() + + def _make_event(text: str): + part = mock.MagicMock() + part.text = text + content = mock.MagicMock() + content.parts = [part] + event = mock.MagicMock() + event.is_final_response.return_value = True + event.content = content + return event + + async def _async_gen(*args, **kwargs): + yield _make_event("first") + yield _make_event("second") + + runner = mock.MagicMock() + runner.agent = agent + runner.run_async = mock.MagicMock(side_effect=_async_gen) + runner.session_service = mock.MagicMock() + + handler = ADKAgentModelHandler(agent=agent) + results = list(handler.run_inference(batch=["hi"], model=runner)) + + self.assertEqual(results[0].inference, "first") + + def test_invoke_agent_static_method_directly(self): + """Unit test the async _invoke_agent helper directly.""" + agent = _make_mock_agent() + runner = _make_mock_runner(agent, final_text="direct result") + + result = asyncio.run( + ADKAgentModelHandler._invoke_agent( + runner, "user", "session-1", mock.MagicMock())) + self.assertEqual(result, "direct result") + + +if __name__ == '__main__': + unittest.main() From f3514c03ddafb975bf9cf16c1655307e92e29aa2 Mon Sep 17 00:00:00 2001 From: Danny Mccormick Date: Mon, 23 Mar 2026 10:41:54 -0400 Subject: [PATCH 02/12] Small cleanup --- .../ml/inference/agent_development_kit.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/agent_development_kit.py b/sdks/python/apache_beam/ml/inference/agent_development_kit.py index 59dc0cfb2e08..7b5a396e16e1 100644 --- a/sdks/python/apache_beam/ml/inference/agent_development_kit.py +++ b/sdks/python/apache_beam/ml/inference/agent_development_kit.py @@ -81,7 +81,7 @@ _AgentOrFactory = Union["Agent", Callable[[], "Agent"]] -class ADKAgentModelHandler(ModelHandler[Union[str, Any], PredictionResult, +class ADKAgentModelHandler(ModelHandler[Union[str, genai_types.Content], PredictionResult, "Runner"]): """ModelHandler for running ADK agents with the Beam RunInference transform. @@ -184,7 +184,7 @@ def load_model(self) -> "Runner": def run_inference( self, - batch: Sequence[Union[str, Any]], + batch: Sequence[Union[str, genai_types.Content]], model: "Runner", inference_args: Optional[dict[str, Any]] = None, ) -> Iterable[PredictionResult]: @@ -254,7 +254,7 @@ async def _invoke_agent( runner: "Runner", user_id: str, session_id: str, - message: Any, + message: genai_types.Content, ) -> Optional[str]: """Drives the ADK event loop and returns the final response text. @@ -268,7 +268,6 @@ async def _invoke_agent( The text of the agent's final response, or ``None`` if the agent produced no final text response. """ - final_text: Optional[str] = None async for event in runner.run_async( user_id=user_id, session_id=session_id, @@ -276,9 +275,8 @@ async def _invoke_agent( ): if event.is_final_response(): if event.content and event.content.parts: - final_text = event.content.parts[0].text - break - return final_text + return event.content.parts[0].text + return None def get_metrics_namespace(self) -> str: return "ADKAgentModelHandler" From 31bebf53351ce47a4bee5805618b959df4c8fbdd Mon Sep 17 00:00:00 2001 From: Danny Mccormick Date: Mon, 23 Mar 2026 10:58:26 -0400 Subject: [PATCH 03/12] CHANGES --- CHANGES.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGES.md b/CHANGES.md index e91da103c30e..3d9773c97e39 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -69,6 +69,7 @@ ## New Features / Improvements +* Added `ADKAgentModelHandler` for running Google Agent Development Kit (ADK) agents (Python). * Added support for large pipeline options via a file (Python) ([#37370](https://github.com/apache/beam/issues/37370)). * Supported infer schema from dataclass (Python) ([#22085](https://github.com/apache/beam/issues/22085)). Default coder for typehint-ed (or set with_output_type) for non-frozen dataclasses changed to RowCoder. To preserve the old behavior (fast primitive coder), explicitly register the type with FastPrimitiveCoder. * Updates minimum Go version to 1.26.1 ([#37897](https://github.com/apache/beam/issues/37897)). From 69c8beb8d64d801734316e807ab95e5b09c9f033 Mon Sep 17 00:00:00 2001 From: Danny Mccormick Date: Mon, 23 Mar 2026 12:23:53 -0400 Subject: [PATCH 04/12] Fix up some tests --- .../ml/inference/agent_development_kit.py | 15 +++--- .../inference/agent_development_kit_test.py | 48 +++++++------------ sdks/python/setup.py | 1 + 3 files changed, 27 insertions(+), 37 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/agent_development_kit.py b/sdks/python/apache_beam/ml/inference/agent_development_kit.py index 7b5a396e16e1..eab196b3af1c 100644 --- a/sdks/python/apache_beam/ml/inference/agent_development_kit.py +++ b/sdks/python/apache_beam/ml/inference/agent_development_kit.py @@ -70,10 +70,13 @@ from google.adk.runners import Runner from google.adk.sessions import BaseSessionService from google.adk.sessions import InMemorySessionService - from google.genai import types as genai_types + from google.genai.types import Content as genai_Content + from google.genai.types import Part as genai_Part ADK_AVAILABLE = True except ImportError: ADK_AVAILABLE = False + genai_Content = Any + genai_Part = Any LOGGER = logging.getLogger("ADKAgentModelHandler") @@ -81,7 +84,7 @@ _AgentOrFactory = Union["Agent", Callable[[], "Agent"]] -class ADKAgentModelHandler(ModelHandler[Union[str, genai_types.Content], PredictionResult, +class ADKAgentModelHandler(ModelHandler[Union[str, genai_Content], PredictionResult, "Runner"]): """ModelHandler for running ADK agents with the Beam RunInference transform. @@ -184,7 +187,7 @@ def load_model(self) -> "Runner": def run_inference( self, - batch: Sequence[Union[str, genai_types.Content]], + batch: Sequence[Union[str, genai_Content]], model: "Runner", inference_args: Optional[dict[str, Any]] = None, ) -> Iterable[PredictionResult]: @@ -231,8 +234,8 @@ def run_inference( # Wrap plain strings in a Content object if isinstance(element, str): - message = genai_types.Content( - role="user", parts=[genai_types.Part(text=element)]) + message = genai_Content( + role="user", parts=[genai_Part(text=element)]) else: # Assume the caller has already constructed a types.Content object message = element @@ -254,7 +257,7 @@ async def _invoke_agent( runner: "Runner", user_id: str, session_id: str, - message: genai_types.Content, + message: genai_Content, ) -> Optional[str]: """Drives the ADK event loop and returns the final response text. diff --git a/sdks/python/apache_beam/ml/inference/agent_development_kit_test.py b/sdks/python/apache_beam/ml/inference/agent_development_kit_test.py index 7bd77c52ff12..fab865c0c776 100644 --- a/sdks/python/apache_beam/ml/inference/agent_development_kit_test.py +++ b/sdks/python/apache_beam/ml/inference/agent_development_kit_test.py @@ -23,6 +23,9 @@ try: from apache_beam.ml.inference.agent_development_kit import ADKAgentModelHandler from apache_beam.ml.inference.base import PredictionResult + from google.adk.agents.llm_agent import Agent + from google.adk.runners import Runner + from google.adk.sessions import InMemorySessionService except ImportError: raise unittest.SkipTest('google-adk dependencies are not installed') @@ -104,21 +107,22 @@ def test_metrics_namespace(self): class TestLoadModel(unittest.TestCase): """Tests for load_model / Runner construction.""" - @mock.patch(f"{_MODULE}.Runner") - @mock.patch(f"{_MODULE}.InMemorySessionService") - def test_load_model_with_agent_object( - self, mock_session_cls, mock_runner_cls): - agent = _make_mock_agent() + def test_load_model_with_agent_object(self): + def get_current_time(city: str) -> dict: + """Returns the current time in a specified city.""" + return {"status": "success", "city": city, "time": "10:30 AM"} + + agent = Agent( + model='gemini-3-flash-preview', + name='root_agent', + description="Tells the current time in a specified city.", + instruction="You are a helpful assistant that tells the current time in cities. Use the 'get_current_time' tool for this purpose.", + tools=[get_current_time], + ) handler = ADKAgentModelHandler(agent=agent, app_name="test_app") + runner = handler.load_model() - handler.load_model() - - mock_session_cls.assert_called_once() - mock_runner_cls.assert_called_once_with( - agent=agent, - app_name="test_app", - session_service=mock_session_cls.return_value, - ) + self.assertEqual(agent, runner.agent) @mock.patch(f"{_MODULE}.Runner") @mock.patch(f"{_MODULE}.InMemorySessionService") @@ -136,24 +140,6 @@ def test_load_model_calls_factory(self, mock_session_cls, mock_runner_cls): session_service=mock_session_cls.return_value, ) - @mock.patch(f"{_MODULE}.Runner") - def test_load_model_uses_custom_session_service(self, mock_runner_cls): - agent = _make_mock_agent() - custom_session_service = mock.MagicMock() - session_factory = mock.MagicMock(return_value=custom_session_service) - - handler = ADKAgentModelHandler( - agent=agent, session_service_factory=session_factory) - handler.load_model() - - session_factory.assert_called_once() - mock_runner_cls.assert_called_once_with( - agent=agent, - app_name="beam_inference", - session_service=custom_session_service, - ) - - class TestRunInference(unittest.TestCase): """Tests for run_inference output and batching.""" diff --git a/sdks/python/setup.py b/sdks/python/setup.py index 1ad37f6f0243..24bc25410a2f 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -164,6 +164,7 @@ def cythonize(*args, **kwargs): ml_base = [ 'embeddings>=0.0.4', # 0.0.3 crashes setuptools + 'google-adk', 'onnxruntime', 'langchain', 'sentence-transformers>=2.2.2', From 633e08a6d2a507a22ec0b560d40b4b839fa52f61 Mon Sep 17 00:00:00 2001 From: Danny Mccormick Date: Mon, 23 Mar 2026 15:45:48 -0400 Subject: [PATCH 05/12] Linting --- .../ml/inference/agent_development_kit.py | 11 +++++------ .../inference/agent_development_kit_test.py | 19 ++++++++----------- 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/agent_development_kit.py b/sdks/python/apache_beam/ml/inference/agent_development_kit.py index eab196b3af1c..3512fbd14cd8 100644 --- a/sdks/python/apache_beam/ml/inference/agent_development_kit.py +++ b/sdks/python/apache_beam/ml/inference/agent_development_kit.py @@ -84,7 +84,8 @@ _AgentOrFactory = Union["Agent", Callable[[], "Agent"]] -class ADKAgentModelHandler(ModelHandler[Union[str, genai_Content], PredictionResult, +class ADKAgentModelHandler(ModelHandler[Union[str, genai_Content], + PredictionResult, "Runner"]): """ModelHandler for running ADK agents with the Beam RunInference transform. @@ -118,13 +119,12 @@ class ADKAgentModelHandler(ModelHandler[Union[str, genai_Content], PredictionRes element_size_fn: Optional function that returns the size (weight) of an element. """ - def __init__( self, agent: _AgentOrFactory, app_name: str = "beam_inference", - session_service_factory: Optional[Callable[[], "BaseSessionService"]] = - None, + session_service_factory: Optional[Callable[[], + "BaseSessionService"]] = None, *, min_batch_size: Optional[int] = None, max_batch_size: Optional[int] = None, @@ -234,8 +234,7 @@ def run_inference( # Wrap plain strings in a Content object if isinstance(element, str): - message = genai_Content( - role="user", parts=[genai_Part(text=element)]) + message = genai_Content(role="user", parts=[genai_Part(text=element)]) else: # Assume the caller has already constructed a types.Content object message = element diff --git a/sdks/python/apache_beam/ml/inference/agent_development_kit_test.py b/sdks/python/apache_beam/ml/inference/agent_development_kit_test.py index fab865c0c776..68a554b37b25 100644 --- a/sdks/python/apache_beam/ml/inference/agent_development_kit_test.py +++ b/sdks/python/apache_beam/ml/inference/agent_development_kit_test.py @@ -21,11 +21,10 @@ from unittest import mock try: + from google.adk.agents.llm_agent import Agent + from apache_beam.ml.inference.agent_development_kit import ADKAgentModelHandler from apache_beam.ml.inference.base import PredictionResult - from google.adk.agents.llm_agent import Agent - from google.adk.runners import Runner - from google.adk.sessions import InMemorySessionService except ImportError: raise unittest.SkipTest('google-adk dependencies are not installed') @@ -72,7 +71,6 @@ async def _async_gen(*args, **kwargs): class TestADKAgentModelHandlerInit(unittest.TestCase): """Tests for __init__ argument validation.""" - def test_raises_if_agent_is_none(self): with self.assertRaises((ValueError, TypeError)): ADKAgentModelHandler(agent=None) # type: ignore[arg-type] @@ -106,17 +104,18 @@ def test_metrics_namespace(self): class TestLoadModel(unittest.TestCase): """Tests for load_model / Runner construction.""" - def test_load_model_with_agent_object(self): def get_current_time(city: str) -> dict: - """Returns the current time in a specified city.""" - return {"status": "success", "city": city, "time": "10:30 AM"} + """Returns the current time in a specified city.""" + return {"status": "success", "city": city, "time": "10:30 AM"} agent = Agent( model='gemini-3-flash-preview', name='root_agent', description="Tells the current time in a specified city.", - instruction="You are a helpful assistant that tells the current time in cities. Use the 'get_current_time' tool for this purpose.", + instruction= + "You are a helpful assistant that tells the current time in cities. " + "Use the 'get_current_time' tool for this purpose.", tools=[get_current_time], ) handler = ADKAgentModelHandler(agent=agent, app_name="test_app") @@ -140,9 +139,9 @@ def test_load_model_calls_factory(self, mock_session_cls, mock_runner_cls): session_service=mock_session_cls.return_value, ) + class TestRunInference(unittest.TestCase): """Tests for run_inference output and batching.""" - def test_string_input_returns_prediction_result(self): agent = _make_mock_agent() runner = _make_mock_runner(agent, final_text="Paris") @@ -211,7 +210,6 @@ def test_custom_user_id_passed_to_runner(self): class TestSessionManagement(unittest.TestCase): """Tests for session creation and session_id handling.""" - def test_each_element_gets_unique_session_by_default(self): agent = _make_mock_agent() runner = _make_mock_runner(agent) @@ -256,7 +254,6 @@ def test_session_created_with_correct_app_name(self): class TestResponseExtraction(unittest.TestCase): """Tests for extraction of the final response from the event stream.""" - def test_returns_none_when_no_final_response(self): """Agent emits only non-final events; inference should be None.""" agent = _make_mock_agent() From c0cf2f1c0b460beaaa355b65c1217a4f52ebe2d2 Mon Sep 17 00:00:00 2001 From: Danny Mccormick Date: Mon, 23 Mar 2026 16:33:48 -0400 Subject: [PATCH 06/12] lint --- CHANGES.md | 2 +- sdks/python/apache_beam/ml/inference/agent_development_kit.py | 4 ++-- .../apache_beam/ml/inference/agent_development_kit_test.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 3d9773c97e39..8b526905af04 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -69,7 +69,7 @@ ## New Features / Improvements -* Added `ADKAgentModelHandler` for running Google Agent Development Kit (ADK) agents (Python). +* Added `ADKAgentModelHandler` for running Google Agent Development Kit (ADK) agents (Python) ([#37917](https://github.com/apache/beam/issues/37917)). * Added support for large pipeline options via a file (Python) ([#37370](https://github.com/apache/beam/issues/37370)). * Supported infer schema from dataclass (Python) ([#22085](https://github.com/apache/beam/issues/22085)). Default coder for typehint-ed (or set with_output_type) for non-frozen dataclasses changed to RowCoder. To preserve the old behavior (fast primitive coder), explicitly register the type with FastPrimitiveCoder. * Updates minimum Go version to 1.26.1 ([#37897](https://github.com/apache/beam/issues/37897)). diff --git a/sdks/python/apache_beam/ml/inference/agent_development_kit.py b/sdks/python/apache_beam/ml/inference/agent_development_kit.py index 3512fbd14cd8..7ebbd7ef5a69 100644 --- a/sdks/python/apache_beam/ml/inference/agent_development_kit.py +++ b/sdks/python/apache_beam/ml/inference/agent_development_kit.py @@ -75,8 +75,8 @@ ADK_AVAILABLE = True except ImportError: ADK_AVAILABLE = False - genai_Content = Any - genai_Part = Any + genai_Content = Any # type: ignore[assignment, misc] + genai_Part = Any # type: ignore[assignment, misc] LOGGER = logging.getLogger("ADKAgentModelHandler") diff --git a/sdks/python/apache_beam/ml/inference/agent_development_kit_test.py b/sdks/python/apache_beam/ml/inference/agent_development_kit_test.py index 68a554b37b25..64ee3b87a542 100644 --- a/sdks/python/apache_beam/ml/inference/agent_development_kit_test.py +++ b/sdks/python/apache_beam/ml/inference/agent_development_kit_test.py @@ -73,7 +73,7 @@ class TestADKAgentModelHandlerInit(unittest.TestCase): """Tests for __init__ argument validation.""" def test_raises_if_agent_is_none(self): with self.assertRaises((ValueError, TypeError)): - ADKAgentModelHandler(agent=None) # type: ignore[arg-type] + ADKAgentModelHandler(agent=None) def test_accepts_agent_object(self): agent = _make_mock_agent() From 9363570d1ad23971b659916ab88fdf62b28d565d Mon Sep 17 00:00:00 2001 From: Danny Mccormick Date: Tue, 24 Mar 2026 11:14:46 -0400 Subject: [PATCH 07/12] remove disclaimer, we don't do previews like this --- sdks/python/apache_beam/ml/inference/agent_development_kit.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/agent_development_kit.py b/sdks/python/apache_beam/ml/inference/agent_development_kit.py index 7ebbd7ef5a69..90c81a67497a 100644 --- a/sdks/python/apache_beam/ml/inference/agent_development_kit.py +++ b/sdks/python/apache_beam/ml/inference/agent_development_kit.py @@ -22,9 +22,6 @@ :class:`google.adk.agents.llm_agent.LlmAgent` so it can be used with the :class:`~apache_beam.ml.inference.base.RunInference` transform. -**NOTE:** This API and its implementation are under development and do not -provide backward compatibility guarantees. - Typical usage:: import apache_beam as beam From 1586202ce6a86bbf09ec594a4be2ebea080cbcea Mon Sep 17 00:00:00 2001 From: Danny Mccormick Date: Tue, 24 Mar 2026 16:39:38 -0400 Subject: [PATCH 08/12] Fix gemini comments --- .../ml/inference/agent_development_kit.py | 16 +++++++++++++--- .../ml/inference/agent_development_kit_test.py | 2 +- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/agent_development_kit.py b/sdks/python/apache_beam/ml/inference/agent_development_kit.py index 90c81a67497a..c2fff6ba4bda 100644 --- a/sdks/python/apache_beam/ml/inference/agent_development_kit.py +++ b/sdks/python/apache_beam/ml/inference/agent_development_kit.py @@ -217,8 +217,9 @@ def run_inference( inference_args = {} user_id: str = inference_args.get("user_id", "beam_user") + agent_invocations = [] + elements_with_sessions = [] - results = [] for element in batch: session_id: str = inference_args.get("session_id", str(uuid.uuid4())) @@ -236,13 +237,22 @@ def run_inference( # Assume the caller has already constructed a types.Content object message = element - response_text = asyncio.run( + agent_invocations.append( self._invoke_agent(model, user_id, session_id, message)) + elements_with_sessions.append(element) + + # Run all agent invocations concurrently + async def _run_concurrently(): + return await asyncio.gather(*agent_invocations) + response_texts = asyncio.run(_run_concurrently()) + + results = [] + for i, element in enumerate(elements_with_sessions): results.append( PredictionResult( example=element, - inference=response_text, + inference=response_texts[i], model_id=model.agent.name, )) diff --git a/sdks/python/apache_beam/ml/inference/agent_development_kit_test.py b/sdks/python/apache_beam/ml/inference/agent_development_kit_test.py index 64ee3b87a542..02243f092cd1 100644 --- a/sdks/python/apache_beam/ml/inference/agent_development_kit_test.py +++ b/sdks/python/apache_beam/ml/inference/agent_development_kit_test.py @@ -21,7 +21,7 @@ from unittest import mock try: - from google.adk.agents.llm_agent import Agent + from google.adk.agents import Agent from apache_beam.ml.inference.agent_development_kit import ADKAgentModelHandler from apache_beam.ml.inference.base import PredictionResult From 5d0706cc5adc40bd87128734728ef602b2b2f93c Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Thu, 26 Mar 2026 09:05:55 -0400 Subject: [PATCH 09/12] Apply suggestions from code review Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- .../ml/inference/agent_development_kit.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/agent_development_kit.py b/sdks/python/apache_beam/ml/inference/agent_development_kit.py index c2fff6ba4bda..a04455aed5ac 100644 --- a/sdks/python/apache_beam/ml/inference/agent_development_kit.py +++ b/sdks/python/apache_beam/ml/inference/agent_development_kit.py @@ -224,11 +224,15 @@ def run_inference( session_id: str = inference_args.get("session_id", str(uuid.uuid4())) # Ensure a session exists for this invocation - model.session_service.create_session( - app_name=self._app_name, - user_id=user_id, - session_id=session_id, - ) + try: + model.session_service.create_session( + app_name=self._app_name, + user_id=user_id, + session_id=session_id, + ) + except SessionExistsError: + # It's okay if the session already exists for shared session IDs. + pass # Wrap plain strings in a Content object if isinstance(element, str): @@ -283,8 +287,8 @@ async def _invoke_agent( new_message=message, ): if event.is_final_response(): - if event.content and event.content.parts: - return event.content.parts[0].text + if event.content: + return event.content.text return None def get_metrics_namespace(self) -> str: From fc3e0d55976be02f75689e15bc9f83387b903314 Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Thu, 26 Mar 2026 09:36:26 -0400 Subject: [PATCH 10/12] Update sdks/python/apache_beam/ml/inference/agent_development_kit.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- sdks/python/apache_beam/ml/inference/agent_development_kit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/inference/agent_development_kit.py b/sdks/python/apache_beam/ml/inference/agent_development_kit.py index a04455aed5ac..74b274935cb5 100644 --- a/sdks/python/apache_beam/ml/inference/agent_development_kit.py +++ b/sdks/python/apache_beam/ml/inference/agent_development_kit.py @@ -230,7 +230,7 @@ def run_inference( user_id=user_id, session_id=session_id, ) - except SessionExistsError: + except sessions.SessionExistsError: # It's okay if the session already exists for shared session IDs. pass From 2b6351e59171d0895931fa28af871fad7f9965f3 Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Thu, 26 Mar 2026 09:39:38 -0400 Subject: [PATCH 11/12] Update sdks/python/apache_beam/ml/inference/agent_development_kit.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- sdks/python/apache_beam/ml/inference/agent_development_kit.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sdks/python/apache_beam/ml/inference/agent_development_kit.py b/sdks/python/apache_beam/ml/inference/agent_development_kit.py index 74b274935cb5..33e5144d35b0 100644 --- a/sdks/python/apache_beam/ml/inference/agent_development_kit.py +++ b/sdks/python/apache_beam/ml/inference/agent_development_kit.py @@ -65,6 +65,7 @@ try: from google.adk.agents import Agent from google.adk.runners import Runner + from google.adk import sessions from google.adk.sessions import BaseSessionService from google.adk.sessions import InMemorySessionService from google.genai.types import Content as genai_Content From 64d02c86a1a66fad6b29f988c2ebcc093cc33062 Mon Sep 17 00:00:00 2001 From: Danny Mccormick Date: Fri, 27 Mar 2026 09:43:37 -0400 Subject: [PATCH 12/12] tests + lint --- .../apache_beam/ml/inference/agent_development_kit.py | 2 +- .../ml/inference/agent_development_kit_test.py | 9 ++------- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/agent_development_kit.py b/sdks/python/apache_beam/ml/inference/agent_development_kit.py index 33e5144d35b0..196fc62c1934 100644 --- a/sdks/python/apache_beam/ml/inference/agent_development_kit.py +++ b/sdks/python/apache_beam/ml/inference/agent_development_kit.py @@ -63,9 +63,9 @@ from apache_beam.ml.inference.base import PredictionResult try: + from google.adk import sessions from google.adk.agents import Agent from google.adk.runners import Runner - from google.adk import sessions from google.adk.sessions import BaseSessionService from google.adk.sessions import InMemorySessionService from google.genai.types import Content as genai_Content diff --git a/sdks/python/apache_beam/ml/inference/agent_development_kit_test.py b/sdks/python/apache_beam/ml/inference/agent_development_kit_test.py index 02243f092cd1..6d59bceb9d39 100644 --- a/sdks/python/apache_beam/ml/inference/agent_development_kit_test.py +++ b/sdks/python/apache_beam/ml/inference/agent_development_kit_test.py @@ -42,11 +42,8 @@ def _make_mock_runner( ) -> mock.MagicMock: """Returns a mock Runner whose run_async yields one final-response event.""" # Build a mock event that looks like a final response - part = mock.MagicMock() - part.text = final_text - content = mock.MagicMock() - content.parts = [part] + content.text = final_text event = mock.MagicMock() event.is_final_response.return_value = True @@ -301,10 +298,8 @@ def test_stops_after_first_final_response(self): agent = _make_mock_agent() def _make_event(text: str): - part = mock.MagicMock() - part.text = text content = mock.MagicMock() - content.parts = [part] + content.text = text event = mock.MagicMock() event.is_final_response.return_value = True event.content = content