Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@ internal abstract class ProcessContentMetadataBase : GraphDataTypeBase
/// <param name="identifier">The unique identifier for the content.</param>
/// <param name="isTruncated">Indicates if the content is truncated.</param>
/// <param name="name">The name of the content.</param>
protected ProcessContentMetadataBase(ContentBase content, string identifier, bool isTruncated, string name) : base(ProcessConversationMetadataDataType)
/// <param name="correlationId">The correlation ID for the content.</param>
protected ProcessContentMetadataBase(ContentBase content, string identifier, bool isTruncated, string name, string correlationId) : base(ProcessConversationMetadataDataType)
{
this.Identifier = identifier;
this.IsTruncated = isTruncated;
this.Content = content;
this.Name = name;
this.CorrelationId = correlationId;
}

/// <summary>
Expand Down Expand Up @@ -55,7 +57,7 @@ protected ProcessContentMetadataBase(ContentBase content, string identifier, boo
/// Identifier to group multiple contents.
/// </summary>
[JsonPropertyName("correlationId")]
public string? CorrelationId { get; set; }
public string CorrelationId { get; set; }

/// <summary>
/// Gets or sets the sequenceNumber.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ internal sealed class ProcessConversationMetadata : ProcessContentMetadataBase
/// <summary>
/// Initializes a new instance of the <see cref="ProcessConversationMetadata"/> class.
/// </summary>
public ProcessConversationMetadata(ContentBase contentBase, string identifier, bool isTruncated, string name) : base(contentBase, identifier, isTruncated, name)
public ProcessConversationMetadata(ContentBase contentBase, string identifier, bool isTruncated, string name, string correlationId) : base(contentBase, identifier, isTruncated, name, correlationId)
{
this.DataType = ProcessConversationMetadataDataType;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ internal sealed class ProcessFileMetadata : ProcessContentMetadataBase
/// <summary>
/// Initializes a new instance of the <see cref="ProcessFileMetadata"/> class.
/// </summary>
public ProcessFileMetadata(ContentBase contentBase, string identifier, bool isTruncated, string name) : base(contentBase, identifier, isTruncated, name)
public ProcessFileMetadata(ContentBase contentBase, string identifier, bool isTruncated, string name, string correlationId) : base(contentBase, identifier, isTruncated, name, correlationId)
{
this.DataType = ProcessFileMetadataDataType;
}
Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/Microsoft.Agents.AI.Purview/PurviewSettings.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public class PurviewSettings
/// <param name="appName">The publicly visible name of the application.</param>
public PurviewSettings(string appName)
{
this.AppName = appName;
this.AppName = string.IsNullOrWhiteSpace(appName) ? throw new ArgumentException("AppName cannot be null or whitespace.", nameof(appName)) : appName;
}

/// <summary>
Expand Down
25 changes: 20 additions & 5 deletions dotnet/src/Microsoft.Agents.AI.Purview/PurviewWrapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ private static string GetSessionIdFromAgentSession(AgentSession? session, IEnume
}
}

return Guid.NewGuid().ToString();
return string.Empty;
}

