diff --git a/python/packages/core/agent_framework/_workflows/_agent.py b/python/packages/core/agent_framework/_workflows/_agent.py index 2d9b37e1f5..5d8e698c84 100644 --- a/python/packages/core/agent_framework/_workflows/_agent.py +++ b/python/packages/core/agent_framework/_workflows/_agent.py @@ -54,6 +54,8 @@ class WorkflowAgent(BaseAgent): # Class variable for the request info function name REQUEST_INFO_FUNCTION_NAME: ClassVar[str] = "request_info" + _SESSION_STATE_KEY: ClassVar[str] = "workflow_agent" + _PENDING_REQUESTS_STATE_KEY: ClassVar[str] = "pending_request_info_events" @dataclass class RequestInfoFunctionArgs: @@ -258,6 +260,7 @@ async def _run_impl( An AgentResponse representing the workflow execution results. """ input_messages = normalize_messages_input(messages) + self._restore_pending_requests_from_session(session) if ( not any( @@ -291,10 +294,11 @@ async def _run_impl( ) # combine the messages session_messages: list[Message] = session_context.get_messages(include_input=True) + workflow_input_messages = input_messages if bool(self.pending_requests) else session_messages output_events: list[WorkflowEvent[Any]] = [] async for event in self._run_core( - session_messages, + workflow_input_messages, checkpoint_id, checkpoint_storage, streaming=False, @@ -311,6 +315,7 @@ async def _run_impl( session_context._response = result # type: ignore[assignment] await self._run_after_providers(session=provider_session, context=session_context) + self._persist_pending_requests_to_session(session) return result async def _run_stream_impl( @@ -338,6 +343,7 @@ async def _run_stream_impl( AgentResponseUpdate objects representing the workflow execution progress. """ input_messages = normalize_messages_input(messages) + self._restore_pending_requests_from_session(session) if ( not any( @@ -372,9 +378,10 @@ async def _run_stream_impl( # combine the messages session_messages: list[Message] = session_context.get_messages(include_input=True) + workflow_input_messages = input_messages if bool(self.pending_requests) else session_messages all_updates: list[AgentResponseUpdate] = [] async for event in self._run_core( - session_messages, + workflow_input_messages, checkpoint_id, checkpoint_storage, streaming=True, @@ -392,6 +399,7 @@ async def _run_stream_impl( session_context._response = AgentResponse.from_updates(all_updates) # type: ignore[assignment] await self._run_after_providers(session=provider_session, context=session_context) + self._persist_pending_requests_to_session(session) async def _run_core( self, @@ -425,6 +433,8 @@ async def _run_core( async for event in self.workflow.run( responses=function_responses, stream=True, + checkpoint_id=checkpoint_id, + checkpoint_storage=checkpoint_storage, function_invocation_kwargs=function_invocation_kwargs, client_kwargs=client_kwargs, ): @@ -432,6 +442,8 @@ async def _run_core( else: for event in await self.workflow.run( responses=function_responses, + checkpoint_id=checkpoint_id, + checkpoint_storage=checkpoint_storage, function_invocation_kwargs=function_invocation_kwargs, client_kwargs=client_kwargs, ): @@ -484,6 +496,60 @@ async def _run_core( # endregion Run Methods + def _restore_pending_requests_from_session(self, session: AgentSession | None) -> None: + """Load pending request-info events from the session state.""" + if session is None: + return + + agent_state = session.state.get(self._SESSION_STATE_KEY) + if not isinstance(agent_state, dict): + self.pending_requests.clear() + return + + pending_requests_payload = agent_state.get(self._PENDING_REQUESTS_STATE_KEY) + if not isinstance(pending_requests_payload, dict): + self.pending_requests.clear() + return + + restored_pending: dict[str, WorkflowEvent[Any]] = {} + for request_id, request_payload in pending_requests_payload.items(): + if isinstance(request_payload, WorkflowEvent): + restored_pending[request_id] = request_payload + continue + + if not isinstance(request_payload, dict): + logger.warning("Skipping malformed pending request payload for request_id '%s'.", request_id) + continue + + try: + restored_pending[request_id] = WorkflowEvent.from_dict(request_payload) + except Exception as exc: # pragma: no cover - defensive + logger.warning( + "Failed to restore pending request payload for request_id '%s': %s", + request_id, + exc, + ) + + self.pending_requests.clear() + self.pending_requests.update(restored_pending) + + def _persist_pending_requests_to_session(self, session: AgentSession | None) -> None: + """Persist pending request-info events to the session state.""" + if session is None: + return + + agent_state = session.state.setdefault(self._SESSION_STATE_KEY, {}) + if not isinstance(agent_state, dict): + logger.warning( + "Skipping pending request persistence because '%s' is not a mapping.", + self._SESSION_STATE_KEY, + ) + return + + agent_state[self._PENDING_REQUESTS_STATE_KEY] = { + request_id: event.to_dict() for request_id, event in self.pending_requests.items() + } + def _process_pending_requests(self, input_messages: Sequence[Message]) -> dict[str, Any]: """Process pending requests by extracting function responses and updating state. diff --git a/python/packages/core/tests/workflow/test_workflow_agent.py b/python/packages/core/tests/workflow/test_workflow_agent.py index 3dcdd26c86..4fad39eae4 100644 --- a/python/packages/core/tests/workflow/test_workflow_agent.py +++ b/python/packages/core/tests/workflow/test_workflow_agent.py @@ -293,6 +293,94 @@ async def test_end_to_end_request_info_handling(self): # Verify cleanup - pending requests should be cleared after function response handling assert len(agent.pending_requests) == 0 + async def test_request_info_resume_after_session_restore_with_checkpoint(self): + """Pending request metadata in AgentSession should resume the same request_id after restore.""" + from agent_framework import InMemoryCheckpointStorage + + simple_executor = SimpleExecutor(id="simple", response_text="SimpleResponse", streaming=False) + requesting_executor = RequestingExecutor(id="requester", streaming=False) + checkpoint_storage = InMemoryCheckpointStorage() + + workflow = ( + WorkflowBuilder(start_executor=simple_executor, checkpoint_storage=checkpoint_storage) + .add_edge(simple_executor, requesting_executor) + .build() + ) + agent = WorkflowAgent(workflow=workflow, name="Request Restore Test Agent") + session = AgentSession() + + updates: list[AgentResponseUpdate] = [] + async for update in agent.run("Start request", stream=True, session=session): + updates.append(update) + + approval_update = next( + ( + update + for update in updates + if any(content.type == "function_approval_request" for content in update.contents) + ), + None, + ) + assert approval_update is not None, "Should have received a request_info approval request" + + function_call = next(content for content in approval_update.contents if content.type == "function_call") + approval_request = next( + content for content in approval_update.contents if content.type == "function_approval_request" + ) + request_id = approval_request.id + assert request_id is not None + assert function_call.call_id == request_id + + checkpoints = await checkpoint_storage.list_checkpoints(workflow_name=workflow.name) + checkpoint_with_request = next( + (checkpoint for checkpoint in checkpoints if request_id in checkpoint.pending_request_info_events), + None, + ) + assert checkpoint_with_request is not None + + serialized_session = session.to_dict() + workflow_agent_state = serialized_session["state"].get("workflow_agent", {}) + pending_state = workflow_agent_state.get("pending_request_info_events", {}) + assert request_id in pending_state + + restored_session = AgentSession.from_dict(serialized_session) + + restored_simple_executor = SimpleExecutor(id="simple", response_text="SimpleResponse", streaming=False) + restored_requesting_executor = RequestingExecutor(id="requester", streaming=False) + restored_workflow = ( + WorkflowBuilder(start_executor=restored_simple_executor, checkpoint_storage=checkpoint_storage) + .add_edge(restored_simple_executor, restored_requesting_executor) + .build() + ) + restored_agent = WorkflowAgent(workflow=restored_workflow, name="Request Restore Test Agent") + + response_args = WorkflowAgent.RequestInfoFunctionArgs( + request_id=request_id, + data="User provided answer", + ).to_dict() + approval_response = Content.from_function_approval_response( + approved=True, + id=request_id, + function_call=Content.from_function_call( + call_id=request_id, + name=WorkflowAgent.REQUEST_INFO_FUNCTION_NAME, + arguments=response_args, + ), + ) + response_message = Message(role="user", contents=[approval_response]) + + continuation_result = await restored_agent.run( + response_message, + session=restored_session, + checkpoint_id=checkpoint_with_request.checkpoint_id, + checkpoint_storage=checkpoint_storage, + ) + + assert isinstance(continuation_result, AgentResponse) + response_texts = [message.text for message in continuation_result.messages if message.text] + assert any("Request completed with response: User provided answer" in text for text in response_texts) + assert len(restored_agent.pending_requests) == 0 + def test_workflow_as_agent_method(self) -> None: """Test that Workflow.as_agent() creates a properly configured WorkflowAgent.""" # Create a simple workflow