diff --git a/CHANGES.md b/CHANGES.md index 6d26251f138c..2667838e52d7 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -69,6 +69,7 @@ ## New Features / Improvements +* Added `ADKAgentModelHandler` for running Google Agent Development Kit (ADK) agents (Python) ([#37917](https://github.com/apache/beam/issues/37917)). * (Python) Added exception chaining to preserve error context in CloudSQLEnrichmentHandler, processes utilities, and core transforms ([#37422](https://github.com/apache/beam/issues/37422)). * (Python) Added a pipeline option `--experiments=pip_no_build_isolation` to disable build isolation when installing dependencies in the runtime environment ([#37331](https://github.com/apache/beam/issues/37331)). * (Go) Added OrderedListState support to the Go SDK stateful DoFn API ([#37629](https://github.com/apache/beam/issues/37629)). diff --git a/sdks/python/apache_beam/ml/inference/agent_development_kit.py b/sdks/python/apache_beam/ml/inference/agent_development_kit.py new file mode 100644 index 000000000000..196fc62c1934 --- /dev/null +++ b/sdks/python/apache_beam/ml/inference/agent_development_kit.py @@ -0,0 +1,296 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""ModelHandler for running agents built with the Google Agent Development Kit. + +This module provides :class:`ADKAgentModelHandler`, a Beam +:class:`~apache_beam.ml.inference.base.ModelHandler` that wraps an ADK +:class:`google.adk.agents.llm_agent.LlmAgent` so it can be used with the +:class:`~apache_beam.ml.inference.base.RunInference` transform. + +Typical usage:: + + import apache_beam as beam + from apache_beam.ml.inference.base import RunInference + from apache_beam.ml.inference.agent_development_kit import ADKAgentModelHandler + from google.adk.agents import LlmAgent + + agent = LlmAgent( + name="my_agent", + model="gemini-2.0-flash", + instruction="You are a helpful assistant.", + ) + + with beam.Pipeline() as p: + results = ( + p + | beam.Create(["What is the capital of France?"]) + | RunInference(ADKAgentModelHandler(agent=agent)) + ) + +If your agent contains state that is not picklable (e.g. tool closures that +capture unpicklable objects), pass a zero-arg factory callable instead:: + + handler = ADKAgentModelHandler(agent=lambda: LlmAgent(...)) + +""" + +import asyncio +import logging +import uuid +from collections.abc import Callable +from collections.abc import Iterable +from collections.abc import Sequence +from typing import Any +from typing import Optional +from typing import Union + +from apache_beam.ml.inference.base import ModelHandler +from apache_beam.ml.inference.base import PredictionResult + +try: + from google.adk import sessions + from google.adk.agents import Agent + from google.adk.runners import Runner + from google.adk.sessions import BaseSessionService + from google.adk.sessions import InMemorySessionService + from google.genai.types import Content as genai_Content + from google.genai.types import Part as genai_Part + ADK_AVAILABLE = True +except ImportError: + ADK_AVAILABLE = False + genai_Content = Any # type: ignore[assignment, misc] + genai_Part = Any # type: ignore[assignment, misc] + +LOGGER = logging.getLogger("ADKAgentModelHandler") + +# Type alias for an agent or factory that produces one +_AgentOrFactory = Union["Agent", Callable[[], "Agent"]] + + +class ADKAgentModelHandler(ModelHandler[Union[str, genai_Content], + PredictionResult, + "Runner"]): + """ModelHandler for running ADK agents with the Beam RunInference transform. + + Accepts either a fully constructed :class:`google.adk.agents.Agent` or a + zero-arg factory callable that produces one. The factory form is useful when + the agent contains state that is not picklable and therefore cannot be + serialized alongside the pipeline graph. + + Each call to :meth:`run_inference` invokes the agent once per element in the + batch. By default every invocation uses a fresh, isolated session (stateless). + Stateful multi-turn conversations can be achieved by passing a ``session_id`` + key inside ``inference_args``; elements sharing the same ``session_id`` will + continue the same conversation history. + + Args: + agent: A pre-constructed :class:`~google.adk.agents.Agent` instance, or a + zero-arg callable that returns one. The callable form defers agent + construction to worker ``load_model`` time, which is useful when the + agent cannot be serialized. + app_name: The ADK application name used to namespace sessions. Defaults to + ``"beam_inference"``. + session_service_factory: Optional zero-arg callable returning a + :class:`~google.adk.sessions.BaseSessionService`. When ``None``, an + :class:`~google.adk.sessions.InMemorySessionService` is created + automatically. + min_batch_size: Optional minimum batch size. + max_batch_size: Optional maximum batch size. + max_batch_duration_secs: Optional maximum time to buffer a batch before + emitting; used in streaming contexts. + max_batch_weight: Optional maximum total weight of a batch. + element_size_fn: Optional function that returns the size (weight) of an + element. + """ + def __init__( + self, + agent: _AgentOrFactory, + app_name: str = "beam_inference", + session_service_factory: Optional[Callable[[], + "BaseSessionService"]] = None, + *, + min_batch_size: Optional[int] = None, + max_batch_size: Optional[int] = None, + max_batch_duration_secs: Optional[int] = None, + max_batch_weight: Optional[int] = None, + element_size_fn: Optional[Callable[[Any], int]] = None, + **kwargs): + if not ADK_AVAILABLE: + raise ImportError( + "google-adk is required to use ADKAgentModelHandler. " + "Install it with: pip install google-adk") + + if agent is None: + raise ValueError("'agent' must be an Agent instance or a callable.") + + self._agent_or_factory = agent + self._app_name = app_name + self._session_service_factory = session_service_factory + + super().__init__( + min_batch_size=min_batch_size, + max_batch_size=max_batch_size, + max_batch_duration_secs=max_batch_duration_secs, + max_batch_weight=max_batch_weight, + element_size_fn=element_size_fn, + **kwargs) + + def load_model(self) -> "Runner": + """Instantiates the ADK Runner on the worker. + + Resolves the agent (calling the factory if a callable was provided), then + creates a :class:`~google.adk.runners.Runner` backed by the configured + session service. + + Returns: + A fully initialised :class:`~google.adk.runners.Runner`. + """ + if callable(self._agent_or_factory) and not isinstance( + self._agent_or_factory, Agent): + agent = self._agent_or_factory() + else: + agent = self._agent_or_factory + + if self._session_service_factory is not None: + session_service = self._session_service_factory() + else: + session_service = InMemorySessionService() + + runner = Runner( + agent=agent, + app_name=self._app_name, + session_service=session_service, + ) + LOGGER.info( + "Loaded ADK Runner for agent '%s' (app_name='%s')", + agent.name, + self._app_name, + ) + return runner + + def run_inference( + self, + batch: Sequence[Union[str, genai_Content]], + model: "Runner", + inference_args: Optional[dict[str, Any]] = None, + ) -> Iterable[PredictionResult]: + """Runs the ADK agent on each element in the batch. + + Each element is sent to the agent as a new user turn. The final response + text from the agent is returned as the ``inference`` field of a + :class:`~apache_beam.ml.inference.base.PredictionResult`. + + Args: + batch: A sequence of inputs, each of which is either a ``str`` (the user + message text) or a :class:`google.genai.types.Content` object (for + richer multi-part messages). + model: The :class:`~google.adk.runners.Runner` returned by + :meth:`load_model`. + inference_args: Optional dict of extra arguments. Supported keys: + + - ``"session_id"`` (:class:`str`): If supplied, all elements in this + batch share this session ID, enabling stateful multi-turn + conversations. If omitted, each element receives a unique auto- + generated session ID. + - ``"user_id"`` (:class:`str`): The user identifier to pass to the + runner. Defaults to ``"beam_user"``. + + Returns: + An iterable of :class:`~apache_beam.ml.inference.base.PredictionResult`, + one per input element. + """ + if inference_args is None: + inference_args = {} + + user_id: str = inference_args.get("user_id", "beam_user") + agent_invocations = [] + elements_with_sessions = [] + + for element in batch: + session_id: str = inference_args.get("session_id", str(uuid.uuid4())) + + # Ensure a session exists for this invocation + try: + model.session_service.create_session( + app_name=self._app_name, + user_id=user_id, + session_id=session_id, + ) + except sessions.SessionExistsError: + # It's okay if the session already exists for shared session IDs. + pass + + # Wrap plain strings in a Content object + if isinstance(element, str): + message = genai_Content(role="user", parts=[genai_Part(text=element)]) + else: + # Assume the caller has already constructed a types.Content object + message = element + + agent_invocations.append( + self._invoke_agent(model, user_id, session_id, message)) + elements_with_sessions.append(element) + + # Run all agent invocations concurrently + async def _run_concurrently(): + return await asyncio.gather(*agent_invocations) + + response_texts = asyncio.run(_run_concurrently()) + + results = [] + for i, element in enumerate(elements_with_sessions): + results.append( + PredictionResult( + example=element, + inference=response_texts[i], + model_id=model.agent.name, + )) + + return results + + @staticmethod + async def _invoke_agent( + runner: "Runner", + user_id: str, + session_id: str, + message: genai_Content, + ) -> Optional[str]: + """Drives the ADK event loop and returns the final response text. + + Args: + runner: The ADK Runner to invoke. + user_id: The user ID for this invocation. + session_id: The session ID for this invocation. + message: The :class:`google.genai.types.Content` to send. + + Returns: + The text of the agent's final response, or ``None`` if the agent + produced no final text response. + """ + async for event in runner.run_async( + user_id=user_id, + session_id=session_id, + new_message=message, + ): + if event.is_final_response(): + if event.content: + return event.content.text + return None + + def get_metrics_namespace(self) -> str: + return "ADKAgentModelHandler" diff --git a/sdks/python/apache_beam/ml/inference/agent_development_kit_test.py b/sdks/python/apache_beam/ml/inference/agent_development_kit_test.py new file mode 100644 index 000000000000..6d59bceb9d39 --- /dev/null +++ b/sdks/python/apache_beam/ml/inference/agent_development_kit_test.py @@ -0,0 +1,334 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# pytype: skip-file + +import asyncio +import unittest +from unittest import mock + +try: + from google.adk.agents import Agent + + from apache_beam.ml.inference.agent_development_kit import ADKAgentModelHandler + from apache_beam.ml.inference.base import PredictionResult +except ImportError: + raise unittest.SkipTest('google-adk dependencies are not installed') + + +def _make_mock_agent(name: str = "test_agent") -> mock.MagicMock: + """Returns a mock that quacks like a google.adk.agents.Agent.""" + agent = mock.MagicMock() + agent.name = name + return agent + + +def _make_mock_runner( + agent: mock.MagicMock, + final_text: str = "Hello from agent", +) -> mock.MagicMock: + """Returns a mock Runner whose run_async yields one final-response event.""" + # Build a mock event that looks like a final response + content = mock.MagicMock() + content.text = final_text + + event = mock.MagicMock() + event.is_final_response.return_value = True + event.content = content + + async def _async_gen(*args, **kwargs): + yield event + + runner = mock.MagicMock() + runner.agent = agent + runner.run_async = mock.MagicMock(side_effect=_async_gen) + runner.session_service = mock.MagicMock() + return runner + + +# --------------------------------------------------------------------------- +# Helper: patch ADK imports inside the module under test so tests work even +# when google-adk is installed (avoids constructing real ADK objects). +# --------------------------------------------------------------------------- +_MODULE = "apache_beam.ml.inference.agent_development_kit" + + +class TestADKAgentModelHandlerInit(unittest.TestCase): + """Tests for __init__ argument validation.""" + def test_raises_if_agent_is_none(self): + with self.assertRaises((ValueError, TypeError)): + ADKAgentModelHandler(agent=None) + + def test_accepts_agent_object(self): + agent = _make_mock_agent() + handler = ADKAgentModelHandler(agent=agent) + self.assertEqual(handler._agent_or_factory, agent) + + def test_accepts_agent_factory_callable(self): + agent = _make_mock_agent() + factory = lambda: agent + handler = ADKAgentModelHandler(agent=factory) + self.assertTrue(callable(handler._agent_or_factory)) + + def test_default_app_name(self): + agent = _make_mock_agent() + handler = ADKAgentModelHandler(agent=agent) + self.assertEqual(handler._app_name, "beam_inference") + + def test_custom_app_name(self): + agent = _make_mock_agent() + handler = ADKAgentModelHandler(agent=agent, app_name="my_app") + self.assertEqual(handler._app_name, "my_app") + + def test_metrics_namespace(self): + agent = _make_mock_agent() + handler = ADKAgentModelHandler(agent=agent) + self.assertEqual(handler.get_metrics_namespace(), "ADKAgentModelHandler") + + +class TestLoadModel(unittest.TestCase): + """Tests for load_model / Runner construction.""" + def test_load_model_with_agent_object(self): + def get_current_time(city: str) -> dict: + """Returns the current time in a specified city.""" + return {"status": "success", "city": city, "time": "10:30 AM"} + + agent = Agent( + model='gemini-3-flash-preview', + name='root_agent', + description="Tells the current time in a specified city.", + instruction= + "You are a helpful assistant that tells the current time in cities. " + "Use the 'get_current_time' tool for this purpose.", + tools=[get_current_time], + ) + handler = ADKAgentModelHandler(agent=agent, app_name="test_app") + runner = handler.load_model() + + self.assertEqual(agent, runner.agent) + + @mock.patch(f"{_MODULE}.Runner") + @mock.patch(f"{_MODULE}.InMemorySessionService") + def test_load_model_calls_factory(self, mock_session_cls, mock_runner_cls): + agent = _make_mock_agent() + factory = mock.MagicMock(return_value=agent) + + handler = ADKAgentModelHandler(agent=factory) + handler.load_model() + + factory.assert_called_once() + mock_runner_cls.assert_called_once_with( + agent=agent, + app_name="beam_inference", + session_service=mock_session_cls.return_value, + ) + + +class TestRunInference(unittest.TestCase): + """Tests for run_inference output and batching.""" + def test_string_input_returns_prediction_result(self): + agent = _make_mock_agent() + runner = _make_mock_runner(agent, final_text="Paris") + + handler = ADKAgentModelHandler(agent=agent) + results = list( + handler.run_inference( + batch=["What is the capital of France?"], model=runner)) + + self.assertEqual(len(results), 1) + pr = results[0] + self.assertIsInstance(pr, PredictionResult) + self.assertEqual(pr.example, "What is the capital of France?") + self.assertEqual(pr.inference, "Paris") + self.assertEqual(pr.model_id, "test_agent") + + def test_batch_of_strings(self): + agent = _make_mock_agent() + runner = _make_mock_runner(agent, final_text="answer") + + handler = ADKAgentModelHandler(agent=agent) + results = list( + handler.run_inference(batch=["q1", "q2", "q3"], model=runner)) + + self.assertEqual(len(results), 3) + self.assertEqual([r.example for r in results], ["q1", "q2", "q3"]) + + def test_content_object_input(self): + """Non-string inputs (types.Content) are passed through unchanged.""" + agent = _make_mock_agent() + runner = _make_mock_runner(agent, final_text="Berlin") + + content_input = mock.MagicMock() # simulates types.Content + + handler = ADKAgentModelHandler(agent=agent) + results = list(handler.run_inference(batch=[content_input], model=runner)) + + self.assertEqual(len(results), 1) + self.assertEqual(results[0].example, content_input) + self.assertEqual(results[0].inference, "Berlin") + + def test_none_inference_args_uses_defaults(self): + agent = _make_mock_agent() + runner = _make_mock_runner(agent) + + handler = ADKAgentModelHandler(agent=agent) + results = list( + handler.run_inference( + batch=["hello"], model=runner, inference_args=None)) + self.assertEqual(len(results), 1) + + def test_custom_user_id_passed_to_runner(self): + agent = _make_mock_agent() + runner = _make_mock_runner(agent) + + handler = ADKAgentModelHandler(agent=agent) + handler.run_inference( + batch=["hi"], + model=runner, + inference_args={"user_id": "custom_user"}, + ) + + call_kwargs = runner.run_async.call_args[1] + self.assertEqual(call_kwargs["user_id"], "custom_user") + + +class TestSessionManagement(unittest.TestCase): + """Tests for session creation and session_id handling.""" + def test_each_element_gets_unique_session_by_default(self): + agent = _make_mock_agent() + runner = _make_mock_runner(agent) + + handler = ADKAgentModelHandler(agent=agent) + handler.run_inference(batch=["a", "b", "c"], model=runner) + + # create_session should have been called 3 times with distinct session IDs + calls = runner.session_service.create_session.call_args_list + self.assertEqual(len(calls), 3) + session_ids = [c[1]["session_id"] for c in calls] + self.assertEqual(len(set(session_ids)), 3, "Expected unique session IDs") + + def test_shared_session_id_from_inference_args(self): + agent = _make_mock_agent() + runner = _make_mock_runner(agent) + + handler = ADKAgentModelHandler(agent=agent) + handler.run_inference( + batch=["turn1", "turn2"], + model=runner, + inference_args={"session_id": "my-session"}, + ) + + calls = runner.session_service.create_session.call_args_list + session_ids = [c[1]["session_id"] for c in calls] + self.assertTrue( + all(sid == "my-session" for sid in session_ids), + "All elements should share the provided session_id", + ) + + def test_session_created_with_correct_app_name(self): + agent = _make_mock_agent() + runner = _make_mock_runner(agent) + + handler = ADKAgentModelHandler(agent=agent, app_name="my_app") + handler.run_inference(batch=["hello"], model=runner) + + call_kwargs = runner.session_service.create_session.call_args[1] + self.assertEqual(call_kwargs["app_name"], "my_app") + + +class TestResponseExtraction(unittest.TestCase): + """Tests for extraction of the final response from the event stream.""" + def test_returns_none_when_no_final_response(self): + """Agent emits only non-final events; inference should be None.""" + agent = _make_mock_agent() + + # Build a runner that yields only non-final events + non_final_event = mock.MagicMock() + non_final_event.is_final_response.return_value = False + + async def _async_gen(*args, **kwargs): + yield non_final_event + + runner = mock.MagicMock() + runner.agent = agent + runner.run_async = mock.MagicMock(side_effect=_async_gen) + runner.session_service = mock.MagicMock() + + handler = ADKAgentModelHandler(agent=agent) + results = list(handler.run_inference(batch=["hello"], model=runner)) + + self.assertEqual(len(results), 1) + self.assertIsNone(results[0].inference) + + def test_returns_none_when_final_event_has_no_content(self): + agent = _make_mock_agent() + + event = mock.MagicMock() + event.is_final_response.return_value = True + event.content = None + + async def _async_gen(*args, **kwargs): + yield event + + runner = mock.MagicMock() + runner.agent = agent + runner.run_async = mock.MagicMock(side_effect=_async_gen) + runner.session_service = mock.MagicMock() + + handler = ADKAgentModelHandler(agent=agent) + results = list(handler.run_inference(batch=["hello"], model=runner)) + + self.assertIsNone(results[0].inference) + + def test_stops_after_first_final_response(self): + """Multiple final events: only the first one's text should be used.""" + agent = _make_mock_agent() + + def _make_event(text: str): + content = mock.MagicMock() + content.text = text + event = mock.MagicMock() + event.is_final_response.return_value = True + event.content = content + return event + + async def _async_gen(*args, **kwargs): + yield _make_event("first") + yield _make_event("second") + + runner = mock.MagicMock() + runner.agent = agent + runner.run_async = mock.MagicMock(side_effect=_async_gen) + runner.session_service = mock.MagicMock() + + handler = ADKAgentModelHandler(agent=agent) + results = list(handler.run_inference(batch=["hi"], model=runner)) + + self.assertEqual(results[0].inference, "first") + + def test_invoke_agent_static_method_directly(self): + """Unit test the async _invoke_agent helper directly.""" + agent = _make_mock_agent() + runner = _make_mock_runner(agent, final_text="direct result") + + result = asyncio.run( + ADKAgentModelHandler._invoke_agent( + runner, "user", "session-1", mock.MagicMock())) + self.assertEqual(result, "direct result") + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/setup.py b/sdks/python/setup.py index 1ad37f6f0243..24bc25410a2f 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -164,6 +164,7 @@ def cythonize(*args, **kwargs): ml_base = [ 'embeddings>=0.0.4', # 0.0.3 crashes setuptools + 'google-adk', 'onnxruntime', 'langchain', 'sentence-transformers>=2.2.2',