diff --git a/python/copilot/tools.py b/python/copilot/tools.py index a82a48b1e9..620a8cd58a 100644 --- a/python/copilot/tools.py +++ b/python/copilot/tools.py @@ -13,7 +13,7 @@ from dataclasses import dataclass, field from typing import Any, Literal, TypeVar, get_type_hints, overload -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError ToolResultType = Literal["success", "failure", "rejected", "denied", "timeout"] @@ -211,7 +211,21 @@ async def wrapped_handler(invocation: ToolInvocation) -> ToolResult: if takes_params: args = invocation.arguments or {} if ptype is not None and _is_pydantic_model(ptype): - call_args.append(ptype.model_validate(args)) + try: + call_args.append(ptype.model_validate(args)) + except ValidationError as exc: + # Highlight input validation problems to the LLM. + parts = [] + for err in exc.errors(): + loc = ".".join(map(str, err["loc"])) + msg = err["msg"] + parts.append(f"{loc}: {msg}" if loc else msg) + return ToolResult( + text_result_for_llm="Invalid tool arguments:\n" + "\n".join(parts), + result_type="failure", + error=str(exc), + tool_telemetry={}, + ) else: call_args.append(args) if takes_invocation: diff --git a/python/test_tools.py b/python/test_tools.py index d583b59c01..c5230385f2 100644 --- a/python/test_tools.py +++ b/python/test_tools.py @@ -3,7 +3,7 @@ import json import pytest -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field, field_validator from copilot import define_tool from copilot.tools import ( @@ -197,6 +197,91 @@ def failing_tool(params: Params, invocation: ToolInvocation) -> str: # But the actual error is stored internally assert result.error == "secret error message" + async def test_validation_error_is_surfaced_to_llm(self): + class Params(BaseModel): + username: str + + @field_validator("username") + @classmethod + def check_username(cls, v: str) -> str: + if v == "admin": + raise ValueError("username 'admin' is reserved") + return v + + @define_tool("validate", description="A validating tool") + def validating_tool(params: Params) -> str: + return "ok" + + invocation = ToolInvocation( + session_id="s1", + tool_call_id="c1", + tool_name="validate", + arguments={"username": "admin"}, + ) + + result = await validating_tool.handler(invocation) + + assert result.result_type == "failure" + assert result.text_result_for_llm.startswith("Invalid tool arguments:") + assert "username 'admin' is reserved" in result.text_result_for_llm + # Full detail is retained in the debug field. + assert result.error is not None + + async def test_validation_error_extra_forbid_includes_field_name(self): + class Params(BaseModel): + model_config = ConfigDict(extra="forbid") + + request: str + + @define_tool("strict", description="A strict tool") + def strict_tool(params: Params) -> str: + return "ok" + + invocation = ToolInvocation( + session_id="s1", + tool_call_id="c1", + tool_name="strict", + arguments={"request": "ok", "extra_field": "unexpected"}, + ) + + result = await strict_tool.handler(invocation) + + assert result.result_type == "failure" + assert result.text_result_for_llm.startswith("Invalid tool arguments:") + # The offending key name is carried in `loc` even though the generic + # message is "Extra inputs are not permitted". + assert "extra_field" in result.text_result_for_llm + assert result.error is not None + + async def test_validation_error_from_handler_body_is_redacted(self): + class Params(BaseModel): + pass + + class Internal(BaseModel): + count: int + + @define_tool("body", description="A tool that validates internally") + def body_tool(params: Params) -> str: + Internal.model_validate({"count": "secret-not-an-int"}) + return "ok" + + invocation = ToolInvocation( + session_id="s1", + tool_call_id="c1", + tool_name="body", + arguments={}, + ) + + result = await body_tool.handler(invocation) + + assert result.result_type == "failure" + # A ValidationError from the handler body must not be surfaced as an + # argument-validation error; it stays redacted like any other exception. + assert not result.text_result_for_llm.startswith("Invalid tool arguments:") + assert "secret-not-an-int" not in result.text_result_for_llm + assert "error" in result.text_result_for_llm.lower() + assert result.error is not None + async def test_function_style_api(self): class Params(BaseModel): value: str