diff --git a/dotnet/src/Client.cs b/dotnet/src/Client.cs index 9fbe8c5a7..042148e17 100644 --- a/dotnet/src/Client.cs +++ b/dotnet/src/Client.cs @@ -1172,11 +1172,6 @@ public async Task ResumeSessionAsync(string sessionId, ResumeSes transformCallbacks, hasHooks, "CopilotClient.ResumeSessionAsync"); - if (config.OnMcpAuthRequest is not null) - { - await session.Rpc.EventLog.RegisterInterestAsync("mcp.oauth_required", cancellationToken); - } - try { var (traceparent, tracestate) = TelemetryHelpers.GetTraceContext(); @@ -1259,6 +1254,11 @@ public async Task ResumeSessionAsync(string sessionId, ResumeSes session.SetCapabilities(response.Capabilities); session.SetOpenCanvases(response.OpenCanvases); + if (config.OnMcpAuthRequest is not null) + { + await session.Rpc.EventLog.RegisterInterestAsync("mcp.oauth_required", cancellationToken); + } + await UpdateSessionOptionsForModeAsync(session, config, cancellationToken).ConfigureAwait(false); } catch (Exception ex) diff --git a/dotnet/test/E2E/SessionE2ETests.cs b/dotnet/test/E2E/SessionE2ETests.cs index 202436c04..bcfb46295 100644 --- a/dotnet/test/E2E/SessionE2ETests.cs +++ b/dotnet/test/E2E/SessionE2ETests.cs @@ -269,6 +269,33 @@ public async Task Should_Resume_A_Session_Using_A_New_Client() Assert.Contains("4", answer2!.Data.Content ?? string.Empty); } + [Fact] + public async Task Resumes_A_Persisted_Session_From_A_New_Client_When_An_Mcp_OAuth_Handler_Is_Configured() + { + static Task CancelMcpAuthAsync(McpAuthContext request) + => Task.FromResult(McpAuthResult.Cancel()); + + await using var session1 = await CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + OnMcpAuthRequest = CancelMcpAuthAsync, + }); + var sessionId = session1.SessionId; + + var answer = await session1.SendAndWaitAsync(new MessageOptions { Prompt = "What is 1+1?" }); + Assert.NotNull(answer); + Assert.Contains("2", answer!.Data.Content ?? string.Empty); + + using var newClient = Ctx.CreateClient(); + await using var session2 = await newClient.ResumeSessionAsync(sessionId, new ResumeSessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + OnMcpAuthRequest = CancelMcpAuthAsync, + }); + + Assert.Equal(sessionId, session2.SessionId); + } + [Fact] public async Task Should_Throw_Error_When_Resuming_Non_Existent_Session() { diff --git a/dotnet/test/Unit/ClientSessionLifetimeTests.cs b/dotnet/test/Unit/ClientSessionLifetimeTests.cs index e51a4b911..a028a6c7e 100644 --- a/dotnet/test/Unit/ClientSessionLifetimeTests.cs +++ b/dotnet/test/Unit/ClientSessionLifetimeTests.cs @@ -300,12 +300,12 @@ public async Task ResumeSessionAsync_Registers_McpAuth_Interest_Only_When_Handle Assert.Collection( server.Requests.Take(2), + request => Assert.Equal("session.resume", request.Method), request => { Assert.Equal("session.eventLog.registerInterest", request.Method); Assert.Equal("mcp.oauth_required", request.Params.GetProperty("eventType").GetString()); - }, - request => Assert.Equal("session.resume", request.Method)); + }); } [Fact] diff --git a/go/client.go b/go/client.go index 9e2819047..9b47d8426 100644 --- a/go/client.go +++ b/go/client.go @@ -1150,17 +1150,6 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string, c.sessionsMux.Lock() c.sessions[sessionID] = session c.sessionsMux.Unlock() - if config.OnMCPAuthRequest != nil { - if _, err := c.client.Request(ctx, "session.eventLog.registerInterest", map[string]any{ - "sessionId": sessionID, - "eventType": "mcp.oauth_required", - }); err != nil { - c.sessionsMux.Lock() - delete(c.sessions, sessionID) - c.sessionsMux.Unlock() - return nil, err - } - } if c.options.SessionFS != nil { if config.CreateSessionFSProvider == nil { @@ -1197,6 +1186,18 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string, return nil, fmt.Errorf("failed to unmarshal response: %w", err) } + if config.OnMCPAuthRequest != nil { + if _, err := c.client.Request(ctx, "session.eventLog.registerInterest", map[string]any{ + "sessionId": sessionID, + "eventType": "mcp.oauth_required", + }); err != nil { + c.sessionsMux.Lock() + delete(c.sessions, sessionID) + c.sessionsMux.Unlock() + return nil, err + } + } + session.workspacePath = response.WorkspacePath session.setCapabilities(response.Capabilities) session.setOpenCanvases(response.OpenCanvases) diff --git a/go/client_test.go b/go/client_test.go index c889ced8d..29396480e 100644 --- a/go/client_test.go +++ b/go/client_test.go @@ -1405,7 +1405,7 @@ func TestClient_MCPAuthInterestRegistration(t *testing.T) { assertMCPAuthInterest(t, snapshot[1]) }) - t.Run("resume conditionally registers MCP OAuth interest before session resume", func(t *testing.T) { + t.Run("resume conditionally registers MCP OAuth interest after session resume", func(t *testing.T) { client, requests, cleanup := newInMemoryClient(t) defer cleanup() @@ -1434,13 +1434,13 @@ func TestClient_MCPAuthInterestRegistration(t *testing.T) { defer withAuth.Disconnect() snapshot := requests.snapshot() - if snapshot[0].Method != "session.eventLog.registerInterest" { - t.Fatalf("expected MCP auth interest before session.resume, got %s", snapshot[0].Method) + if snapshot[0].Method != "session.resume" { + t.Fatalf("expected session.resume before MCP auth interest, got %s", snapshot[0].Method) } - if snapshot[1].Method != "session.resume" { - t.Fatalf("expected session.resume after MCP auth interest, got %s", snapshot[1].Method) + if snapshot[1].Method != "session.eventLog.registerInterest" { + t.Fatalf("expected MCP auth interest after session.resume, got %s", snapshot[1].Method) } - assertMCPAuthInterest(t, snapshot[0]) + assertMCPAuthInterest(t, snapshot[1]) }) } diff --git a/go/internal/e2e/resume_mcp_oauth_e2e_test.go b/go/internal/e2e/resume_mcp_oauth_e2e_test.go new file mode 100644 index 000000000..db61f483a --- /dev/null +++ b/go/internal/e2e/resume_mcp_oauth_e2e_test.go @@ -0,0 +1,60 @@ +package e2e + +import ( + "strings" + "testing" + + copilot "github.com/github/copilot-sdk/go" + "github.com/github/copilot-sdk/go/internal/e2e/testharness" +) + +func TestResumeMCPOAuthE2E(t *testing.T) { + ctx := testharness.NewTestContext(t) + client := ctx.NewClient() + t.Cleanup(func() { client.ForceStop() }) + + t.Run("should resume a persisted session with mcp auth handler", func(t *testing.T) { + ctx.ConfigureForTest(t) + + mcpAuthHandler := func(copilot.MCPAuthRequest, copilot.MCPAuthInvocation) (*copilot.MCPAuthResult, error) { + return copilot.MCPAuthResultCancelled(), nil + } + + session1, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + OnMCPAuthRequest: mcpAuthHandler, + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + sessionID := session1.SessionID + + _, err = session1.Send(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) + if err != nil { + t.Fatalf("Failed to send message: %v", err) + } + + answer, err := testharness.GetFinalAssistantMessage(t.Context(), session1) + if err != nil { + t.Fatalf("Failed to get assistant message: %v", err) + } + if ad, ok := answer.Data.(*copilot.AssistantMessageData); !ok || !strings.Contains(ad.Content, "2") { + t.Errorf("Expected answer to contain '2', got %v", answer.Data) + } + + newClient := ctx.NewClient() + t.Cleanup(func() { newClient.ForceStop() }) + + session2, err := newClient.ResumeSession(t.Context(), sessionID, &copilot.ResumeSessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + OnMCPAuthRequest: mcpAuthHandler, + }) + if err != nil { + t.Fatalf("Failed to resume session: %v", err) + } + + if session2.SessionID != sessionID { + t.Errorf("Expected resumed session ID to match, got %q vs %q", session2.SessionID, sessionID) + } + }) +} diff --git a/java/src/test/java/com/github/copilot/McpOAuthResumeE2ETest.java b/java/src/test/java/com/github/copilot/McpOAuthResumeE2ETest.java new file mode 100644 index 000000000..19c15ed59 --- /dev/null +++ b/java/src/test/java/com/github/copilot/McpOAuthResumeE2ETest.java @@ -0,0 +1,76 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; + +import com.github.copilot.generated.AssistantMessageEvent; +import com.github.copilot.rpc.McpAuthResult; +import com.github.copilot.rpc.MessageOptions; +import com.github.copilot.rpc.PermissionHandler; +import com.github.copilot.rpc.ResumeSessionConfig; +import com.github.copilot.rpc.SessionConfig; + +class McpOAuthResumeE2ETest { + + private static final String SNAPSHOT = "resumes_a_persisted_session_from_a_new_client_when_an_mcp_oauth_handler_is_configured"; + + private static E2ETestContext ctx; + + @BeforeAll + static void setup() throws Exception { + ctx = E2ETestContext.create(); + } + + @AfterAll + static void teardown() throws Exception { + if (ctx != null) { + ctx.close(); + } + } + + @Test + @Tag("isolated-resume") + void resumesAPersistedSessionFromANewClientWhenAnMcpOauthHandlerIsConfigured() throws Exception { + ctx.configureForTest("session", SNAPSHOT); + + String sessionId; + try (var client = ctx.createClient(); + var session = client + .createSession( + new SessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL) + .setOnMcpAuthRequest((request, invocation) -> CompletableFuture + .completedFuture(McpAuthResult.cancelled()))) + .get(30, TimeUnit.SECONDS)) { + sessionId = session.getSessionId(); + + AssistantMessageEvent response = session.sendAndWait(new MessageOptions().setPrompt("What is 1+1?"), 60_000) + .get(90, TimeUnit.SECONDS); + assertNotNull(response); + assertTrue(response.getData().content().contains("2"), + "Response should contain 2: " + response.getData().content()); + } + + try (var client = ctx.createClient(); + var session = client + .resumeSession(sessionId, + new ResumeSessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL) + .setOnMcpAuthRequest((request, invocation) -> CompletableFuture + .completedFuture(McpAuthResult.cancelled()))) + .get(30, TimeUnit.SECONDS)) { + assertEquals(sessionId, session.getSessionId()); + } + } +} diff --git a/nodejs/src/client.ts b/nodejs/src/client.ts index 613985103..129a5b73d 100644 --- a/nodejs/src/client.ts +++ b/nodejs/src/client.ts @@ -1575,12 +1575,6 @@ export class CopilotClient { } this.sessions.set(sessionId, session); this.setupSessionFs(session, config); - if (config.onMcpAuthRequest) { - await this.connection!.sendRequest("session.eventLog.registerInterest", { - sessionId, - eventType: "mcp.oauth_required", - }); - } const toolFilterOptions = this.resolveToolFilterOptions(config); @@ -1671,6 +1665,12 @@ export class CopilotClient { session["_workspacePath"] = workspacePath; session.setCapabilities(capabilities); session.setOpenCanvases(openCanvases ?? []); + if (config.onMcpAuthRequest) { + await this.connection!.sendRequest("session.eventLog.registerInterest", { + sessionId, + eventType: "mcp.oauth_required", + }); + } await this.updateSessionOptionsForMode(session, config); } catch (e) { diff --git a/nodejs/test/client.test.ts b/nodejs/test/client.test.ts index 07cd079df..18af68ac9 100644 --- a/nodejs/test/client.test.ts +++ b/nodejs/test/client.test.ts @@ -220,7 +220,7 @@ describe("CopilotClient", () => { ]); }); - it("registers MCP OAuth interest before resuming only when an auth handler is configured", async () => { + it("registers MCP OAuth interest after resuming only when an auth handler is configured", async () => { const client = new CopilotClient(); await client.start(); onTestFinished(() => client.forceStop()); @@ -240,12 +240,23 @@ describe("CopilotClient", () => { onMcpAuthRequest: () => ({ kind: "cancelled" }), }); - expect(spy.mock.calls[0]).toEqual([ - "session.eventLog.registerInterest", - { sessionId: "session-with-auth", eventType: "mcp.oauth_required" }, - ]); - expect(spy.mock.calls[1][0]).toBe("session.resume"); - expect(spy.mock.calls[1][1]).toEqual(expect.objectContaining({ requestPermission: true })); + // `session.eventLog.registerInterest` is session-scoped: the runtime only + // registers the session id while handling `session.resume`, so resume must + // be sent BEFORE registering interest. + const resumeIndex = spy.mock.calls.findIndex(([method]) => method === "session.resume"); + const interestIndex = spy.mock.calls.findIndex( + ([method]) => method === "session.eventLog.registerInterest" + ); + expect(resumeIndex).toBeGreaterThanOrEqual(0); + expect(interestIndex).toBeGreaterThanOrEqual(0); + expect(resumeIndex).toBeLessThan(interestIndex); + expect(spy.mock.calls[resumeIndex][1]).toEqual( + expect.objectContaining({ sessionId: "session-with-auth", requestPermission: true }) + ); + expect(spy.mock.calls[interestIndex][1]).toEqual({ + sessionId: "session-with-auth", + eventType: "mcp.oauth_required", + }); spy.mockClear(); await client.resumeSession("session-without-auth", { diff --git a/nodejs/test/e2e/session.e2e.test.ts b/nodejs/test/e2e/session.e2e.test.ts index 55a064ab4..c29f0790d 100644 --- a/nodejs/test/e2e/session.e2e.test.ts +++ b/nodejs/test/e2e/session.e2e.test.ts @@ -2,7 +2,7 @@ import { rm } from "fs/promises"; import { describe, expect, it, onTestFinished, vi } from "vitest"; import { ParsedHttpExchange } from "../../../test/harness/replayingCapiProxy.js"; import { CopilotClient, approveAll, defineTool, RuntimeConnection } from "../../src/index.js"; -import { createSdkTestContext, isCI } from "./harness/sdkTestContext.js"; +import { createSdkTestContext, DEFAULT_GITHUB_TOKEN, isCI } from "./harness/sdkTestContext.js"; import { getFinalAssistantMessage, getNextEventOfType, retry } from "./harness/sdkTestHelper.js"; describe("Sessions", async () => { @@ -464,6 +464,40 @@ describe("Sessions", async () => { expect(session2.sessionId).toBe(sessionId); }); + it("resumes a persisted session from a new client when an MCP OAuth handler is configured", async () => { + // Take a turn so the session is persisted to the store and can be + // loaded by a different CLI process. + const session1 = await client.createSession({ + onPermissionRequest: approveAll, + onMcpAuthRequest: () => ({ kind: "cancelled" }), + }); + const sessionId = session1.sessionId; + const answer = await session1.sendAndWait({ prompt: "What is 1+1?" }); + expect(answer?.data.content).toContain("2"); + + // Resume from a fresh client (new CLI process). Its routing table does + // not know the session until it handles `session.resume`. Because an MCP + // OAuth handler is configured, the SDK issues a session-scoped + // `session.eventLog.registerInterest` for `mcp.oauth_required`; that must + // be sent AFTER `session.resume`, otherwise the runtime rejects it with + // "Session not found: ". + const newClient = new CopilotClient({ + env, + gitHubToken: isCI + ? DEFAULT_GITHUB_TOKEN + : (process.env.GITHUB_TOKEN ?? DEFAULT_GITHUB_TOKEN), + }); + onTestFinished(() => newClient.forceStop()); + + const session2 = await newClient.resumeSession(sessionId, { + onPermissionRequest: approveAll, + onMcpAuthRequest: () => ({ kind: "cancelled" }), + }); + + expect(session2.sessionId).toBe(sessionId); + await session2.disconnect(); + }); + it("should abort a session", async () => { const session = await client.createSession({ onPermissionRequest: approveAll }); diff --git a/python/copilot/client.py b/python/copilot/client.py index 7dade4440..cbe8f2c51 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -2754,11 +2754,6 @@ async def resume_session( session.on(on_event) with self._sessions_lock: self._sessions[session_id] = session - if on_mcp_auth_request is not None: - await self._client.request( - "session.eventLog.registerInterest", - {"sessionId": session_id, "eventType": "mcp.oauth_required"}, - ) log_timing( logger, logging.DEBUG, @@ -2788,6 +2783,11 @@ async def resume_session( session._set_open_canvases( [OpenCanvasInstance.from_dict(inst) for inst in open_canvases_raw] ) + if on_mcp_auth_request is not None: + await self._client.request( + "session.eventLog.registerInterest", + {"sessionId": session.session_id, "eventType": "mcp.oauth_required"}, + ) except BaseException as exc: with self._sessions_lock: self._sessions.pop(session_id, None) diff --git a/python/e2e/test_session_e2e.py b/python/e2e/test_session_e2e.py index 2dc7bb1dc..22f99e403 100644 --- a/python/e2e/test_session_e2e.py +++ b/python/e2e/test_session_e2e.py @@ -11,7 +11,12 @@ from copilot.session_events import SessionModelChangeData from copilot.tools import Tool, ToolResult -from .testharness import E2ETestContext, get_final_assistant_message, get_next_event_of_type +from .testharness import ( + DEFAULT_GITHUB_TOKEN, + E2ETestContext, + get_final_assistant_message, + get_next_event_of_type, +) pytestmark = pytest.mark.asyncio(loop_scope="module") @@ -275,6 +280,40 @@ async def test_should_resume_a_session_using_a_new_client(self, ctx: E2ETestCont finally: await new_client.force_stop() + async def test_resumes_a_persisted_session_from_a_new_client_when_an_mcp_oauth_handler_is_configured( # noqa: E501 + self, ctx: E2ETestContext + ): + def on_mcp_auth_request(_request, _invocation): + return {"kind": "cancelled"} + + session1 = await ctx.client.create_session( + on_permission_request=PermissionHandler.approve_all, + on_mcp_auth_request=on_mcp_auth_request, + ) + session_id = session1.session_id + answer = await session1.send_and_wait("What is 1+1?") + assert answer is not None + assert "2" in answer.data.content + + github_token = DEFAULT_GITHUB_TOKEN if os.environ.get("GITHUB_ACTIONS") == "true" else None + new_client = CopilotClient( + connection=RuntimeConnection.for_stdio(path=ctx.cli_path), + working_directory=ctx.work_dir, + env=ctx.get_env(), + github_token=github_token, + ) + + try: + session2 = await new_client.resume_session( + session_id, + on_permission_request=PermissionHandler.approve_all, + on_mcp_auth_request=on_mcp_auth_request, + ) + assert session2.session_id == session_id + await session2.disconnect() + finally: + await new_client.force_stop() + async def test_should_throw_error_resuming_nonexistent_session(self, ctx: E2ETestContext): with pytest.raises(Exception): await ctx.client.resume_session( diff --git a/python/test_client.py b/python/test_client.py index db6703c3b..88a099e7a 100644 --- a/python/test_client.py +++ b/python/test_client.py @@ -230,7 +230,7 @@ async def mock_request(method, params, **kwargs): await client.force_stop() @pytest.mark.asyncio - async def test_mcp_auth_handler_registers_interest_before_resume(self): + async def test_mcp_auth_handler_registers_interest_after_resume(self): client = CopilotClient(connection=RuntimeConnection.for_stdio(path=CLI_PATH)) await client.start() try: @@ -251,15 +251,15 @@ async def mock_request(method, params, **kwargs): on_mcp_auth_request=lambda request: {"kind": "cancelled"}, ) - interest_method, interest_payload = captured[0] - resume_method, resume_payload = captured[1] + resume_method, resume_payload = captured[0] + interest_method, interest_payload = captured[1] + assert resume_method == "session.resume" + assert resume_payload["requestPermission"] is True assert interest_method == "session.eventLog.registerInterest" assert interest_payload == { "sessionId": "session-with-auth", "eventType": "mcp.oauth_required", } - assert resume_method == "session.resume" - assert resume_payload["requestPermission"] is True finally: await client.force_stop() diff --git a/rust/src/session.rs b/rust/src/session.rs index b9f17f7de..e5a2dc4dd 100644 --- a/rust/src/session.rs +++ b/rust/src/session.rs @@ -1181,9 +1181,6 @@ impl Client { let mut params = serde_json::to_value(&wire)?; let trace_ctx = self.resolve_trace_context().await; inject_trace_context(&mut params, &trace_ctx); - if has_mcp_auth_handler { - register_mcp_auth_interest(self, &session_id).await?; - } let capabilities = Arc::new(parking_lot::RwLock::new(SessionCapabilities::default())); let setup_start = Instant::now(); @@ -1253,6 +1250,9 @@ impl Client { }) .into()); } + if has_mcp_auth_handler { + register_mcp_auth_interest(self, &session_id).await?; + } // Reload skills after resume (best-effort). let skills_reload_start = Instant::now(); diff --git a/rust/tests/e2e/session.rs b/rust/tests/e2e/session.rs index ee3a010bf..744708db7 100644 --- a/rust/tests/e2e/session.rs +++ b/rust/tests/e2e/session.rs @@ -2,7 +2,9 @@ use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; -use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::handler::{ + ApproveAllHandler, McpAuthHandler, McpAuthRequest, McpAuthResult, +}; use github_copilot_sdk::session_events::{ SessionErrorData, SessionEventType, SessionInfoData, SessionModelChangeData, SessionResumeData, SessionStartData, SessionWarningData, UserMessageData, @@ -12,8 +14,8 @@ use github_copilot_sdk::types::LogLevel as SessionLogLevel; use github_copilot_sdk::{ Attachment, AttachmentLineRange, AttachmentSelectionPosition, AttachmentSelectionRange, AzureProviderOptions, DefaultAgentConfig, Error, GitHubReferenceType, LogOptions, - MessageOptions, ProviderConfig, ResumeSessionConfig, SectionOverride, SessionConfig, - SetModelOptions, SystemMessageConfig, Tool, ToolInvocation, ToolResult, + MessageOptions, ProviderConfig, RequestId, ResumeSessionConfig, SectionOverride, SessionConfig, + SessionId, SetModelOptions, SystemMessageConfig, Tool, ToolInvocation, ToolResult, }; use serde_json::json; @@ -597,6 +599,60 @@ async fn should_resume_a_session_using_a_new_client() { .await; } +#[tokio::test] +async fn resumes_a_persisted_session_from_a_new_client_when_an_mcp_oauth_handler_is_configured() { + with_e2e_context( + "session", + "resumes_a_persisted_session_from_a_new_client_when_an_mcp_oauth_handler_is_configured", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_mcp_auth_handler(Arc::new(CancelMcpAuthHandler)), + ) + .await + .expect("create session"); + let session_id = session.id().clone(); + + let first = session + .send_and_wait("What is 1+1?") + .await + .expect("send") + .expect("assistant message"); + assert!(assistant_message_content(&first).contains('2')); + + session + .disconnect() + .await + .expect("disconnect first session"); + client.stop().await.expect("stop first client"); + + let new_client = ctx.start_client().await; + let resumed = new_client + .resume_session( + ResumeSessionConfig::new(session_id.clone()) + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_mcp_auth_handler(Arc::new(CancelMcpAuthHandler)) + .with_github_token(super::support::DEFAULT_TEST_TOKEN), + ) + .await + .expect("resume session"); + assert_eq!(resumed.id(), &session_id); + + resumed + .disconnect() + .await + .expect("disconnect resumed session"); + new_client.stop().await.expect("stop new client"); + }) + }, + ) + .await; +} + #[tokio::test] async fn should_receive_session_events() { with_e2e_context("session", "should_receive_session_events", |ctx| { @@ -1528,6 +1584,20 @@ async fn latest_user_message( .expect("user.message") } +struct CancelMcpAuthHandler; + +#[async_trait::async_trait] +impl McpAuthHandler for CancelMcpAuthHandler { + async fn handle( + &self, + _session_id: SessionId, + _request_id: RequestId, + _request: McpAuthRequest, + ) -> McpAuthResult { + McpAuthResult::Cancelled + } +} + struct SecretNumberTool; #[async_trait::async_trait] diff --git a/rust/tests/session_test.rs b/rust/tests/session_test.rs index 31b0cc233..d8ace9ad1 100644 --- a/rust/tests/session_test.rs +++ b/rust/tests/session_test.rs @@ -468,6 +468,11 @@ async fn resume_session_registers_mcp_auth_interest_only_with_handler() { } }); + let resume_req = read_framed(&mut server_read).await; + assert_eq!(resume_req["method"], "session.resume"); + assert_eq!(resume_req["params"]["requestPermission"], true); + server_respond_create(&mut server_write, &resume_req, "session-with-auth").await; + let interest_req = read_framed(&mut server_read).await; assert_eq!(interest_req["method"], "session.eventLog.registerInterest"); assert_eq!(interest_req["params"]["eventType"], "mcp.oauth_required"); @@ -483,10 +488,6 @@ async fn resume_session_registers_mcp_auth_interest_only_with_handler() { ) .await; - let resume_req = read_framed(&mut server_read).await; - assert_eq!(resume_req["method"], "session.resume"); - assert_eq!(resume_req["params"]["requestPermission"], true); - server_respond_create(&mut server_write, &resume_req, "session-with-auth").await; respond_to_reload(&mut server_read, &mut server_write).await; let _session = timeout(TIMEOUT, resume_handle).await.unwrap().unwrap(); } diff --git a/test/snapshots/resume_mcp_oauth/should_resume_a_persisted_session_with_mcp_auth_handler.yaml b/test/snapshots/resume_mcp_oauth/should_resume_a_persisted_session_with_mcp_auth_handler.yaml new file mode 100644 index 000000000..250402101 --- /dev/null +++ b/test/snapshots/resume_mcp_oauth/should_resume_a_persisted_session_with_mcp_auth_handler.yaml @@ -0,0 +1,10 @@ +models: + - claude-sonnet-4.5 +conversations: + - messages: + - role: system + content: ${system} + - role: user + content: What is 1+1? + - role: assistant + content: 1 + 1 = 2 diff --git a/test/snapshots/session/resumes_a_persisted_session_from_a_new_client_when_an_mcp_oauth_handler_is_configured.yaml b/test/snapshots/session/resumes_a_persisted_session_from_a_new_client_when_an_mcp_oauth_handler_is_configured.yaml new file mode 100644 index 000000000..250402101 --- /dev/null +++ b/test/snapshots/session/resumes_a_persisted_session_from_a_new_client_when_an_mcp_oauth_handler_is_configured.yaml @@ -0,0 +1,10 @@ +models: + - claude-sonnet-4.5 +conversations: + - messages: + - role: system + content: ${system} + - role: user + content: What is 1+1? + - role: assistant + content: 1 + 1 = 2