Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 68 additions & 2 deletions python/packages/core/agent_framework/_workflows/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ class WorkflowAgent(BaseAgent):

# Class variable for the request info function name
REQUEST_INFO_FUNCTION_NAME: ClassVar[str] = "request_info"
_SESSION_STATE_KEY: ClassVar[str] = "workflow_agent"
_PENDING_REQUESTS_STATE_KEY: ClassVar[str] = "pending_request_info_events"

@dataclass
class RequestInfoFunctionArgs:
Expand Down Expand Up @@ -258,6 +260,7 @@ async def _run_impl(
An AgentResponse representing the workflow execution results.
"""
input_messages = normalize_messages_input(messages)
self._restore_pending_requests_from_session(session)

if (
not any(
Expand Down Expand Up @@ -291,10 +294,11 @@ async def _run_impl(
)
# combine the messages
session_messages: list[Message] = session_context.get_messages(include_input=True)
workflow_input_messages = input_messages if bool(self.pending_requests) else session_messages

output_events: list[WorkflowEvent[Any]] = []
async for event in self._run_core(
session_messages,
workflow_input_messages,
checkpoint_id,
checkpoint_storage,
streaming=False,
Expand All @@ -311,6 +315,7 @@ async def _run_impl(
session_context._response = result # type: ignore[assignment]

await self._run_after_providers(session=provider_session, context=session_context)
self._persist_pending_requests_to_session(session)
return result

async def _run_stream_impl(
Expand Down Expand Up @@ -338,6 +343,7 @@ async def _run_stream_impl(
AgentResponseUpdate objects representing the workflow execution progress.
"""
input_messages = normalize_messages_input(messages)
self._restore_pending_requests_from_session(session)

if (
not any(
Expand Down Expand Up @@ -372,9 +378,10 @@ async def _run_stream_impl(
# combine the messages

session_messages: list[Message] = session_context.get_messages(include_input=True)
workflow_input_messages = input_messages if bool(self.pending_requests) else session_messages
all_updates: list[AgentResponseUpdate] = []
async for event in self._run_core(
session_messages,
workflow_input_messages,
checkpoint_id,
checkpoint_storage,
streaming=True,
Expand All @@ -392,6 +399,7 @@ async def _run_stream_impl(
session_context._response = AgentResponse.from_updates(all_updates) # type: ignore[assignment]

await self._run_after_providers(session=provider_session, context=session_context)
self._persist_pending_requests_to_session(session)

async def _run_core(
self,
Expand Down Expand Up @@ -425,13 +433,17 @@ async def _run_core(
async for event in self.workflow.run(
responses=function_responses,
stream=True,
checkpoint_id=checkpoint_id,
checkpoint_storage=checkpoint_storage,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
):
yield event
else:
for event in await self.workflow.run(
responses=function_responses,
checkpoint_id=checkpoint_id,
checkpoint_storage=checkpoint_storage,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
):
Expand Down Expand Up @@ -484,6 +496,60 @@ async def _run_core(

# endregion Run Methods

def _restore_pending_requests_from_session(self, session: AgentSession | None) -> None:
"""Load pending request-info events from the session state."""
if session is None:
return

agent_state = session.state.get(self._SESSION_STATE_KEY)
if not isinstance(agent_state, dict):
self.pending_requests.clear()
return

pending_requests_payload = agent_state.get(self._PENDING_REQUESTS_STATE_KEY)
if not isinstance(pending_requests_payload, dict):
self.pending_requests.clear()
return

restored_pending: dict[str, WorkflowEvent[Any]] = {}
for request_id, request_payload in pending_requests_payload.items():
if isinstance(request_payload, WorkflowEvent):
restored_pending[request_id] = request_payload
continue

if not isinstance(request_payload, dict):
logger.warning("Skipping malformed pending request payload for request_id '%s'.", request_id)
continue

try:
restored_pending[request_id] = WorkflowEvent.from_dict(request_payload)
except Exception as exc: # pragma: no cover - defensive
logger.warning(
"Failed to restore pending request payload for request_id '%s': %s",
request_id,
exc,
)

self.pending_requests.clear()
self.pending_requests.update(restored_pending)

def _persist_pending_requests_to_session(self, session: AgentSession | None) -> None:
"""Persist pending request-info events to the session state."""
if session is None:
return

agent_state = session.state.setdefault(self._SESSION_STATE_KEY, {})
if not isinstance(agent_state, dict):
logger.warning(
"Skipping pending request persistence because '%s' is not a mapping.",
self._SESSION_STATE_KEY,
)
return

agent_state[self._PENDING_REQUESTS_STATE_KEY] = {
request_id: event.to_dict() for request_id, event in self.pending_requests.items()
}

def _process_pending_requests(self, input_messages: Sequence[Message]) -> dict[str, Any]:
"""Process pending requests by extracting function responses and updating state.

Expand Down
88 changes: 88 additions & 0 deletions python/packages/core/tests/workflow/test_workflow_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,94 @@ async def test_end_to_end_request_info_handling(self):
# Verify cleanup - pending requests should be cleared after function response handling
assert len(agent.pending_requests) == 0

async def test_request_info_resume_after_session_restore_with_checkpoint(self):
"""Pending request metadata in AgentSession should resume the same request_id after restore."""
from agent_framework import InMemoryCheckpointStorage

simple_executor = SimpleExecutor(id="simple", response_text="SimpleResponse", streaming=False)
requesting_executor = RequestingExecutor(id="requester", streaming=False)
checkpoint_storage = InMemoryCheckpointStorage()

workflow = (
WorkflowBuilder(start_executor=simple_executor, checkpoint_storage=checkpoint_storage)
.add_edge(simple_executor, requesting_executor)
.build()
)
agent = WorkflowAgent(workflow=workflow, name="Request Restore Test Agent")
session = AgentSession()

updates: list[AgentResponseUpdate] = []
async for update in agent.run("Start request", stream=True, session=session):
updates.append(update)

approval_update = next(
(
update
for update in updates
if any(content.type == "function_approval_request" for content in update.contents)
),
None,
)
assert approval_update is not None, "Should have received a request_info approval request"

function_call = next(content for content in approval_update.contents if content.type == "function_call")
approval_request = next(
content for content in approval_update.contents if content.type == "function_approval_request"
)
request_id = approval_request.id
assert request_id is not None
assert function_call.call_id == request_id

checkpoints = await checkpoint_storage.list_checkpoints(workflow_name=workflow.name)
checkpoint_with_request = next(
(checkpoint for checkpoint in checkpoints if request_id in checkpoint.pending_request_info_events),
None,
)
assert checkpoint_with_request is not None

serialized_session = session.to_dict()
workflow_agent_state = serialized_session["state"].get("workflow_agent", {})
pending_state = workflow_agent_state.get("pending_request_info_events", {})
assert request_id in pending_state

restored_session = AgentSession.from_dict(serialized_session)

restored_simple_executor = SimpleExecutor(id="simple", response_text="SimpleResponse", streaming=False)
restored_requesting_executor = RequestingExecutor(id="requester", streaming=False)
restored_workflow = (
WorkflowBuilder(start_executor=restored_simple_executor, checkpoint_storage=checkpoint_storage)
.add_edge(restored_simple_executor, restored_requesting_executor)
.build()
)
restored_agent = WorkflowAgent(workflow=restored_workflow, name="Request Restore Test Agent")

response_args = WorkflowAgent.RequestInfoFunctionArgs(
request_id=request_id,
data="User provided answer",
).to_dict()
approval_response = Content.from_function_approval_response(
approved=True,
id=request_id,
function_call=Content.from_function_call(
call_id=request_id,
name=WorkflowAgent.REQUEST_INFO_FUNCTION_NAME,
arguments=response_args,
),
)
response_message = Message(role="user", contents=[approval_response])

continuation_result = await restored_agent.run(
response_message,
session=restored_session,
checkpoint_id=checkpoint_with_request.checkpoint_id,
checkpoint_storage=checkpoint_storage,
)

assert isinstance(continuation_result, AgentResponse)
response_texts = [message.text for message in continuation_result.messages if message.text]
assert any("Request completed with response: User provided answer" in text for text in response_texts)
assert len(restored_agent.pending_requests) == 0

def test_workflow_as_agent_method(self) -> None:
"""Test that Workflow.as_agent() creates a properly configured WorkflowAgent."""
# Create a simple workflow
Expand Down