Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions python/copilot/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from dataclasses import dataclass, field
from typing import Any, Literal, TypeVar, get_type_hints, overload

from pydantic import BaseModel
from pydantic import BaseModel, ValidationError

ToolResultType = Literal["success", "failure", "rejected", "denied", "timeout"]

Expand Down Expand Up @@ -211,7 +211,21 @@ async def wrapped_handler(invocation: ToolInvocation) -> ToolResult:
if takes_params:
args = invocation.arguments or {}
if ptype is not None and _is_pydantic_model(ptype):
call_args.append(ptype.model_validate(args))
try:
call_args.append(ptype.model_validate(args))
except ValidationError as exc:
# Highlight input validation problems to the LLM.
parts = []
for err in exc.errors():
loc = ".".join(map(str, err["loc"]))
msg = err["msg"]
parts.append(f"{loc}: {msg}" if loc else msg)
return ToolResult(
text_result_for_llm="Invalid tool arguments:\n" + "\n".join(parts),
result_type="failure",
error=str(exc),
tool_telemetry={},
)
else:
call_args.append(args)
if takes_invocation:
Expand Down
87 changes: 86 additions & 1 deletion python/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json

import pytest
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field, field_validator

from copilot import define_tool
from copilot.tools import (
Expand Down Expand Up @@ -197,6 +197,91 @@ def failing_tool(params: Params, invocation: ToolInvocation) -> str:
# But the actual error is stored internally
assert result.error == "secret error message"

async def test_validation_error_is_surfaced_to_llm(self):
class Params(BaseModel):
username: str

@field_validator("username")
@classmethod
def check_username(cls, v: str) -> str:
if v == "admin":
raise ValueError("username 'admin' is reserved")
return v

@define_tool("validate", description="A validating tool")
def validating_tool(params: Params) -> str:
return "ok"

invocation = ToolInvocation(
session_id="s1",
tool_call_id="c1",
tool_name="validate",
arguments={"username": "admin"},
)

result = await validating_tool.handler(invocation)

assert result.result_type == "failure"
assert result.text_result_for_llm.startswith("Invalid tool arguments:")
assert "username 'admin' is reserved" in result.text_result_for_llm
# Full detail is retained in the debug field.
assert result.error is not None

async def test_validation_error_extra_forbid_includes_field_name(self):
class Params(BaseModel):
model_config = ConfigDict(extra="forbid")

request: str

@define_tool("strict", description="A strict tool")
def strict_tool(params: Params) -> str:
return "ok"

invocation = ToolInvocation(
session_id="s1",
tool_call_id="c1",
tool_name="strict",
arguments={"request": "ok", "extra_field": "unexpected"},
)

result = await strict_tool.handler(invocation)

assert result.result_type == "failure"
assert result.text_result_for_llm.startswith("Invalid tool arguments:")
# The offending key name is carried in `loc` even though the generic
# message is "Extra inputs are not permitted".
assert "extra_field" in result.text_result_for_llm
assert result.error is not None

async def test_validation_error_from_handler_body_is_redacted(self):
class Params(BaseModel):
pass

class Internal(BaseModel):
count: int

@define_tool("body", description="A tool that validates internally")
def body_tool(params: Params) -> str:
Internal.model_validate({"count": "secret-not-an-int"})
return "ok"

invocation = ToolInvocation(
session_id="s1",
tool_call_id="c1",
tool_name="body",
arguments={},
)

result = await body_tool.handler(invocation)

assert result.result_type == "failure"
# A ValidationError from the handler body must not be surfaced as an
# argument-validation error; it stays redacted like any other exception.
assert not result.text_result_for_llm.startswith("Invalid tool arguments:")
assert "secret-not-an-int" not in result.text_result_for_llm
assert "error" in result.text_result_for_llm.lower()
assert result.error is not None

async def test_function_style_api(self):
class Params(BaseModel):
value: str
Expand Down
Loading