Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/google/adk/a2a/converters/event_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,10 @@ def convert_event_to_a2a_events(

# Handle regular message content
message = convert_event_to_a2a_message(
event, invocation_context, part_converter=part_converter
event,
invocation_context,
part_converter=part_converter,
role=Role.user if event.author == "user" else Role.agent,
)
if message:
running_event = _create_status_update_event(
Expand Down
24 changes: 22 additions & 2 deletions src/google/adk/evaluation/simulation/llm_backed_user_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,12 @@ class LlmBackedUserSimulatorConfig(BaseUserSimulatorConfig):
""",
)

include_function_calls: bool = Field(
default=False,
description="""Whether to include function calls and responses in the
conversation history prompt provided to the user simulator.""",
)

@field_validator("custom_instructions")
@classmethod
def validate_custom_instructions(cls, value: str | None) -> str | None:
Expand Down Expand Up @@ -132,13 +138,15 @@ def __init__(
def _summarize_conversation(
cls,
events: list[Event],
include_function_calls: bool = False,
) -> str:
"""Summarize the conversation to add to the prompt.

Removes tool calls, responses, and thoughts.
Removes responses, thoughts, optionally tool calls and tool responses.

Args:
events: The conversation history to rewrite.
include_function_calls: Whether to include function calls and responses.

Returns:
The summarized conversation history as a string.
Expand All @@ -151,6 +159,16 @@ def _summarize_conversation(
for part in e.content.parts:
if part.text and not part.thought:
rewritten_dialogue.append(f"{author}: {part.text}")
elif include_function_calls and part.function_call:
rewritten_dialogue.append(
f"{author} called tool '{part.function_call.name}' with args:"
f" {part.function_call.args}"
)
elif include_function_calls and part.function_response:
rewritten_dialogue.append(
f"Tool '{part.function_response.name}' returned:"
f" {part.function_response.response}"
)

return "\n\n".join(rewritten_dialogue)

Expand Down Expand Up @@ -255,7 +273,9 @@ async def get_next_user_message(
return NextUserMessage(status=Status.TURN_LIMIT_REACHED)

# rewrite events for the user simulator
rewritten_dialogue = self._summarize_conversation(events)
rewritten_dialogue = self._summarize_conversation(
events, self._config.include_function_calls
)

# query the LLM for the next user message
response, error_reason = await self._get_llm_response(rewritten_dialogue)
Expand Down
27 changes: 20 additions & 7 deletions src/google/adk/models/lite_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,9 +807,14 @@ async def _content_to_message_param(
if isinstance(response, str)
else _safe_json_serialize(response)
)
# gemma4 requires role='tool_responses' for recognizing function_response parts as responses
# from the tool call, instead of OpenAI-compatible 'tool' role used by other models.
# Earlier Gemma versions before version 4 do not support tool use,
# so this check is intentionally scoped to only look for "gemma4" in the model name.
tool_role = "tool_responses" if "gemma4" in model.lower() else "tool"
tool_messages.append(
ChatCompletionToolMessage(
role="tool",
role=tool_role,
tool_call_id=part.function_response.id,
content=response_content,
)
Expand All @@ -824,6 +829,7 @@ async def _content_to_message_param(
follow_up = await _content_to_message_param(
types.Content(role=content.role, parts=non_tool_parts),
provider=provider,
model=model,
)
follow_up_messages = (
follow_up if isinstance(follow_up, list) else [follow_up]
Expand Down Expand Up @@ -934,12 +940,16 @@ async def _content_to_message_param(
)


def _ensure_tool_results(messages: List[Message]) -> List[Message]:
def _ensure_tool_results(messages: List[Message], model: str) -> List[Message]:
"""Insert placeholder tool messages for missing tool results.

LiteLLM-backed providers like OpenAI and Anthropic reject histories where an
assistant tool call is not followed by tool responses before the next
non-tool message. This helps recover from interrupted tool execution.

For models that expect a different tool response role (e.g. Gemma4 models,
which require 'tool_responses' instead of 'tool'), the role is adjusted
accordingly.
"""
if not messages:
return messages
Expand All @@ -948,17 +958,19 @@ def _ensure_tool_results(messages: List[Message]) -> List[Message]:

healed_messages: List[Message] = []
pending_tool_call_ids: List[str] = []
expected_tool_role = "tool_responses" if "gemma4" in model.lower() else "tool"

for message in messages:
role = message.get("role")
if pending_tool_call_ids and role != "tool":

if pending_tool_call_ids and role != expected_tool_role:
logger.warning(
"Missing tool results for tool_call_id(s): %s",
pending_tool_call_ids,
)
healed_messages.extend(
ChatCompletionToolMessage(
role="tool",
role=expected_tool_role,
tool_call_id=tool_call_id,
content=_MISSING_TOOL_RESULT_MESSAGE,
)
Expand All @@ -971,21 +983,22 @@ def _ensure_tool_results(messages: List[Message]) -> List[Message]:
pending_tool_call_ids = [
tool_call.get("id") for tool_call in tool_calls if tool_call.get("id")
]
elif role == "tool":
elif role == expected_tool_role:
tool_call_id = message.get("tool_call_id")
if tool_call_id in pending_tool_call_ids:
pending_tool_call_ids.remove(tool_call_id)

healed_messages.append(message)

# Final block also uses expected_tool_role
if pending_tool_call_ids:
logger.warning(
"Missing tool results for tool_call_id(s): %s",
pending_tool_call_ids,
)
healed_messages.extend(
ChatCompletionToolMessage(
role="tool",
role=expected_tool_role,
tool_call_id=tool_call_id,
content=_MISSING_TOOL_RESULT_MESSAGE,
)
Expand Down Expand Up @@ -1905,7 +1918,7 @@ async def _get_completion_inputs(
content=llm_request.config.system_instruction,
),
)
messages = _ensure_tool_results(messages)
messages = _ensure_tool_results(messages, model)

# 2. Convert tool declarations
tools: Optional[List[Dict]] = None
Expand Down
10 changes: 5 additions & 5 deletions src/google/adk/telemetry/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,12 +358,12 @@ def trace_call_llm(
except AttributeError:
pass

try:
llm_response_json = llm_response.model_dump_json(exclude_none=True)
except Exception: # pylint: disable=broad-exception-caught
llm_response_json = '<not serializable>'

if _should_add_request_response_to_spans():
try:
llm_response_json = llm_response.model_dump_json(exclude_none=True)
except Exception: # pylint: disable=broad-exception-caught
llm_response_json = '<not serializable>'

span.set_attribute(
'gcp.vertex.agent.llm_response',
llm_response_json,
Expand Down
144 changes: 144 additions & 0 deletions src/google/adk/tools/_gda_stream_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import json
from typing import Any

import requests


def get_stream(
url: str,
ca_payload: dict[str, Any],
headers: dict[str, str],
max_query_result_rows: int,
) -> list[dict[str, Any]]:
"""Sends a JSON request to a streaming API and returns a list of messages."""
with requests.Session() as s:
accumulator = ""
messages = []
data_msg_idx = -1

with s.post(url, json=ca_payload, headers=headers, stream=True) as resp:
resp.raise_for_status()
for line in resp.iter_lines():
if not line:
continue

decoded_line = line.decode("utf-8")

if decoded_line == "[{":
accumulator = "{"
elif decoded_line == "}]":
accumulator += "}"
elif decoded_line == ",":
continue
else:
accumulator += decoded_line

try:
data_json = json.loads(accumulator)
except ValueError:
continue

accumulator = ""

if not isinstance(data_json, dict):
messages.append(data_json)
continue

processed_msg = None
data_result = _extract_data_result(data_json)
if data_result is not None:
processed_msg = _format_data_retrieved(
data_result, max_query_result_rows
)
if data_msg_idx >= 0:
messages[data_msg_idx] = {
"Data Retrieved": "Intermediate result omitted"
}
data_msg_idx = len(messages)
elif isinstance(data_json.get("systemMessage"), dict):
processed_msg = data_json["systemMessage"]
else:
processed_msg = data_json

if processed_msg is not None:
messages.append(processed_msg)

return messages


def _extract_data_result(msg: dict[str, Any]) -> dict[str, Any] | None:
"""Attempts to find the result.data deep inside the generic dict."""
sm = msg.get("systemMessage")
if not isinstance(sm, dict):
return None
data = sm.get("data")
if not isinstance(data, dict):
return None
result = data.get("result")
if not isinstance(result, dict):
return None
if "data" in result and isinstance(result["data"], list):
return result
return None


def _format_data_retrieved(
result: dict[str, Any], max_rows: int
) -> dict[str, Any]:
"""Transforms the raw result dict into the simplified Toolbox format."""
raw_data = result.get("data", [])

fields = []
schema = result.get("schema")
if isinstance(schema, dict):
schema_fields = schema.get("fields")
if isinstance(schema_fields, list):
fields = schema_fields

headers = []
for f in fields:
if isinstance(f, dict):
name = f.get("name")
if isinstance(name, str):
headers.append(name)

if not headers and raw_data:
first_row = raw_data[0]
if isinstance(first_row, dict):
headers = list(first_row.keys())

total_rows = len(raw_data)
num_to_display = min(total_rows, max_rows)

rows = []
for r in raw_data[:num_to_display]:
if isinstance(r, dict):
row = [r.get(h) for h in headers]
rows.append(row)

summary = f"Showing all {total_rows} rows."
if total_rows > max_rows:
summary = f"Showing the first {num_to_display} of {total_rows} total rows."

return {
"Data Retrieved": {
"headers": headers,
"rows": rows,
"summary": summary,
}
}
Loading
Loading