/// <summary>
Expand Down Expand Up @@ -136,12 +136,15 @@ public async Task<ChatResponse> ProcessChatContentAsync(IEnumerable<ChatMessage>
/// <returns>The agent's response. This could be the response from the agent or a message indicating that Purview has blocked the prompt or response.</returns>
public async Task<AgentResponse> ProcessAgentContentAsync(IEnumerable<ChatMessage> messages, AgentSession? session, AgentRunOptions? options, AIAgent innerAgent, CancellationToken cancellationToken)
{
string sessionId = GetSessionIdFromAgentSession(session, messages);

string? resolvedUserId = null;

string sessionId = string.Empty;
try
{
sessionId = GetSessionIdFromAgentSession(session, messages);
if (string.IsNullOrEmpty(sessionId))
{
sessionId = Guid.NewGuid().ToString();
}
(bool shouldBlockPrompt, resolvedUserId) = await this._scopedProcessor.ProcessMessagesAsync(messages, sessionId, Activity.UploadText, this._purviewSettings, null, cancellationToken).ConfigureAwait(false);

if (shouldBlockPrompt)
Expand Down Expand Up @@ -171,7 +174,19 @@ public async Task<AgentResponse> ProcessAgentContentAsync(IEnumerable<ChatMessag

try
{
(bool shouldBlockResponse, _) = await this._scopedProcessor.ProcessMessagesAsync(response.Messages, sessionId, Activity.UploadText, this._purviewSettings, resolvedUserId, cancellationToken).ConfigureAwait(false);
string sessionIdResponse = GetSessionIdFromAgentSession(session, messages);
if (string.IsNullOrEmpty(sessionIdResponse))
{
if (string.IsNullOrEmpty(sessionId))
{
sessionIdResponse = Guid.NewGuid().ToString();
}
else
{
sessionIdResponse = sessionId;
}
}
(bool shouldBlockResponse, _) = await this._scopedProcessor.ProcessMessagesAsync(response.Messages, sessionIdResponse, Activity.UploadText, this._purviewSettings, resolvedUserId, cancellationToken).ConfigureAwait(false);

if (shouldBlockResponse)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,10 @@ private async Task<List<ProcessContentRequest>> MapMessageToPCRequestsAsync(IEnu
{
string messageId = message.MessageId ?? Guid.NewGuid().ToString();
ContentBase content = new PurviewTextContent(message.Text);
ProcessConversationMetadata conversationmetadata = new(content, messageId, false, $"Agent Framework Message {messageId}")
string correlationId = (sessionId ?? Guid.NewGuid().ToString()) + "@AF";
ProcessConversationMetadata conversationMetadata = new(content, messageId, false, $"Agent Framework Message {messageId}", correlationId)
{
CorrelationId = sessionId ?? Guid.NewGuid().ToString()
SequenceNumber = DateTime.UtcNow.Ticks,
};
ActivityMetadata activityMetadata = new(activity);
PolicyLocation policyLocation;
Expand Down Expand Up @@ -162,7 +163,7 @@ private async Task<List<ProcessContentRequest>> MapMessageToPCRequestsAsync(IEnu
OperatingSystemVersion = "Unknown"
}
};
ContentToProcess contentToProcess = new([conversationmetadata], activityMetadata, deviceMetadata, integratedAppMetadata, protectedAppMetadata);
ContentToProcess contentToProcess = new([conversationMetadata], activityMetadata, deviceMetadata, integratedAppMetadata, protectedAppMetadata);

if (userId == null &&
tokenInfo?.UserId != null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ private static ProcessContentRequest CreateValidProcessContentRequest()
private static ContentToProcess CreateValidContentToProcess()
{
var content = new PurviewTextContent("Test content");
var metadata = new ProcessConversationMetadata(content, "msg-123", false, "Test message");
var metadata = new ProcessConversationMetadata(content, "msg-123", false, "Test message", "test-correlation-id");
var activityMetadata = new ActivityMetadata(Activity.UploadText);
var deviceMetadata = new DeviceMetadata
{
Expand Down
34 changes: 31 additions & 3 deletions python/packages/purview/agent_framework_purview/_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,25 @@ def __init__(
self._processor = ScopedContentProcessor(self._client, settings, cache_provider)
self._settings = settings

@staticmethod
def _get_agent_session_id(context: AgentContext) -> str | None:
"""Resolve a session/conversation id from the agent run context.

Resolution order:
1. thread.service_thread_id
2. First message whose additional_properties contains 'conversation_id'
3. None: the downstream processor will generate a new UUID
"""
if context.thread and context.thread.service_thread_id:
return context.thread.service_thread_id

for message in context.messages:
conversation_id = message.additional_properties.get("conversation_id")
if conversation_id is not None:
return str(conversation_id)

return None

async def process(
self,
context: AgentContext,
Expand All @@ -53,8 +72,9 @@ async def process(
resolved_user_id: str | None = None
try:
# Pre (prompt) check
session_id = self._get_agent_session_id(context)
should_block_prompt, resolved_user_id = await self._processor.process_messages(
context.messages, Activity.UPLOAD_TEXT
context.messages, Activity.UPLOAD_TEXT, session_id=session_id
)
if should_block_prompt:
from agent_framework import AgentResponse, ChatMessage
Expand All @@ -79,10 +99,14 @@ async def process(
try:
# Post (response) check only if we have a normal AgentResponse
# Use the same user_id from the request for the response evaluation
session_id_response = self._get_agent_session_id(context)
if session_id_response is None:
session_id_response = session_id
if context.result and not context.stream:
should_block_response, _ = await self._processor.process_messages(
context.result.messages, # type: ignore[union-attr]
Activity.UPLOAD_TEXT,
session_id=session_id,
user_id=resolved_user_id,
)
if should_block_response:
Expand Down Expand Up @@ -144,8 +168,9 @@ async def process(
) -> None: # type: ignore[override]
resolved_user_id: str | None = None
try:
session_id = context.options.get("conversation_id") if context.options else None
should_block_prompt, resolved_user_id = await self._processor.process_messages(
context.messages, Activity.UPLOAD_TEXT
context.messages, Activity.UPLOAD_TEXT, session_id=session_id
)
if should_block_prompt:
from agent_framework import ChatMessage, ChatResponse
Expand All @@ -169,12 +194,15 @@ async def process(
try:
# Post (response) evaluation only if non-streaming and we have messages result shape
# Use the same user_id from the request for the response evaluation
session_id_response = context.options.get("conversation_id") if context.options else None
if session_id_response is None:
session_id_response = session_id
if context.result and not context.stream:
result_obj = context.result
messages = getattr(result_obj, "messages", None)
if messages:
should_block_response, _ = await self._processor.process_messages(
messages, Activity.UPLOAD_TEXT, user_id=resolved_user_id
messages, Activity.UPLOAD_TEXT, session_id=session_id_response, user_id=resolved_user_id
)
if should_block_response:
from agent_framework import ChatMessage, ChatResponse
Expand Down
26 changes: 20 additions & 6 deletions python/packages/purview/agent_framework_purview/_processor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.

import asyncio
import time
import uuid
from collections.abc import Iterable, MutableMapping
from typing import Any
Expand Down Expand Up @@ -62,21 +63,26 @@ def __init__(self, client: PurviewClient, settings: PurviewSettings, cache_provi
self._background_tasks: set[asyncio.Task[Any]] = set()

async def process_messages(
self, messages: Iterable[ChatMessage], activity: Activity, user_id: str | None = None
self,
messages: Iterable[ChatMessage],
activity: Activity,
session_id: str | None = None,
user_id: str | None = None,
) -> tuple[bool, str | None]:
"""Process messages for policy evaluation.
Args:
messages: The messages to process
activity: The activity type (e.g., UPLOAD_TEXT)
session_id: Optional session/conversation id. Else, a new GUID is generated.
user_id: Optional user_id to use for all messages. If provided, this is the fallback.
Returns:
A tuple of (should_block: bool, resolved_user_id: str | None).
The resolved_user_id can be stored and passed back when processing the response
to ensure the same user context is maintained throughout the request/response cycle.
"""
pc_requests, resolved_user_id = await self._map_messages(messages, activity, user_id)
pc_requests, resolved_user_id = await self._map_messages(messages, activity, session_id, user_id)
should_block = False
for req in pc_requests:
resp = await self._process_with_scopes(req)
Expand All @@ -90,13 +96,18 @@ async def process_messages(
return should_block, resolved_user_id

async def _map_messages(
self, messages: Iterable[ChatMessage], activity: Activity, provided_user_id: str | None = None
self,
messages: Iterable[ChatMessage],
activity: Activity,
session_id: str | None = None,
provided_user_id: str | None = None,
) -> tuple[list[ProcessContentRequest], str | None]:
"""Map messages to ProcessContentRequests.
Args:
messages: The messages to map
activity: The activity type
session_id: Optional session/conversation id to use for correlation
provided_user_id: Optional user_id to use. If provided, this is the fallback.
Returns:
Expand Down Expand Up @@ -137,12 +148,14 @@ async def _map_messages(
for m in messages:
message_id = m.message_id or str(uuid.uuid4())
content = PurviewTextContent(data=m.text or "")
correlation_id = (session_id or str(uuid.uuid4())) + "@AF"
meta = ProcessConversationMetadata(
identifier=message_id,
content=content,
name=f"Agent Framework Message {message_id}",
is_truncated=False,
correlation_id=str(uuid.uuid4()),
correlation_id=correlation_id,
sequence_number=time.time_ns(),
)
activity_meta = ActivityMetadata(activity=activity)

Expand All @@ -159,12 +172,13 @@ async def _map_messages(
else:
raise ValueError("App location not provided or inferable")

app_version = self._settings.app_version or "Unknown"
protected_app = ProtectedAppMetadata(
name=self._settings.app_name,
version="1.0",
version=app_version,
application_location=policy_location,
)
integrated_app = IntegratedAppMetadata(name=self._settings.app_name, version="1.0")
integrated_app = IntegratedAppMetadata(name=self._settings.app_name, version=app_version)
device_meta = DeviceMetadata(
operating_system_specifications=OperatingSystemSpecifications(
operating_system_platform="Unknown", operating_system_version="Unknown"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get_policy_location(self) -> dict[str, str]:


class PurviewSettings(AFBaseSettings):
"""Settings for Purview integration mirroring .NET PurviewSettings.
"""Settings for Purview integration.
Attributes:
app_name: Public app name.
Expand Down
Loading
Loading