From 59c020b6542cfbb510e169f5abcaa56a532b48f6 Mon Sep 17 00:00:00 2001 From: Raman369AI Date: Tue, 12 May 2026 22:46:00 -0500 Subject: [PATCH 1/8] fix: terminate infinite retry loop in RunSkillScriptTool on SCRIPT_NOT_FOUND Mirrors the LoadSkillResourceTool fix (#5651): when the LLM hallucinates a script path that does not exist in the skill's scripts/ directory, RunSkillScriptTool returns SCRIPT_NOT_FOUND as a soft error and the model retries with a different plausible path each turn. Nothing escalates between attempts, so the loop terminates only when RunConfig.max_llm_calls (default 500) is exhausted. Adds an invocation-scoped failure counter under temp:_adk_skill_script_not_found_count_<invocation_id>. The first miss within an invocation still returns SCRIPT_NOT_FOUND (unchanged); any subsequent miss returns the new SCRIPT_NOT_FOUND_FATAL with an explicit "do not retry, report and stop" message and the failure count. The counter is path-agnostic, so it fires even when the model hallucinates a different script path on each retry. The temp: prefix keeps the key out of durable session storage; the invocation_id suffix isolates in-memory backends where temp: keys are not auto-cleared between invocations. The default system prompt also gains a no-retry rule for run_skill_script errors so the model has a semantic reason to stop, complementing the code-level guard. Tests cover: first-miss soft error preserved, 2nd-miss-same-path escalates, 2nd-miss-different-path escalates (path-agnostic), per-invocation isolation, counter key uses temp: prefix. --- src/google/adk/tools/skill_toolset.py | 22 ++++ tests/unittests/tools/test_skill_toolset.py | 115 +++++++++++++++++++- 2 files changed, 136 insertions(+), 1 deletion(-) diff --git a/src/google/adk/tools/skill_toolset.py b/src/google/adk/tools/skill_toolset.py index 3c60bd5918..b599ae192d 100644 --- a/src/google/adk/tools/skill_toolset.py +++ b/src/google/adk/tools/skill_toolset.py @@ -84,6 +84,9 @@ "4. 
Use `run_skill_script` to run scripts from a skill's `scripts/` " "directory. Use `load_skill_resource` to view script content first if " "needed.\n" + "5. If `run_skill_script` returns an error (for example " + "`SCRIPT_NOT_FOUND`), do not retry the same script or guess a different " + "script path. Report the error to the user and stop.\n" ) @@ -840,6 +843,25 @@ async def run_async( script = skill.resources.get_script(file_path) if script is None: + # Invocation-scoped failure counter. Counts SCRIPT_NOT_FOUND across ALL + # paths so the guard fires even when the LLM hallucinates a different + # script path on each retry. The `temp:` prefix prevents persistence to + # durable session storage; invocation_id isolates in-memory backends. + counter_key = ( + f"temp:_adk_skill_script_not_found_count_{tool_context.invocation_id}" + ) + fail_count = int(tool_context.state.get(counter_key) or 0) + 1 + tool_context.state[counter_key] = fail_count + if fail_count > 1: + return { + "error": ( + f"Script '{file_path}' not found in skill '{skill_name}'." + f" This is script lookup failure #{fail_count} this" + " invocation. Do not retry any script path — report the" + " error to the user and stop." 
+ ), + "error_code": "SCRIPT_NOT_FOUND_FATAL", + } return { "error": f"Script '{file_path}' not found in skill '{skill_name}'.", "error_code": "SCRIPT_NOT_FOUND", diff --git a/tests/unittests/tools/test_skill_toolset.py b/tests/unittests/tools/test_skill_toolset.py index b58c01b91b..1abfd411fc 100644 --- a/tests/unittests/tools/test_skill_toolset.py +++ b/tests/unittests/tools/test_skill_toolset.py @@ -499,7 +499,7 @@ async def test_scripts_resource_not_found(mock_skill1, tool_context_instance): # RunSkillScriptTool tests -def _make_tool_context_with_agent(agent=None): +def _make_tool_context_with_agent(agent=None, invocation_id="test_invocation"): """Creates a mock ToolContext with _invocation_context.agent.""" ctx = mock.MagicMock(spec=tool_context.ToolContext) ctx._invocation_context = mock.MagicMock() @@ -507,6 +507,7 @@ def _make_tool_context_with_agent(agent=None): ctx._invocation_context.agent.name = "test_agent" ctx._invocation_context.agent_states = {} ctx.agent_name = "test_agent" + ctx.invocation_id = invocation_id ctx.state = {} return ctx @@ -577,6 +578,118 @@ async def test_execute_script_script_not_found(mock_skill1): tool_context=ctx, ) assert result["error_code"] == "SCRIPT_NOT_FOUND" + assert result["error"] == ( + "Script 'nonexistent.py' not found in skill 'skill1'." 
+ ) + + +@pytest.mark.asyncio +async def test_execute_script_repeated_failure_escalates_to_fatal(mock_skill1): + """Any second SCRIPT_NOT_FOUND within an invocation returns SCRIPT_NOT_FOUND_FATAL.""" + executor = _make_mock_executor() + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + + args = {"skill_name": "skill1", "file_path": "scripts/nonexistent.py"} + + result1 = await tool.run_async(args=args, tool_context=ctx) + assert result1["error_code"] == "SCRIPT_NOT_FOUND" + + result2 = await tool.run_async(args=args, tool_context=ctx) + assert result2["error_code"] == "SCRIPT_NOT_FOUND_FATAL" + assert "Do not retry" in result2["error"] + assert "stop" in result2["error"].lower() + assert "failure #2" in result2["error"] + + +@pytest.mark.asyncio +async def test_execute_script_different_path_also_escalates_to_fatal( + mock_skill1, +): + """A different missing script on the second call still escalates to SCRIPT_NOT_FOUND_FATAL. + + The counter is path-agnostic: any second not-found within the same invocation + is fatal, even when the LLM hallucinates a different script path on each + retry. 
+ """ + executor = _make_mock_executor() + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + + result1 = await tool.run_async( + args={"skill_name": "skill1", "file_path": "scripts/missing_a.py"}, + tool_context=ctx, + ) + assert result1["error_code"] == "SCRIPT_NOT_FOUND" + + result2 = await tool.run_async( + args={"skill_name": "skill1", "file_path": "scripts/missing_b.py"}, + tool_context=ctx, + ) + assert result2["error_code"] == "SCRIPT_NOT_FOUND_FATAL" + assert "Do not retry" in result2["error"] + + +@pytest.mark.asyncio +async def test_execute_script_failures_isolated_per_invocation(mock_skill1): + """Failure counter does not leak across invocations. + + A SCRIPT_NOT_FOUND in invocation A must not increment invocation B's + counter; invocation B's first missing-script call must still return the + soft error, even when both invocations share the same session state dict. + """ + executor = _make_mock_executor() + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + + shared_state = {} + ctx_a = _make_tool_context_with_agent(invocation_id="inv_a") + ctx_a.state = shared_state + ctx_b = _make_tool_context_with_agent(invocation_id="inv_b") + ctx_b.state = shared_state + + # invocation A: one failure — counter for inv_a reaches 1 (soft). + result_a = await tool.run_async( + args={"skill_name": "skill1", "file_path": "scripts/typo.py"}, + tool_context=ctx_a, + ) + assert result_a["error_code"] == "SCRIPT_NOT_FOUND" + + # invocation B, first attempt (same path) — counter for inv_b = 1 (soft). + result_b1 = await tool.run_async( + args={"skill_name": "skill1", "file_path": "scripts/typo.py"}, + tool_context=ctx_b, + ) + assert result_b1["error_code"] == "SCRIPT_NOT_FOUND" + + # invocation B, second attempt (different path) — counter for inv_b = 2 (fatal). 
+ result_b2 = await tool.run_async( + args={"skill_name": "skill1", "file_path": "scripts/other.py"}, + tool_context=ctx_b, + ) + assert result_b2["error_code"] == "SCRIPT_NOT_FOUND_FATAL" + + +@pytest.mark.asyncio +async def test_execute_script_counter_uses_temp_prefix(mock_skill1): + """Failure-counter key uses the `temp:` prefix so it is not persisted.""" + executor = _make_mock_executor() + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + + await tool.run_async( + args={"skill_name": "skill1", "file_path": "scripts/missing.py"}, + tool_context=ctx, + ) + + # The counter key must start with `temp:` so it is trimmed from the event + # delta and never reaches durable storage. + guard_keys = [k for k in ctx.state if "skill_script_not_found_count" in k] + assert guard_keys, "Failure counter did not write a tracking key." + assert all(k.startswith("temp:") for k in guard_keys) @pytest.mark.asyncio From 6e6621c7680168cfb7e6ae36cb8718c529e46316 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 15 May 2026 17:49:32 -0700 Subject: [PATCH 2/8] perf: lazy-load service registries and split apps.app to cut cold start ~8% PiperOrigin-RevId: 916265957 --- src/google/adk/agents/invocation_context.py | 4 +- src/google/adk/apps/__init__.py | 22 +---- src/google/adk/apps/_configs.py | 95 ------------------- src/google/adk/apps/app.py | 81 ++++++++++++++-- src/google/adk/artifacts/__init__.py | 26 +---- src/google/adk/flows/llm_flows/single_flow.py | 2 +- src/google/adk/memory/__init__.py | 34 +++---- src/google/adk/plugins/__init__.py | 29 +----- src/google/adk/runners.py | 17 ++-- src/google/adk/sessions/__init__.py | 30 ++---- 10 files changed, 111 insertions(+), 229 deletions(-) delete mode 100644 src/google/adk/apps/_configs.py diff --git a/src/google/adk/agents/invocation_context.py b/src/google/adk/agents/invocation_context.py index 
8dc1a01af2..7fdbaee89b 100644 --- a/src/google/adk/agents/invocation_context.py +++ b/src/google/adk/agents/invocation_context.py @@ -25,8 +25,8 @@ from pydantic import Field from pydantic import PrivateAttr -from ..apps._configs import EventsCompactionConfig -from ..apps._configs import ResumabilityConfig +from ..apps.app import EventsCompactionConfig +from ..apps.app import ResumabilityConfig from ..artifacts.base_artifact_service import BaseArtifactService from ..auth.auth_credential import AuthCredential from ..auth.credential_service.base_credential_service import BaseCredentialService diff --git a/src/google/adk/apps/__init__.py b/src/google/adk/apps/__init__.py index 319293967b..3a5d0b0643 100644 --- a/src/google/adk/apps/__init__.py +++ b/src/google/adk/apps/__init__.py @@ -12,28 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - -import importlib -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from ._configs import ResumabilityConfig - from .app import App +from .app import App +from .app import ResumabilityConfig __all__ = [ 'App', 'ResumabilityConfig', ] - -_LAZY_MEMBERS: dict[str, str] = { - 'App': 'app', - 'ResumabilityConfig': '_configs', -} - - -def __getattr__(name: str): - if name in _LAZY_MEMBERS: - module = importlib.import_module(f'{__name__}.{_LAZY_MEMBERS[name]}') - return vars(module)[name] - raise AttributeError(f'module {__name__!r} has no attribute {name!r}') diff --git a/src/google/adk/apps/_configs.py b/src/google/adk/apps/_configs.py deleted file mode 100644 index 87f3666ebd..0000000000 --- a/src/google/adk/apps/_configs.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright 2026 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from typing import Optional - -from pydantic import BaseModel -from pydantic import ConfigDict -from pydantic import Field -from pydantic import model_validator - -from ..utils.feature_decorator import experimental -from .base_events_summarizer import BaseEventsSummarizer - - -@experimental -class ResumabilityConfig(BaseModel): - """The config of the resumability for an application. - - The "resumability" in ADK refers to the ability to: - 1. pause an invocation upon a long-running function call. - 2. resume an invocation from the last event, if it's paused or failed midway - through. - - Note: ADK resumes the invocation in a best-effort manner: - 1. Tool call to resume needs to be idempotent because we only guarantee - an at-least-once behavior once resumed. - 2. Any temporary / in-memory state will be lost upon resumption. - """ - - is_resumable: bool = False - """Whether the app supports agent resumption. - If enabled, the feature will be enabled for all agents in the app. 
- """ - - -@experimental -class EventsCompactionConfig(BaseModel): - """The config of event compaction for an application.""" - - model_config = ConfigDict( - arbitrary_types_allowed=True, - extra="forbid", - ) - - summarizer: Optional[BaseEventsSummarizer] = None - """The event summarizer to use for compaction.""" - - compaction_interval: int - """The number of *new* user-initiated invocations that, once - fully represented in the session's events, will trigger a compaction.""" - - overlap_size: int - """The number of preceding invocations to include from the - end of the last compacted range. This creates an overlap between consecutive - compacted summaries, maintaining context.""" - - token_threshold: Optional[int] = Field( - default=None, - gt=0, - ) - """Post-invocation token threshold trigger. - - If set, ADK will attempt a post-invocation compaction when the most recently - observed prompt token count meets or exceeds this threshold. - """ - - event_retention_size: Optional[int] = Field(default=None, ge=0) - """Post-invocation raw event retention size. - - If token-based post-invocation compaction is triggered, this keeps the last N - raw events un-compacted. - """ - - @model_validator(mode="after") - def _validate_token_params(self) -> EventsCompactionConfig: - token_threshold_set = self.token_threshold is not None - retention_size_set = self.event_retention_size is not None - if token_threshold_set != retention_size_set: - raise ValueError( - "token_threshold and event_retention_size must be set together." 
- ) - return self diff --git a/src/google/adk/apps/app.py b/src/google/adk/apps/app.py index 9bde128b7a..c20d581d9b 100644 --- a/src/google/adk/apps/app.py +++ b/src/google/adk/apps/app.py @@ -22,16 +22,9 @@ from ..agents.base_agent import BaseAgent from ..agents.context_cache_config import ContextCacheConfig +from ..apps.base_events_summarizer import BaseEventsSummarizer from ..plugins.base_plugin import BasePlugin -from ._configs import EventsCompactionConfig -from ._configs import ResumabilityConfig - -__all__ = [ - "App", - "EventsCompactionConfig", - "ResumabilityConfig", - "validate_app_name", -] +from ..utils.feature_decorator import experimental def validate_app_name(name: str) -> None: @@ -45,6 +38,76 @@ def validate_app_name(name: str) -> None: raise ValueError("App name cannot be 'user'; reserved for end-user input.") +@experimental +class ResumabilityConfig(BaseModel): + """The config of the resumability for an application. + + The "resumability" in ADK refers to the ability to: + 1. pause an invocation upon a long-running function call. + 2. resume an invocation from the last event, if it's paused or failed midway + through. + + Note: ADK resumes the invocation in a best-effort manner: + 1. Tool call to resume needs to be idempotent because we only guarantee + an at-least-once behavior once resumed. + 2. Any temporary / in-memory state will be lost upon resumption. + """ + + is_resumable: bool = False + """Whether the app supports agent resumption. + If enabled, the feature will be enabled for all agents in the app. 
+ """ + + +@experimental +class EventsCompactionConfig(BaseModel): + """The config of event compaction for an application.""" + + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) + + summarizer: Optional[BaseEventsSummarizer] = None + """The event summarizer to use for compaction.""" + + compaction_interval: int + """The number of *new* user-initiated invocations that, once + fully represented in the session's events, will trigger a compaction.""" + + overlap_size: int + """The number of preceding invocations to include from the + end of the last compacted range. This creates an overlap between consecutive + compacted summaries, maintaining context.""" + + token_threshold: Optional[int] = Field( + default=None, + gt=0, + ) + """Post-invocation token threshold trigger. + + If set, ADK will attempt a post-invocation compaction when the most recently + observed prompt token count meets or exceeds this threshold. + """ + + event_retention_size: Optional[int] = Field(default=None, ge=0) + """Post-invocation raw event retention size. + + If token-based post-invocation compaction is triggered, this keeps the last N + raw events un-compacted. + """ + + @model_validator(mode="after") + def _validate_token_params(self) -> EventsCompactionConfig: + token_threshold_set = self.token_threshold is not None + retention_size_set = self.event_retention_size is not None + if token_threshold_set != retention_size_set: + raise ValueError( + "token_threshold and event_retention_size must be set together." + ) + return self + + class App(BaseModel): """Represents an LLM-backed agentic application. diff --git a/src/google/adk/artifacts/__init__.py b/src/google/adk/artifacts/__init__.py index af7912e617..5e56ffc737 100644 --- a/src/google/adk/artifacts/__init__.py +++ b/src/google/adk/artifacts/__init__.py @@ -12,17 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from __future__ import annotations - -import importlib -from typing import TYPE_CHECKING - from .base_artifact_service import BaseArtifactService - -if TYPE_CHECKING: - from .file_artifact_service import FileArtifactService - from .gcs_artifact_service import GcsArtifactService - from .in_memory_artifact_service import InMemoryArtifactService +from .file_artifact_service import FileArtifactService +from .gcs_artifact_service import GcsArtifactService +from .in_memory_artifact_service import InMemoryArtifactService __all__ = [ 'BaseArtifactService', @@ -30,16 +23,3 @@ 'GcsArtifactService', 'InMemoryArtifactService', ] - -_LAZY_MEMBERS: dict[str, str] = { - 'FileArtifactService': 'file_artifact_service', - 'GcsArtifactService': 'gcs_artifact_service', - 'InMemoryArtifactService': 'in_memory_artifact_service', -} - - -def __getattr__(name: str): - if name in _LAZY_MEMBERS: - module = importlib.import_module(f'{__name__}.{_LAZY_MEMBERS[name]}') - return vars(module)[name] - raise AttributeError(f'module {__name__!r} has no attribute {name!r}') diff --git a/src/google/adk/flows/llm_flows/single_flow.py b/src/google/adk/flows/llm_flows/single_flow.py index cc3fc9e6fa..932a265ed1 100644 --- a/src/google/adk/flows/llm_flows/single_flow.py +++ b/src/google/adk/flows/llm_flows/single_flow.py @@ -22,6 +22,7 @@ from . import _nl_planning from . import _output_schema_processor from . import basic +from . import compaction from . import contents from . import context_cache_processor from . import identity @@ -35,7 +36,6 @@ def _create_request_processors(): """Create the standard request processor list for a single-agent flow.""" - from . 
import compaction from ...auth import auth_preprocessor return [ diff --git a/src/google/adk/memory/__init__.py b/src/google/adk/memory/__init__.py index d40f3bf7d9..c47fb8ec40 100644 --- a/src/google/adk/memory/__init__.py +++ b/src/google/adk/memory/__init__.py @@ -11,35 +11,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from __future__ import annotations - -import importlib -from typing import TYPE_CHECKING +import logging from .base_memory_service import BaseMemoryService +from .in_memory_memory_service import InMemoryMemoryService +from .vertex_ai_memory_bank_service import VertexAiMemoryBankService -if TYPE_CHECKING: - from .in_memory_memory_service import InMemoryMemoryService - from .vertex_ai_memory_bank_service import VertexAiMemoryBankService - from .vertex_ai_rag_memory_service import VertexAiRagMemoryService +logger = logging.getLogger('google_adk.' + __name__) __all__ = [ 'BaseMemoryService', 'InMemoryMemoryService', 'VertexAiMemoryBankService', - 'VertexAiRagMemoryService', ] -_LAZY_MEMBERS: dict[str, str] = { - 'InMemoryMemoryService': 'in_memory_memory_service', - 'VertexAiMemoryBankService': 'vertex_ai_memory_bank_service', - 'VertexAiRagMemoryService': 'vertex_ai_rag_memory_service', -} - +try: + from .vertex_ai_rag_memory_service import VertexAiRagMemoryService -def __getattr__(name: str): - if name in _LAZY_MEMBERS: - module = importlib.import_module(f'{__name__}.{_LAZY_MEMBERS[name]}') - return vars(module)[name] - raise AttributeError(f'module {__name__!r} has no attribute {name!r}') + __all__.append('VertexAiRagMemoryService') +except ImportError: + logger.debug( + 'The Vertex SDK is not installed. If you want to use the' + ' VertexAiRagMemoryService please install it. If not, you can ignore this' + ' warning.' 
+ ) diff --git a/src/google/adk/plugins/__init__.py b/src/google/adk/plugins/__init__.py index 70347fd25e..45caf16038 100644 --- a/src/google/adk/plugins/__init__.py +++ b/src/google/adk/plugins/__init__.py @@ -1,7 +1,8 @@ # Copyright 2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); -# you may in obtain a copy of the License at +# you may not use this file except in compliance with the License. +# You may in obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # @@ -11,18 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - -import importlib -from typing import TYPE_CHECKING - from .base_plugin import BasePlugin +from .debug_logging_plugin import DebugLoggingPlugin +from .logging_plugin import LoggingPlugin from .plugin_manager import PluginManager - -if TYPE_CHECKING: - from .debug_logging_plugin import DebugLoggingPlugin - from .logging_plugin import LoggingPlugin - from .reflect_retry_tool_plugin import ReflectAndRetryToolPlugin +from .reflect_retry_tool_plugin import ReflectAndRetryToolPlugin __all__ = [ 'BasePlugin', @@ -31,16 +25,3 @@ 'PluginManager', 'ReflectAndRetryToolPlugin', ] - -_LAZY_MEMBERS: dict[str, str] = { - 'DebugLoggingPlugin': 'debug_logging_plugin', - 'LoggingPlugin': 'logging_plugin', - 'ReflectAndRetryToolPlugin': 'reflect_retry_tool_plugin', -} - - -def __getattr__(name: str): - if name in _LAZY_MEMBERS: - module = importlib.import_module(f'{__name__}.{_LAZY_MEMBERS[name]}') - return vars(module)[name] - raise AttributeError(f'module {__name__!r} has no attribute {name!r}') diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 850c26bbba..f352e6eb5f 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -26,9 +26,9 @@ from typing import Generator from typing import List from typing import Optional -from typing import TYPE_CHECKING import warnings 
+from google.adk.apps.compaction import _run_compaction_for_sliding_window from google.genai import types from .agents.base_agent import BaseAgent @@ -38,7 +38,10 @@ from .agents.invocation_context import new_invocation_context_id from .agents.live_request_queue import LiveRequestQueue from .agents.run_config import RunConfig +from .apps.app import App +from .apps.app import ResumabilityConfig from .artifacts.base_artifact_service import BaseArtifactService +from .artifacts.in_memory_artifact_service import InMemoryArtifactService from .auth.credential_service.base_credential_service import BaseCredentialService from .code_executors.built_in_code_executor import BuiltInCodeExecutor from .errors.session_not_found_error import SessionNotFoundError @@ -48,21 +51,19 @@ from .flows.llm_flows.functions import find_event_by_function_call_id from .flows.llm_flows.functions import find_matching_function_call from .memory.base_memory_service import BaseMemoryService +from .memory.in_memory_memory_service import InMemoryMemoryService from .platform.thread import create_thread from .plugins.base_plugin import BasePlugin from .plugins.plugin_manager import PluginManager from .sessions.base_session_service import BaseSessionService from .sessions.base_session_service import GetSessionConfig +from .sessions.in_memory_session_service import InMemorySessionService from .sessions.session import Session from .telemetry.tracing import tracer from .tools.base_toolset import BaseToolset from .utils._debug_output import print_event from .utils.context_utils import Aclosing -if TYPE_CHECKING: - from .apps.app import App - from .apps.app import ResumabilityConfig - logger = logging.getLogger('google_adk.' + __name__) @@ -619,8 +620,6 @@ async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: # the end of an invocation.) 
if self.app and self.app.events_compaction_config: logger.debug('Running event compactor.') - from google.adk.apps.compaction import _run_compaction_for_sliding_window - await _run_compaction_for_sliding_window( self.app, invocation_context.session, @@ -1678,10 +1677,6 @@ def __init__( app: Optional App instance. plugin_close_timeout: The timeout in seconds for plugin close methods. """ - from .artifacts.in_memory_artifact_service import InMemoryArtifactService - from .memory.in_memory_memory_service import InMemoryMemoryService - from .sessions.in_memory_session_service import InMemorySessionService - if app is None and app_name is None: app_name = 'InMemoryRunner' super().__init__( diff --git a/src/google/adk/sessions/__init__.py b/src/google/adk/sessions/__init__.py index db983f96f1..7505eda346 100644 --- a/src/google/adk/sessions/__init__.py +++ b/src/google/adk/sessions/__init__.py @@ -11,20 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- -from __future__ import annotations - -import importlib -from typing import TYPE_CHECKING - from .base_session_service import BaseSessionService +from .in_memory_session_service import InMemorySessionService from .session import Session from .state import State - -if TYPE_CHECKING: - from .database_session_service import DatabaseSessionService - from .in_memory_session_service import InMemorySessionService - from .vertex_ai_session_service import VertexAiSessionService +from .vertex_ai_session_service import VertexAiSessionService __all__ = [ 'BaseSessionService', @@ -35,23 +26,16 @@ 'VertexAiSessionService', ] -_LAZY_MEMBERS: dict[str, str] = { - 'InMemorySessionService': 'in_memory_session_service', - 'VertexAiSessionService': 'vertex_ai_session_service', -} - def __getattr__(name: str): - if name in _LAZY_MEMBERS: - module = importlib.import_module(f'{__name__}.{_LAZY_MEMBERS[name]}') - return vars(module)[name] if name == 'DatabaseSessionService': try: - module = importlib.import_module(f'{__name__}.database_session_service') + from .database_session_service import DatabaseSessionService + + return DatabaseSessionService except ImportError as e: raise ImportError( - 'DatabaseSessionService requires sqlalchemy>=2.0, please ensure it' - ' is installed correctly.' + 'DatabaseSessionService requires sqlalchemy>=2.0, please ensure it is' + ' installed correctly.' 
) from e - return vars(module)['DatabaseSessionService'] raise AttributeError(f'module {__name__!r} has no attribute {name!r}') From bd062ec9eb4b48cc6d4ec45aaf0a1f8f847b6d7b Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 15 May 2026 18:34:57 -0700 Subject: [PATCH 3/8] perf: lazy-load service registries and split apps.app to cut cold start ~8% PiperOrigin-RevId: 916278514 --- src/google/adk/agents/invocation_context.py | 4 +- src/google/adk/apps/__init__.py | 22 ++++- src/google/adk/apps/_configs.py | 95 +++++++++++++++++++ src/google/adk/apps/app.py | 81 ++-------------- src/google/adk/artifacts/__init__.py | 26 ++++- src/google/adk/flows/llm_flows/single_flow.py | 2 +- src/google/adk/memory/__init__.py | 34 ++++--- src/google/adk/plugins/__init__.py | 29 +++++- src/google/adk/runners.py | 17 ++-- src/google/adk/sessions/__init__.py | 30 ++++-- 10 files changed, 229 insertions(+), 111 deletions(-) create mode 100644 src/google/adk/apps/_configs.py diff --git a/src/google/adk/agents/invocation_context.py b/src/google/adk/agents/invocation_context.py index 7fdbaee89b..8dc1a01af2 100644 --- a/src/google/adk/agents/invocation_context.py +++ b/src/google/adk/agents/invocation_context.py @@ -25,8 +25,8 @@ from pydantic import Field from pydantic import PrivateAttr -from ..apps.app import EventsCompactionConfig -from ..apps.app import ResumabilityConfig +from ..apps._configs import EventsCompactionConfig +from ..apps._configs import ResumabilityConfig from ..artifacts.base_artifact_service import BaseArtifactService from ..auth.auth_credential import AuthCredential from ..auth.credential_service.base_credential_service import BaseCredentialService diff --git a/src/google/adk/apps/__init__.py b/src/google/adk/apps/__init__.py index 3a5d0b0643..319293967b 100644 --- a/src/google/adk/apps/__init__.py +++ b/src/google/adk/apps/__init__.py @@ -12,10 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .app import App -from .app import ResumabilityConfig +from __future__ import annotations + +import importlib +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ._configs import ResumabilityConfig + from .app import App __all__ = [ 'App', 'ResumabilityConfig', ] + +_LAZY_MEMBERS: dict[str, str] = { + 'App': 'app', + 'ResumabilityConfig': '_configs', +} + + +def __getattr__(name: str): + if name in _LAZY_MEMBERS: + module = importlib.import_module(f'{__name__}.{_LAZY_MEMBERS[name]}') + return vars(module)[name] + raise AttributeError(f'module {__name__!r} has no attribute {name!r}') diff --git a/src/google/adk/apps/_configs.py b/src/google/adk/apps/_configs.py new file mode 100644 index 0000000000..87f3666ebd --- /dev/null +++ b/src/google/adk/apps/_configs.py @@ -0,0 +1,95 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Optional + +from pydantic import BaseModel +from pydantic import ConfigDict +from pydantic import Field +from pydantic import model_validator + +from ..utils.feature_decorator import experimental +from .base_events_summarizer import BaseEventsSummarizer + + +@experimental +class ResumabilityConfig(BaseModel): + """The config of the resumability for an application. + + The "resumability" in ADK refers to the ability to: + 1. pause an invocation upon a long-running function call. + 2. 
resume an invocation from the last event, if it's paused or failed midway + through. + + Note: ADK resumes the invocation in a best-effort manner: + 1. Tool call to resume needs to be idempotent because we only guarantee + an at-least-once behavior once resumed. + 2. Any temporary / in-memory state will be lost upon resumption. + """ + + is_resumable: bool = False + """Whether the app supports agent resumption. + If enabled, the feature will be enabled for all agents in the app. + """ + + +@experimental +class EventsCompactionConfig(BaseModel): + """The config of event compaction for an application.""" + + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) + + summarizer: Optional[BaseEventsSummarizer] = None + """The event summarizer to use for compaction.""" + + compaction_interval: int + """The number of *new* user-initiated invocations that, once + fully represented in the session's events, will trigger a compaction.""" + + overlap_size: int + """The number of preceding invocations to include from the + end of the last compacted range. This creates an overlap between consecutive + compacted summaries, maintaining context.""" + + token_threshold: Optional[int] = Field( + default=None, + gt=0, + ) + """Post-invocation token threshold trigger. + + If set, ADK will attempt a post-invocation compaction when the most recently + observed prompt token count meets or exceeds this threshold. + """ + + event_retention_size: Optional[int] = Field(default=None, ge=0) + """Post-invocation raw event retention size. + + If token-based post-invocation compaction is triggered, this keeps the last N + raw events un-compacted. 
+ """ + + @model_validator(mode="after") + def _validate_token_params(self) -> EventsCompactionConfig: + token_threshold_set = self.token_threshold is not None + retention_size_set = self.event_retention_size is not None + if token_threshold_set != retention_size_set: + raise ValueError( + "token_threshold and event_retention_size must be set together." + ) + return self diff --git a/src/google/adk/apps/app.py b/src/google/adk/apps/app.py index c20d581d9b..9bde128b7a 100644 --- a/src/google/adk/apps/app.py +++ b/src/google/adk/apps/app.py @@ -22,9 +22,16 @@ from ..agents.base_agent import BaseAgent from ..agents.context_cache_config import ContextCacheConfig -from ..apps.base_events_summarizer import BaseEventsSummarizer from ..plugins.base_plugin import BasePlugin -from ..utils.feature_decorator import experimental +from ._configs import EventsCompactionConfig +from ._configs import ResumabilityConfig + +__all__ = [ + "App", + "EventsCompactionConfig", + "ResumabilityConfig", + "validate_app_name", +] def validate_app_name(name: str) -> None: @@ -38,76 +45,6 @@ def validate_app_name(name: str) -> None: raise ValueError("App name cannot be 'user'; reserved for end-user input.") -@experimental -class ResumabilityConfig(BaseModel): - """The config of the resumability for an application. - - The "resumability" in ADK refers to the ability to: - 1. pause an invocation upon a long-running function call. - 2. resume an invocation from the last event, if it's paused or failed midway - through. - - Note: ADK resumes the invocation in a best-effort manner: - 1. Tool call to resume needs to be idempotent because we only guarantee - an at-least-once behavior once resumed. - 2. Any temporary / in-memory state will be lost upon resumption. - """ - - is_resumable: bool = False - """Whether the app supports agent resumption. - If enabled, the feature will be enabled for all agents in the app. 
- """ - - -@experimental -class EventsCompactionConfig(BaseModel): - """The config of event compaction for an application.""" - - model_config = ConfigDict( - arbitrary_types_allowed=True, - extra="forbid", - ) - - summarizer: Optional[BaseEventsSummarizer] = None - """The event summarizer to use for compaction.""" - - compaction_interval: int - """The number of *new* user-initiated invocations that, once - fully represented in the session's events, will trigger a compaction.""" - - overlap_size: int - """The number of preceding invocations to include from the - end of the last compacted range. This creates an overlap between consecutive - compacted summaries, maintaining context.""" - - token_threshold: Optional[int] = Field( - default=None, - gt=0, - ) - """Post-invocation token threshold trigger. - - If set, ADK will attempt a post-invocation compaction when the most recently - observed prompt token count meets or exceeds this threshold. - """ - - event_retention_size: Optional[int] = Field(default=None, ge=0) - """Post-invocation raw event retention size. - - If token-based post-invocation compaction is triggered, this keeps the last N - raw events un-compacted. - """ - - @model_validator(mode="after") - def _validate_token_params(self) -> EventsCompactionConfig: - token_threshold_set = self.token_threshold is not None - retention_size_set = self.event_retention_size is not None - if token_threshold_set != retention_size_set: - raise ValueError( - "token_threshold and event_retention_size must be set together." - ) - return self - - class App(BaseModel): """Represents an LLM-backed agentic application. diff --git a/src/google/adk/artifacts/__init__.py b/src/google/adk/artifacts/__init__.py index 5e56ffc737..af7912e617 100644 --- a/src/google/adk/artifacts/__init__.py +++ b/src/google/adk/artifacts/__init__.py @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + +import importlib +from typing import TYPE_CHECKING + from .base_artifact_service import BaseArtifactService -from .file_artifact_service import FileArtifactService -from .gcs_artifact_service import GcsArtifactService -from .in_memory_artifact_service import InMemoryArtifactService + +if TYPE_CHECKING: + from .file_artifact_service import FileArtifactService + from .gcs_artifact_service import GcsArtifactService + from .in_memory_artifact_service import InMemoryArtifactService __all__ = [ 'BaseArtifactService', @@ -23,3 +30,16 @@ 'GcsArtifactService', 'InMemoryArtifactService', ] + +_LAZY_MEMBERS: dict[str, str] = { + 'FileArtifactService': 'file_artifact_service', + 'GcsArtifactService': 'gcs_artifact_service', + 'InMemoryArtifactService': 'in_memory_artifact_service', +} + + +def __getattr__(name: str): + if name in _LAZY_MEMBERS: + module = importlib.import_module(f'{__name__}.{_LAZY_MEMBERS[name]}') + return vars(module)[name] + raise AttributeError(f'module {__name__!r} has no attribute {name!r}') diff --git a/src/google/adk/flows/llm_flows/single_flow.py b/src/google/adk/flows/llm_flows/single_flow.py index 932a265ed1..cc3fc9e6fa 100644 --- a/src/google/adk/flows/llm_flows/single_flow.py +++ b/src/google/adk/flows/llm_flows/single_flow.py @@ -22,7 +22,6 @@ from . import _nl_planning from . import _output_schema_processor from . import basic -from . import compaction from . import contents from . import context_cache_processor from . import identity @@ -36,6 +35,7 @@ def _create_request_processors(): """Create the standard request processor list for a single-agent flow.""" + from . 
import compaction from ...auth import auth_preprocessor return [ diff --git a/src/google/adk/memory/__init__.py b/src/google/adk/memory/__init__.py index c47fb8ec40..d40f3bf7d9 100644 --- a/src/google/adk/memory/__init__.py +++ b/src/google/adk/memory/__init__.py @@ -11,27 +11,35 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging + +from __future__ import annotations + +import importlib +from typing import TYPE_CHECKING from .base_memory_service import BaseMemoryService -from .in_memory_memory_service import InMemoryMemoryService -from .vertex_ai_memory_bank_service import VertexAiMemoryBankService -logger = logging.getLogger('google_adk.' + __name__) +if TYPE_CHECKING: + from .in_memory_memory_service import InMemoryMemoryService + from .vertex_ai_memory_bank_service import VertexAiMemoryBankService + from .vertex_ai_rag_memory_service import VertexAiRagMemoryService __all__ = [ 'BaseMemoryService', 'InMemoryMemoryService', 'VertexAiMemoryBankService', + 'VertexAiRagMemoryService', ] -try: - from .vertex_ai_rag_memory_service import VertexAiRagMemoryService +_LAZY_MEMBERS: dict[str, str] = { + 'InMemoryMemoryService': 'in_memory_memory_service', + 'VertexAiMemoryBankService': 'vertex_ai_memory_bank_service', + 'VertexAiRagMemoryService': 'vertex_ai_rag_memory_service', +} + - __all__.append('VertexAiRagMemoryService') -except ImportError: - logger.debug( - 'The Vertex SDK is not installed. If you want to use the' - ' VertexAiRagMemoryService please install it. If not, you can ignore this' - ' warning.' 
- ) +def __getattr__(name: str): + if name in _LAZY_MEMBERS: + module = importlib.import_module(f'{__name__}.{_LAZY_MEMBERS[name]}') + return vars(module)[name] + raise AttributeError(f'module {__name__!r} has no attribute {name!r}') diff --git a/src/google/adk/plugins/__init__.py b/src/google/adk/plugins/__init__.py index 45caf16038..70347fd25e 100644 --- a/src/google/adk/plugins/__init__.py +++ b/src/google/adk/plugins/__init__.py @@ -1,8 +1,8 @@ # Copyright 2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. -# You may in obtain a copy of the License at +# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # @@ -12,11 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + +import importlib +from typing import TYPE_CHECKING + from .base_plugin import BasePlugin -from .debug_logging_plugin import DebugLoggingPlugin -from .logging_plugin import LoggingPlugin from .plugin_manager import PluginManager -from .reflect_retry_tool_plugin import ReflectAndRetryToolPlugin + +if TYPE_CHECKING: + from .debug_logging_plugin import DebugLoggingPlugin + from .logging_plugin import LoggingPlugin + from .reflect_retry_tool_plugin import ReflectAndRetryToolPlugin __all__ = [ 'BasePlugin', @@ -25,3 +32,16 @@ 'PluginManager', 'ReflectAndRetryToolPlugin', ] + +_LAZY_MEMBERS: dict[str, str] = { + 'DebugLoggingPlugin': 'debug_logging_plugin', + 'LoggingPlugin': 'logging_plugin', + 'ReflectAndRetryToolPlugin': 'reflect_retry_tool_plugin', +} + + +def __getattr__(name: str): + if name in _LAZY_MEMBERS: + module = importlib.import_module(f'{__name__}.{_LAZY_MEMBERS[name]}') + return vars(module)[name] + raise AttributeError(f'module {__name__!r} has no attribute {name!r}') diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index
f352e6eb5f..850c26bbba 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -26,9 +26,9 @@ from typing import Generator from typing import List from typing import Optional +from typing import TYPE_CHECKING import warnings -from google.adk.apps.compaction import _run_compaction_for_sliding_window from google.genai import types from .agents.base_agent import BaseAgent @@ -38,10 +38,7 @@ from .agents.invocation_context import new_invocation_context_id from .agents.live_request_queue import LiveRequestQueue from .agents.run_config import RunConfig -from .apps.app import App -from .apps.app import ResumabilityConfig from .artifacts.base_artifact_service import BaseArtifactService -from .artifacts.in_memory_artifact_service import InMemoryArtifactService from .auth.credential_service.base_credential_service import BaseCredentialService from .code_executors.built_in_code_executor import BuiltInCodeExecutor from .errors.session_not_found_error import SessionNotFoundError @@ -51,19 +48,21 @@ from .flows.llm_flows.functions import find_event_by_function_call_id from .flows.llm_flows.functions import find_matching_function_call from .memory.base_memory_service import BaseMemoryService -from .memory.in_memory_memory_service import InMemoryMemoryService from .platform.thread import create_thread from .plugins.base_plugin import BasePlugin from .plugins.plugin_manager import PluginManager from .sessions.base_session_service import BaseSessionService from .sessions.base_session_service import GetSessionConfig -from .sessions.in_memory_session_service import InMemorySessionService from .sessions.session import Session from .telemetry.tracing import tracer from .tools.base_toolset import BaseToolset from .utils._debug_output import print_event from .utils.context_utils import Aclosing +if TYPE_CHECKING: + from .apps.app import App + from .apps.app import ResumabilityConfig + logger = logging.getLogger('google_adk.' 
+ __name__) @@ -620,6 +619,8 @@ async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: # the end of an invocation.) if self.app and self.app.events_compaction_config: logger.debug('Running event compactor.') + from google.adk.apps.compaction import _run_compaction_for_sliding_window + await _run_compaction_for_sliding_window( self.app, invocation_context.session, @@ -1677,6 +1678,10 @@ def __init__( app: Optional App instance. plugin_close_timeout: The timeout in seconds for plugin close methods. """ + from .artifacts.in_memory_artifact_service import InMemoryArtifactService + from .memory.in_memory_memory_service import InMemoryMemoryService + from .sessions.in_memory_session_service import InMemorySessionService + if app is None and app_name is None: app_name = 'InMemoryRunner' super().__init__( diff --git a/src/google/adk/sessions/__init__.py b/src/google/adk/sessions/__init__.py index 7505eda346..db983f96f1 100644 --- a/src/google/adk/sessions/__init__.py +++ b/src/google/adk/sessions/__init__.py @@ -11,11 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ +from __future__ import annotations + +import importlib +from typing import TYPE_CHECKING + from .base_session_service import BaseSessionService -from .in_memory_session_service import InMemorySessionService from .session import Session from .state import State -from .vertex_ai_session_service import VertexAiSessionService + +if TYPE_CHECKING: + from .database_session_service import DatabaseSessionService + from .in_memory_session_service import InMemorySessionService + from .vertex_ai_session_service import VertexAiSessionService __all__ = [ 'BaseSessionService', @@ -26,16 +35,23 @@ 'VertexAiSessionService', ] +_LAZY_MEMBERS: dict[str, str] = { + 'InMemorySessionService': 'in_memory_session_service', + 'VertexAiSessionService': 'vertex_ai_session_service', +} + def __getattr__(name: str): + if name in _LAZY_MEMBERS: + module = importlib.import_module(f'{__name__}.{_LAZY_MEMBERS[name]}') + return vars(module)[name] if name == 'DatabaseSessionService': try: - from .database_session_service import DatabaseSessionService - - return DatabaseSessionService + module = importlib.import_module(f'{__name__}.database_session_service') except ImportError as e: raise ImportError( - 'DatabaseSessionService requires sqlalchemy>=2.0, please ensure it is' - ' installed correctly.' + 'DatabaseSessionService requires sqlalchemy>=2.0, please ensure it' + ' is installed correctly.' 
) from e + return vars(module)['DatabaseSessionService'] raise AttributeError(f'module {__name__!r} has no attribute {name!r}') From 12844939f1a89b2a06c592a52bbd3c293860e808 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 18 May 2026 02:36:12 -0700 Subject: [PATCH 4/8] fix: only serialize llm_response to json if it will be included in the trace PiperOrigin-RevId: 917108173 --- src/google/adk/telemetry/tracing.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/google/adk/telemetry/tracing.py b/src/google/adk/telemetry/tracing.py index 84a061a48b..32040b8bc6 100644 --- a/src/google/adk/telemetry/tracing.py +++ b/src/google/adk/telemetry/tracing.py @@ -358,12 +358,12 @@ def trace_call_llm( except AttributeError: pass - try: - llm_response_json = llm_response.model_dump_json(exclude_none=True) - except Exception: # pylint: disable=broad-exception-caught - llm_response_json = '' - if _should_add_request_response_to_spans(): + try: + llm_response_json = llm_response.model_dump_json(exclude_none=True) + except Exception: # pylint: disable=broad-exception-caught + llm_response_json = '' + span.set_attribute( 'gcp.vertex.agent.llm_response', llm_response_json, From 3d07960a70031fb7786485f58a964a98dbdb932d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jorge=20Israel=20Fr=C3=B3meta=20Moya?= Date: Mon, 18 May 2026 07:53:26 -0700 Subject: [PATCH 5/8] fix: use tool_responses role for gemma4 models in LiteLLM integration Merge https://github.com/google/adk-python/pull/5655 Closes: #5650 Co-authored-by: Xuan Yang COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/5655 from jfrometa88:main d74b5216b0937ff371b25cc8f272ee8ab33c8dc9 PiperOrigin-RevId: 917231422 --- src/google/adk/models/lite_llm.py | 27 ++- .../models/test_lite_llm_gemma_tool_role.py | 177 ++++++++++++++++++ 2 files changed, 197 insertions(+), 7 deletions(-) create mode 100644 tests/unittests/models/test_lite_llm_gemma_tool_role.py diff --git 
a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index 3a6c36624d..d5c400ea0f 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -807,9 +807,14 @@ async def _content_to_message_param( if isinstance(response, str) else _safe_json_serialize(response) ) + # gemma4 requires role='tool_responses' for recognizing function_response parts as responses + # from the tool call, instead of OpenAI-compatible 'tool' role used by other models. + # Earlier Gemma versions before version 4 do not support tool use, + # so this check is intentionally scoped to only look for "gemma4" in the model name. + tool_role = "tool_responses" if "gemma4" in model.lower() else "tool" tool_messages.append( ChatCompletionToolMessage( - role="tool", + role=tool_role, tool_call_id=part.function_response.id, content=response_content, ) @@ -824,6 +829,7 @@ async def _content_to_message_param( follow_up = await _content_to_message_param( types.Content(role=content.role, parts=non_tool_parts), provider=provider, + model=model, ) follow_up_messages = ( follow_up if isinstance(follow_up, list) else [follow_up] @@ -934,12 +940,16 @@ async def _content_to_message_param( ) -def _ensure_tool_results(messages: List[Message]) -> List[Message]: +def _ensure_tool_results(messages: List[Message], model: str) -> List[Message]: """Insert placeholder tool messages for missing tool results. LiteLLM-backed providers like OpenAI and Anthropic reject histories where an assistant tool call is not followed by tool responses before the next non-tool message. This helps recover from interrupted tool execution. + + For models that expect a different tool response role (e.g. Gemma4 models, + which require 'tool_responses' instead of 'tool'), the role is adjusted + accordingly. 
""" if not messages: return messages @@ -948,17 +958,19 @@ def _ensure_tool_results(messages: List[Message]) -> List[Message]: healed_messages: List[Message] = [] pending_tool_call_ids: List[str] = [] + expected_tool_role = "tool_responses" if "gemma4" in model.lower() else "tool" for message in messages: role = message.get("role") - if pending_tool_call_ids and role != "tool": + + if pending_tool_call_ids and role != expected_tool_role: logger.warning( "Missing tool results for tool_call_id(s): %s", pending_tool_call_ids, ) healed_messages.extend( ChatCompletionToolMessage( - role="tool", + role=expected_tool_role, tool_call_id=tool_call_id, content=_MISSING_TOOL_RESULT_MESSAGE, ) @@ -971,13 +983,14 @@ def _ensure_tool_results(messages: List[Message]) -> List[Message]: pending_tool_call_ids = [ tool_call.get("id") for tool_call in tool_calls if tool_call.get("id") ] - elif role == "tool": + elif role == expected_tool_role: tool_call_id = message.get("tool_call_id") if tool_call_id in pending_tool_call_ids: pending_tool_call_ids.remove(tool_call_id) healed_messages.append(message) + # Final block also uses expected_tool_role if pending_tool_call_ids: logger.warning( "Missing tool results for tool_call_id(s): %s", @@ -985,7 +998,7 @@ def _ensure_tool_results(messages: List[Message]) -> List[Message]: ) healed_messages.extend( ChatCompletionToolMessage( - role="tool", + role=expected_tool_role, tool_call_id=tool_call_id, content=_MISSING_TOOL_RESULT_MESSAGE, ) @@ -1905,7 +1918,7 @@ async def _get_completion_inputs( content=llm_request.config.system_instruction, ), ) - messages = _ensure_tool_results(messages) + messages = _ensure_tool_results(messages, model) # 2. 
Convert tool declarations tools: Optional[List[Dict]] = None diff --git a/tests/unittests/models/test_lite_llm_gemma_tool_role.py b/tests/unittests/models/test_lite_llm_gemma_tool_role.py new file mode 100644 index 0000000000..901978dff8 --- /dev/null +++ b/tests/unittests/models/test_lite_llm_gemma_tool_role.py @@ -0,0 +1,177 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Gemma-specific tool role handling in _content_to_message_param. + +Gemma's chat template expects role='tool_responses' for tool result messages, +while the OpenAI-compatible default is role='tool'. This module verifies that +_content_to_message_param sets the correct role based on the model name. 
+""" + +from typing import Any + +from google.adk.models.lite_llm import _content_to_message_param +from google.genai import types +import pytest + + +def _make_function_response_content( + function_name: str = "get_weather", + response_data: dict[str, Any] | None = None, + call_id: str = "call_001", +) -> types.Content: + """Builds a types.Content with a single function_response part.""" + if response_data is None: + response_data = {"city": "Santiago de Cuba", "condition": "sunny"} + return types.Content( + role="user", + parts=[ + types.Part( + function_response=types.FunctionResponse( + name=function_name, + response=response_data, + id=call_id, + ) + ) + ], + ) + + +def _make_multi_function_response_content( + call_ids: list[str] | None = None, +) -> types.Content: + """Builds a types.Content with multiple function_response parts.""" + if call_ids is None: + call_ids = ["call_001", "call_002"] + return types.Content( + role="user", + parts=[ + types.Part( + function_response=types.FunctionResponse( + name=f"tool_{i}", + response={"result": f"value_{i}"}, + id=call_id, + ) + ) + for i, call_id in enumerate(call_ids) + ], + ) + + +def _extract_role(msg) -> str: + """Extracts role from a litellm message, whether dict or object.""" + if isinstance(msg, dict): + return msg["role"] + return msg.role + + +class TestToolRoleSingleResponse: + """_content_to_message_param with a single function_response part.""" + + @pytest.mark.asyncio + async def test_gemma4_model_uses_tool_responses_role(self): + """Models containing 'gemma4' should get role='tool_responses'.""" + content = _make_function_response_content() + + result = await _content_to_message_param(content, model="ollama/gemma4:e2b") + + assert _extract_role(result) == "tool_responses", ( + "Gemma models require role='tool_responses' to match their chat " + "template; role='tool' causes infinite tool-calling loops." 
+ ) + + @pytest.mark.asyncio + async def test_gemma4_uppercase_model_name(self): + """Model name matching should be case-insensitive.""" + content = _make_function_response_content() + + result = await _content_to_message_param(content, model="ollama/Gemma4:31b") + + assert _extract_role(result) == "tool_responses" + + @pytest.mark.asyncio + async def test_tool_call_id_and_content_preserved(self): + """Fix must not alter tool_call_id or content — only role changes.""" + content = _make_function_response_content( + response_data={"status": "ok"}, call_id="my_call_123" + ) + + result = await _content_to_message_param(content, model="ollama/gemma4:e2b") + + if isinstance(result, dict): + assert result["tool_call_id"] == "my_call_123" + assert "ok" in result["content"] + else: + assert result.tool_call_id == "my_call_123" + assert "ok" in result.content + + @pytest.mark.asyncio + async def test_empty_model_string_uses_tool_role(self): + """Empty model string should fall back to default role='tool'.""" + content = _make_function_response_content() + + result = await _content_to_message_param(content, model="") + + assert _extract_role(result) == "tool" + + @pytest.mark.asyncio + async def test_unrelated_models_use_tool_role(self): + """Models that do not contain 'gemma4' must not be affected.""" + unaffected_models = [ + "ollama/llama3:8b", + "ollama/qwen2.5-coder:3b", + "anthropic/claude-3-opus", + "openai/gpt-4o", + "ollama/gemma3:4b", # gemma3 != gemma4 + ] + for model in unaffected_models: + content = _make_function_response_content() + result = await _content_to_message_param(content, model=model) + assert ( + _extract_role(result) == "tool" + ), f"Model '{model}' should not be affected by the Gemma4 fix." 
+ + +class TestToolRoleMultipleResponses: + """_content_to_message_param with multiple function_response parts.""" + + @pytest.mark.asyncio + async def test_gemma4_all_messages_use_tool_responses_role(self): + """All messages in a multi-response must have role='tool_responses'.""" + content = _make_multi_function_response_content( + call_ids=["call_a", "call_b", "call_c"] + ) + + result = await _content_to_message_param(content, model="ollama/gemma4:4b") + + assert isinstance(result, list) + assert len(result) == 3 + for msg in result: + assert _extract_role(msg) == "tool_responses", ( + "Every tool message in a multi-response must use 'tool_responses' " + "for Gemma4 models." + ) + + @pytest.mark.asyncio + async def test_non_gemma_multi_response_uses_tool_role(self): + """Non-Gemma multi-response messages should all have role='tool'.""" + content = _make_multi_function_response_content( + call_ids=["call_a", "call_b"] + ) + + result = await _content_to_message_param(content, model="openai/gpt-4o") + + assert isinstance(result, list) + for msg in result: + assert _extract_role(msg) == "tool" From 59f7347a635bc56fa8abdd3c7c771ae11bebf9ab Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 18 May 2026 09:28:05 -0700 Subject: [PATCH 6/8] fix(small): Convert events to the A2A format while respecting user vs agent role PiperOrigin-RevId: 917275872 --- .../adk/a2a/converters/event_converter.py | 5 ++- .../a2a/converters/test_event_converter.py | 37 +++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/src/google/adk/a2a/converters/event_converter.py b/src/google/adk/a2a/converters/event_converter.py index e6a890941f..7ebd9f6d1c 100644 --- a/src/google/adk/a2a/converters/event_converter.py +++ b/src/google/adk/a2a/converters/event_converter.py @@ -570,7 +570,10 @@ def convert_event_to_a2a_events( # Handle regular message content message = convert_event_to_a2a_message( - event, invocation_context, part_converter=part_converter + event, 
+ invocation_context, + part_converter=part_converter, + role=Role.user if event.author == "user" else Role.agent, ) if message: running_event = _create_status_update_event( diff --git a/tests/unittests/a2a/converters/test_event_converter.py b/tests/unittests/a2a/converters/test_event_converter.py index 61f8c3aca6..e850b0123b 100644 --- a/tests/unittests/a2a/converters/test_event_converter.py +++ b/tests/unittests/a2a/converters/test_event_converter.py @@ -33,6 +33,7 @@ from google.adk.a2a.converters.event_converter import convert_event_to_a2a_events from google.adk.a2a.converters.event_converter import convert_event_to_a2a_message from google.adk.a2a.converters.event_converter import DEFAULT_ERROR_MESSAGE +from google.adk.a2a.converters.part_converter import convert_genai_part_to_a2a_part from google.adk.a2a.converters.utils import ADK_METADATA_KEY_PREFIX from google.adk.agents.invocation_context import InvocationContext from google.adk.events.event import Event @@ -438,6 +439,42 @@ def test_convert_event_to_a2a_events_with_custom_ids(self): context_id, ) + def test_convert_event_to_a2a_events_user_role(self): + """Test event to A2A events conversion with events from a user.""" + # Setup message + mock_message = Mock(spec=Message) + mock_message.parts = [] + + with patch( + "google.adk.a2a.converters.event_converter.convert_event_to_a2a_message" + ) as mock_convert_message: + mock_convert_message.return_value = mock_message + + with patch( + "google.adk.a2a.converters.event_converter._create_status_update_event" + ) as mock_create_running: + mock_running_event = Mock() + mock_create_running.return_value = mock_running_event + self.mock_event.author = "user" + + task_id = "custom-task-id" + context_id = "custom-context-id" + + result = convert_event_to_a2a_events( + self.mock_event, self.mock_invocation_context, task_id, context_id + ) + + assert len(result) == 1 + assert result[0] == mock_running_event + + # Verify the function is called with the specific task_id 
and context_id + mock_convert_message.assert_called_once_with( + self.mock_event, + self.mock_invocation_context, + part_converter=convert_genai_part_to_a2a_part, + role=Role.user, + ) + def test_create_status_update_event_with_auth_required_state(self): """Test creation of status update event with auth_required state.""" from a2a.types import DataPart From 48f1b302510c3520643db739494ff8ea318b7b8f Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 18 May 2026 11:03:31 -0700 Subject: [PATCH 7/8] feat: Simplify data retrieved handling of ask_data_agent tool and ask_data_insights tool PiperOrigin-RevId: 917326615 --- src/google/adk/tools/_gda_stream_util.py | 144 +++++++++ .../adk/tools/bigquery/data_insights_tool.py | 220 ++----------- .../adk/tools/data_agent/data_agent_tool.py | 260 +++------------- .../test_bigquery_data_insights_tool.py | 185 ++--------- ...k_data_insights_penguins_highest_mass.yaml | 293 ++---------------- .../tools/data_agent/test_data_agent_tool.py | 119 +++---- .../unittests/tools/test__gda_stream_util.py | 163 ++++++++++ 7 files changed, 460 insertions(+), 924 deletions(-) create mode 100644 src/google/adk/tools/_gda_stream_util.py create mode 100644 tests/unittests/tools/test__gda_stream_util.py diff --git a/src/google/adk/tools/_gda_stream_util.py b/src/google/adk/tools/_gda_stream_util.py new file mode 100644 index 0000000000..b8a6863168 --- /dev/null +++ b/src/google/adk/tools/_gda_stream_util.py @@ -0,0 +1,144 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import json +from typing import Any + +import requests + + +def get_stream( + url: str, + ca_payload: dict[str, Any], + headers: dict[str, str], + max_query_result_rows: int, +) -> list[dict[str, Any]]: + """Sends a JSON request to a streaming API and returns a list of messages.""" + with requests.Session() as s: + accumulator = "" + messages = [] + data_msg_idx = -1 + + with s.post(url, json=ca_payload, headers=headers, stream=True) as resp: + resp.raise_for_status() + for line in resp.iter_lines(): + if not line: + continue + + decoded_line = line.decode("utf-8") + + if decoded_line == "[{": + accumulator = "{" + elif decoded_line == "}]": + accumulator += "}" + elif decoded_line == ",": + continue + else: + accumulator += decoded_line + + try: + data_json = json.loads(accumulator) + except ValueError: + continue + + accumulator = "" + + if not isinstance(data_json, dict): + messages.append(data_json) + continue + + processed_msg = None + data_result = _extract_data_result(data_json) + if data_result is not None: + processed_msg = _format_data_retrieved( + data_result, max_query_result_rows + ) + if data_msg_idx >= 0: + messages[data_msg_idx] = { + "Data Retrieved": "Intermediate result omitted" + } + data_msg_idx = len(messages) + elif isinstance(data_json.get("systemMessage"), dict): + processed_msg = data_json["systemMessage"] + else: + processed_msg = data_json + + if processed_msg is not None: + messages.append(processed_msg) + + return messages + + +def _extract_data_result(msg: dict[str, Any]) -> dict[str, Any] | None: + """Attempts to find the result.data deep inside the generic dict.""" + sm = msg.get("systemMessage") + if not isinstance(sm, dict): + return None + data = sm.get("data") + if not isinstance(data, dict): + return None + result = data.get("result") + if not isinstance(result, dict): + return None + 
if "data" in result and isinstance(result["data"], list): + return result + return None + + +def _format_data_retrieved( + result: dict[str, Any], max_rows: int +) -> dict[str, Any]: + """Transforms the raw result dict into the simplified Toolbox format.""" + raw_data = result.get("data", []) + + fields = [] + schema = result.get("schema") + if isinstance(schema, dict): + schema_fields = schema.get("fields") + if isinstance(schema_fields, list): + fields = schema_fields + + headers = [] + for f in fields: + if isinstance(f, dict): + name = f.get("name") + if isinstance(name, str): + headers.append(name) + + if not headers and raw_data: + first_row = raw_data[0] + if isinstance(first_row, dict): + headers = list(first_row.keys()) + + total_rows = len(raw_data) + num_to_display = min(total_rows, max_rows) + + rows = [] + for r in raw_data[:num_to_display]: + if isinstance(r, dict): + row = [r.get(h) for h in headers] + rows.append(row) + + summary = f"Showing all {total_rows} rows." + if total_rows > max_rows: + summary = f"Showing the first {num_to_display} of {total_rows} total rows." + + return { + "Data Retrieved": { + "headers": headers, + "rows": rows, + "summary": summary, + } + } diff --git a/src/google/adk/tools/bigquery/data_insights_tool.py b/src/google/adk/tools/bigquery/data_insights_tool.py index ba30eb5dbd..45f45d874f 100644 --- a/src/google/adk/tools/bigquery/data_insights_tool.py +++ b/src/google/adk/tools/bigquery/data_insights_tool.py @@ -23,6 +23,7 @@ import requests from . import client +from .. import _gda_stream_util from .config import BigQueryToolConfig _GDA_CLIENT_ID = "GOOGLE_ADK" @@ -66,8 +67,9 @@ def ask_data_insights( A dictionary with two keys: - 'status': A string indicating the final status (e.g., "SUCCESS"). - 'response': A list of dictionaries, where each dictionary - represents a step in the API's execution process (e.g., SQL - generation, data retrieval, final answer). 
+ represents a step in the agent's execution process and can + contain keys like 'text', 'data', or 'Data Retrieved' indicating + thought process, SQL generation, data retrieval, or final answer. Example: A query joining multiple tables, showing the full return structure. @@ -99,7 +101,18 @@ def ask_data_insights( "status": "SUCCESS", "response": [ { - "SQL Generated": "SELECT t1.customer_name, SUM(t2.order_total) ... " + "text": { + "parts": [ + "Analyzing context", + "Retrieved context for 2 tables." + ], + "textType": "THOUGHT" + } + }, + { + "data": { + "generatedSql": "SELECT t1.customer_name, SUM(t2.order_total) ..." + } }, { "Data Retrieved": { @@ -109,7 +122,12 @@ def ask_data_insights( } }, { - "Answer": "The customer who spent the most was Jane Doe." + "text": { + "parts": [ + "The customer who spent the most was Jane Doe." + ], + "textType": "FINAL_RESPONSE" + } } ] } @@ -155,7 +173,7 @@ def ask_data_insights( "clientIdEnum": _GDA_CLIENT_ID, } - resp = _get_stream( + resp = _gda_stream_util.get_stream( ca_url, ca_payload, headers, settings.max_query_result_rows ) except Exception as ex: # pylint: disable=broad-except @@ -163,195 +181,5 @@ def ask_data_insights( "status": "ERROR", "error_details": str(ex), } - return {"status": "SUCCESS", "response": resp} - - -def _get_stream( - url: str, - ca_payload: Dict[str, Any], - headers: Dict[str, str], - max_query_result_rows: int, -) -> List[Dict[str, Any]]: - """Sends a JSON request to a streaming API and returns a list of messages.""" - s = requests.Session() - - accumulator = "" - messages = [] - - with s.post(url, json=ca_payload, headers=headers, stream=True) as resp: - for line in resp.iter_lines(): - if not line: - continue - - decoded_line = str(line, encoding="utf-8") - - if decoded_line == "[{": - accumulator = "{" - elif decoded_line == "}]": - accumulator += "}" - elif decoded_line == ",": - continue - else: - accumulator += decoded_line - - if not _is_json(accumulator): - continue - - data_json = 
json.loads(accumulator) - if "systemMessage" not in data_json: - if "error" in data_json: - _append_message(messages, _handle_error(data_json["error"])) - continue - - system_message = data_json["systemMessage"] - if "text" in system_message: - _append_message(messages, _handle_text_response(system_message["text"])) - elif "schema" in system_message: - _append_message( - messages, - _handle_schema_response(system_message["schema"]), - ) - elif "data" in system_message: - _append_message( - messages, - _handle_data_response( - system_message["data"], max_query_result_rows - ), - ) - accumulator = "" - return messages - - -def _is_json(s: str) -> bool: - """Checks if a string is a valid JSON object.""" - try: - json.loads(s) - except ValueError: - return False - return True - - -def _get_property( - data: Dict[str, Any], field_name: str, default: Any = "" -) -> Any: - """Safely gets a property from a dictionary.""" - return data.get(field_name, default) - - -def _format_bq_table_ref(table_ref: Dict[str, str]) -> str: - """Formats a BigQuery table reference dictionary into a string.""" - return f"{table_ref.get('projectId')}.{table_ref.get('datasetId')}.{table_ref.get('tableId')}" - - -def _format_schema_as_dict( - data: Dict[str, Any], -) -> Dict[str, List[Any]]: - """Extracts schema fields into a dictionary.""" - fields = data.get("fields", []) - if not fields: - return {"columns": []} - - column_details = [] - headers = ["Column", "Type", "Description", "Mode"] - rows: List[List[str, str, str, str]] = [] - for field in fields: - row_list = [ - _get_property(field, "name"), - _get_property(field, "type"), - _get_property(field, "description", ""), - _get_property(field, "mode"), - ] - rows.append(row_list) - - return {"headers": headers, "rows": rows} - -def _format_datasource_as_dict(datasource: Dict[str, Any]) -> Dict[str, Any]: - """Formats a full datasource object into a dictionary with its name and schema.""" - source_name = 
_format_bq_table_ref(datasource["bigqueryTableReference"]) - - schema = _format_schema_as_dict(datasource["schema"]) - return {"source_name": source_name, "schema": schema} - - -def _handle_text_response(resp: Dict[str, Any]) -> Dict[str, str]: - """Formats a text response into a dictionary.""" - parts = resp.get("parts", []) - return {"Answer": "".join(parts)} - - -def _handle_schema_response(resp: Dict[str, Any]) -> Dict[str, Any]: - """Formats a schema response into a dictionary.""" - if "query" in resp: - return {"Question": resp["query"].get("question", "")} - elif "result" in resp: - datasources = resp["result"].get("datasources", []) - # Format each datasource and join them with newlines - formatted_sources = [_format_datasource_as_dict(ds) for ds in datasources] - return {"Schema Resolved": formatted_sources} - return {} - - -def _handle_data_response( - resp: Dict[str, Any], max_query_result_rows: int -) -> Dict[str, Any]: - """Formats a data response into a dictionary.""" - if "query" in resp: - query = resp["query"] - return { - "Retrieval Query": { - "Query Name": query.get("name", "N/A"), - "Question": query.get("question", "N/A"), - } - } - elif "generatedSql" in resp: - return {"SQL Generated": resp["generatedSql"]} - elif "result" in resp: - schema = resp["result"]["schema"] - headers = [field.get("name") for field in schema.get("fields", [])] - - all_rows = resp["result"].get("data", []) - total_rows = len(all_rows) - - compact_rows = [] - for row_dict in all_rows[:max_query_result_rows]: - row_values = [row_dict.get(header) for header in headers] - compact_rows.append(row_values) - - summary_string = f"Showing all {total_rows} rows." - if total_rows > max_query_result_rows: - summary_string = ( - f"Showing the first {len(compact_rows)} of {total_rows} total rows." 
- ) - - return { - "Data Retrieved": { - "headers": headers, - "rows": compact_rows, - "summary": summary_string, - } - } - - return {} - - -def _handle_error(resp: Dict[str, Any]) -> Dict[str, Dict[str, Any]]: - """Formats an error response into a dictionary.""" - return { - "Error": { - "Code": resp.get("code", "N/A"), - "Message": resp.get("message", "No message provided."), - } - } - - -def _append_message( - messages: List[Dict[str, Any]], new_message: Dict[str, Any] -): - if not new_message: - return - - if messages and ("Data Retrieved" in messages[-1]): - messages.pop() - - messages.append(new_message) + return {"status": "SUCCESS", "response": resp} diff --git a/src/google/adk/tools/data_agent/data_agent_tool.py b/src/google/adk/tools/data_agent/data_agent_tool.py index ca58eb7c46..007fe71158 100644 --- a/src/google/adk/tools/data_agent/data_agent_tool.py +++ b/src/google/adk/tools/data_agent/data_agent_tool.py @@ -19,6 +19,7 @@ from google.auth.credentials import Credentials import requests +from .. import _gda_stream_util from ..tool_context import ToolContext from .config import DataAgentToolConfig @@ -129,7 +130,7 @@ def list_accessible_data_agents( except Exception as ex: # pylint: disable=broad-except return { "status": "ERROR", - "error_details": repr(ex), + "error_details": str(ex), } @@ -196,7 +197,7 @@ def get_data_agent_info( except Exception as ex: # pylint: disable=broad-except return { "status": "ERROR", - "error_details": repr(ex), + "error_details": str(ex), } @@ -221,21 +222,19 @@ def ask_data_agent( A dictionary with two keys: - 'status': A string indicating the final status (e.g., "SUCCESS"). - 'response': A list of dictionaries, where each dictionary - represents a step in the agent's execution process (e.g., SQL - generation, data retrieval, final answer). 
Note that the 'Answer' - step contains a text response which may summarize findings or refer - to previous steps of agent execution, such as 'Data Retrieved', in - which cases, the 'Answer' step does not include the result data. + represents a step in the agent's execution process and can + contain keys like 'text', 'data', or 'Data Retrieved' indicating + thought process, SQL generation, data retrieval, or final answer. Examples: A query to a data agent, showing the full return structure. - The original question: "Which customer from New York spent the most last - month?" + The original question: "What is the average tree height in San + Francisco?" >>> ask_data_agent( ... - data_agent_name="projects/my-project/locations/global/dataAgents/sales-agent", - ... query="Which customer from New York spent the most last month?", + data_agent_name="projects/my-project/locations/global/dataAgents/sf-trees-agent", + ... query="What is the average tree height in San Francisco?", ... credentials=credentials, ... tool_context=tool_context, ... ) @@ -243,42 +242,39 @@ def ask_data_agent( "status": "SUCCESS", "response": [ { - "Question": "Which customer from New York spent the most last - month?" - }, - { - "Schema Resolved": [ - { - "source_name": "my-gcp-project.sales_data.customers", - "schema": { - "headers": ["Column", "Type", "Description", "Mode"], - "rows": [ - ["customer_id", "INT64", "Customer ID", "REQUIRED"], - ["customer_name", "STRING", "Customer Name", "NULLABLE"], - ] - } - } - ] - }, - { - "Retrieval Query": { - "Query Name": "top_spender", - "Question": "Find top spending customer from New York in the last - month." + "text": { + "parts": [ + "Analyzing context", + "Retrieved context for 1 table." + ], + "textType": "THOUGHT" } }, { - "SQL Generated": "SELECT t1.customer_name, SUM(t2.order_total) ... 
" + "data": { + "generatedSql": "SELECT\n AVG(SAFE_CAST(street_trees.dbh AS FLOAT64)) AS average_height\nFROM\n bigquery-public-data.san_francisco.street_trees AS street_trees;" + } }, { "Data Retrieved": { - "headers": ["customer_name", "total_spent"], - "rows": [["Jane Doe", 1234.56]], + "headers": [ + "average_height" + ], + "rows": [ + [ + 10.073475670972512 + ] + ], "summary": "Showing all 1 rows." } }, { - "Answer": "The customer who spent the most last month was Jane Doe." + "text": { + "parts": [ + "### Summary\nBased on the street tree data for San Francisco, the average height (recorded in the dbh column) is approximately 10.07." + ], + "textType": "FINAL_RESPONSE" + } } ] } @@ -298,196 +294,16 @@ def ask_data_agent( }, "clientIdEnum": _GDA_CLIENT_ID, } - resp = _get_stream( + resp = _gda_stream_util.get_stream( chat_url, chat_payload, - headers=headers, - max_query_result_rows=settings.max_query_result_rows, + headers, + settings.max_query_result_rows, ) + return {"status": "SUCCESS", "response": resp} except Exception as ex: # pylint: disable=broad-except return { "status": "ERROR", - "error_details": repr(ex), + "error_details": str(ex), } - - -def _get_stream( - url: str, - ca_payload: dict[str, Any], - *, - headers: dict[str, str], - max_query_result_rows: int, -) -> list[dict[str, Any]]: - """Sends a JSON request to a streaming API and returns a list of messages.""" - s = requests.Session() - - accumulator = "" - messages = [] - - with s.post(url, json=ca_payload, headers=headers, stream=True) as resp: - for line in resp.iter_lines(): - if not line: - continue - - decoded_line = str(line, encoding="utf-8") - - if decoded_line == "[{": - accumulator = "{" - elif decoded_line == "}]": - accumulator += "}" - elif decoded_line == ",": - continue - else: - accumulator += decoded_line - - try: - data_json = json.loads(accumulator) - except ValueError: - continue - if "systemMessage" not in data_json: - if "error" in data_json: - _append_message( - 
messages, - _handle_error(data_json["error"]), - ) - continue - - system_message = data_json["systemMessage"] - if "text" in system_message: - _append_message( - messages, - _handle_text_response(system_message["text"]), - ) - elif "schema" in system_message: - _append_message( - messages, - _handle_schema_response(system_message["schema"]), - ) - elif "data" in system_message: - _append_message( - messages, - _handle_data_response( - system_message["data"], max_query_result_rows - ), - ) - accumulator = "" - return messages - - -def _format_bq_table_ref(table_ref: dict[str, str]) -> str: - """Formats a BigQuery table reference dictionary into a string.""" - return f"{table_ref.get('projectId')}.{table_ref.get('datasetId')}.{table_ref.get('tableId')}" - - -def _format_schema_as_dict( - data: dict[str, Any], -) -> dict[str, list[Any]]: - """Extracts schema fields into a dictionary.""" - fields = data.get("fields", []) - if not fields: - return {"columns": []} - - column_details = [] - headers = ["Column", "Type", "Description", "Mode"] - rows: list[list[str, str, str, str]] = [] - for field in fields: - row_list = [ - field.get("name", ""), - field.get("type", ""), - field.get("description", ""), - field.get("mode", ""), - ] - rows.append(row_list) - - return {"headers": headers, "rows": rows} - - -def _format_datasource_as_dict(datasource: dict[str, Any]) -> dict[str, Any]: - """Formats a full datasource object into a dictionary with its name and schema.""" - source_name = _format_bq_table_ref(datasource["bigqueryTableReference"]) - - schema = _format_schema_as_dict(datasource["schema"]) - return {"source_name": source_name, "schema": schema} - - -def _handle_text_response(resp: dict[str, Any]) -> dict[str, str]: - """Formats a text response into a dictionary.""" - parts = resp.get("parts", []) - return {"Answer": "".join(parts)} - - -def _handle_schema_response(resp: dict[str, Any]) -> dict[str, Any]: - """Formats a schema response into a dictionary.""" - if 
"query" in resp: - return {"Question": resp["query"].get("question", "")} - elif "result" in resp: - datasources = resp["result"].get("datasources", []) - # Format each datasource and join them with newlines - formatted_sources = [_format_datasource_as_dict(ds) for ds in datasources] - return {"Schema Resolved": formatted_sources} - return {} - - -def _handle_data_response( - resp: dict[str, Any], max_query_result_rows: int -) -> dict[str, Any]: - """Formats a data response into a dictionary.""" - if "query" in resp: - query = resp["query"] - return { - "Retrieval Query": { - "Query Name": query.get("name", "N/A"), - "Question": query.get("question", "N/A"), - } - } - elif "generatedSql" in resp: - return {"SQL Generated": resp["generatedSql"]} - elif "result" in resp: - schema = resp["result"]["schema"] - headers = [field.get("name") for field in schema.get("fields", [])] - - all_rows = resp["result"].get("data", []) - total_rows = len(all_rows) - - compact_rows = [] - for row_dict in all_rows[:max_query_result_rows]: - row_values = [row_dict.get(header) for header in headers] - compact_rows.append(row_values) - - summary_string = f"Showing all {total_rows} rows." - if total_rows > max_query_result_rows: - summary_string = ( - f"Showing the first {len(compact_rows)} of {total_rows} total rows." 
- ) - - return { - "Data Retrieved": { - "headers": headers, - "rows": compact_rows, - "summary": summary_string, - } - } - - return {} - - -def _handle_error(resp: dict[str, Any]) -> dict[str, dict[str, Any]]: - """Formats an error response into a dictionary.""" - return { - "Error": { - "Code": resp.get("code", "N/A"), - "Message": resp.get("message", "No message provided."), - } - } - - -def _append_message( - messages: list[dict[str, Any]], - new_message: dict[str, Any], -): - """Appends a message to the list.""" - if not new_message: - return - - messages.append(new_message) diff --git a/tests/unittests/tools/bigquery/test_bigquery_data_insights_tool.py b/tests/unittests/tools/bigquery/test_bigquery_data_insights_tool.py index b62c68358c..e52a90f96c 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_data_insights_tool.py +++ b/tests/unittests/tools/bigquery/test_bigquery_data_insights_tool.py @@ -52,22 +52,32 @@ def test_ask_data_insights_pipeline_from_file(mock_post, case_file_path): mock_post.return_value.__enter__.return_value = mock_response # 5. Call the function under test - result = data_insights_tool._get_stream( # pylint: disable=protected-access - url="fake_url", - ca_payload={}, - headers={}, - max_query_result_rows=50, + mock_creds = mock.Mock() + mock_creds.token = "fake-token" + mock_settings = mock.Mock() + mock_settings.max_query_result_rows = 50 + result = data_insights_tool.ask_data_insights( + project_id="test-project", + user_query_with_context=case_data["user_question"], + table_references=[], + credentials=mock_creds, + settings=mock_settings, ) # 6. 
Assert that the final list of dicts matches the expected output - assert result == expected_final_list + assert result["status"] == "SUCCESS" + assert result["response"] == expected_final_list -@mock.patch.object(data_insights_tool, "_get_stream") +@mock.patch.object(data_insights_tool._gda_stream_util, "get_stream") def test_ask_data_insights_success(mock_get_stream): """Tests the success path of ask_data_insights using decorators.""" # 1. Configure the behavior of the mocked functions - mock_get_stream.return_value = "Final formatted string from stream" + mock_stream = [ + {"text": {"parts": ["response1"], "textType": "THOUGHT"}}, + {"text": {"parts": ["response2"], "textType": "FINAL_RESPONSE"}}, + ] + mock_get_stream.return_value = mock_stream # 2. Create mock inputs for the function call mock_creds = mock.Mock() @@ -86,7 +96,7 @@ def test_ask_data_insights_success(mock_get_stream): # 4. Assert the results are as expected assert result["status"] == "SUCCESS" - assert result["response"] == "Final formatted string from stream" + assert result["response"] == mock_stream mock_get_stream.assert_called_once() # Verify that the correct headers and client ID were passed to _get_stream @@ -96,7 +106,7 @@ def test_ask_data_insights_success(mock_get_stream): assert headers["Authorization"] == "Bearer fake-token" -@mock.patch.object(data_insights_tool, "_get_stream") +@mock.patch.object(data_insights_tool._gda_stream_util, "get_stream") def test_ask_data_insights_handles_exception(mock_get_stream): """Tests the exception path of ask_data_insights using decorators.""" # 1. Configure one of the mocks to raise an error @@ -120,158 +130,3 @@ def test_ask_data_insights_handles_exception(mock_get_stream): assert result["status"] == "ERROR" assert "API call failed!" 
in result["error_details"] mock_get_stream.assert_called_once() - - -@pytest.mark.parametrize( - "initial_messages, new_message, expected_list", - [ - pytest.param( - [{"Thinking": None}, {"Schema Resolved": {}}], - {"SQL Generated": "SELECT 1"}, - [ - {"Thinking": None}, - {"Schema Resolved": {}}, - {"SQL Generated": "SELECT 1"}, - ], - id="append_when_last_message_is_not_data", - ), - pytest.param( - [{"Thinking": None}, {"Data Retrieved": {"rows": [1]}}], - {"Data Retrieved": {"rows": [1, 2]}}, - [{"Thinking": None}, {"Data Retrieved": {"rows": [1, 2]}}], - id="replace_when_last_message_is_data", - ), - pytest.param( - [], - {"Answer": "First Message"}, - [{"Answer": "First Message"}], - id="append_to_an_empty_list", - ), - pytest.param( - [{"Data Retrieved": {}}], - {}, - [{"Data Retrieved": {}}], - id="should_not_append_an_empty_new_message", - ), - ], -) -def test_append_message(initial_messages, new_message, expected_list): - """Tests the logic of replacing the last message if it's a data message.""" - messages_copy = initial_messages.copy() - data_insights_tool._append_message(messages_copy, new_message) # pylint: disable=protected-access - assert messages_copy == expected_list - - -@pytest.mark.parametrize( - "response_dict, expected_output", - [ - pytest.param( - {"parts": ["The answer", " is 42."]}, - {"Answer": "The answer is 42."}, - id="multiple_parts", - ), - pytest.param( - {"parts": ["Hello"]}, {"Answer": "Hello"}, id="single_part" - ), - pytest.param({}, {"Answer": ""}, id="empty_response"), - ], -) -def test_handle_text_response(response_dict, expected_output): - """Tests the text response handler.""" - result = data_insights_tool._handle_text_response(response_dict) # pylint: disable=protected-access - assert result == expected_output - - -@pytest.mark.parametrize( - "response_dict, expected_output", - [ - pytest.param( - {"query": {"question": "What is the schema?"}}, - {"Question": "What is the schema?"}, - id="schema_query_path", - ), - 
pytest.param( - { - "result": { - "datasources": [{ - "bigqueryTableReference": { - "projectId": "p", - "datasetId": "d", - "tableId": "t", - }, - "schema": { - "fields": [{"name": "col1", "type": "STRING"}] - }, - }] - } - }, - { - "Schema Resolved": [{ - "source_name": "p.d.t", - "schema": { - "headers": ["Column", "Type", "Description", "Mode"], - "rows": [["col1", "STRING", "", ""]], - }, - }] - }, - id="schema_result_path", - ), - ], -) -def test_handle_schema_response(response_dict, expected_output): - """Tests different paths of the schema response handler.""" - result = data_insights_tool._handle_schema_response(response_dict) # pylint: disable=protected-access - assert result == expected_output - - -@pytest.mark.parametrize( - "response_dict, expected_output", - [ - pytest.param( - {"generatedSql": "SELECT 1;"}, - {"SQL Generated": "SELECT 1;"}, - id="format_generated_sql", - ), - pytest.param( - { - "result": { - "schema": {"fields": [{"name": "id"}, {"name": "name"}]}, - "data": [{"id": 1, "name": "A"}, {"id": 2, "name": "B"}], - } - }, - { - "Data Retrieved": { - "headers": ["id", "name"], - "rows": [[1, "A"], [2, "B"]], - "summary": "Showing all 2 rows.", - } - }, - id="format_data_result_table", - ), - ], -) -def test_handle_data_response(response_dict, expected_output): - """Tests different paths of the data response handler, including truncation.""" - result = data_insights_tool._handle_data_response(response_dict, 100) # pylint: disable=protected-access - assert result == expected_output - - -@pytest.mark.parametrize( - "response_dict, expected_output", - [ - pytest.param( - {"code": 404, "message": "Not Found"}, - {"Error": {"Code": 404, "Message": "Not Found"}}, - id="full_error_message", - ), - pytest.param( - {"code": 500}, - {"Error": {"Code": 500, "Message": "No message provided."}}, - id="error_with_missing_message", - ), - ], -) -def test_handle_error(response_dict, expected_output): - """Tests the error response handler.""" - result = 
data_insights_tool._handle_error(response_dict) # pylint: disable=protected-access - assert result == expected_output diff --git a/tests/unittests/tools/bigquery/test_data/ask_data_insights_penguins_highest_mass.yaml b/tests/unittests/tools/bigquery/test_data/ask_data_insights_penguins_highest_mass.yaml index 2f2603f573..ad4fd43193 100644 --- a/tests/unittests/tools/bigquery/test_data/ask_data_insights_penguins_highest_mass.yaml +++ b/tests/unittests/tools/bigquery/test_data/ask_data_insights_penguins_highest_mass.yaml @@ -6,128 +6,11 @@ mock_api_stream: | [{ "timestamp": "2025-07-17T17:25:28.231Z", "systemMessage": { - "schema": { - "query": { - "question": "Penguins on which island have the highest average body mass?" - } - } - } - } - , - { - "timestamp": "2025-07-17T17:25:29.406Z", - "systemMessage": { - "schema": { - "result": { - "datasources": [ - { - "bigqueryTableReference": { - "projectId": "bigframes-dev-perf", - "datasetId": "bigframes_testing_eu", - "tableId": "penguins" - }, - "schema": { - "fields": [ - { - "name": "species", - "type": "STRING", - "mode": "NULLABLE" - }, - { - "name": "island", - "type": "STRING", - "mode": "NULLABLE" - }, - { - "name": "culmen_length_mm", - "type": "FLOAT64", - "mode": "NULLABLE" - }, - { - "name": "culmen_depth_mm", - "type": "FLOAT64", - "mode": "NULLABLE" - }, - { - "name": "flipper_length_mm", - "type": "FLOAT64", - "mode": "NULLABLE" - }, - { - "name": "body_mass_g", - "type": "FLOAT64", - "mode": "NULLABLE" - }, - { - "name": "sex", - "type": "STRING", - "mode": "NULLABLE" - } - ] - } - } - ] - } - } - } - } - , - { - "timestamp": "2025-07-17T17:25:30.431Z", - "systemMessage": { - "data": { - "query": { - "question": "What is the average body mass for each island?", - "datasources": [ - { - "bigqueryTableReference": { - "projectId": "bigframes-dev-perf", - "datasetId": "bigframes_testing_eu", - "tableId": "penguins" - }, - "schema": { - "fields": [ - { - "name": "species", - "type": "STRING", - "mode": 
"NULLABLE" - }, - { - "name": "island", - "type": "STRING", - "mode": "NULLABLE" - }, - { - "name": "culmen_length_mm", - "type": "FLOAT64", - "mode": "NULLABLE" - }, - { - "name": "culmen_depth_mm", - "type": "FLOAT64", - "mode": "NULLABLE" - }, - { - "name": "flipper_length_mm", - "type": "FLOAT64", - "mode": "NULLABLE" - }, - { - "name": "body_mass_g", - "type": "FLOAT64", - "mode": "NULLABLE" - }, - { - "name": "sex", - "type": "STRING", - "mode": "NULLABLE" - } - ] - } - } - ], - "name": "average_body_mass_by_island" - } + "text": { + "parts": [ + "Penguins on which island have the highest average body mass?" + ], + "textType": "THOUGHT" } } } @@ -141,38 +24,6 @@ mock_api_stream: | } } , - { - "timestamp": "2025-07-17T17:25:32.378Z", - "systemMessage": { - "data": { - "bigQueryJob": { - "projectId": "bigframes-dev-perf", - "jobId": "job_S4PGRwxO78_FrVmCHW_sklpeZFps", - "destinationTable": { - "projectId": "bigframes-dev-perf", - "datasetId": "_376b2bd1b83171a540d39ff3d58f39752e2724c9", - "tableId": "anonev_4a9PK1uHzAHwAOpSNOxMVhpUppM2sllR68riN6t41kM" - }, - "location": "EU", - "schema": { - "fields": [ - { - "name": "island", - "type": "STRING", - "mode": "NULLABLE" - }, - { - "name": "average_body_mass", - "type": "FLOAT", - "mode": "NULLABLE" - } - ] - } - } - } - } - } - , { "timestamp": "2025-07-17T17:25:32.664Z", "systemMessage": { @@ -212,125 +63,39 @@ mock_api_stream: | } } , - { - "timestamp": "2025-07-17T17:25:33.808Z", - "systemMessage": { - "chart": { - "query": { - "instructions": "Create a bar chart showing the average body mass for each island. 
The island should be on the x axis and the average body mass should be on the y axis.", - "dataResultName": "average_body_mass_by_island" - } - } - } - } - , - { - "timestamp": "2025-07-17T17:25:38.999Z", - "systemMessage": { - "chart": { - "result": { - "vegaConfig": { - "mark": { - "type": "bar", - "tooltip": true - }, - "encoding": { - "x": { - "field": "island", - "type": "nominal", - "title": "Island", - "axis": { - "labelOverlap": true - }, - "sort": {} - }, - "y": { - "field": "average_body_mass", - "type": "quantitative", - "title": "Average Body Mass", - "axis": { - "labelOverlap": true - }, - "sort": {} - } - }, - "title": "Average Body Mass for Each Island", - "data": { - "values": [ - { - "island": "Biscoe", - "average_body_mass": 4716.0179640718534 - }, - { - "island": "Dream", - "average_body_mass": 3712.9032258064512 - }, - { - "island": "Torgersen", - "average_body_mass": 3706.3725490196075 - } - ] - } - }, - "image": {} - } - } - } - } - , { "timestamp": "2025-07-17T17:25:40.018Z", "systemMessage": { "text": { "parts": [ "Penguins on Biscoe island have the highest average body mass, with an average of 4716.02g." - ] + ], + "textType": "FINAL_RESPONSE" } } } ] expected_output: -- Question: Penguins on which island have the highest average body mass? -- Schema Resolved: - - source_name: bigframes-dev-perf.bigframes_testing_eu.penguins - schema: - headers: - - Column - - Type - - Description - - Mode - rows: - - - species - - STRING - - '' - - NULLABLE - - - island - - STRING - - '' - - NULLABLE - - - culmen_length_mm - - FLOAT64 - - '' - - NULLABLE - - - culmen_depth_mm - - FLOAT64 - - '' - - NULLABLE - - - flipper_length_mm - - FLOAT64 - - '' - - NULLABLE - - - body_mass_g - - FLOAT64 - - '' - - NULLABLE - - - sex - - STRING - - '' - - NULLABLE -- Retrieval Query: - Query Name: average_body_mass_by_island - Question: What is the average body mass for each island? 
-- SQL Generated: "SELECT island, AVG(body_mass_g) AS average_body_mass\nFROM `bigframes-dev-perf`.`bigframes_testing_eu`.`penguins`\nGROUP BY island;" -- Answer: Penguins on Biscoe island have the highest average body mass, with an average of 4716.02g. +- text: + parts: + - 'Penguins on which island have the highest average body mass?' + textType: THOUGHT +- data: + generatedSql: "SELECT island, AVG(body_mass_g) AS average_body_mass\nFROM `bigframes-dev-perf`.`bigframes_testing_eu`.`penguins`\nGROUP BY island;" +- Data Retrieved: + headers: + - island + - average_body_mass + rows: + - - Biscoe + - '4716.017964071853' + - - Dream + - '3712.9032258064512' + - - Torgersen + - '3706.3725490196075' + summary: Showing all 3 rows. +- text: + parts: + - "Penguins on Biscoe island have the highest average body mass, with an average of 4716.02g." + textType: FINAL_RESPONSE diff --git a/tests/unittests/tools/data_agent/test_data_agent_tool.py b/tests/unittests/tools/data_agent/test_data_agent_tool.py index 54b3e8d327..47be0ab765 100644 --- a/tests/unittests/tools/data_agent/test_data_agent_tool.py +++ b/tests/unittests/tools/data_agent/test_data_agent_tool.py @@ -36,7 +36,14 @@ def test_list_accessible_data_agents_success(mock_requests): ) assert result["status"] == "SUCCESS" assert result["response"] == ["agent1", "agent2"] - mock_requests.get.assert_called_once() + mock_requests.get.assert_called_once_with( + "https://geminidataanalytics.googleapis.com/v1beta/projects/test-project/locations/global/dataAgents:listAccessible", + headers={ + "Authorization": "Bearer fake-token", + "Content-Type": "application/json", + "X-Goog-API-Client": "GOOGLE_ADK", + }, + ) @mock.patch.object(data_agent_tool, "requests", autospec=True) @@ -65,7 +72,14 @@ def test_get_data_agent_info_success(mock_requests): result = data_agent_tool.get_data_agent_info("agent_name", mock_creds) assert result["status"] == "SUCCESS" assert result["response"] == "agent_info" - 
mock_requests.get.assert_called_once() + mock_requests.get.assert_called_once_with( + "https://geminidataanalytics.googleapis.com/v1beta/agent_name", + headers={ + "Authorization": "Bearer fake-token", + "Content-Type": "application/json", + "X-Goog-API-Client": "GOOGLE_ADK", + }, + ) @mock.patch.object(data_agent_tool, "requests", autospec=True) @@ -80,7 +94,9 @@ def test_get_data_agent_info_exception(mock_requests): mock_requests.get.assert_called_once() -@mock.patch.object(data_agent_tool, "_get_stream", autospec=True) +@mock.patch.object( + data_agent_tool._gda_stream_util, "get_stream", autospec=True +) @mock.patch.object(data_agent_tool, "requests", autospec=True) @mock.patch.object(data_agent_tool, "get_data_agent_info", autospec=True) def test_ask_data_agent_success( @@ -91,8 +107,8 @@ def test_ask_data_agent_success( mock_creds.token = "fake-token" mock_get_agent_info.return_value = {"status": "SUCCESS", "response": {}} mock_get_stream.return_value = [ - {"Answer": "response1"}, - {"Answer": "response2"}, + {"text": {"parts": ["response1"], "textType": "THOUGHT"}}, + {"text": {"parts": ["response2"], "textType": "FINAL_RESPONSE"}}, ] mock_invocation_context = mock.Mock() mock_invocation_context.session.state = {} @@ -108,14 +124,31 @@ def test_ask_data_agent_success( ) assert result["status"] == "SUCCESS" assert result["response"] == [ - {"Answer": "response1"}, - {"Answer": "response2"}, + {"text": {"parts": ["response1"], "textType": "THOUGHT"}}, + {"text": {"parts": ["response2"], "textType": "FINAL_RESPONSE"}}, ] mock_get_agent_info.assert_called_once() - mock_get_stream.assert_called_once() + mock_get_stream.assert_called_once_with( + "https://geminidataanalytics.googleapis.com/v1beta/projects/p/locations/l:chat", + { + "messages": [{"userMessage": {"text": "query"}}], + "dataAgentContext": { + "dataAgent": "projects/p/locations/l/dataAgents/a", + }, + "clientIdEnum": "GOOGLE_ADK", + }, + { + "Authorization": "Bearer fake-token", + "Content-Type": 
"application/json", + "X-Goog-API-Client": "GOOGLE_ADK", + }, + mock_settings.max_query_result_rows, + ) -@mock.patch.object(data_agent_tool, "_get_stream", autospec=True) +@mock.patch.object( + data_agent_tool._gda_stream_util, "get_stream", autospec=True +) @mock.patch.object(data_agent_tool, "requests", autospec=True) @mock.patch.object(data_agent_tool, "get_data_agent_info", autospec=True) def test_ask_data_agent_exception( @@ -141,71 +174,3 @@ def test_ask_data_agent_exception( assert result["status"] == "ERROR" assert "Chat failed!" in result["error_details"] mock_get_stream.assert_called_once() - - -@pytest.mark.parametrize( - "case_file_path", - [ - pytest.param("test_data/ask_data_insights_penguins_highest_mass.yaml"), - ], -) -@mock.patch.object(requests.Session, "post") -def test_get_stream_from_file(mock_post, case_file_path): - """Runs a full integration test for the _get_stream function using data from a specific file.""" - # 1. Construct the full, absolute path to the data file - full_path = pathlib.Path(__file__).parent.parent / "bigquery" / case_file_path - - # 2. Load the test case data from the specified YAML file - with open(full_path, "r", encoding="utf-8") as f: - case_data = yaml.safe_load(f) - - # 3. Prepare the mock stream and expected output from the loaded data - mock_stream_str = case_data["mock_api_stream"] - fake_stream_lines = [ - line.encode("utf-8") for line in mock_stream_str.splitlines() - ] - # Load the expected output as a list of dictionaries, not a single string - expected_final_list = case_data["expected_output"] - data_retrieved = { - "Data Retrieved": { - "headers": ["island", "average_body_mass"], - "rows": [ - ["Biscoe", "4716.017964071853"], - ["Dream", "3712.9032258064512"], - ["Torgersen", "3706.3725490196075"], - ], - "summary": "Showing all 3 rows.", - } - } - expected_final_list.insert(-1, data_retrieved) - - # 4. 
Configure the mock for requests.post - mock_response = mock.Mock() - mock_response.iter_lines.return_value = fake_stream_lines - # Add raise_for_status mock which is called in the updated code - mock_response.raise_for_status.return_value = None - mock_post.return_value.__enter__.return_value = mock_response - - # 5. Call the function under test - result = data_agent_tool._get_stream( # pylint: disable=protected-access - url="fake_url", - ca_payload={}, - headers={}, - max_query_result_rows=50, - ) - - # 6. Assert that the final list of dicts matches the expected output - assert result == expected_final_list - - -def test_get_http_headers_includes_client_id(): - """Tests _get_http_headers includes the correct GDA client ID.""" - mock_creds = mock.Mock() - mock_creds.token = "fake-token" - - # pylint: disable=protected-access - headers = data_agent_tool._get_http_headers(mock_creds) - - assert headers["X-Goog-API-Client"] == "GOOGLE_ADK" - assert headers["Content-Type"] == "application/json" - assert headers["Authorization"] == "Bearer fake-token" diff --git a/tests/unittests/tools/test__gda_stream_util.py b/tests/unittests/tools/test__gda_stream_util.py new file mode 100644 index 0000000000..6945e67833 --- /dev/null +++ b/tests/unittests/tools/test__gda_stream_util.py @@ -0,0 +1,163 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import unittest +from unittest import mock + +from google.adk.tools import _gda_stream_util +import requests + + +class MockResponse: + + def __init__(self, lines): + self._lines = lines + + def iter_lines(self): + return iter(self._lines) + + def raise_for_status(self): + pass + + def __enter__(self): + return self + + def __exit__(self, *args): + pass + + +class GdaStreamUtilTest(unittest.TestCase): + + def test_extract_data_result_success(self): + msg = { + "systemMessage": {"data": {"result": {"data": [1, 2], "schema": {}}}} + } + self.assertEqual( + _gda_stream_util._extract_data_result(msg), + {"data": [1, 2], "schema": {}}, + ) + + def test_extract_data_result_failure(self): + self.assertIsNone(_gda_stream_util._extract_data_result({})) + self.assertIsNone( + _gda_stream_util._extract_data_result({"systemMessage": None}) + ) + self.assertIsNone( + _gda_stream_util._extract_data_result({"systemMessage": {"data": None}}) + ) + self.assertIsNone( + _gda_stream_util._extract_data_result( + {"systemMessage": {"data": {"result": None}}} + ) + ) + self.assertIsNone( + _gda_stream_util._extract_data_result( + {"systemMessage": {"data": {"result": {"no_data": 1}}}} + ) + ) + + def test_format_data_retrieved_simple(self): + result = { + "data": [{"col1": "val1", "col2": 10}], + "schema": {"fields": [{"name": "col1"}, {"name": "col2"}]}, + } + formatted = _gda_stream_util._format_data_retrieved(result, 10) + self.assertEqual( + formatted, + { + "Data Retrieved": { + "headers": ["col1", "col2"], + "rows": [["val1", 10]], + "summary": "Showing all 1 rows.", + } + }, + ) + + def test_format_data_retrieved_truncation(self): + result = { + "data": [{"col1": f"val{i}"} for i in range(5)], + "schema": {"fields": [{"name": "col1"}]}, + } + formatted = _gda_stream_util._format_data_retrieved(result, 2) + self.assertEqual( + formatted, + { + "Data Retrieved": { + "headers": ["col1"], + "rows": [["val0"], ["val1"]], + "summary": "Showing the first 2 of 5 total rows.", + } + }, + 
) + + def test_format_data_retrieved_missing_schema(self): + result = {"data": [{"col1": "val1"}], "schema": None} + formatted = _gda_stream_util._format_data_retrieved(result, 10) + self.assertEqual( + formatted, + { + "Data Retrieved": { + "headers": ["col1"], + "rows": [["val1"]], + "summary": "Showing all 1 rows.", + } + }, + ) + + @mock.patch("requests.Session.post") + def test_get_stream(self, mock_post): + stream_lines = [ + b"[{", + b'"systemMessage": {"text": "msg1"}', + b"}", + b",", + b"{", + ( + b'"systemMessage": { "data": { "result": { "data": [{"a":1}],' + b' "schema": {"fields":[{"name":"a"}]}}}}' + ), + b"}", + b",", + b"{", + ( + b'"systemMessage": { "data": { "result": { "data": [{"b":2}],' + b' "schema": {"fields":[{"name":"b"}]}}}}' + ), + b"}", + b",", + b"{", + b'"systemMessage": {"text": "msg4"}', + b"}]", + ] + mock_post.return_value = MockResponse(stream_lines) + messages = _gda_stream_util.get_stream("url", {}, {}, 10) + self.assertEqual(len(messages), 4) + self.assertEqual(messages[0], {"text": "msg1"}) + self.assertEqual( + messages[1], {"Data Retrieved": "Intermediate result omitted"} + ) + self.assertEqual( + messages[2], + { + "Data Retrieved": { + "headers": ["b"], + "rows": [[2]], + "summary": "Showing all 1 rows.", + } + }, + ) + self.assertEqual(messages[3], {"text": "msg4"}) + + +if __name__ == "__main__": + unittest.main() From baf7efbaa92ce9d71152ea9ba7f5d0706277b171 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 18 May 2026 11:30:48 -0700 Subject: [PATCH 8/8] feat: Added config option to include tool calls/responses in conversation history passed to user simulator PiperOrigin-RevId: 917340645 --- .../simulation/llm_backed_user_simulator.py | 24 +++++++++++++++++-- .../test_llm_backed_user_simulator.py | 13 ++++++++++ 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/src/google/adk/evaluation/simulation/llm_backed_user_simulator.py b/src/google/adk/evaluation/simulation/llm_backed_user_simulator.py 
index fcfc4fd0ad..f5fe612aaa 100644 --- a/src/google/adk/evaluation/simulation/llm_backed_user_simulator.py +++ b/src/google/adk/evaluation/simulation/llm_backed_user_simulator.py @@ -85,6 +85,12 @@ class LlmBackedUserSimulatorConfig(BaseUserSimulatorConfig): """, ) + include_function_calls: bool = Field( + default=False, + description="""Whether to include function calls and responses in the +conversation history prompt provided to the user simulator.""", + ) + @field_validator("custom_instructions") @classmethod def validate_custom_instructions(cls, value: str | None) -> str | None: @@ -132,13 +138,15 @@ def __init__( def _summarize_conversation( cls, events: list[Event], + include_function_calls: bool = False, ) -> str: """Summarize the conversation to add to the prompt. - Removes tool calls, responses, and thoughts. + Removes thoughts and, unless requested, tool calls and tool responses. Args: events: The conversation history to rewrite. + include_function_calls: Whether to include function calls and responses. Returns: The summarized conversation history as a string. 
@@ -151,6 +159,16 @@ def _summarize_conversation( for part in e.content.parts: if part.text and not part.thought: rewritten_dialogue.append(f"{author}: {part.text}") + elif include_function_calls and part.function_call: + rewritten_dialogue.append( + f"{author} called tool '{part.function_call.name}' with args:" + f" {part.function_call.args}" + ) + elif include_function_calls and part.function_response: + rewritten_dialogue.append( + f"Tool '{part.function_response.name}' returned:" + f" {part.function_response.response}" + ) return "\n\n".join(rewritten_dialogue) @@ -255,7 +273,9 @@ async def get_next_user_message( return NextUserMessage(status=Status.TURN_LIMIT_REACHED) # rewrite events for the user simulator - rewritten_dialogue = self._summarize_conversation(events) + rewritten_dialogue = self._summarize_conversation( + events, self._config.include_function_calls + ) # query the LLM for the next user message response, error_reason = await self._get_llm_response(rewritten_dialogue) diff --git a/tests/unittests/evaluation/simulation/test_llm_backed_user_simulator.py b/tests/unittests/evaluation/simulation/test_llm_backed_user_simulator.py index 2ff957509d..9d5882cc9f 100644 --- a/tests/unittests/evaluation/simulation/test_llm_backed_user_simulator.py +++ b/tests/unittests/evaluation/simulation/test_llm_backed_user_simulator.py @@ -119,6 +119,19 @@ def test_convert_conversation_to_user_sim_pov(self): ) assert rewritten_dialogue == _EXPECTED_REWRITTEN_DIALOGUE_LONG + def test_summarize_conversation_with_function_calls(self): + """Tests _summarize_conversation with include_function_calls=True.""" + rewritten_dialogue = LlmBackedUserSimulator._summarize_conversation( + _INPUT_EVENTS, include_function_calls=True + ) + expected = ( + "user: Can you help me?\n\n" + "helpful_assistant called tool 'get_user_name' with args: None\n\n" + "Tool 'get_user_name' returned: {'name': 'John Doe'}\n\n" + "helpful_assistant: Hi John, what can I do for you?" 
+ ) + assert rewritten_dialogue == expected + async def to_async_iter(items): for item in items: