Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions dotnet/src/Client.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1172,11 +1172,6 @@ public async Task<CopilotSession> ResumeSessionAsync(string sessionId, ResumeSes
transformCallbacks,
hasHooks,
"CopilotClient.ResumeSessionAsync");
if (config.OnMcpAuthRequest is not null)
{
await session.Rpc.EventLog.RegisterInterestAsync("mcp.oauth_required", cancellationToken);
}

try
{
var (traceparent, tracestate) = TelemetryHelpers.GetTraceContext();
Expand Down Expand Up @@ -1259,6 +1254,11 @@ public async Task<CopilotSession> ResumeSessionAsync(string sessionId, ResumeSes
session.SetCapabilities(response.Capabilities);
session.SetOpenCanvases(response.OpenCanvases);

if (config.OnMcpAuthRequest is not null)
{
await session.Rpc.EventLog.RegisterInterestAsync("mcp.oauth_required", cancellationToken);
}

await UpdateSessionOptionsForModeAsync(session, config, cancellationToken).ConfigureAwait(false);
}
catch (Exception ex)
Expand Down
27 changes: 27 additions & 0 deletions dotnet/test/E2E/SessionE2ETests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,33 @@ public async Task Should_Resume_A_Session_Using_A_New_Client()
Assert.Contains("4", answer2!.Data.Content ?? string.Empty);
}

[Fact]
public async Task Resumes_A_Persisted_Session_From_A_New_Client_When_An_Mcp_OAuth_Handler_Is_Configured()
{
static Task<McpAuthResult?> CancelMcpAuthAsync(McpAuthContext request)
=> Task.FromResult<McpAuthResult?>(McpAuthResult.Cancel());

await using var session1 = await CreateSessionAsync(new SessionConfig
{
OnPermissionRequest = PermissionHandler.ApproveAll,
OnMcpAuthRequest = CancelMcpAuthAsync,
});
var sessionId = session1.SessionId;

var answer = await session1.SendAndWaitAsync(new MessageOptions { Prompt = "What is 1+1?" });
Assert.NotNull(answer);
Assert.Contains("2", answer!.Data.Content ?? string.Empty);

using var newClient = Ctx.CreateClient();
await using var session2 = await newClient.ResumeSessionAsync(sessionId, new ResumeSessionConfig
{
OnPermissionRequest = PermissionHandler.ApproveAll,
OnMcpAuthRequest = CancelMcpAuthAsync,
});

Assert.Equal(sessionId, session2.SessionId);
}

[Fact]
public async Task Should_Throw_Error_When_Resuming_Non_Existent_Session()
{
Expand Down
4 changes: 2 additions & 2 deletions dotnet/test/Unit/ClientSessionLifetimeTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -300,12 +300,12 @@ public async Task ResumeSessionAsync_Registers_McpAuth_Interest_Only_When_Handle

Assert.Collection(
server.Requests.Take(2),
request => Assert.Equal("session.resume", request.Method),
request =>
{
Assert.Equal("session.eventLog.registerInterest", request.Method);
Assert.Equal("mcp.oauth_required", request.Params.GetProperty("eventType").GetString());
},
request => Assert.Equal("session.resume", request.Method));
});
}

[Fact]
Expand Down
23 changes: 12 additions & 11 deletions go/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1150,17 +1150,6 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string,
c.sessionsMux.Lock()
c.sessions[sessionID] = session
c.sessionsMux.Unlock()
if config.OnMCPAuthRequest != nil {
if _, err := c.client.Request(ctx, "session.eventLog.registerInterest", map[string]any{
"sessionId": sessionID,
"eventType": "mcp.oauth_required",
}); err != nil {
c.sessionsMux.Lock()
delete(c.sessions, sessionID)
c.sessionsMux.Unlock()
return nil, err
}
}

if c.options.SessionFS != nil {
if config.CreateSessionFSProvider == nil {
Expand Down Expand Up @@ -1197,6 +1186,18 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string,
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
}

if config.OnMCPAuthRequest != nil {
if _, err := c.client.Request(ctx, "session.eventLog.registerInterest", map[string]any{
"sessionId": sessionID,
"eventType": "mcp.oauth_required",
}); err != nil {
c.sessionsMux.Lock()
delete(c.sessions, sessionID)
c.sessionsMux.Unlock()
return nil, err
}
}

session.workspacePath = response.WorkspacePath
session.setCapabilities(response.Capabilities)
session.setOpenCanvases(response.OpenCanvases)
Expand Down
12 changes: 6 additions & 6 deletions go/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1405,7 +1405,7 @@ func TestClient_MCPAuthInterestRegistration(t *testing.T) {
assertMCPAuthInterest(t, snapshot[1])
})

t.Run("resume conditionally registers MCP OAuth interest before session resume", func(t *testing.T) {
t.Run("resume conditionally registers MCP OAuth interest after session resume", func(t *testing.T) {
client, requests, cleanup := newInMemoryClient(t)
defer cleanup()

Expand Down Expand Up @@ -1434,13 +1434,13 @@ func TestClient_MCPAuthInterestRegistration(t *testing.T) {
defer withAuth.Disconnect()

snapshot := requests.snapshot()
if snapshot[0].Method != "session.eventLog.registerInterest" {
t.Fatalf("expected MCP auth interest before session.resume, got %s", snapshot[0].Method)
if snapshot[0].Method != "session.resume" {
t.Fatalf("expected session.resume before MCP auth interest, got %s", snapshot[0].Method)
}
if snapshot[1].Method != "session.resume" {
t.Fatalf("expected session.resume after MCP auth interest, got %s", snapshot[1].Method)
if snapshot[1].Method != "session.eventLog.registerInterest" {
t.Fatalf("expected MCP auth interest after session.resume, got %s", snapshot[1].Method)
}
assertMCPAuthInterest(t, snapshot[0])
assertMCPAuthInterest(t, snapshot[1])
})
}

Expand Down
60 changes: 60 additions & 0 deletions go/internal/e2e/resume_mcp_oauth_e2e_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package e2e

import (
"strings"
"testing"

copilot "github.com/github/copilot-sdk/go"
"github.com/github/copilot-sdk/go/internal/e2e/testharness"
)

func TestResumeMCPOAuthE2E(t *testing.T) {
ctx := testharness.NewTestContext(t)
client := ctx.NewClient()
t.Cleanup(func() { client.ForceStop() })

t.Run("should resume a persisted session with mcp auth handler", func(t *testing.T) {
ctx.ConfigureForTest(t)

mcpAuthHandler := func(copilot.MCPAuthRequest, copilot.MCPAuthInvocation) (*copilot.MCPAuthResult, error) {
return copilot.MCPAuthResultCancelled(), nil
}

session1, err := client.CreateSession(t.Context(), &copilot.SessionConfig{
OnPermissionRequest: copilot.PermissionHandler.ApproveAll,
OnMCPAuthRequest: mcpAuthHandler,
})
if err != nil {
t.Fatalf("Failed to create session: %v", err)
}
sessionID := session1.SessionID

_, err = session1.Send(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"})
if err != nil {
t.Fatalf("Failed to send message: %v", err)
}

answer, err := testharness.GetFinalAssistantMessage(t.Context(), session1)
if err != nil {
t.Fatalf("Failed to get assistant message: %v", err)
}
if ad, ok := answer.Data.(*copilot.AssistantMessageData); !ok || !strings.Contains(ad.Content, "2") {
t.Errorf("Expected answer to contain '2', got %v", answer.Data)
}

newClient := ctx.NewClient()
t.Cleanup(func() { newClient.ForceStop() })

session2, err := newClient.ResumeSession(t.Context(), sessionID, &copilot.ResumeSessionConfig{
OnPermissionRequest: copilot.PermissionHandler.ApproveAll,
OnMCPAuthRequest: mcpAuthHandler,
})
if err != nil {
t.Fatalf("Failed to resume session: %v", err)
}

if session2.SessionID != sessionID {
t.Errorf("Expected resumed session ID to match, got %q vs %q", session2.SessionID, sessionID)
}
})
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
*--------------------------------------------------------------------------------------------*/

package com.github.copilot;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;

import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;

import com.github.copilot.generated.AssistantMessageEvent;
import com.github.copilot.rpc.McpAuthResult;
import com.github.copilot.rpc.MessageOptions;
import com.github.copilot.rpc.PermissionHandler;
import com.github.copilot.rpc.ResumeSessionConfig;
import com.github.copilot.rpc.SessionConfig;

class McpOAuthResumeE2ETest {

private static final String SNAPSHOT = "resumes_a_persisted_session_from_a_new_client_when_an_mcp_oauth_handler_is_configured";

private static E2ETestContext ctx;

@BeforeAll
static void setup() throws Exception {
ctx = E2ETestContext.create();
}

@AfterAll
static void teardown() throws Exception {
if (ctx != null) {
ctx.close();
}
}

@Test
@Tag("isolated-resume")
void resumesAPersistedSessionFromANewClientWhenAnMcpOauthHandlerIsConfigured() throws Exception {
ctx.configureForTest("session", SNAPSHOT);

String sessionId;
try (var client = ctx.createClient();
var session = client
.createSession(
new SessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL)
.setOnMcpAuthRequest((request, invocation) -> CompletableFuture
.completedFuture(McpAuthResult.cancelled())))
.get(30, TimeUnit.SECONDS)) {
sessionId = session.getSessionId();

AssistantMessageEvent response = session.sendAndWait(new MessageOptions().setPrompt("What is 1+1?"), 60_000)
.get(90, TimeUnit.SECONDS);
assertNotNull(response);
assertTrue(response.getData().content().contains("2"),
"Response should contain 2: " + response.getData().content());
}

try (var client = ctx.createClient();
var session = client
.resumeSession(sessionId,
new ResumeSessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL)
.setOnMcpAuthRequest((request, invocation) -> CompletableFuture
.completedFuture(McpAuthResult.cancelled())))
.get(30, TimeUnit.SECONDS)) {
assertEquals(sessionId, session.getSessionId());
}
}
}
12 changes: 6 additions & 6 deletions nodejs/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1575,12 +1575,6 @@ export class CopilotClient {
}
this.sessions.set(sessionId, session);
this.setupSessionFs(session, config);
if (config.onMcpAuthRequest) {
await this.connection!.sendRequest("session.eventLog.registerInterest", {
sessionId,
eventType: "mcp.oauth_required",
});
}

const toolFilterOptions = this.resolveToolFilterOptions(config);

Expand Down Expand Up @@ -1671,6 +1665,12 @@ export class CopilotClient {
session["_workspacePath"] = workspacePath;
session.setCapabilities(capabilities);
session.setOpenCanvases(openCanvases ?? []);
if (config.onMcpAuthRequest) {
Comment thread
MackinnonBuck marked this conversation as resolved.
await this.connection!.sendRequest("session.eventLog.registerInterest", {
sessionId,
eventType: "mcp.oauth_required",
});
}

await this.updateSessionOptionsForMode(session, config);
} catch (e) {
Expand Down
25 changes: 18 additions & 7 deletions nodejs/test/client.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ describe("CopilotClient", () => {
]);
});

it("registers MCP OAuth interest before resuming only when an auth handler is configured", async () => {
it("registers MCP OAuth interest after resuming only when an auth handler is configured", async () => {
const client = new CopilotClient();
await client.start();
onTestFinished(() => client.forceStop());
Expand All @@ -240,12 +240,23 @@ describe("CopilotClient", () => {
onMcpAuthRequest: () => ({ kind: "cancelled" }),
});

expect(spy.mock.calls[0]).toEqual([
"session.eventLog.registerInterest",
{ sessionId: "session-with-auth", eventType: "mcp.oauth_required" },
]);
expect(spy.mock.calls[1][0]).toBe("session.resume");
expect(spy.mock.calls[1][1]).toEqual(expect.objectContaining({ requestPermission: true }));
// `session.eventLog.registerInterest` is session-scoped: the runtime only
// registers the session id while handling `session.resume`, so resume must
// be sent BEFORE registering interest.
const resumeIndex = spy.mock.calls.findIndex(([method]) => method === "session.resume");
const interestIndex = spy.mock.calls.findIndex(
([method]) => method === "session.eventLog.registerInterest"
);
expect(resumeIndex).toBeGreaterThanOrEqual(0);
expect(interestIndex).toBeGreaterThanOrEqual(0);
expect(resumeIndex).toBeLessThan(interestIndex);
expect(spy.mock.calls[resumeIndex][1]).toEqual(
expect.objectContaining({ sessionId: "session-with-auth", requestPermission: true })
);
expect(spy.mock.calls[interestIndex][1]).toEqual({
sessionId: "session-with-auth",
eventType: "mcp.oauth_required",
});

spy.mockClear();
await client.resumeSession("session-without-auth", {
Expand Down
36 changes: 35 additions & 1 deletion nodejs/test/e2e/session.e2e.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { rm } from "fs/promises";
import { describe, expect, it, onTestFinished, vi } from "vitest";
import { ParsedHttpExchange } from "../../../test/harness/replayingCapiProxy.js";
import { CopilotClient, approveAll, defineTool, RuntimeConnection } from "../../src/index.js";
import { createSdkTestContext, isCI } from "./harness/sdkTestContext.js";
import { createSdkTestContext, DEFAULT_GITHUB_TOKEN, isCI } from "./harness/sdkTestContext.js";
import { getFinalAssistantMessage, getNextEventOfType, retry } from "./harness/sdkTestHelper.js";

describe("Sessions", async () => {
Expand Down Expand Up @@ -464,6 +464,40 @@ describe("Sessions", async () => {
expect(session2.sessionId).toBe(sessionId);
});

it("resumes a persisted session from a new client when an MCP OAuth handler is configured", async () => {
// Take a turn so the session is persisted to the store and can be
// loaded by a different CLI process.
const session1 = await client.createSession({
onPermissionRequest: approveAll,
onMcpAuthRequest: () => ({ kind: "cancelled" }),
});
const sessionId = session1.sessionId;
const answer = await session1.sendAndWait({ prompt: "What is 1+1?" });
expect(answer?.data.content).toContain("2");

// Resume from a fresh client (new CLI process). Its routing table does
// not know the session until it handles `session.resume`. Because an MCP
// OAuth handler is configured, the SDK issues a session-scoped
// `session.eventLog.registerInterest` for `mcp.oauth_required`; that must
// be sent AFTER `session.resume`, otherwise the runtime rejects it with
// "Session not found: <id>".
const newClient = new CopilotClient({
env,
gitHubToken: isCI
? DEFAULT_GITHUB_TOKEN
: (process.env.GITHUB_TOKEN ?? DEFAULT_GITHUB_TOKEN),
});
onTestFinished(() => newClient.forceStop());

const session2 = await newClient.resumeSession(sessionId, {
onPermissionRequest: approveAll,
onMcpAuthRequest: () => ({ kind: "cancelled" }),
});

expect(session2.sessionId).toBe(sessionId);
await session2.disconnect();
});

it("should abort a session", async () => {
const session = await client.createSession({ onPermissionRequest: approveAll });

Expand Down
Loading
Loading