50 changes: 34 additions & 16 deletions src/google/adk/models/lite_llm.py
@@ -468,6 +468,8 @@ def _part_has_payload(part: types.Part) -> bool:
     return True
   if part.inline_data and part.inline_data.data:
     return True
+  if part.function_response:
+    return True
   if part.file_data and (part.file_data.file_uri or part.file_data.data):
     return True
   return False
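For context, this is roughly how the helper reads after the change. The branch producing the first `return True` is truncated above the hunk, so the `part.text` check shown here is an assumption:

```python
from google.genai import types


def _part_has_payload(part: types.Part) -> bool:
  """Returns True if the part carries content worth forwarding."""
  if part.text:  # assumed: the branch truncated above the hunk checks text
    return True
  if part.inline_data and part.inline_data.data:
    return True
  # Added by this change: a function response now counts as a payload,
  # presumably so tool-result parts are no longer treated as empty.
  if part.function_response:
    return True
  if part.file_data and (part.file_data.file_uri or part.file_data.data):
    return True
  return False
```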
@@ -1293,12 +1295,15 @@ def _model_response_to_chunk(
       if not func_name and not func_args:
         continue

-      yield FunctionChunk(
-          id=tool_call.id,
-          name=func_name,
-          args=func_args,
-          index=func_index,
-      ), finish_reason
+      yield (
+          FunctionChunk(
+              id=tool_call.id,
+              name=func_name,
+              args=func_args,
+              index=func_index,
+          ),
+          finish_reason,
+      )

   if finish_reason and not (message_content or tool_calls):
     yield None, finish_reason
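Functionally the new form is identical; the generator still yields `(chunk, finish_reason)` pairs, only the tuple is now parenthesized explicitly. A minimal sketch of walking those pairs, with the import path mirroring the test file below and `response` assumed to be a litellm `ModelResponse` obtained elsewhere:

```python
from google.adk.models.lite_llm import (  # names taken from this diff
    FunctionChunk,
    UsageMetadataChunk,
    _model_response_to_chunk,
)


def drain_chunks(response):
  """Sketch: walk the (chunk, finish_reason) pairs from one litellm response."""
  for chunk, finish_reason in _model_response_to_chunk(response):
    if chunk is None:
      print("finish reason only:", finish_reason)
    elif isinstance(chunk, FunctionChunk):
      print("tool call", chunk.index, chunk.name, chunk.args)
    elif isinstance(chunk, UsageMetadataChunk):
      print("tokens used:", chunk.total_tokens)
    else:  # TextChunk and any other chunk types
      print("text/other chunk, finish reason:", finish_reason)
```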
@@ -1310,12 +1315,17 @@
   # finish_reason set. But this is not the case we are observing from litellm.
   # So we are sending it as a separate chunk to be set on the llm_response.
   if response.get("usage", None):
-    yield UsageMetadataChunk(
-        prompt_tokens=response["usage"].get("prompt_tokens", 0),
-        completion_tokens=response["usage"].get("completion_tokens", 0),
-        total_tokens=response["usage"].get("total_tokens", 0),
-        cached_prompt_tokens=_extract_cached_prompt_tokens(response["usage"]),
-    ), None
+    yield (
+        UsageMetadataChunk(
+            prompt_tokens=response["usage"].get("prompt_tokens", 0),
+            completion_tokens=response["usage"].get("completion_tokens", 0),
+            total_tokens=response["usage"].get("total_tokens", 0),
+            cached_prompt_tokens=_extract_cached_prompt_tokens(
+                response["usage"]
+            ),
+        ),
+        None,
+    )


 def _model_response_to_generate_content_response(
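The usage chunk is built from litellm's `usage` dictionary. A small illustration of the mapping follows; the body of `_extract_cached_prompt_tokens` is not part of this diff, so the `prompt_tokens_details.cached_tokens` shape below is an assumption based on the OpenAI-style usage format litellm forwards:

```python
from google.adk.models.lite_llm import UsageMetadataChunk

# Illustrative usage payload; field names follow the OpenAI-style convention.
usage = {
    "prompt_tokens": 120,
    "completion_tokens": 35,
    "total_tokens": 155,
    "prompt_tokens_details": {"cached_tokens": 100},  # assumed cached-token source
}

chunk = UsageMetadataChunk(
    prompt_tokens=usage.get("prompt_tokens", 0),
    completion_tokens=usage.get("completion_tokens", 0),
    total_tokens=usage.get("total_tokens", 0),
    cached_prompt_tokens=usage["prompt_tokens_details"]["cached_tokens"],
)
print(chunk)
```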
@@ -1708,7 +1718,12 @@ def _warn_gemini_via_litellm(model_string: str) -> None:
   # Check if warning should be suppressed via environment variable
   if os.environ.get(
       "ADK_SUPPRESS_GEMINI_LITELLM_WARNINGS", ""
-  ).strip().lower() in ("1", "true", "yes", "on"):
+  ).strip().lower() in (
+      "1",
+      "true",
+      "yes",
+      "on",
+  ):
     return

   warnings.warn(
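The warning can be opted out of per process; the accepted values are exactly the four strings in the tuple above, compared case-insensitively after stripping whitespace. A minimal way to set this before the model is constructed:

```python
import os

# Any of "1", "true", "yes", "on" (case-insensitive) suppresses the warning.
os.environ["ADK_SUPPRESS_GEMINI_LITELLM_WARNINGS"] = "true"
```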
@@ -1812,9 +1827,12 @@ async def generate_content_async(
     logger.debug(_build_request_log(llm_request))

     effective_model = llm_request.model or self.model
-    messages, tools, response_format, generation_params = (
-        await _get_completion_inputs(llm_request, effective_model)
-    )
+    (
+        messages,
+        tools,
+        response_format,
+        generation_params,
+    ) = await _get_completion_inputs(llm_request, effective_model)
     normalized_messages = _normalize_ollama_chat_messages(
         messages,
         model=effective_model,
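Both spellings unpack the same awaited 4-tuple; the new layout simply puts one target per line to stay within the line-length limit. A toy equivalence check, where the coroutine is a stand-in rather than the real `_get_completion_inputs`:

```python
import asyncio


async def fake_completion_inputs():
  # Stand-in for _get_completion_inputs(llm_request, effective_model).
  return ["messages"], ["tools"], {"type": "json"}, {"temperature": 0.1}


async def main():
  # Old layout: one line of targets, parenthesized await on the right.
  messages, tools, response_format, generation_params = (
      await fake_completion_inputs()
  )
  # New layout: parenthesized target list, one name per line.
  (
      messages,
      tools,
      response_format,
      generation_params,
  ) = await fake_completion_inputs()
  print(messages, tools, response_format, generation_params)


asyncio.run(main())
```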
15 changes: 15 additions & 0 deletions tests/unittests/models/test_litellm_function_response.py
@@ -0,0 +1,15 @@
+from google.adk.models.lite_llm import _part_has_payload
+from google.genai import types
+import pytest
+
+
+def test_part_has_payload_with_function_response():
+  part = types.Part.from_function_response(
+      name="test_fn", response={"result": "success"}
+  )
+  assert _part_has_payload(part) is True
+
+
+def test_part_has_payload_without_payload():
+  part = types.Part()
+  assert _part_has_payload(part) is False
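The new tests cover the function-response branch and the empty-part default. One more case the same predicate should satisfy is sketched here, under the assumption that `types.Part.from_bytes` is available in the installed `google-genai` release:

```python
from google.adk.models.lite_llm import _part_has_payload
from google.genai import types


def test_part_has_payload_with_inline_data():
  # Assumption: Part.from_bytes populates part.inline_data, which the
  # helper checks alongside function_response and file_data.
  part = types.Part.from_bytes(data=b"\x89PNG", mime_type="image/png")
  assert _part_has_payload(part) is True
```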