diff --git a/splunklib/ai/engines/langchain.py b/splunklib/ai/engines/langchain.py index 76fa100b..0052c30d 100644 --- a/splunklib/ai/engines/langchain.py +++ b/splunklib/ai/engines/langchain.py @@ -1266,7 +1266,7 @@ def _convert_model_result_from_lc(model_response: LC_ModelCallResult) -> ModelRe def _convert_agent_state_to_lc(state: AgentState) -> LC_AgentState[Any]: - messages = [_map_message_to_langchain(m) for m in state.response.messages] + messages = [_map_message_to_langchain(m) for m in state.messages] return LC_AgentState(messages=messages) @@ -1627,14 +1627,9 @@ def _convert_agent_state_from_langchain( messages = state["messages"] total_tokens_counter = _get_approximate_token_counter(model) total_tokens = total_tokens_counter(messages) - - response = AgentResponse[Any | None]( - messages=[_map_message_from_langchain(m) for m in state["messages"]], - structured_output=state.get("structured_response"), - ) - + messages = [_map_message_from_langchain(m) for m in state["messages"]] return AgentState( - response=response, + messages=messages, total_steps=len(messages), token_count=total_tokens, ) diff --git a/splunklib/ai/middleware.py b/splunklib/ai/middleware.py index 8814c5d6..0231dbb6 100644 --- a/splunklib/ai/middleware.py +++ b/splunklib/ai/middleware.py @@ -12,7 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. -from collections.abc import Awaitable, Callable +from collections.abc import Sequence, Awaitable, Callable from dataclasses import dataclass from typing import Any, override @@ -35,7 +35,7 @@ class AgentState: """AgentState is available through certain middlewares and contains information about the current state of an agent execution.""" # holds messages exchanged so far in the conversation - response: AgentResponse[Any | None] + messages: Sequence[BaseMessage] # steps taken so far in the conversation total_steps: int # tokens used so far in the conversation @@ -96,7 +96,7 @@ def __post_init__(self) -> None: @dataclass(frozen=True) class AgentRequest: - messages: list[BaseMessage] + messages: Sequence[BaseMessage] AgentMiddlewareHandler = Callable[[AgentRequest], Awaitable[AgentResponse[Any | None]]] diff --git a/tests/integration/ai/test_agent.py b/tests/integration/ai/test_agent.py index 1f9ea591..ad906dbd 100644 --- a/tests/integration/ai/test_agent.py +++ b/tests/integration/ai/test_agent.py @@ -532,7 +532,7 @@ async def _model_call_middleware( req: ModelRequest, _handler: ModelMiddlewareHandler ) -> ModelResponse: if after_subagent_call: - msgs = req.state.response.messages + msgs = req.state.messages assert isinstance(msgs[-1], SubagentMessage) assert isinstance(msgs[-1].result, SubagentFailureResult) diff --git a/tests/integration/ai/test_conversation_store.py b/tests/integration/ai/test_conversation_store.py index a5c10b34..77a756f2 100644 --- a/tests/integration/ai/test_conversation_store.py +++ b/tests/integration/ai/test_conversation_store.py @@ -66,9 +66,9 @@ async def _model_middleware( if after_first_call: # Previous messages included. - assert len(request.state.response.messages) == 3 + assert len(request.state.messages) == 3 else: - assert len(request.state.response.messages) == 1 + assert len(request.state.messages) == 1 return await handler(request) @agent_middleware @@ -166,7 +166,7 @@ async def _model_middleware( nonlocal model_middleware_called model_middleware_called = True - assert len(request.state.response.messages) == 1 + assert len(request.state.messages) == 1 return await handler(request) async with Agent( @@ -276,9 +276,9 @@ async def _model_middleware( nonlocal after_first_call if after_first_call: - assert len(request.state.response.messages) == 3 + assert len(request.state.messages) == 3 else: - assert len(request.state.response.messages) == 1 + assert len(request.state.messages) == 1 after_first_call = True return await handler(request) @@ -347,9 +347,9 @@ async def _model_middleware( nonlocal after_first_call if after_first_call: - assert len(request.state.response.messages) == 3 + assert len(request.state.messages) == 3 else: - assert len(request.state.response.messages) == 1 + assert len(request.state.messages) == 1 after_first_call = True return await handler(request) diff --git a/tests/integration/ai/test_hooks.py b/tests/integration/ai/test_hooks.py index ad22a75b..8ad1601e 100644 --- a/tests/integration/ai/test_hooks.py +++ b/tests/integration/ai/test_hooks.py @@ -47,7 +47,7 @@ def test_hook_before(req: ModelRequest) -> None: hook_calls += 1 assert req.system_message.startswith("Your name is stefan") - assert len(req.state.response.messages) == 1 + assert len(req.state.messages) == 1 @before_model async def test_async_hook_before(req: ModelRequest) -> None: @@ -55,7 +55,7 @@ async def test_async_hook_before(req: ModelRequest) -> None: hook_calls += 1 assert req.system_message.startswith("Your name is stefan") - assert len(req.state.response.messages) == 1 + assert len(req.state.messages) == 1 @after_model def test_hook_after(resp: ModelResponse) -> None: diff --git a/tests/integration/ai/test_middleware.py b/tests/integration/ai/test_middleware.py index d699bb5b..b2adfed9 100644 --- a/tests/integration/ai/test_middleware.py +++ b/tests/integration/ai/test_middleware.py @@ -78,7 +78,7 @@ async def test_middleware( assert call.args == {"city": "Krakow"} state = request.state - assert len(state.response.messages) == 2 + assert len(state.messages) == 2 response = await handler(request) assert isinstance(response.result, ToolResult) @@ -699,10 +699,7 @@ async def mutating_middleware( ) -> ModelResponse: new_state = replace( request.state, - response=replace( - request.state.response, - messages=[HumanMessage(content="What is the capital of France?")], - ), + messages=[HumanMessage(content="What is the capital of France?")], ) return await handler(replace(request, state=new_state)) diff --git a/tests/unit/ai/test_default_limits.py b/tests/unit/ai/test_default_limits.py index e97c67c7..bd998075 100644 --- a/tests/unit/ai/test_default_limits.py +++ b/tests/unit/ai/test_default_limits.py @@ -48,7 +48,7 @@ def _make_agent_request() -> AgentRequest: def _make_model_request(token_count: int = 0, total_steps: int = 0) -> ModelRequest: state = AgentState( - response=AgentResponse(messages=[], structured_output=None), + messages=[], total_steps=total_steps, token_count=token_count, ) @@ -141,7 +141,7 @@ async def test_timeout_fires_when_deadline_exceeded(self) -> None: mw = TimeoutLimitMiddleware(60.0) mw._deadline = monotonic() - 1.0 # pyright: ignore[reportPrivateUsage] # already in the past - state = AgentState(response=AgentResponse(messages=[], structured_output=None), total_steps=0, token_count=0) + state = AgentState(messages=[], total_steps=0, token_count=0) request = ModelRequest(system_message="", state=state) with self.assertRaises(TimeoutExceededException): @@ -166,4 +166,3 @@ async def test_raises_when_steps_in_request_reach_limit(self) -> None: await mw.model_middleware(_make_model_request(total_steps=2), _noop_model_handler) with self.assertRaises(StepsLimitExceededException): await mw.model_middleware(_make_model_request(total_steps=3), _noop_model_handler) -