diff --git a/tests/ModelContextProtocol.Tests/DiagnosticTests.cs b/tests/ModelContextProtocol.Tests/DiagnosticTests.cs index bbe7d153f..03e00af8d 100644 --- a/tests/ModelContextProtocol.Tests/DiagnosticTests.cs +++ b/tests/ModelContextProtocol.Tests/DiagnosticTests.cs @@ -18,6 +18,18 @@ public async Task Session_TracksActivities() var activities = new List(); var clientToServerLog = new List(); + // Predicate for the expected server tool-call activity, including all required tags. + // Defined here so it can be reused for both the wait and the assertion below. + Func isExpectedServerToolCall = a => + a.DisplayName == "tools/call DoubleValue" && + a.Kind == ActivityKind.Server && + a.Status == ActivityStatusCode.Unset && + a.Tags.Any(t => t.Key == "gen_ai.tool.name" && t.Value == "DoubleValue") && + a.Tags.Any(t => t.Key == "mcp.method.name" && t.Value == "tools/call") && + a.Tags.Any(t => t.Key == "gen_ai.operation.name" && t.Value == "execute_tool") && + a.Tags.Any(t => t.Key == "mcp.protocol.version" && !string.IsNullOrEmpty(t.Value)) && + a.Tags.Any(t => t.Key == "mcp.session.id" && !string.IsNullOrEmpty(t.Value)); + using (var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder() .AddSource("Experimental.ModelContextProtocol") .AddInMemoryExporter(activities) @@ -36,9 +48,11 @@ await RunConnected(async (client, server) => // Wait for server-side activities to be exported. The server processes messages // via fire-and-forget tasks, so activities may not be immediately available // after the client operation completes. Wait for the specific activity we need - // rather than a count, as other server activities may be exported first. - await WaitForAsync(() => activities.Any(a => - a.DisplayName == "tools/call DoubleValue" && a.Kind == ActivityKind.Server)); + // (including required tags) rather than just the display name, so that we don't + // assert before all tags have been populated. + await WaitForAsync( + () => activities.Any(isExpectedServerToolCall), + failureMessage: "Timed out waiting for the expected server tool-call activity (tools/call DoubleValue) to be exported with required tags."); } Assert.NotEmpty(activities); @@ -54,13 +68,7 @@ await WaitForAsync(() => activities.Any(a => // Per semantic conventions: mcp.protocol.version should be present after initialization Assert.Contains(clientToolCall.Tags, t => t.Key == "mcp.protocol.version" && !string.IsNullOrEmpty(t.Value)); - var serverToolCall = Assert.Single(activities, a => - a.Tags.Any(t => t.Key == "gen_ai.tool.name" && t.Value == "DoubleValue") && - a.Tags.Any(t => t.Key == "mcp.method.name" && t.Value == "tools/call") && - a.Tags.Any(t => t.Key == "gen_ai.operation.name" && t.Value == "execute_tool") && - a.DisplayName == "tools/call DoubleValue" && - a.Kind == ActivityKind.Server && - a.Status == ActivityStatusCode.Unset); + var serverToolCall = Assert.Single(activities, a => isExpectedServerToolCall(a)); // Per semantic conventions: mcp.protocol.version should be present after initialization Assert.Contains(serverToolCall.Tags, t => t.Key == "mcp.protocol.version" && !string.IsNullOrEmpty(t.Value)); @@ -245,12 +253,19 @@ private static async Task RunConnected(Func action, await serverTask; } - private static async Task WaitForAsync(Func condition, int timeoutMs = 10_000) + private static async Task WaitForAsync(Func condition, int timeoutMs = 10_000, string? failureMessage = null) { using var cts = new CancellationTokenSource(timeoutMs); - while (!condition()) + try + { + while (!condition()) + { + await Task.Delay(10, cts.Token); + } + } + catch (TaskCanceledException) { - await Task.Delay(10, cts.Token); + throw new Xunit.Sdk.XunitException(failureMessage ?? $"Condition was not met within {timeoutMs}ms."); } } }