diff --git a/dotnet/src/Microsoft.Agents.AI.Purview/Models/Common/ProcessContentMetadataBase.cs b/dotnet/src/Microsoft.Agents.AI.Purview/Models/Common/ProcessContentMetadataBase.cs
index a401288127..ee9978fdf2 100644
--- a/dotnet/src/Microsoft.Agents.AI.Purview/Models/Common/ProcessContentMetadataBase.cs
+++ b/dotnet/src/Microsoft.Agents.AI.Purview/Models/Common/ProcessContentMetadataBase.cs
@@ -21,12 +21,14 @@ internal abstract class ProcessContentMetadataBase : GraphDataTypeBase
/// The unique identifier for the content.
/// Indicates if the content is truncated.
/// The name of the content.
- protected ProcessContentMetadataBase(ContentBase content, string identifier, bool isTruncated, string name) : base(ProcessConversationMetadataDataType)
+ /// The correlation ID for the content.
+ protected ProcessContentMetadataBase(ContentBase content, string identifier, bool isTruncated, string name, string correlationId) : base(ProcessConversationMetadataDataType)
{
this.Identifier = identifier;
this.IsTruncated = isTruncated;
this.Content = content;
this.Name = name;
+ this.CorrelationId = correlationId;
}
///
@@ -55,7 +57,7 @@ protected ProcessContentMetadataBase(ContentBase content, string identifier, boo
/// Identifier to group multiple contents.
///
[JsonPropertyName("correlationId")]
- public string? CorrelationId { get; set; }
+ public string CorrelationId { get; set; }
///
/// Gets or sets the sequenceNumber.
diff --git a/dotnet/src/Microsoft.Agents.AI.Purview/Models/Common/ProcessConversationMetadata.cs b/dotnet/src/Microsoft.Agents.AI.Purview/Models/Common/ProcessConversationMetadata.cs
index 86bedb9248..9100eac02e 100644
--- a/dotnet/src/Microsoft.Agents.AI.Purview/Models/Common/ProcessConversationMetadata.cs
+++ b/dotnet/src/Microsoft.Agents.AI.Purview/Models/Common/ProcessConversationMetadata.cs
@@ -15,7 +15,7 @@ internal sealed class ProcessConversationMetadata : ProcessContentMetadataBase
///
/// Initializes a new instance of the class.
///
- public ProcessConversationMetadata(ContentBase contentBase, string identifier, bool isTruncated, string name) : base(contentBase, identifier, isTruncated, name)
+ public ProcessConversationMetadata(ContentBase contentBase, string identifier, bool isTruncated, string name, string correlationId) : base(contentBase, identifier, isTruncated, name, correlationId)
{
this.DataType = ProcessConversationMetadataDataType;
}
diff --git a/dotnet/src/Microsoft.Agents.AI.Purview/Models/Common/ProcessFileMetadata.cs b/dotnet/src/Microsoft.Agents.AI.Purview/Models/Common/ProcessFileMetadata.cs
index a9f1749bed..89c0912e09 100644
--- a/dotnet/src/Microsoft.Agents.AI.Purview/Models/Common/ProcessFileMetadata.cs
+++ b/dotnet/src/Microsoft.Agents.AI.Purview/Models/Common/ProcessFileMetadata.cs
@@ -14,7 +14,7 @@ internal sealed class ProcessFileMetadata : ProcessContentMetadataBase
///
/// Initializes a new instance of the class.
///
- public ProcessFileMetadata(ContentBase contentBase, string identifier, bool isTruncated, string name) : base(contentBase, identifier, isTruncated, name)
+ public ProcessFileMetadata(ContentBase contentBase, string identifier, bool isTruncated, string name, string correlationId) : base(contentBase, identifier, isTruncated, name, correlationId)
{
this.DataType = ProcessFileMetadataDataType;
}
diff --git a/dotnet/src/Microsoft.Agents.AI.Purview/PurviewSettings.cs b/dotnet/src/Microsoft.Agents.AI.Purview/PurviewSettings.cs
index cb400805c6..508f531bbe 100644
--- a/dotnet/src/Microsoft.Agents.AI.Purview/PurviewSettings.cs
+++ b/dotnet/src/Microsoft.Agents.AI.Purview/PurviewSettings.cs
@@ -19,7 +19,7 @@ public class PurviewSettings
/// The publicly visible name of the application.
public PurviewSettings(string appName)
{
- this.AppName = appName;
+ this.AppName = string.IsNullOrWhiteSpace(appName) ? throw new ArgumentException("AppName cannot be null or whitespace.", nameof(appName)) : appName;
}
///
diff --git a/dotnet/src/Microsoft.Agents.AI.Purview/PurviewWrapper.cs b/dotnet/src/Microsoft.Agents.AI.Purview/PurviewWrapper.cs
index 5a63448478..9b6cdc2ffd 100644
--- a/dotnet/src/Microsoft.Agents.AI.Purview/PurviewWrapper.cs
+++ b/dotnet/src/Microsoft.Agents.AI.Purview/PurviewWrapper.cs
@@ -53,7 +53,7 @@ private static string GetSessionIdFromAgentSession(AgentSession? session, IEnume
}
}
- return Guid.NewGuid().ToString();
+ return string.Empty;
}
///
@@ -136,12 +136,15 @@ public async Task ProcessChatContentAsync(IEnumerable
/// The agent's response. This could be the response from the agent or a message indicating that Purview has blocked the prompt or response.
public async Task ProcessAgentContentAsync(IEnumerable messages, AgentSession? session, AgentRunOptions? options, AIAgent innerAgent, CancellationToken cancellationToken)
{
- string sessionId = GetSessionIdFromAgentSession(session, messages);
-
string? resolvedUserId = null;
-
+ string sessionId = string.Empty;
try
{
+ sessionId = GetSessionIdFromAgentSession(session, messages);
+ if (string.IsNullOrEmpty(sessionId))
+ {
+ sessionId = Guid.NewGuid().ToString();
+ }
(bool shouldBlockPrompt, resolvedUserId) = await this._scopedProcessor.ProcessMessagesAsync(messages, sessionId, Activity.UploadText, this._purviewSettings, null, cancellationToken).ConfigureAwait(false);
if (shouldBlockPrompt)
@@ -171,7 +174,19 @@ public async Task ProcessAgentContentAsync(IEnumerable> MapMessageToPCRequestsAsync(IEnu
{
string messageId = message.MessageId ?? Guid.NewGuid().ToString();
ContentBase content = new PurviewTextContent(message.Text);
- ProcessConversationMetadata conversationmetadata = new(content, messageId, false, $"Agent Framework Message {messageId}")
+ string correlationId = (sessionId ?? Guid.NewGuid().ToString()) + "@AF";
+ ProcessConversationMetadata conversationMetadata = new(content, messageId, false, $"Agent Framework Message {messageId}", correlationId)
{
- CorrelationId = sessionId ?? Guid.NewGuid().ToString()
+ SequenceNumber = DateTime.UtcNow.Ticks,
};
ActivityMetadata activityMetadata = new(activity);
PolicyLocation policyLocation;
@@ -162,7 +163,7 @@ private async Task> MapMessageToPCRequestsAsync(IEnu
OperatingSystemVersion = "Unknown"
}
};
- ContentToProcess contentToProcess = new([conversationmetadata], activityMetadata, deviceMetadata, integratedAppMetadata, protectedAppMetadata);
+ ContentToProcess contentToProcess = new([conversationMetadata], activityMetadata, deviceMetadata, integratedAppMetadata, protectedAppMetadata);
if (userId == null &&
tokenInfo?.UserId != null)
diff --git a/dotnet/tests/Microsoft.Agents.AI.Purview.UnitTests/PurviewClientTests.cs b/dotnet/tests/Microsoft.Agents.AI.Purview.UnitTests/PurviewClientTests.cs
index 0846decc2f..38abc903d3 100644
--- a/dotnet/tests/Microsoft.Agents.AI.Purview.UnitTests/PurviewClientTests.cs
+++ b/dotnet/tests/Microsoft.Agents.AI.Purview.UnitTests/PurviewClientTests.cs
@@ -478,7 +478,7 @@ private static ProcessContentRequest CreateValidProcessContentRequest()
private static ContentToProcess CreateValidContentToProcess()
{
var content = new PurviewTextContent("Test content");
- var metadata = new ProcessConversationMetadata(content, "msg-123", false, "Test message");
+ var metadata = new ProcessConversationMetadata(content, "msg-123", false, "Test message", "test-correlation-id");
var activityMetadata = new ActivityMetadata(Activity.UploadText);
var deviceMetadata = new DeviceMetadata
{
diff --git a/python/packages/purview/agent_framework_purview/_middleware.py b/python/packages/purview/agent_framework_purview/_middleware.py
index 42f8b37df6..52a74ffc10 100644
--- a/python/packages/purview/agent_framework_purview/_middleware.py
+++ b/python/packages/purview/agent_framework_purview/_middleware.py
@@ -45,6 +45,25 @@ def __init__(
self._processor = ScopedContentProcessor(self._client, settings, cache_provider)
self._settings = settings
+ @staticmethod
+ def _get_agent_session_id(context: AgentContext) -> str | None:
+ """Resolve a session/conversation id from the agent run context.
+
+ Resolution order:
+ 1. thread.service_thread_id
+ 2. First message whose additional_properties contains 'conversation_id'
+ 3. None: the downstream processor will generate a new UUID
+ """
+ if context.thread and context.thread.service_thread_id:
+ return context.thread.service_thread_id
+
+ for message in context.messages:
+ conversation_id = message.additional_properties.get("conversation_id")
+ if conversation_id is not None:
+ return str(conversation_id)
+
+ return None
+
async def process(
self,
context: AgentContext,
@@ -53,8 +72,9 @@ async def process(
resolved_user_id: str | None = None
try:
# Pre (prompt) check
+ session_id = self._get_agent_session_id(context)
should_block_prompt, resolved_user_id = await self._processor.process_messages(
- context.messages, Activity.UPLOAD_TEXT
+ context.messages, Activity.UPLOAD_TEXT, session_id=session_id
)
if should_block_prompt:
from agent_framework import AgentResponse, ChatMessage
@@ -79,10 +99,14 @@ async def process(
try:
# Post (response) check only if we have a normal AgentResponse
# Use the same user_id from the request for the response evaluation
+ session_id_response = self._get_agent_session_id(context)
+ if session_id_response is None:
+ session_id_response = session_id
if context.result and not context.stream:
should_block_response, _ = await self._processor.process_messages(
context.result.messages, # type: ignore[union-attr]
Activity.UPLOAD_TEXT,
+ session_id=session_id,
user_id=resolved_user_id,
)
if should_block_response:
@@ -144,8 +168,9 @@ async def process(
) -> None: # type: ignore[override]
resolved_user_id: str | None = None
try:
+ session_id = context.options.get("conversation_id") if context.options else None
should_block_prompt, resolved_user_id = await self._processor.process_messages(
- context.messages, Activity.UPLOAD_TEXT
+ context.messages, Activity.UPLOAD_TEXT, session_id=session_id
)
if should_block_prompt:
from agent_framework import ChatMessage, ChatResponse
@@ -169,12 +194,15 @@ async def process(
try:
# Post (response) evaluation only if non-streaming and we have messages result shape
# Use the same user_id from the request for the response evaluation
+ session_id_response = context.options.get("conversation_id") if context.options else None
+ if session_id_response is None:
+ session_id_response = session_id
if context.result and not context.stream:
result_obj = context.result
messages = getattr(result_obj, "messages", None)
if messages:
should_block_response, _ = await self._processor.process_messages(
- messages, Activity.UPLOAD_TEXT, user_id=resolved_user_id
+ messages, Activity.UPLOAD_TEXT, session_id=session_id_response, user_id=resolved_user_id
)
if should_block_response:
from agent_framework import ChatMessage, ChatResponse
diff --git a/python/packages/purview/agent_framework_purview/_processor.py b/python/packages/purview/agent_framework_purview/_processor.py
index fb115783f5..e2206a781b 100644
--- a/python/packages/purview/agent_framework_purview/_processor.py
+++ b/python/packages/purview/agent_framework_purview/_processor.py
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
+import time
import uuid
from collections.abc import Iterable, MutableMapping
from typing import Any
@@ -62,13 +63,18 @@ def __init__(self, client: PurviewClient, settings: PurviewSettings, cache_provi
self._background_tasks: set[asyncio.Task[Any]] = set()
async def process_messages(
- self, messages: Iterable[ChatMessage], activity: Activity, user_id: str | None = None
+ self,
+ messages: Iterable[ChatMessage],
+ activity: Activity,
+ session_id: str | None = None,
+ user_id: str | None = None,
) -> tuple[bool, str | None]:
"""Process messages for policy evaluation.
Args:
messages: The messages to process
activity: The activity type (e.g., UPLOAD_TEXT)
+ session_id: Optional session/conversation id. Else, a new GUID is generated.
user_id: Optional user_id to use for all messages. If provided, this is the fallback.
Returns:
@@ -76,7 +82,7 @@ async def process_messages(
The resolved_user_id can be stored and passed back when processing the response
to ensure the same user context is maintained throughout the request/response cycle.
"""
- pc_requests, resolved_user_id = await self._map_messages(messages, activity, user_id)
+ pc_requests, resolved_user_id = await self._map_messages(messages, activity, session_id, user_id)
should_block = False
for req in pc_requests:
resp = await self._process_with_scopes(req)
@@ -90,13 +96,18 @@ async def process_messages(
return should_block, resolved_user_id
async def _map_messages(
- self, messages: Iterable[ChatMessage], activity: Activity, provided_user_id: str | None = None
+ self,
+ messages: Iterable[ChatMessage],
+ activity: Activity,
+ session_id: str | None = None,
+ provided_user_id: str | None = None,
) -> tuple[list[ProcessContentRequest], str | None]:
"""Map messages to ProcessContentRequests.
Args:
messages: The messages to map
activity: The activity type
+ session_id: Optional session/conversation id to use for correlation
provided_user_id: Optional user_id to use. If provided, this is the fallback.
Returns:
@@ -137,12 +148,14 @@ async def _map_messages(
for m in messages:
message_id = m.message_id or str(uuid.uuid4())
content = PurviewTextContent(data=m.text or "")
+ correlation_id = (session_id or str(uuid.uuid4())) + "@AF"
meta = ProcessConversationMetadata(
identifier=message_id,
content=content,
name=f"Agent Framework Message {message_id}",
is_truncated=False,
- correlation_id=str(uuid.uuid4()),
+ correlation_id=correlation_id,
+ sequence_number=time.time_ns(),
)
activity_meta = ActivityMetadata(activity=activity)
@@ -159,12 +172,13 @@ async def _map_messages(
else:
raise ValueError("App location not provided or inferable")
+ app_version = self._settings.app_version or "Unknown"
protected_app = ProtectedAppMetadata(
name=self._settings.app_name,
- version="1.0",
+ version=app_version,
application_location=policy_location,
)
- integrated_app = IntegratedAppMetadata(name=self._settings.app_name, version="1.0")
+ integrated_app = IntegratedAppMetadata(name=self._settings.app_name, version=app_version)
device_meta = DeviceMetadata(
operating_system_specifications=OperatingSystemSpecifications(
operating_system_platform="Unknown", operating_system_version="Unknown"
diff --git a/python/packages/purview/agent_framework_purview/_settings.py b/python/packages/purview/agent_framework_purview/_settings.py
index 529b1399aa..3710d9de52 100644
--- a/python/packages/purview/agent_framework_purview/_settings.py
+++ b/python/packages/purview/agent_framework_purview/_settings.py
@@ -35,7 +35,7 @@ def get_policy_location(self) -> dict[str, str]:
class PurviewSettings(AFBaseSettings):
- """Settings for Purview integration mirroring .NET PurviewSettings.
+ """Settings for Purview integration.
Attributes:
app_name: Public app name.
diff --git a/python/packages/purview/tests/test_chat_middleware.py b/python/packages/purview/tests/test_chat_middleware.py
index d42c5a85a9..41ed8e0e4e 100644
--- a/python/packages/purview/tests/test_chat_middleware.py
+++ b/python/packages/purview/tests/test_chat_middleware.py
@@ -9,6 +9,7 @@
from azure.core.credentials import AccessToken
from agent_framework_purview import PurviewChatPolicyMiddleware, PurviewSettings
+from agent_framework_purview._models import Activity
@dataclass
@@ -82,7 +83,7 @@ async def mock_next(ctx: ChatContext) -> None: # should not run
async def test_blocks_response(self, middleware: PurviewChatPolicyMiddleware, chat_context: ChatContext) -> None:
call_state = {"count": 0}
- async def side_effect(messages, activity, user_id=None):
+ async def side_effect(messages, activity, session_id=None, user_id=None):
call_state["count"] += 1
should_block = call_state["count"] == 2
return (should_block, "user-123")
@@ -157,7 +158,7 @@ async def test_chat_middleware_uses_consistent_user_id(
"""Test that the same user_id from pre-check is used in post-check."""
captured_user_ids = []
- async def mock_process_messages(messages, activity, user_id=None):
+ async def mock_process_messages(messages, activity, session_id=None, user_id=None):
captured_user_ids.append(user_id)
return (False, "resolved-user-123")
@@ -362,3 +363,67 @@ async def mock_next(ctx: ChatContext) -> None:
with pytest.raises(ValueError, match="post"):
await middleware.process(context, mock_next)
+
+ async def test_chat_middleware_uses_conversation_id_from_options(
+ self, middleware: PurviewChatPolicyMiddleware
+ ) -> None:
+ """Test that session_id is extracted from context.options['conversation_id']."""
+ chat_client = DummyChatClient()
+ messages = [ChatMessage(role="user", text="Hello")]
+ options = {"conversation_id": "conv-123", "model": "test-model"}
+ context = ChatContext(chat_client=chat_client, messages=messages, options=options)
+
+ with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc:
+
+ async def mock_next(ctx: ChatContext) -> None:
+ result = MagicMock()
+ result.messages = [ChatMessage(role="assistant", text="Hi")]
+ ctx.result = result
+
+ await middleware.process(context, mock_next)
+
+ # Verify session_id is passed to both pre-check and post-check
+ assert mock_proc.call_count == 2
+ mock_proc.assert_any_call(messages, Activity.UPLOAD_TEXT, session_id="conv-123")
+
+ async def test_chat_middleware_passes_none_session_id_when_options_missing(
+ self, middleware: PurviewChatPolicyMiddleware
+ ) -> None:
+ """Test that session_id is None when options don't contain conversation_id."""
+ chat_client = DummyChatClient()
+ messages = [ChatMessage(role="user", text="Hello")]
+ context = ChatContext(chat_client=chat_client, messages=messages, options=None)
+
+ with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc:
+
+ async def mock_next(ctx: ChatContext) -> None:
+ result = MagicMock()
+ result.messages = [ChatMessage(role="assistant", text="Hi")]
+ ctx.result = result
+
+ await middleware.process(context, mock_next)
+
+ # Verify session_id=None is passed
+ mock_proc.assert_any_call(messages, Activity.UPLOAD_TEXT, session_id=None)
+
+ async def test_chat_middleware_session_id_used_in_post_check(self, middleware: PurviewChatPolicyMiddleware) -> None:
+ """Test that session_id is passed to post-check process_messages call."""
+ chat_client = DummyChatClient()
+ messages = [ChatMessage(role="user", text="Hello")]
+ options = {"conversation_id": "conv-999"}
+ context = ChatContext(chat_client=chat_client, messages=messages, options=options)
+
+ with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc:
+
+ async def mock_next(ctx: ChatContext) -> None:
+ result = MagicMock()
+ result.messages = [ChatMessage(role="assistant", text="Response")]
+ ctx.result = result
+
+ await middleware.process(context, mock_next)
+
+ # Verify both calls include session_id
+ assert mock_proc.call_count == 2
+ # Check post-check call includes session_id
+ post_check_call = mock_proc.call_args_list[1]
+ assert post_check_call[1]["session_id"] == "conv-999"
diff --git a/python/packages/purview/tests/test_middleware.py b/python/packages/purview/tests/test_middleware.py
index b0aadd8cd5..71eaa93056 100644
--- a/python/packages/purview/tests/test_middleware.py
+++ b/python/packages/purview/tests/test_middleware.py
@@ -5,10 +5,11 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
-from agent_framework import AgentContext, AgentResponse, ChatMessage, MiddlewareTermination
+from agent_framework import AgentContext, AgentResponse, AgentThread, ChatMessage, MiddlewareTermination
from azure.core.credentials import AccessToken
from agent_framework_purview import PurviewPolicyMiddleware, PurviewSettings
+from agent_framework_purview._models import Activity
class TestPurviewPolicyMiddleware:
@@ -92,7 +93,7 @@ async def test_middleware_checks_response(self, middleware: PurviewPolicyMiddlew
call_count = 0
- async def mock_process_messages(messages, activity, user_id=None):
+ async def mock_process_messages(messages, activity, session_id=None, user_id=None):
nonlocal call_count
call_count += 1
should_block = call_count != 1
@@ -335,3 +336,93 @@ async def mock_next(ctx):
# Should raise the exception
with pytest.raises(ValueError, match="Test error"):
await middleware.process(context, mock_next)
+
+ async def test_middleware_uses_thread_service_thread_id_as_session_id(
+ self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock
+ ) -> None:
+ """Test that session_id is extracted from thread.service_thread_id."""
+ thread = AgentThread(service_thread_id="thread-123")
+ context = AgentContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Hello")], thread=thread)
+
+ with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc:
+
+ async def mock_next(ctx: AgentContext) -> None:
+ ctx.result = AgentResponse(messages=[ChatMessage(role="assistant", text="Hi")])
+
+ await middleware.process(context, mock_next)
+
+ # Verify session_id is passed to both pre-check and post-check
+ assert mock_proc.call_count == 2
+ mock_proc.assert_any_call(context.messages, Activity.UPLOAD_TEXT, session_id="thread-123")
+
+ async def test_middleware_uses_message_conversation_id_as_session_id(
+ self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock
+ ) -> None:
+ """Test that session_id is extracted from message.additional_properties['conversation_id']."""
+ messages = [ChatMessage(role="user", text="Hello", additional_properties={"conversation_id": "conv-456"})]
+ context = AgentContext(agent=mock_agent, messages=messages)
+
+ with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc:
+
+ async def mock_next(ctx: AgentContext) -> None:
+ ctx.result = AgentResponse(messages=[ChatMessage(role="assistant", text="Hi")])
+
+ await middleware.process(context, mock_next)
+
+ # Verify session_id is passed to both pre-check and post-check
+ assert mock_proc.call_count == 2
+ mock_proc.assert_any_call(messages, Activity.UPLOAD_TEXT, session_id="conv-456")
+
+ async def test_middleware_thread_id_takes_precedence_over_message_conversation_id(
+ self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock
+ ) -> None:
+ """Test that thread.service_thread_id takes precedence over message conversation_id."""
+ thread = AgentThread(service_thread_id="thread-789")
+ messages = [ChatMessage(role="user", text="Hello", additional_properties={"conversation_id": "conv-456"})]
+ context = AgentContext(agent=mock_agent, messages=messages, thread=thread)
+
+ with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc:
+
+ async def mock_next(ctx: AgentContext) -> None:
+ ctx.result = AgentResponse(messages=[ChatMessage(role="assistant", text="Hi")])
+
+ await middleware.process(context, mock_next)
+
+ # Verify thread ID is used, not message conversation_id
+ mock_proc.assert_any_call(messages, Activity.UPLOAD_TEXT, session_id="thread-789")
+
+ async def test_middleware_passes_none_session_id_when_not_available(
+ self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock
+ ) -> None:
+ """Test that session_id is None when no thread or conversation_id is available."""
+ context = AgentContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Hello")])
+
+ with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc:
+
+ async def mock_next(ctx: AgentContext) -> None:
+ ctx.result = AgentResponse(messages=[ChatMessage(role="assistant", text="Hi")])
+
+ await middleware.process(context, mock_next)
+
+ # Verify session_id=None is passed
+ mock_proc.assert_any_call(context.messages, Activity.UPLOAD_TEXT, session_id=None)
+
+ async def test_middleware_session_id_used_in_post_check(
+ self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock
+ ) -> None:
+ """Test that session_id is passed to post-check process_messages call."""
+ thread = AgentThread(service_thread_id="thread-999")
+ context = AgentContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Hello")], thread=thread)
+
+ with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc:
+
+ async def mock_next(ctx: AgentContext) -> None:
+ ctx.result = AgentResponse(messages=[ChatMessage(role="assistant", text="Response")])
+
+ await middleware.process(context, mock_next)
+
+ # Verify both calls include session_id
+ assert mock_proc.call_count == 2
+ # Check post-check call includes session_id
+ post_check_call = mock_proc.call_args_list[1]
+ assert post_check_call[1]["session_id"] == "thread-999"
diff --git a/python/packages/purview/tests/test_processor.py b/python/packages/purview/tests/test_processor.py
index f122c6e059..be4d4aca89 100644
--- a/python/packages/purview/tests/test_processor.py
+++ b/python/packages/purview/tests/test_processor.py
@@ -92,7 +92,7 @@ async def test_process_messages_with_defaults(self, processor: ScopedContentProc
assert should_block is False
assert user_id is None
- mock_map.assert_called_once_with(messages, Activity.UPLOAD_TEXT, None)
+ mock_map.assert_called_once_with(messages, Activity.UPLOAD_TEXT, None, None)
async def test_process_messages_blocks_content(
self, processor: ScopedContentProcessor, process_content_request_factory