diff --git a/dotnet/src/BearerTokenProvider.cs b/dotnet/src/BearerTokenProvider.cs new file mode 100644 index 000000000..2c59da09b --- /dev/null +++ b/dotnet/src/BearerTokenProvider.cs @@ -0,0 +1,32 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +using System.Diagnostics.CodeAnalysis; + +namespace GitHub.Copilot; + +/// +/// Arguments passed to a bearer-token callback (the GetBearerToken property +/// on / ) when the +/// runtime needs a fresh bearer token for a BYOK provider. +/// +/// +/// Part of the experimental managed-identity / bearer-token-provider surface and +/// may change or be removed in future SDK or CLI releases. +/// +[Experimental(Diagnostics.Experimental)] +public sealed class ProviderTokenArgs +{ + /// + /// Name of the BYOK provider needing a token. For the singular, whole-session + /// this is the implicit provider name + /// ("default"); for entries it is + /// . + /// + /// + /// The callback closes over its own token scope/audience; the runtime is + /// provider-agnostic and forwards only the provider name. + /// + public required string ProviderName { get; init; } +} diff --git a/dotnet/src/Client.cs b/dotnet/src/Client.cs index ef90b41bf..dff034b79 100644 --- a/dotnet/src/Client.cs +++ b/dotnet/src/Client.cs @@ -652,6 +652,7 @@ private CopilotSession InitializeSession( } ConfigureSessionFsHandlers(session, config.CreateSessionFsProvider); session.SetCanvasHandler(config.CanvasHandler); + session.RegisterBearerTokenProviders(BuildBearerTokenCallbacks(config)); RegisterSession(session); session.StartProcessingEvents(); LoggingHelpers.LogTiming(_logger, LogLevel.Debug, null, @@ -664,6 +665,37 @@ private CopilotSession InitializeSession( return session; } + /// + /// Implicit provider name for the singular, whole-session . + /// + private const string DefaultBearerTokenProviderName = "default"; + + /// + /// Collects the per-provider GetBearerToken callbacks keyed by + /// provider name for session-side registration. The singular, whole-session + /// uses the implicit + /// . + /// + private static Dictionary>> BuildBearerTokenCallbacks(SessionConfigBase config) + { + var callbacks = new Dictionary>>(StringComparer.Ordinal); + if (config.Provider?.GetBearerToken is { } singular) + { + callbacks[DefaultBearerTokenProviderName] = singular; + } + if (config.Providers != null) + { + foreach (var provider in config.Providers) + { + if (provider.GetBearerToken is { } callback) + { + callbacks[provider.Name] = callback; + } + } + } + return callbacks; + } + /// /// Catches misuse of / /// at the SDK boundary so diff --git a/dotnet/src/Session.cs b/dotnet/src/Session.cs index e63b7fa59..f8e285eab 100644 --- a/dotnet/src/Session.cs +++ b/dotnet/src/Session.cs @@ -58,6 +58,7 @@ public sealed partial class CopilotSession : IAsyncDisposable { private readonly Dictionary _toolHandlers = []; private readonly Dictionary> _commandHandlers = []; + private readonly Dictionary>> _bearerTokenProviders = new(StringComparer.Ordinal); private readonly ILogger _logger; private readonly CopilotClient _parentClient; @@ -76,9 +77,7 @@ private sealed record EventSubscription(Type EventType, Action Han private Dictionary>>? _transformCallbacks; private readonly SemaphoreSlim _transformCallbacksLock = new(1, 1); -#pragma warning disable GHCP001 private IReadOnlyList _openCanvases = Array.Empty(); -#pragma warning restore GHCP001 private int _isDisposed; @@ -126,7 +125,6 @@ public SessionCapabilities Capabilities private set; } -#pragma warning disable GHCP001 /// /// Canvas instances currently known to be open for this session. /// @@ -136,7 +134,6 @@ public SessionCapabilities Capabilities /// [Experimental(Diagnostics.Experimental)] public IReadOnlyList OpenCanvases => _openCanvases; -#pragma warning restore GHCP001 /// /// Gets the UI API for eliciting information from the user during this session. @@ -873,6 +870,51 @@ internal void RegisterAutoModeSwitchHandler(Func + /// Registers per-provider GetBearerToken callbacks for BYOK + /// providers configured with managed-identity / on-demand bearer-token auth. + /// + /// + /// The runtime never receives the callback itself; the SDK strips it from the + /// provider config and instead sends hasBearerTokenProvider: true. When + /// the runtime needs a token it issues a session-scoped + /// providerToken.getToken request, which this handler routes to the + /// matching per-provider callback. + /// + /// Map of provider name to callback, or null/empty to clear. + internal void RegisterBearerTokenProviders(IReadOnlyDictionary>>? providers) + { + _bearerTokenProviders.Clear(); + if (providers is null || providers.Count == 0) + { + ClientSessionApis.ProviderToken = null; + return; + } + foreach (var (name, callback) in providers) + { + _bearerTokenProviders[name] = callback; + } + ClientSessionApis.ProviderToken = new BearerTokenProviderHandler(this); + } + + /// + /// Routes runtime providerToken.getToken requests to the matching + /// per-provider GetBearerToken callback registered on the session. + /// + private sealed class BearerTokenProviderHandler(CopilotSession session) : IProviderTokenHandler + { + public async Task GetTokenAsync(ProviderTokenAcquireRequest request, CancellationToken cancellationToken = default) + { + if (!session._bearerTokenProviders.TryGetValue(request.ProviderName, out var callback)) + { + throw new InvalidOperationException( + $"No bearer-token provider registered for provider \"{request.ProviderName}\""); + } + var token = await callback(new ProviderTokenArgs { ProviderName = request.ProviderName }).ConfigureAwait(false); + return new ProviderTokenAcquireResult { Token = token }; + } + } + /// /// Sets the capabilities reported by the host for this session. /// @@ -882,7 +924,6 @@ internal void SetCapabilities(SessionCapabilities? capabilities) Capabilities = capabilities ?? new SessionCapabilities(); } -#pragma warning disable GHCP001 internal void SetOpenCanvases(IList? canvases) { _openCanvases = canvases is { Count: > 0 } @@ -959,7 +1000,6 @@ private static JsonElement SerializeActionResult(object? value) var element = CopilotClient.ToJsonElementForWire(value); return element ?? NullJsonElement; } -#pragma warning restore GHCP001 private sealed class CanvasHandlerAdapter(ICanvasHandler handler) : Rpc.ICanvasHandler { diff --git a/dotnet/src/Types.cs b/dotnet/src/Types.cs index b02aa272a..a33dff61c 100644 --- a/dotnet/src/Types.cs +++ b/dotnet/src/Types.cs @@ -5,12 +5,14 @@ using GitHub.Copilot.Rpc; using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; +using System; using System.ComponentModel; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Text.Json; using System.Text.Json.Nodes; using System.Text.Json.Serialization; +using System.Threading.Tasks; namespace GitHub.Copilot; @@ -2041,6 +2043,28 @@ public sealed class ProviderConfig [JsonPropertyName("bearerToken")] public string? BearerToken { get; set; } + /// + /// Wire-only flag, emitted automatically when is set, that tells + /// the runtime to request a token over the session-scoped providerToken.getToken RPC + /// before each outbound request to this provider. Derived from ; + /// internal and never part of the public API. + /// + [JsonInclude] + [JsonPropertyName("hasBearerTokenProvider")] + internal bool? HasBearerTokenProvider => GetBearerToken is not null ? true : null; + + /// + /// Per-request callback that resolves a bearer token on demand for this BYOK provider (for + /// example via Azure Managed Identity). The Copilot SDK takes no identity dependency: supply a + /// callback backed by your own identity library. Never serialized — setting it makes the SDK send + /// hasBearerTokenProvider: true on the wire and answer the runtime's + /// providerToken.getToken requests. Mutually exclusive with and + /// . + /// + [JsonIgnore] + [Experimental(Diagnostics.Experimental)] + public Func>? GetBearerToken { get; set; } + /// /// Azure-specific configuration options. /// @@ -2173,6 +2197,28 @@ public sealed class NamedProviderConfig [JsonPropertyName("bearerToken")] public string? BearerToken { get; set; } + /// + /// Wire-only flag, emitted automatically when is set, that tells + /// the runtime to request a token over the session-scoped providerToken.getToken RPC + /// before each outbound request to this provider. Derived from ; + /// internal and never part of the public API. + /// + [JsonInclude] + [JsonPropertyName("hasBearerTokenProvider")] + internal bool? HasBearerTokenProvider => GetBearerToken is not null ? true : null; + + /// + /// Per-request callback that resolves a bearer token on demand for this BYOK provider (for + /// example via Azure Managed Identity). The Copilot SDK takes no identity dependency: supply a + /// callback backed by your own identity library. Never serialized — setting it makes the SDK send + /// hasBearerTokenProvider: true on the wire and answer the runtime's + /// providerToken.getToken requests. Mutually exclusive with and + /// . + /// + [JsonIgnore] + [Experimental(Diagnostics.Experimental)] + public Func>? GetBearerToken { get; set; } + /// /// Azure-specific configuration options. /// diff --git a/dotnet/test/E2E/ByokBearerTokenProviderE2ETests.cs b/dotnet/test/E2E/ByokBearerTokenProviderE2ETests.cs new file mode 100644 index 000000000..3f869a437 --- /dev/null +++ b/dotnet/test/E2E/ByokBearerTokenProviderE2ETests.cs @@ -0,0 +1,287 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +using System.Collections.Concurrent; +using System.Net; +using System.Net.Http; +using GitHub.Copilot.Test.Harness; +using Xunit; +using Xunit.Abstractions; + +namespace GitHub.Copilot.Test.E2E; + +/// +/// End-to-end coverage for the experimental BYOK bearer-token-provider surface +/// (GetBearerToken on a provider config). The callback stays entirely on +/// the SDK/client side: the SDK strips it from the wire config, sets the +/// hasBearerTokenProvider flag, and the runtime calls back over the +/// session-scoped providerToken.getToken RPC before each outbound model +/// request, applying the returned token as the Authorization header. +/// +/// +/// +/// These tests mirror the Node SDK's byok_bearer_token_provider.e2e.test.ts. +/// Rather than standing up a real HTTP listener, each test installs a +/// that intercepts the runtime's outbound +/// model request in-process, captures the Authorization header, and +/// returns a synthetic response — so nothing touches the network and there is no +/// CAPI proxy acting as the inference endpoint. They validate, against a real +/// runtime: +/// +/// +/// the callback's token reaches the model request as Authorization: Bearer <token>; +/// the runtime re-acquires a token per request (no runtime-side caching); +/// per-provider dispatch routes each provider's turn to its own callback, +/// and the resulting token reaches that provider's endpoint. +/// +/// +public class ByokBearerTokenProviderE2ETests(E2ETestFixture fixture, ITestOutputHelper output) + : E2ETestBase(fixture, "byok_bearer_token_provider", output) +{ + // Fake BYOK provider hosts. These are never actually dialed: the request + // handler fully answers any request aimed at a `.invalid` host, so they only + // need to be syntactically valid, non-resolving URLs. Distinct hosts let the + // per-provider test assert routing by host. + private const string PrimaryHost = "byok-endpoint.invalid"; + private const string PrimaryBaseUrl = $"https://{PrimaryHost}/v1"; + private const string RedHost = "byok-red.invalid"; + private const string RedBaseUrl = $"https://{RedHost}/v1"; + private const string BlueHost = "byok-blue.invalid"; + private const string BlueBaseUrl = $"https://{BlueHost}/v1"; + + private CopilotClient CreateClientWith(CapturingRequestHandler handler) => + Ctx.CreateClient(options: new CopilotClientOptions + { + Connection = RuntimeConnection.ForStdio(), + RequestHandler = handler, + }); + + /// + /// Drives one BYOK turn against the given providers/models. The capturing + /// handler 404s the BYOK request, which errors the turn after the runtime has + /// already applied the (token-bearing) Authorization header — which is + /// all these tests assert on. The resulting error is swallowed. + /// + private static async Task RunTurnAsync( + CopilotClient client, + IList providers, + IList models, + string selectionId, + string prompt) + { + var session = await client.CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + Model = selectionId, + Providers = providers, + Models = models, + }); + try + { + await session.SendAndWaitAsync(new MessageOptions { Prompt = prompt }); + } + catch + { + // The handler always 404s the BYOK endpoint, so the turn errors after + // the token-bearing request was already captured. Expected. + } + finally + { + await session.DisposeAsync(); + } + } + + [Fact] + public async Task Applies_The_Callbacks_Token_As_The_Authorization_Header() + { + const string sentinel = "sentinel-bearer-token-abc123"; + var calls = 0; + + var handler = new CapturingRequestHandler(); + await using var client = CreateClientWith(handler); + await client.StartAsync(); + + var providers = new List + { + new() + { + Name = "mi", + Type = "openai", + WireApi = "completions", + BaseUrl = PrimaryBaseUrl, + GetBearerToken = _ => + { + Interlocked.Increment(ref calls); + return Task.FromResult(sentinel); + }, + }, + }; + var models = new List + { + new() { Id = "default", Provider = "mi", WireModel = "byok-gpt-4o" }, + }; + + await RunTurnAsync(client, providers, models, "mi/default", "What is 5+5?"); + + // The runtime acquired a token via the callback and applied it verbatim as + // the bearer credential on the outbound model request. + Assert.Contains($"Bearer {sentinel}", handler.AuthHeaders()); + Assert.True(calls >= 1, "Expected the bearer-token callback to be invoked at least once."); + } + + [Fact] + public async Task Re_Acquires_A_Fresh_Token_For_Each_Request() + { + var calls = 0; + + var handler = new CapturingRequestHandler(); + await using var client = CreateClientWith(handler); + await client.StartAsync(); + + var providers = new List + { + new() + { + Name = "mi", + Type = "openai", + WireApi = "completions", + BaseUrl = PrimaryBaseUrl, + // A distinct token per acquisition proves the runtime re-invokes + // the callback per request rather than caching a previous token. + GetBearerToken = _ => + { + var n = Interlocked.Increment(ref calls); + return Task.FromResult($"rotating-token-{n}"); + }, + }, + }; + var models = new List + { + new() { Id = "default", Provider = "mi", WireModel = "byok-gpt-4o" }, + }; + + await RunTurnAsync(client, providers, models, "mi/default", "What is 1+1?"); + await RunTurnAsync(client, providers, models, "mi/default", "What is 2+2?"); + + // Each outbound request carries a freshly-acquired, distinct token. + var auths = handler.AuthHeaders(); + Assert.True(auths.Count >= 2, $"Expected at least 2 captured Authorization headers, saw {auths.Count}."); + Assert.Matches(@"^Bearer rotating-token-\d+$", auths[0]); + Assert.Matches(@"^Bearer rotating-token-\d+$", auths[1]); + Assert.NotEqual(auths[0], auths[1]); + Assert.True(calls >= 2, "Expected the bearer-token callback to be invoked at least twice."); + } + + [Fact] + public async Task Dispatches_Token_Acquisition_Per_Provider() + { + var tokenByProvider = new Dictionary + { + ["red"] = "token-for-red", + ["blue"] = "token-for-blue", + }; + var acquiredFor = new ConcurrentBag(); + + Func> MakeCallback(string providerName) => + args => + { + // The runtime forwards the requesting provider's name so the client + // can dispatch to the right credential. + Assert.Equal(providerName, args.ProviderName); + acquiredFor.Add(providerName); + return Task.FromResult(tokenByProvider[providerName]); + }; + + var handler = new CapturingRequestHandler(); + await using var client = CreateClientWith(handler); + await client.StartAsync(); + + var providers = new List + { + new() + { + Name = "red", + Type = "openai", + WireApi = "completions", + BaseUrl = RedBaseUrl, + GetBearerToken = MakeCallback("red"), + }, + new() + { + Name = "blue", + Type = "openai", + WireApi = "completions", + BaseUrl = BlueBaseUrl, + GetBearerToken = MakeCallback("blue"), + }, + }; + var models = new List + { + new() { Id = "default", Provider = "red", WireModel = "byok-gpt-4o" }, + new() { Id = "default", Provider = "blue", WireModel = "byok-gpt-4o" }, + }; + + await RunTurnAsync(client, providers, models, "red/default", "What is 3+3?"); + await RunTurnAsync(client, providers, models, "blue/default", "What is 4+4?"); + + // Each provider's turn was authenticated with its own token AND that token + // was delivered to that provider's endpoint, proving per-provider dispatch + // (not a single session-global credential). + Assert.Equal($"Bearer {tokenByProvider["red"]}", handler.AuthHeaderForHost(RedHost)); + Assert.Equal($"Bearer {tokenByProvider["blue"]}", handler.AuthHeaderForHost(BlueHost)); + Assert.Contains("red", acquiredFor); + Assert.Contains("blue", acquiredFor); + } +} + +/// +/// A used in place of a real HTTP listener. +/// The runtime invokes for every model-layer HTTP +/// request. Requests aimed at a fake BYOK host (*.invalid) are captured — +/// recording the Authorization header the runtime applied after calling +/// the provider's GetBearerToken callback over the session-scoped +/// providerToken.getToken RPC — and answered with a synthetic 404 +/// (a non-retryable status, so each outbound model request yields exactly one +/// capture). Every other request (CAPI bootstrap: model catalog, policy, …) is +/// served a synthetic well-formed response so the bootstrap never touches the +/// network. +/// +internal sealed class CapturingRequestHandler : CopilotRequestHandler +{ + private readonly ConcurrentQueue _captures = new(); + + protected override Task SendRequestAsync(HttpRequestMessage request, CopilotRequestContext ctx) + { + var uri = request.RequestUri!; + if (uri.Host.EndsWith(".invalid", StringComparison.Ordinal)) + { + _captures.Enqueue(new CapturedRequest( + uri.Host, + request.Headers.TryGetValues("Authorization", out var values) + ? string.Join(", ", values) + : null)); + + return Task.FromResult(new HttpResponseMessage(HttpStatusCode.NotFound) + { + Content = new StringContent( + "{\"error\":{\"message\":\"fake byok endpoint\"}}", + System.Text.Encoding.UTF8, + "application/json"), + }); + } + + // CAPI bootstrap (model catalog, policy, …) — answered off-network. + return Task.FromResult(RecordingRequestHandler.BuildNonInferenceResponse(uri.ToString())); + } + + /// The Authorization headers captured across BYOK requests, in arrival order. + public IReadOnlyList AuthHeaders() => + [.. _captures.Select(c => c.Authorization).Where(v => v is not null).Cast()]; + + /// The Authorization header captured for requests aimed at , if any. + public string? AuthHeaderForHost(string host) => + _captures.FirstOrDefault(c => string.Equals(c.Host, host, StringComparison.Ordinal))?.Authorization; + + private sealed record CapturedRequest(string Host, string? Authorization); +} diff --git a/go/client.go b/go/client.go index 6e4b28516..7cdb4fbad 100644 --- a/go/client.go +++ b/go/client.go @@ -53,6 +53,30 @@ import ( "github.com/github/copilot-sdk/go/rpc" ) +// defaultBearerTokenProviderName is the implicit provider name for the singular, +// whole-session [ProviderConfig]. Named providers are keyed by their own Name. +const defaultBearerTokenProviderName = "default" + +// collectBearerTokenProviders gathers the per-provider [GetBearerToken] callbacks +// from the singular provider and any named providers, keyed by provider name. The +// singular provider uses the implicit name "default"; named providers use their +// own Name. Returns nil when no callbacks are configured. +func collectBearerTokenProviders(provider *ProviderConfig, providers []NamedProviderConfig) map[string]GetBearerToken { + callbacks := make(map[string]GetBearerToken) + if provider != nil && provider.GetBearerToken != nil { + callbacks[defaultBearerTokenProviderName] = provider.GetBearerToken + } + for i := range providers { + if providers[i].GetBearerToken != nil { + callbacks[providers[i].Name] = providers[i].GetBearerToken + } + } + if len(callbacks) == 0 { + return nil + } + return callbacks +} + func validateSessionFSConfig(config *SessionFSConfig) error { if config == nil { return nil @@ -809,6 +833,9 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses if config.CanvasHandler != nil { s.registerCanvasHandler(config.CanvasHandler) } + if bearerTokenProviders := collectBearerTokenProviders(config.Provider, config.Providers); bearerTokenProviders != nil { + s.registerBearerTokenProviders(bearerTokenProviders) + } c.sessionsMux.Lock() c.sessions[sessionID] = s @@ -1106,6 +1133,9 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string, if config.CanvasHandler != nil { session.registerCanvasHandler(config.CanvasHandler) } + if bearerTokenProviders := collectBearerTokenProviders(config.Provider, config.Providers); bearerTokenProviders != nil { + session.registerBearerTokenProviders(bearerTokenProviders) + } c.sessionsMux.Lock() c.sessions[sessionID] = session diff --git a/go/internal/e2e/byok_bearer_token_provider_e2e_test.go b/go/internal/e2e/byok_bearer_token_provider_e2e_test.go new file mode 100644 index 000000000..6a6e5cbc2 --- /dev/null +++ b/go/internal/e2e/byok_bearer_token_provider_e2e_test.go @@ -0,0 +1,284 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package e2e + +import ( + "net/http" + "strconv" + "strings" + "sync" + "testing" + + copilot "github.com/github/copilot-sdk/go" + "github.com/github/copilot-sdk/go/internal/e2e/testharness" +) + +// Fake BYOK provider base URLs. These hosts are never actually dialed: the +// capturing RoundTripper fully answers any request aimed at a `.invalid` host, +// so they only need to be syntactically valid, non-resolving URLs. Distinct +// hosts let the per-provider test assert routing by host. +const ( + byokPrimaryHost = "byok-endpoint.invalid" + byokPrimaryBaseURL = "https://" + byokPrimaryHost + "/v1" + byokRedHost = "byok-red.invalid" + byokRedBaseURL = "https://" + byokRedHost + "/v1" + byokBlueHost = "byok-blue.invalid" + byokBlueBaseURL = "https://" + byokBlueHost + "/v1" +) + +// capturedBYOKRequest records the host and Authorization header of one outbound +// HTTP request the runtime aimed at a fake BYOK provider endpoint. +type capturedBYOKRequest struct { + host string + authorization string +} + +// byokCapturingRoundTripper stands in for a real HTTP upstream. It records the +// `Authorization` header the runtime applied (after calling the provider's +// GetBearerToken callback over the session-scoped `providerToken.getToken` RPC) +// for every request aimed at a fake `.invalid` BYOK host, answering them with a +// synthetic 404 (a non-retryable status, so each outbound model request yields +// exactly one capture). Every other request (CAPI bootstrap: model catalog, +// policy, session) is fabricated locally so the test never touches the network. +type byokCapturingRoundTripper struct { + mu sync.Mutex + captures []capturedBYOKRequest +} + +func (rt *byokCapturingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + if strings.HasSuffix(req.URL.Hostname(), ".invalid") { + rt.mu.Lock() + rt.captures = append(rt.captures, capturedBYOKRequest{ + host: req.URL.Host, + authorization: req.Header.Get("Authorization"), + }) + rt.mu.Unlock() + if req.Body != nil { + _ = req.Body.Close() + } + return buildJSONResponse(http.StatusNotFound, `{"error":{"message":"fake byok endpoint"}}`), nil + } + return buildNonInferenceResponse(req.URL.String()), nil +} + +// authHeaders returns the captured Authorization headers in arrival order. +func (rt *byokCapturingRoundTripper) authHeaders() []string { + rt.mu.Lock() + defer rt.mu.Unlock() + headers := make([]string, 0, len(rt.captures)) + for _, c := range rt.captures { + if c.authorization != "" { + headers = append(headers, c.authorization) + } + } + return headers +} + +// authHeaderForHost returns the Authorization header captured for requests aimed +// at host, if any. +func (rt *byokCapturingRoundTripper) authHeaderForHost(host string) string { + rt.mu.Lock() + defer rt.mu.Unlock() + for _, c := range rt.captures { + if c.host == host { + return c.authorization + } + } + return "" +} + +func (rt *byokCapturingRoundTripper) reset() { + rt.mu.Lock() + defer rt.mu.Unlock() + rt.captures = nil +} + +// TestBYOKBearerTokenProvider is end-to-end coverage for the experimental BYOK +// bearer-token-provider surface (GetBearerToken on a provider config). The +// callback stays entirely on the SDK/client side: the SDK strips it from the +// wire config, sets the `hasBearerTokenProvider` flag, and the runtime calls +// back over the session-scoped `providerToken.getToken` RPC before each outbound +// model request, applying the returned token as the `Authorization` header. +// +// Rather than standing up a real HTTP listener, the test installs a capturing +// RoundTripper that intercepts the runtime's outbound model request in-process, +// captures the `Authorization` header, and returns a synthetic response. It +// validates, against a real runtime: +// 1. the callback's token reaches the model request as `Authorization: Bearer `; +// 2. the runtime re-acquires a token per request (no runtime-side caching); +// 3. per-provider dispatch routes each provider's turn to its own callback, and +// the resulting token reaches that provider's endpoint. +func TestBYOKBearerTokenProvider(t *testing.T) { + ctx := testharness.NewTestContext(t) + rt := &byokCapturingRoundTripper{} + handler := &copilot.CopilotRequestHandler{Transport: rt} + + client := newCopilotRequestClient(ctx, handler) + t.Cleanup(func() { client.ForceStop() }) + if err := client.Start(t.Context()); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // runTurn drives one BYOK turn; the synthetic 404 errors the turn after the + // runtime has already sent the token-bearing request, which is all the test + // asserts on, so the resulting error is expected and swallowed. + runTurn := func(providers []copilot.NamedProviderConfig, models []copilot.ProviderModelConfig, selectionID, prompt string) { + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + Model: selectionID, + Providers: providers, + Models: models, + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + _, _ = session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: prompt}) + _ = session.Disconnect() + } + + t.Run("applies the callback's token as the Authorization header", func(t *testing.T) { + rt.reset() + const sentinel = "sentinel-bearer-token-abc123" + var mu sync.Mutex + calls := 0 + getBearerToken := func(args copilot.ProviderTokenArgs) (string, error) { + mu.Lock() + calls++ + mu.Unlock() + return sentinel, nil + } + + providers := []copilot.NamedProviderConfig{{ + Name: "mi", + Type: "openai", + WireAPI: "completions", + BaseURL: byokPrimaryBaseURL, + GetBearerToken: getBearerToken, + }} + models := []copilot.ProviderModelConfig{{ID: "default", Provider: "mi", WireModel: "byok-gpt-4o"}} + + runTurn(providers, models, "mi/default", "What is 5+5?") + + // The runtime acquired a token via the callback and applied it verbatim + // as the bearer credential on the outbound model request. + if !containsString(rt.authHeaders(), "Bearer "+sentinel) { + t.Fatalf("Expected captured Authorization headers to contain %q, got %v", "Bearer "+sentinel, rt.authHeaders()) + } + mu.Lock() + gotCalls := calls + mu.Unlock() + if gotCalls < 1 { + t.Fatalf("Expected the callback to be invoked at least once, got %d", gotCalls) + } + }) + + t.Run("re-acquires a fresh token for each request (no runtime caching)", func(t *testing.T) { + rt.reset() + var mu sync.Mutex + calls := 0 + getBearerToken := func(args copilot.ProviderTokenArgs) (string, error) { + mu.Lock() + calls++ + token := "rotating-token-" + strconv.Itoa(calls) + mu.Unlock() + // A distinct token per acquisition proves the runtime re-invokes the + // callback per request rather than caching a previous token. + return token, nil + } + + providers := []copilot.NamedProviderConfig{{ + Name: "mi", + Type: "openai", + WireAPI: "completions", + BaseURL: byokPrimaryBaseURL, + GetBearerToken: getBearerToken, + }} + models := []copilot.ProviderModelConfig{{ID: "default", Provider: "mi", WireModel: "byok-gpt-4o"}} + + runTurn(providers, models, "mi/default", "What is 1+1?") + runTurn(providers, models, "mi/default", "What is 2+2?") + + // Each outbound request carries a freshly-acquired, distinct token. + auths := rt.authHeaders() + if len(auths) < 2 { + t.Fatalf("Expected at least 2 captured Authorization headers, got %d: %v", len(auths), auths) + } + if !strings.HasPrefix(auths[0], "Bearer rotating-token-") || !strings.HasPrefix(auths[1], "Bearer rotating-token-") { + t.Fatalf("Expected rotating-token bearer headers, got %v", auths) + } + if auths[0] == auths[1] { + t.Fatalf("Expected distinct tokens per request, both were %q", auths[0]) + } + mu.Lock() + gotCalls := calls + mu.Unlock() + if gotCalls < 2 { + t.Fatalf("Expected the callback to be invoked at least twice, got %d", gotCalls) + } + }) + + t.Run("dispatches token acquisition per provider", func(t *testing.T) { + rt.reset() + tokenByProvider := map[string]string{ + "red": "token-for-red", + "blue": "token-for-blue", + } + var mu sync.Mutex + var acquiredFor []string + makeCallback := func(providerName string) copilot.GetBearerToken { + return func(args copilot.ProviderTokenArgs) (string, error) { + // The runtime forwards the requesting provider's name so the + // client can dispatch to the right credential. + if args.ProviderName != providerName { + t.Errorf("Expected providerName %q, got %q", providerName, args.ProviderName) + } + mu.Lock() + acquiredFor = append(acquiredFor, providerName) + mu.Unlock() + return tokenByProvider[providerName], nil + } + } + + providers := []copilot.NamedProviderConfig{ + { + Name: "red", + Type: "openai", + WireAPI: "completions", + BaseURL: byokRedBaseURL, + GetBearerToken: makeCallback("red"), + }, + { + Name: "blue", + Type: "openai", + WireAPI: "completions", + BaseURL: byokBlueBaseURL, + GetBearerToken: makeCallback("blue"), + }, + } + models := []copilot.ProviderModelConfig{ + {ID: "default", Provider: "red", WireModel: "byok-gpt-4o"}, + {ID: "default", Provider: "blue", WireModel: "byok-gpt-4o"}, + } + + runTurn(providers, models, "red/default", "What is 3+3?") + runTurn(providers, models, "blue/default", "What is 4+4?") + + // Each provider's turn was authenticated with its own token AND that + // token was delivered to that provider's endpoint, proving per-provider + // dispatch (not a single session-global credential). + if got := rt.authHeaderForHost(byokRedHost); got != "Bearer "+tokenByProvider["red"] { + t.Fatalf("Expected red host to receive %q, got %q", "Bearer "+tokenByProvider["red"], got) + } + if got := rt.authHeaderForHost(byokBlueHost); got != "Bearer "+tokenByProvider["blue"] { + t.Fatalf("Expected blue host to receive %q, got %q", "Bearer "+tokenByProvider["blue"], got) + } + mu.Lock() + got := append([]string(nil), acquiredFor...) + mu.Unlock() + if !containsString(got, "red") || !containsString(got, "blue") { + t.Fatalf("Expected both providers to acquire tokens, got %v", got) + } + }) +} diff --git a/go/session.go b/go/session.go index 1aeb70c24..d92466d8e 100644 --- a/go/session.go +++ b/go/session.go @@ -77,6 +77,8 @@ type Session struct { elicitationMu sync.RWMutex canvasHandler CanvasHandler canvasMu sync.RWMutex + bearerTokenProviders map[string]GetBearerToken + bearerTokenMu sync.RWMutex openCanvases []rpc.OpenCanvasInstance openCanvasesMu sync.RWMutex capabilities SessionCapabilities @@ -181,6 +183,66 @@ func (s *Session) getCanvasHandler() CanvasHandler { return s.canvasHandler } +// registerBearerTokenProviders installs per-provider [GetBearerToken] callbacks +// for BYOK providers configured with managed-identity / on-demand bearer-token +// auth, keyed by provider name. +// +// The runtime never receives the callback itself; the SDK strips it from the +// provider config and instead sends `hasBearerTokenProvider: true`. When the +// runtime needs a token it issues a session-scoped `providerToken.getToken` +// request, which the session's provider-token adapter routes to the matching +// per-provider callback. +func (s *Session) registerBearerTokenProviders(providers map[string]GetBearerToken) { + s.bearerTokenMu.Lock() + defer s.bearerTokenMu.Unlock() + s.bearerTokenProviders = make(map[string]GetBearerToken, len(providers)) + for name, callback := range providers { + if callback == nil { + continue + } + s.bearerTokenProviders[name] = callback + } +} + +func (s *Session) getBearerTokenProvider(providerName string) GetBearerToken { + s.bearerTokenMu.RLock() + defer s.bearerTokenMu.RUnlock() + return s.bearerTokenProviders[providerName] +} + +type providerTokenClientSessionAdapter struct { + session *Session +} + +func newProviderTokenClientSessionAdapter(session *Session) rpc.ProviderTokenHandler { + return &providerTokenClientSessionAdapter{session: session} +} + +func (a *providerTokenClientSessionAdapter) GetToken(request *rpc.ProviderTokenAcquireRequest) (*rpc.ProviderTokenAcquireResult, error) { + if request == nil { + return nil, providerTokenJSONRPCError("missing provider token request") + } + if a.session == nil || a.session.SessionID != request.SessionID { + return nil, providerTokenJSONRPCError(fmt.Sprintf("unknown session %s", request.SessionID)) + } + callback := a.session.getBearerTokenProvider(request.ProviderName) + if callback == nil { + return nil, providerTokenJSONRPCError(fmt.Sprintf("No bearer-token provider registered for provider %q", request.ProviderName)) + } + token, err := callback(ProviderTokenArgs{ProviderName: request.ProviderName}) + if err != nil { + return nil, providerTokenJSONRPCError(err.Error()) + } + return &rpc.ProviderTokenAcquireResult{Token: token}, nil +} + +func providerTokenJSONRPCError(message string) *jsonrpc2.Error { + return &jsonrpc2.Error{ + Code: -32603, + Message: message, + } +} + type canvasClientSessionAdapter struct { session *Session } @@ -307,6 +369,7 @@ func newSession(sessionID string, client *jsonrpc2.Client, workspacePath string) RPC: rpc.NewSessionRPC(client, sessionID), } s.clientSessionAPIs.Canvas = newCanvasClientSessionAdapter(s) + s.clientSessionAPIs.ProviderToken = newProviderTokenClientSessionAdapter(s) go s.processEvents() return s } diff --git a/go/types.go b/go/types.go index ac3b138a5..52a06e110 100644 --- a/go/types.go +++ b/go/types.go @@ -1563,6 +1563,37 @@ type ResumeSessionConfig struct { // general external use. ExpAssignments any } + +// ProviderTokenArgs carries the context passed to a [GetBearerToken] callback +// when the runtime needs a fresh bearer token for a BYOK provider. +// +// Experimental: ProviderTokenArgs is part of the experimental managed-identity / +// bearer-token-provider surface and may change or be removed in future SDK or CLI +// releases. +type ProviderTokenArgs struct { + // ProviderName is the name of the BYOK provider needing a token. For the + // singular, whole-session [ProviderConfig] this is the implicit provider name + // ("default"); for [NamedProviderConfig] entries it is + // [NamedProviderConfig.Name]. + // + // The callback closes over its own token scope/audience; the runtime is + // provider-agnostic and forwards only the provider name. + ProviderName string +} + +// GetBearerToken is a per-provider callback that resolves a bearer token on +// demand, returning the raw token string (without the "Bearer " prefix). The +// Copilot SDK itself takes no Azure dependency: the consumer supplies this +// callback backed by their own identity library (for example azidentity's +// DefaultAzureCredential.GetToken), and the runtime calls it once before each +// outbound model request. The runtime does no caching of its own, so the callback +// (or the identity library it wraps) owns token caching and refresh. +// +// Experimental: GetBearerToken is part of the experimental managed-identity / +// bearer-token-provider surface and may change or be removed in future SDK or CLI +// releases. +type GetBearerToken func(args ProviderTokenArgs) (string, error) + type ProviderConfig struct { // Type is the provider type: "openai", "azure", or "anthropic". Defaults to "openai". Type string `json:"type,omitempty"` @@ -1603,6 +1634,33 @@ type ProviderConfig struct { // tokens. When hit, the model stops generating and returns a truncated // response. MaxOutputTokens int `json:"maxOutputTokens,omitempty"` + // GetBearerToken resolves a bearer token on demand for this provider + // (managed-identity / on-demand auth). When set, the SDK strips the callback + // from the wire config and instead sends `hasBearerTokenProvider: true`; the + // runtime calls back over the session-scoped `providerToken.getToken` RPC + // before each outbound model request and applies the returned token as the + // Authorization header. Never serialized. + // + // Experimental: part of the experimental managed-identity / bearer-token-provider + // surface and may change or be removed in future SDK or CLI releases. + GetBearerToken GetBearerToken `json:"-"` +} + +// MarshalJSON serializes the provider config, deriving the wire-only +// `hasBearerTokenProvider` flag from the presence of [ProviderConfig.GetBearerToken]. +// The non-serializable callback never crosses the RPC boundary; the runtime only +// learns that a token provider exists and forwards the provider name back when it +// needs a token. +func (p ProviderConfig) MarshalJSON() ([]byte, error) { + type wire ProviderConfig + aux := struct { + wire + HasBearerTokenProvider *bool `json:"hasBearerTokenProvider,omitempty"` + }{wire: wire(p)} + if p.GetBearerToken != nil { + aux.HasBearerTokenProvider = Bool(true) + } + return json.Marshal(aux) } // CapiSessionOptions configures provider-scoped Copilot API (CAPI) session behavior. @@ -1657,6 +1715,33 @@ type NamedProviderConfig struct { Azure *AzureProviderOptions `json:"azure,omitempty"` // Headers are custom HTTP headers included in all outbound provider requests. Headers map[string]string `json:"headers,omitempty"` + // GetBearerToken resolves a bearer token on demand for this provider + // (managed-identity / on-demand auth). When set, the SDK strips the callback + // from the wire config and instead sends `hasBearerTokenProvider: true`; the + // runtime calls back over the session-scoped `providerToken.getToken` RPC + // before each outbound model request and applies the returned token as the + // Authorization header. Never serialized. + // + // Experimental: part of the experimental managed-identity / bearer-token-provider + // surface and may change or be removed in future SDK or CLI releases. + GetBearerToken GetBearerToken `json:"-"` +} + +// MarshalJSON serializes the named provider config, deriving the wire-only +// `hasBearerTokenProvider` flag from the presence of +// [NamedProviderConfig.GetBearerToken]. The non-serializable callback never +// crosses the RPC boundary; the runtime only learns that a token provider exists +// and forwards the provider name back when it needs a token. +func (p NamedProviderConfig) MarshalJSON() ([]byte, error) { + type wire NamedProviderConfig + aux := struct { + wire + HasBearerTokenProvider *bool `json:"hasBearerTokenProvider,omitempty"` + }{wire: wire(p)} + if p.GetBearerToken != nil { + aux.HasBearerTokenProvider = Bool(true) + } + return json.Marshal(aux) } // ProviderModelConfig is a BYOK model definition that references a diff --git a/java/src/main/java/com/github/copilot/CopilotSession.java b/java/src/main/java/com/github/copilot/CopilotSession.java index 90a2662d0..9e0391594 100644 --- a/java/src/main/java/com/github/copilot/CopilotSession.java +++ b/java/src/main/java/com/github/copilot/CopilotSession.java @@ -74,6 +74,7 @@ import com.github.copilot.rpc.ExitPlanModeRequest; import com.github.copilot.rpc.ExitPlanModeResult; import com.github.copilot.rpc.ElicitationSchema; +import com.github.copilot.rpc.GetBearerToken; import com.github.copilot.rpc.GetMessagesResponse; import com.github.copilot.rpc.HookInvocation; import com.github.copilot.rpc.InputOptions; @@ -168,6 +169,7 @@ public final class CopilotSession implements AutoCloseable { private final Set> eventHandlers = ConcurrentHashMap.newKeySet(); private final Map toolHandlers = new ConcurrentHashMap<>(); private final Map commandHandlers = new ConcurrentHashMap<>(); + private final Map bearerTokenProviders = new ConcurrentHashMap<>(); private final AtomicReference permissionHandler = new AtomicReference<>(); private final AtomicReference userInputHandler = new AtomicReference<>(); private final AtomicReference elicitationHandler = new AtomicReference<>(); @@ -1347,6 +1349,33 @@ void registerElicitationHandler(ElicitationHandler handler) { elicitationHandler.set(handler); } + /** + * Registers bearer-token provider callbacks for this session. + *

+ * Called internally when creating or resuming a session with BYOK providers + * that use managed-identity token callbacks. + * + * @param providers + * the callbacks keyed by provider name + */ + void registerBearerTokenProviders(Map providers) { + bearerTokenProviders.clear(); + if (providers != null) { + bearerTokenProviders.putAll(providers); + } + } + + /** + * Gets the bearer-token provider callback for the given provider name. + * + * @param providerName + * the provider name + * @return the registered callback, or {@code null} if none is registered + */ + GetBearerToken getBearerTokenProvider(String providerName) { + return bearerTokenProviders.get(providerName); + } + /** * Registers an exit-plan-mode handler for this session. *

diff --git a/java/src/main/java/com/github/copilot/RpcHandlerDispatcher.java b/java/src/main/java/com/github/copilot/RpcHandlerDispatcher.java index 391f270db..b62e8c582 100644 --- a/java/src/main/java/com/github/copilot/RpcHandlerDispatcher.java +++ b/java/src/main/java/com/github/copilot/RpcHandlerDispatcher.java @@ -19,6 +19,8 @@ import com.github.copilot.generated.SessionEvent; import com.github.copilot.rpc.AutoModeSwitchRequest; import com.github.copilot.rpc.ExitPlanModeRequest; +import com.github.copilot.rpc.GetBearerToken; +import com.github.copilot.rpc.ProviderTokenArgs; import com.github.copilot.rpc.PermissionRequestResult; import com.github.copilot.rpc.PermissionRequestResultKind; import com.github.copilot.rpc.SessionLifecycleEvent; @@ -88,6 +90,8 @@ void registerHandlers(JsonRpcClient rpc) { rpc.registerMethodHandler("hooks.invoke", (requestId, params) -> handleHooksInvoke(rpc, requestId, params)); rpc.registerMethodHandler("systemMessage.transform", (requestId, params) -> handleSystemMessageTransform(rpc, requestId, params)); + rpc.registerMethodHandler("providerToken.getToken", + (requestId, params) -> handleProviderTokenGetToken(rpc, requestId, params)); } private void handleSessionEvent(JsonNode params) { @@ -300,6 +304,68 @@ private void handleUserInputRequest(JsonRpcClient rpc, String requestId, JsonNod }); } + private void handleProviderTokenGetToken(JsonRpcClient rpc, String requestId, JsonNode params) { + LOG.fine("Received providerToken.getToken: " + params); + runAsync(() -> { + final long requestIdLong = parseRequestId(requestId, "providerToken.getToken"); + if (requestIdLong == -1) { + return; + } + try { + String sessionId = params.get("sessionId").asText(); + String providerName = params.get("providerName").asText(); + + CopilotSession session = sessions.get(sessionId); + if (session == null) { + rpc.sendErrorResponse(requestIdLong, -32602, "Unknown session " + sessionId); + return; + } + + GetBearerToken provider = session.getBearerTokenProvider(providerName); + if (provider == null) { + rpc.sendErrorResponse(requestIdLong, -32603, + "No bearer-token provider registered for provider " + providerName); + return; + } + + CompletableFuture tokenFuture = provider.getToken(new ProviderTokenArgs(providerName)); + if (tokenFuture == null) { + rpc.sendErrorResponse(requestIdLong, -32603, + "Bearer-token provider returned null future for provider " + providerName); + return; + } + + tokenFuture.thenAccept(token -> { + try { + if (token == null) { + rpc.sendErrorResponse(requestIdLong, -32603, + "Bearer-token provider returned null token for provider " + providerName); + return; + } + rpc.sendResponse(requestIdLong, Map.of("token", token)); + } catch (IOException e) { + LOG.log(Level.SEVERE, "Error sending provider token response", e); + } + }).exceptionally(ex -> { + LOG.log(Level.WARNING, "Bearer-token provider exception", ex); + try { + rpc.sendErrorResponse(requestIdLong, -32603, "Bearer-token provider error: " + ex.getMessage()); + } catch (IOException e) { + LOG.log(Level.SEVERE, "Error sending provider token error", e); + } + return null; + }); + } catch (Exception e) { + LOG.log(Level.SEVERE, "Error handling providerToken.getToken", e); + try { + rpc.sendErrorResponse(requestIdLong, -32603, "Provider token handler error: " + e.getMessage()); + } catch (IOException ioException) { + LOG.log(Level.SEVERE, "Error sending provider token handler error", ioException); + } + } + }); + } + private void handleExitPlanModeRequest(JsonRpcClient rpc, String requestId, JsonNode params) { runAsync(() -> { final long requestIdLong = parseRequestId(requestId, "exitPlanMode.request"); diff --git a/java/src/main/java/com/github/copilot/SessionRequestBuilder.java b/java/src/main/java/com/github/copilot/SessionRequestBuilder.java index 072cf480d..943894d28 100644 --- a/java/src/main/java/com/github/copilot/SessionRequestBuilder.java +++ b/java/src/main/java/com/github/copilot/SessionRequestBuilder.java @@ -5,11 +5,15 @@ package com.github.copilot; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.function.Function; import com.github.copilot.rpc.CreateSessionRequest; +import com.github.copilot.rpc.ProviderConfig; +import com.github.copilot.rpc.NamedProviderConfig; +import com.github.copilot.rpc.GetBearerToken; import com.github.copilot.rpc.CommandWireDefinition; import com.github.copilot.rpc.ResumeSessionConfig; import com.github.copilot.rpc.ResumeSessionRequest; @@ -331,6 +335,11 @@ static void configureSession(CopilotSession session, SessionConfig config) { if (config.getOnElicitationRequest() != null) { session.registerElicitationHandler(config.getOnElicitationRequest()); } + Map bearerTokenProviders = collectBearerTokenProviders(config.getProvider(), + config.getProviders()); + if (!bearerTokenProviders.isEmpty()) { + session.registerBearerTokenProviders(bearerTokenProviders); + } if (config.getOnExitPlanMode() != null) { session.registerExitPlanModeHandler(config.getOnExitPlanMode()); } @@ -373,6 +382,11 @@ static void configureSession(CopilotSession session, ResumeSessionConfig config) if (config.getOnElicitationRequest() != null) { session.registerElicitationHandler(config.getOnElicitationRequest()); } + Map bearerTokenProviders = collectBearerTokenProviders(config.getProvider(), + config.getProviders()); + if (!bearerTokenProviders.isEmpty()) { + session.registerBearerTokenProviders(bearerTokenProviders); + } if (config.getOnExitPlanMode() != null) { session.registerExitPlanModeHandler(config.getOnExitPlanMode()); } @@ -383,4 +397,21 @@ static void configureSession(CopilotSession session, ResumeSessionConfig config) session.on(config.getOnEvent()); } } + + private static Map collectBearerTokenProviders(ProviderConfig provider, + List providers) { + Map bearerTokenProviders = new HashMap<>(); + if (provider != null && provider.getGetBearerToken() != null) { + bearerTokenProviders.put("default", provider.getGetBearerToken()); + } + if (providers != null) { + for (NamedProviderConfig namedProvider : providers) { + if (namedProvider != null && namedProvider.getName() != null + && namedProvider.getGetBearerToken() != null) { + bearerTokenProviders.put(namedProvider.getName(), namedProvider.getGetBearerToken()); + } + } + } + return bearerTokenProviders; + } } diff --git a/java/src/main/java/com/github/copilot/rpc/GetBearerToken.java b/java/src/main/java/com/github/copilot/rpc/GetBearerToken.java new file mode 100644 index 000000000..27ec7f09c --- /dev/null +++ b/java/src/main/java/com/github/copilot/rpc/GetBearerToken.java @@ -0,0 +1,40 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.rpc; + +import java.util.concurrent.CompletableFuture; + +import com.github.copilot.CopilotExperimental; + +/** + * Functional interface for supplying per-provider bearer tokens for BYOK + * provider requests. + *

+ * The callback returns the raw token without a {@code Bearer } prefix. The SDK + * keeps this callback client-side and the runtime requests a token via the + * session-scoped {@code providerToken.getToken} RPC before each outbound model + * request. + *

+ * Experimental. This managed-identity surface may change or be + * removed in future SDK or CLI releases. + * + * @see ProviderConfig#setGetBearerToken(GetBearerToken) + * @see NamedProviderConfig#setGetBearerToken(GetBearerToken) + * @since 1.0.0 + */ +@CopilotExperimental +@FunctionalInterface +public interface GetBearerToken { + + /** + * Gets a bearer token for the provider identified by {@code args}. + * + * @param args + * the provider token request arguments + * @return a future that completes with the raw token, without a {@code Bearer } + * prefix + */ + CompletableFuture getToken(ProviderTokenArgs args); +} diff --git a/java/src/main/java/com/github/copilot/rpc/NamedProviderConfig.java b/java/src/main/java/com/github/copilot/rpc/NamedProviderConfig.java index dbc157739..2bdf2678f 100644 --- a/java/src/main/java/com/github/copilot/rpc/NamedProviderConfig.java +++ b/java/src/main/java/com/github/copilot/rpc/NamedProviderConfig.java @@ -7,6 +7,7 @@ import java.util.Collections; import java.util.Map; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; @@ -59,6 +60,9 @@ public class NamedProviderConfig { @JsonProperty("bearerToken") private String bearerToken; + @JsonIgnore + private GetBearerToken getBearerToken; + @JsonProperty("azure") private AzureOptions azure; @@ -212,6 +216,39 @@ public NamedProviderConfig setBearerToken(String bearerToken) { return this; } + /** + * Gets the bearer-token provider callback. + * + * @return the bearer-token provider callback, or {@code null} if not set + */ + public GetBearerToken getGetBearerToken() { + return getBearerToken; + } + + /** + * Sets a callback that supplies bearer tokens for outbound provider requests. + *

+ * Experimental. The callback stays SDK-side and is not + * serialized. Instead, the runtime receives a {@code hasBearerTokenProvider} + * flag and calls back over the session-scoped {@code providerToken.getToken} + * RPC before each model request. Return the raw token without a {@code Bearer } + * prefix. + * + * @param getBearerToken + * the bearer-token provider callback + * @return this config for method chaining + */ + public NamedProviderConfig setGetBearerToken(GetBearerToken getBearerToken) { + this.getBearerToken = getBearerToken; + return this; + } + + @JsonProperty("hasBearerTokenProvider") + @JsonInclude(JsonInclude.Include.NON_NULL) + Boolean hasBearerTokenProviderWireFlag() { + return getBearerToken != null ? Boolean.TRUE : null; + } + /** * Gets the Azure-specific options. * diff --git a/java/src/main/java/com/github/copilot/rpc/ProviderConfig.java b/java/src/main/java/com/github/copilot/rpc/ProviderConfig.java index 8ba492ed9..ae59e7ead 100644 --- a/java/src/main/java/com/github/copilot/rpc/ProviderConfig.java +++ b/java/src/main/java/com/github/copilot/rpc/ProviderConfig.java @@ -56,6 +56,9 @@ public class ProviderConfig { @JsonProperty("bearerToken") private String bearerToken; + @JsonIgnore + private GetBearerToken getBearerToken; + @JsonProperty("azure") private AzureOptions azure; @@ -222,6 +225,39 @@ public ProviderConfig setBearerToken(String bearerToken) { return this; } + /** + * Gets the bearer-token provider callback. + * + * @return the bearer-token provider callback, or {@code null} if not set + */ + public GetBearerToken getGetBearerToken() { + return getBearerToken; + } + + /** + * Sets a callback that supplies bearer tokens for outbound provider requests. + *

+ * Experimental. The callback stays SDK-side and is not + * serialized. Instead, the runtime receives a {@code hasBearerTokenProvider} + * flag and calls back over the session-scoped {@code providerToken.getToken} + * RPC before each model request. Return the raw token without a {@code Bearer } + * prefix. + * + * @param getBearerToken + * the bearer-token provider callback + * @return this config for method chaining + */ + public ProviderConfig setGetBearerToken(GetBearerToken getBearerToken) { + this.getBearerToken = getBearerToken; + return this; + } + + @JsonProperty("hasBearerTokenProvider") + @JsonInclude(JsonInclude.Include.NON_NULL) + Boolean hasBearerTokenProviderWireFlag() { + return getBearerToken != null ? Boolean.TRUE : null; + } + /** * Gets the Azure-specific options. * diff --git a/java/src/main/java/com/github/copilot/rpc/ProviderTokenArgs.java b/java/src/main/java/com/github/copilot/rpc/ProviderTokenArgs.java new file mode 100644 index 000000000..3866cc0ad --- /dev/null +++ b/java/src/main/java/com/github/copilot/rpc/ProviderTokenArgs.java @@ -0,0 +1,63 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.rpc; + +import com.github.copilot.CopilotExperimental; + +/** + * Arguments passed to a BYOK bearer-token provider callback. + *

+ * Experimental. This managed-identity surface may change or be + * removed in future SDK or CLI releases. + * + * @since 1.0.0 + */ +@CopilotExperimental +public class ProviderTokenArgs { + + private String providerName; + + /** + * Creates an empty argument object. + */ + public ProviderTokenArgs() { + } + + /** + * Creates argument object for the named provider. + * + * @param providerName + * the name of the BYOK provider needing a token; {@code "default"} + * for the singular whole-session provider, otherwise the named + * provider's {@code name} + */ + public ProviderTokenArgs(String providerName) { + this.providerName = providerName; + } + + /** + * Gets the name of the BYOK provider needing a token. + *

+ * The value is {@code "default"} for the singular whole-session provider, + * otherwise the named provider's {@code name}. + * + * @return the provider name + */ + public String getProviderName() { + return providerName; + } + + /** + * Sets the name of the BYOK provider needing a token. + * + * @param providerName + * the provider name + * @return this args instance for method chaining + */ + public ProviderTokenArgs setProviderName(String providerName) { + this.providerName = providerName; + return this; + } +} diff --git a/java/src/test/java/com/github/copilot/ByokBearerTokenProviderE2ETest.java b/java/src/test/java/com/github/copilot/ByokBearerTokenProviderE2ETest.java new file mode 100644 index 000000000..253ce136c --- /dev/null +++ b/java/src/test/java/com/github/copilot/ByokBearerTokenProviderE2ETest.java @@ -0,0 +1,274 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot; + +import static com.github.copilot.CopilotRequestTestSupport.buildNonInferenceResponse; +import static com.github.copilot.CopilotRequestTestSupport.newLlmClient; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpHeaders; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import javax.net.ssl.SSLSession; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import com.github.copilot.rpc.GetBearerToken; +import com.github.copilot.rpc.MessageOptions; +import com.github.copilot.rpc.NamedProviderConfig; +import com.github.copilot.rpc.PermissionHandler; +import com.github.copilot.rpc.ProviderModelConfig; +import com.github.copilot.rpc.SessionConfig; + +/** + * End-to-end coverage for the experimental BYOK bearer-token-provider surface + * ({@code getBearerToken} on a provider config). The callback stays entirely on + * the SDK/client side: the SDK keeps it off the wire, sends only the + * {@code hasBearerTokenProvider} flag, and the runtime calls back over the + * session-scoped {@code providerToken.getToken} RPC before each outbound model + * request. + */ +public class ByokBearerTokenProviderE2ETest { + + private static final String PRIMARY_HOST = "byok-endpoint.invalid"; + private static final String PRIMARY_BASE_URL = "https://" + PRIMARY_HOST + "/v1"; + private static final String RED_HOST = "byok-red.invalid"; + private static final String RED_BASE_URL = "https://" + RED_HOST + "/v1"; + private static final String BLUE_HOST = "byok-blue.invalid"; + private static final String BLUE_BASE_URL = "https://" + BLUE_HOST + "/v1"; + + private static E2ETestContext ctx; + private CapturingRequestHandler handler; + + @BeforeAll + static void setup() throws Exception { + ctx = E2ETestContext.create(); + } + + @AfterAll + static void teardown() throws Exception { + if (ctx != null) { + ctx.close(); + } + } + + @BeforeEach + void resetHandler() { + handler = new CapturingRequestHandler(); + } + + @Test + void appliesCallbackTokenAsAuthorizationHeader() throws Exception { + String sentinel = "sentinel-bearer-token-abc123"; + AtomicInteger calls = new AtomicInteger(); + GetBearerToken getBearerToken = args -> { + calls.incrementAndGet(); + return CompletableFuture.completedFuture(sentinel); + }; + + List providers = List.of(new NamedProviderConfig().setName("mi").setType("openai") + .setWireApi("completions").setBaseUrl(PRIMARY_BASE_URL).setGetBearerToken(getBearerToken)); + List models = List + .of(new ProviderModelConfig().setId("default").setProvider("mi").setWireModel("byok-gpt-4o")); + + runTurn(providers, models, "mi/default", "What is 5+5?"); + + assertTrue(handler.authHeaders().contains("Bearer " + sentinel), + "Expected captured Authorization headers to contain the callback token: " + handler.authHeaders()); + assertTrue(calls.get() >= 1, "Expected the callback to be invoked at least once"); + } + + @Test + void reacquiresFreshTokenForEachRequest() throws Exception { + AtomicInteger calls = new AtomicInteger(); + GetBearerToken getBearerToken = args -> CompletableFuture + .completedFuture("rotating-token-" + calls.incrementAndGet()); + + List providers = List.of(new NamedProviderConfig().setName("mi").setType("openai") + .setWireApi("completions").setBaseUrl(PRIMARY_BASE_URL).setGetBearerToken(getBearerToken)); + List models = List + .of(new ProviderModelConfig().setId("default").setProvider("mi").setWireModel("byok-gpt-4o")); + + runTurn(providers, models, "mi/default", "What is 1+1?"); + runTurn(providers, models, "mi/default", "What is 2+2?"); + + List auths = handler.authHeaders(); + assertTrue(auths.size() >= 2, "Expected at least two captured Authorization headers, got " + auths); + assertTrue(auths.get(0).startsWith("Bearer rotating-token-"), "Expected rotating token, got " + auths); + assertTrue(auths.get(1).startsWith("Bearer rotating-token-"), "Expected rotating token, got " + auths); + assertNotEquals(auths.get(0), auths.get(1), "Expected distinct tokens per request"); + assertTrue(calls.get() >= 2, "Expected the callback to be invoked at least twice"); + } + + @Test + void dispatchesTokenAcquisitionPerProvider() throws Exception { + List acquiredFor = new ArrayList<>(); + GetBearerToken redCallback = args -> { + assertEquals("red", args.getProviderName(), "Expected providerName to be forwarded"); + synchronized (acquiredFor) { + acquiredFor.add("red"); + } + return CompletableFuture.completedFuture("token-for-red"); + }; + GetBearerToken blueCallback = args -> { + assertEquals("blue", args.getProviderName(), "Expected providerName to be forwarded"); + synchronized (acquiredFor) { + acquiredFor.add("blue"); + } + return CompletableFuture.completedFuture("token-for-blue"); + }; + + List providers = List.of( + new NamedProviderConfig().setName("red").setType("openai").setWireApi("completions") + .setBaseUrl(RED_BASE_URL).setGetBearerToken(redCallback), + new NamedProviderConfig().setName("blue").setType("openai").setWireApi("completions") + .setBaseUrl(BLUE_BASE_URL).setGetBearerToken(blueCallback)); + List models = List.of( + new ProviderModelConfig().setId("default").setProvider("red").setWireModel("byok-gpt-4o"), + new ProviderModelConfig().setId("default").setProvider("blue").setWireModel("byok-gpt-4o")); + + runTurn(providers, models, "red/default", "What is 3+3?"); + runTurn(providers, models, "blue/default", "What is 4+4?"); + + assertEquals("Bearer token-for-red", handler.authHeaderForHost(RED_HOST)); + assertEquals("Bearer token-for-blue", handler.authHeaderForHost(BLUE_HOST)); + synchronized (acquiredFor) { + assertTrue(acquiredFor.contains("red"), "Expected red provider to acquire a token"); + assertTrue(acquiredFor.contains("blue"), "Expected blue provider to acquire a token"); + } + } + + private void runTurn(List providers, List models, String selectionId, + String prompt) throws Exception { + try (CopilotClient client = newLlmClient(ctx, handler)) { + CopilotSession session = client + .createSession(new SessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL) + .setModel(selectionId).setProviders(providers).setModels(models)) + .get(60, TimeUnit.SECONDS); + try { + session.sendAndWait(new MessageOptions().setPrompt(prompt)).get(60, TimeUnit.SECONDS); + } catch (Exception ignored) { + // The fake BYOK endpoint returns 404 after capturing the token-bearing request. + } finally { + try { + session.close(); + } catch (Exception ignored) { + // Ignore disconnect errors for the fake BYOK endpoint. + } + } + } + } + + private static final class CapturingRequestHandler extends CopilotRequestHandler { + + private final ConcurrentLinkedQueue captures = new ConcurrentLinkedQueue<>(); + + @Override + protected HttpResponse sendRequest(HttpRequest request, CopilotRequestContext rctx) + throws Exception { + String host = request.uri().getHost(); + if (host != null && host.endsWith(".invalid")) { + captures.add(new CapturedRequest(request.uri().getHost(), + request.headers().firstValue("Authorization").orElse(null))); + return new StubHttpResponse(404, "{\"error\":{\"message\":\"fake byok endpoint\"}}"); + } + return buildNonInferenceResponse(request.uri().toString()); + } + + List authHeaders() { + List auths = new ArrayList<>(); + for (CapturedRequest capture : captures) { + if (capture.authorization() != null) { + auths.add(capture.authorization()); + } + } + return auths; + } + + String authHeaderForHost(String host) { + for (CapturedRequest capture : captures) { + if (host.equals(capture.host())) { + return capture.authorization(); + } + } + return null; + } + } + + private static final class StubHttpResponse implements HttpResponse { + + private final int status; + private final HttpHeaders headers; + private final byte[] body; + + StubHttpResponse(int status, String body) { + this.status = status; + this.body = body.getBytes(StandardCharsets.UTF_8); + this.headers = HttpHeaders.of(Map.of("content-type", List.of("application/json")), (k, v) -> true); + } + + @Override + public int statusCode() { + return status; + } + + @Override + public HttpRequest request() { + return null; + } + + @Override + public Optional> previousResponse() { + return Optional.empty(); + } + + @Override + public HttpHeaders headers() { + return headers; + } + + @Override + public InputStream body() { + return new ByteArrayInputStream(body); + } + + @Override + public Optional sslSession() { + return Optional.empty(); + } + + @Override + public URI uri() { + return null; + } + + @Override + public HttpClient.Version version() { + return HttpClient.Version.HTTP_1_1; + } + } + + private record CapturedRequest(String host, String authorization) { + } +} diff --git a/nodejs/src/client.ts b/nodejs/src/client.ts index a1b403930..e0315b211 100644 --- a/nodejs/src/client.ts +++ b/nodejs/src/client.ts @@ -51,11 +51,14 @@ import type { ExitPlanModeResult, ForegroundSessionInfo, GetAuthStatusResponse, + GetBearerToken, GetStatusResponse, InternalRuntimeConnection, LargeToolOutputConfig, MCPServerConfig, ModelInfo, + NamedProviderConfig, + ProviderConfig, ResumeSessionConfig, SectionTransformFn, SessionConfig, @@ -154,6 +157,64 @@ function toJsonSchema(parameters: Tool["parameters"]): Record | return parameters; } +/** Implicit provider name for the singular, whole-session {@link ProviderConfig}. */ +const DEFAULT_PROVIDER_NAME = "default"; + +/** Wire-safe singular provider config carrying the `hasBearerTokenProvider` flag. */ +type WireProviderConfig = Omit & { + hasBearerTokenProvider?: boolean; +}; + +/** Wire-safe named provider config carrying the `hasBearerTokenProvider` flag. */ +type WireNamedProviderConfig = Omit & { + hasBearerTokenProvider?: boolean; +}; + +/** + * Strips the non-serializable {@link GetBearerToken} callbacks from the singular + * and named provider configs before they cross the RPC boundary, replacing each + * with a `hasBearerTokenProvider: true` wire flag. The callback closes over its + * own token scope/audience, so nothing scope-related crosses the wire — the + * runtime only forwards the provider name back when it needs a token. + * Returns wire-safe provider configs alongside a map of provider name → callback + * for session-side registration. + */ +function extractBearerTokenProviders( + provider: ProviderConfig | undefined, + providers: NamedProviderConfig[] | undefined +): { + wireProvider: WireProviderConfig | undefined; + wireProviders: WireNamedProviderConfig[] | undefined; + callbacks: Map; +} { + const callbacks = new Map(); + + let wireProvider: WireProviderConfig | undefined = provider; + if (provider?.getBearerToken) { + const { getBearerToken, ...rest } = provider; + callbacks.set(DEFAULT_PROVIDER_NAME, getBearerToken); + wireProvider = { + ...rest, + hasBearerTokenProvider: true, + }; + } + + let wireProviders: WireNamedProviderConfig[] | undefined = providers; + if (providers?.some((p) => p.getBearerToken)) { + wireProviders = providers.map((p) => { + if (!p.getBearerToken) return p; + const { getBearerToken, ...rest } = p; + callbacks.set(p.name, getBearerToken); + return { + ...rest, + hasBearerTokenProvider: true, + }; + }); + } + + return { wireProvider, wireProviders, callbacks }; +} + /** * Convert MCP server configs from public API format (workingDirectory) to * wire format (cwd) expected by the runtime. @@ -1244,6 +1305,15 @@ export class CopilotClient { const useServerGeneratedId = config.cloud != null && callerSessionId == null; const localSessionId = useServerGeneratedId ? undefined : (callerSessionId ?? randomUUID()); + // Strip non-serializable getBearerToken callbacks from provider configs, + // replacing them with a wire flag; keep the callbacks for session-side + // registration so the runtime can call back to acquire tokens. + const { + wireProvider: bearerWireProvider, + wireProviders: bearerWireProviders, + callbacks: bearerTokenCallbacks, + } = extractBearerTokenProviders(config.provider, config.providers); + // Extract transform callbacks from system message config before serialization. const { wirePayload: wireSystemMessage, transformCallbacks } = extractTransformCallbacks( config.systemMessage @@ -1261,6 +1331,9 @@ export class CopilotClient { s.registerTools(config.tools); s.registerCanvases(config.canvases); s.registerCommands(config.commands); + if (bearerTokenCallbacks.size > 0) { + s.registerBearerTokenProviders(bearerTokenCallbacks); + } s.registerPermissionHandler(config.onPermissionRequest); if (config.onUserInputRequest) { s.registerUserInputHandler(config.onUserInputRequest); @@ -1332,9 +1405,9 @@ export class CopilotClient { availableTools: toolFilterOptions.availableTools, excludedTools: toolFilterOptions.excludedTools, toolFilterPrecedence: toolFilterOptions.toolFilterPrecedence, - provider: config.provider, + provider: bearerWireProvider, capi: config.capi, - providers: config.providers, + providers: bearerWireProviders, models: config.models, enableSessionTelemetry: config.enableSessionTelemetry, modelCapabilities: config.modelCapabilities, @@ -1454,6 +1527,14 @@ export class CopilotClient { session.registerTools(config.tools); session.registerCanvases(config.canvases); session.registerCommands(config.commands); + const { + wireProvider: bearerWireProvider, + wireProviders: bearerWireProviders, + callbacks: bearerTokenCallbacks, + } = extractBearerTokenProviders(config.provider, config.providers); + if (bearerTokenCallbacks.size > 0) { + session.registerBearerTokenProviders(bearerTokenCallbacks); + } session.registerPermissionHandler(config.onPermissionRequest); if (config.onUserInputRequest) { session.registerUserInputHandler(config.onUserInputRequest); @@ -1520,9 +1601,9 @@ export class CopilotClient { name: cmd.name, description: cmd.description, })), - provider: config.provider, + provider: bearerWireProvider, capi: config.capi, - providers: config.providers, + providers: bearerWireProviders, models: config.models, modelCapabilities: config.modelCapabilities, largeOutput: toWireLargeOutput(config.largeOutput), diff --git a/nodejs/src/index.ts b/nodejs/src/index.ts index 9bf02a32c..740a7bc89 100644 --- a/nodejs/src/index.ts +++ b/nodejs/src/index.ts @@ -84,6 +84,7 @@ export type { MCPHTTPServerConfig, MCPServerConfig, DefaultAgentConfig, + GetBearerToken, MessageOptions, ModelBilling, ModelBillingTokenPrices, @@ -99,6 +100,7 @@ export type { PermissionRequestResult, ProviderConfig, ProviderModelConfig, + ProviderTokenArgs, RemoteSessionMode, ResumeSessionConfig, SectionOverride, diff --git a/nodejs/src/session.ts b/nodejs/src/session.ts index 83effdef7..d87d2b9de 100644 --- a/nodejs/src/session.ts +++ b/nodejs/src/session.ts @@ -26,6 +26,7 @@ import type { ExitPlanModeHandler, ExitPlanModeRequest, ExitPlanModeResult, + GetBearerToken, UiInputOptions, MessageOptions, PermissionHandler, @@ -120,6 +121,7 @@ export class CopilotSession { new Map(); private toolHandlers: Map = new Map(); private canvases: Map = new Map(); + private bearerTokenProviders: Map = new Map(); private commandHandlers: Map = new Map(); private permissionHandler?: PermissionHandler; private userInputHandler?: UserInputHandler; @@ -795,6 +797,45 @@ export class CopilotSession { }; } + /** + * Registers per-provider {@link GetBearerToken} callbacks for BYOK providers + * configured with managed-identity / on-demand bearer-token auth. + * + * The runtime never receives the callback itself; the SDK strips it from the + * provider config and instead sends `hasBearerTokenProvider: true`. When the + * runtime needs a token it issues a session-scoped `providerToken.getToken` + * request, which this handler routes to the matching per-provider callback. + * + * @param providers - Map of provider name → callback, or undefined/empty to clear. + * @internal This method is called internally when creating/resuming a session. + */ + registerBearerTokenProviders(providers?: Map): void { + this.bearerTokenProviders.clear(); + if (!providers || providers.size === 0) { + delete this.clientSessionApis.providerToken; + return; + } + for (const [name, callback] of providers) { + this.bearerTokenProviders.set(name, callback); + } + + const self = this; + this.clientSessionApis.providerToken = { + async getToken(params) { + const callback = self.bearerTokenProviders.get(params.providerName); + if (!callback) { + throw new Error( + `No bearer-token provider registered for provider "${params.providerName}"` + ); + } + const token = await callback({ + providerName: params.providerName, + }); + return { token }; + }, + }; + } + /** * Registers command handlers for this session. * diff --git a/nodejs/src/types.ts b/nodejs/src/types.ts index 1db91df94..61d9ca06d 100644 --- a/nodejs/src/types.ts +++ b/nodejs/src/types.ts @@ -2214,6 +2214,39 @@ export interface ResumeSessionConfig extends SessionConfigBase { openCanvases?: OpenCanvasInstance[]; } +/** + * Arguments passed to a {@link GetBearerToken} callback when the runtime needs a + * fresh bearer token for a BYOK provider. + * + * @experimental Part of the experimental managed-identity / bearer-token-provider + * surface and may change or be removed in future SDK or CLI releases. + */ +export interface ProviderTokenArgs { + /** + * Name of the BYOK provider needing a token. For the singular, whole-session + * {@link ProviderConfig} this is the implicit provider name (`"default"`); for + * {@link NamedProviderConfig} entries it is {@link NamedProviderConfig.name}. + * + * The callback closes over its own token scope/audience; the runtime is + * provider-agnostic and forwards only the provider name. + */ + providerName: string; +} + +/** + * Per-provider callback that resolves a bearer token on demand, returning the + * raw token string (without the `Bearer ` prefix). The Copilot SDK itself takes + * no Azure dependency: the consumer supplies this callback backed by their own + * identity library (for example `@azure/identity`'s + * `DefaultAzureCredential.getToken(scope)`), and the runtime calls it once before + * each outbound model request. The runtime does no caching of its own, so the + * callback (or the identity library it wraps) owns token caching and refresh. + * + * @experimental Part of the experimental managed-identity / bearer-token-provider + * surface and may change or be removed in future SDK or CLI releases. + */ +export type GetBearerToken = (args: ProviderTokenArgs) => Promise; + /** * Configuration for a custom API provider. */ @@ -2256,6 +2289,18 @@ export interface ProviderConfig { */ bearerToken?: string; + /** + * Per-request bearer-token provider for managed-identity / on-demand auth. + * When set, the SDK keeps this function client-side (it is never serialized) + * and the runtime calls back into this client to acquire a token before each + * outbound request. The runtime does no caching of its own, so the callback + * owns token caching and refresh. Mutually exclusive with {@link apiKey} / + * {@link bearerToken}. + * + * @experimental + */ + getBearerToken?: GetBearerToken; + /** * Azure-specific options */ @@ -2347,6 +2392,18 @@ export interface NamedProviderConfig { */ bearerToken?: string; + /** + * Per-request bearer-token provider for managed-identity / on-demand auth. + * When set, the SDK keeps this function client-side (it is never serialized) + * and the runtime calls back into this client to acquire a token before each + * outbound request. The runtime does no caching of its own, so the callback + * owns token caching and refresh. Mutually exclusive with {@link apiKey} / + * {@link bearerToken}. + * + * @experimental + */ + getBearerToken?: GetBearerToken; + /** * Azure-specific options. */ diff --git a/nodejs/test/e2e/byok_bearer_token_provider.e2e.test.ts b/nodejs/test/e2e/byok_bearer_token_provider.e2e.test.ts new file mode 100644 index 000000000..228b7a022 --- /dev/null +++ b/nodejs/test/e2e/byok_bearer_token_provider.e2e.test.ts @@ -0,0 +1,255 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +import { beforeEach, describe, expect, it } from "vitest"; +import { approveAll, CopilotRequestHandler } from "../../src/index.js"; +import type { + CopilotRequestContext, + GetBearerToken, + NamedProviderConfig, + ProviderModelConfig, +} from "../../src/index.js"; +import { createSdkTestContext } from "./harness/sdkTestContext.js"; + +/** + * A captured outbound HTTP request the runtime aimed at a fake BYOK provider + * endpoint: just the host and the `Authorization` header, which is all these + * tests need to assert on. + */ +interface CapturedRequest { + host: string; + authorization?: string; +} + +// Fake BYOK provider base URLs. These hosts are never actually dialed: the +// client-global request interceptor fully answers any request aimed at a +// `.invalid` host, so they only need to be syntactically valid, non-resolving +// URLs. Distinct hosts let the per-provider test assert routing by host. +const PRIMARY_HOST = "byok-endpoint.invalid"; +const PRIMARY_BASE_URL = `https://${PRIMARY_HOST}/v1`; +const RED_HOST = "byok-red.invalid"; +const RED_BASE_URL = `https://${RED_HOST}/v1`; +const BLUE_HOST = "byok-blue.invalid"; +const BLUE_BASE_URL = `https://${BLUE_HOST}/v1`; + +/** + * Client-global HTTP request interceptor (from the SDK's `CopilotRequestHandler` + * surface) used in place of a real HTTP listener. + * + * The runtime invokes {@link sendRequest} for every model-layer HTTP request it + * would otherwise issue. We capture the ones aimed at a fake BYOK host — + * recording the `Authorization` header the runtime applied after calling the + * provider's `getBearerToken` callback over the session-scoped + * `providerToken.getToken` RPC — and answer them with a synthetic `404` (a + * non-retryable status, so each outbound model request yields exactly one + * capture). Every other request (CAPI bootstrap: model catalog, policy, …) is + * passed straight through to the real network via `super.sendRequest`. + * + * Because the handler is client-global (one per CLI process), it is installed + * once for the whole fixture and {@link reset} between tests. + */ +class CapturingRequestHandler extends CopilotRequestHandler { + public readonly captures: CapturedRequest[] = []; + + protected override async sendRequest( + request: Request, + ctx: CopilotRequestContext + ): Promise { + const url = new URL(request.url); + if (url.hostname.endsWith(".invalid")) { + this.captures.push({ + host: url.host, + authorization: request.headers.get("authorization") ?? undefined, + }); + return new Response(JSON.stringify({ error: { message: "fake byok endpoint" } }), { + status: 404, + headers: { "content-type": "application/json" }, + }); + } + return super.sendRequest(request, ctx); + } + + reset(): void { + this.captures.length = 0; + } + + /** The `Authorization` headers captured across BYOK requests, in arrival order. */ + authHeaders(): string[] { + return this.captures + .map((c) => c.authorization) + .filter((v): v is string => typeof v === "string"); + } + + /** The `Authorization` header captured for requests aimed at `host`, if any. */ + authHeaderForHost(host: string): string | undefined { + return this.captures.find((c) => c.host === host)?.authorization; + } +} + +/** + * End-to-end coverage for the experimental BYOK bearer-token-provider surface + * (`getBearerToken` on a provider config). The callback stays entirely on the + * SDK/client side: the SDK strips it from the wire config, sets the + * `hasBearerTokenProvider` flag, and the runtime calls back over the session-scoped + * `providerToken.getToken` RPC before each outbound model request, applying the + * returned token as the `Authorization` header. + * + * Rather than standing up a real HTTP listener, these tests install a + * client-global {@link CapturingRequestHandler} that intercepts the runtime's + * outbound model request in-process, captures the `Authorization` header, and + * returns a synthetic response. They validate, against a real runtime: + * 1. the callback's token reaches the model request as `Authorization: Bearer `; + * 2. the runtime re-acquires a token per request (no runtime-side caching); + * 3. per-provider dispatch routes each provider's turn to its own callback, + * and the resulting token reaches that provider's endpoint. + */ +describe("BYOK bearer-token provider", async () => { + const handler = new CapturingRequestHandler(); + const { copilotClient: client } = await createSdkTestContext({ + copilotClientOptions: { requestHandler: handler }, + }); + + beforeEach(() => { + handler.reset(); + }); + + /** Drive one BYOK turn; the synthetic 404 errors the turn, which is expected. */ + async function runTurn( + providers: NamedProviderConfig[], + models: ProviderModelConfig[], + selectionId: string, + prompt: string + ): Promise { + const session = await client.createSession({ + onPermissionRequest: approveAll, + model: selectionId, + providers, + models, + }); + try { + // The interceptor always 404s, so the turn errors after the runtime + // has already sent the (token-bearing) request — which is all we + // assert on. Swallow the resulting error. + await session.sendAndWait({ prompt }).catch(() => undefined); + } finally { + try { + await session.disconnect(); + } catch { + // ignore disconnect errors for the fake BYOK endpoint + } + } + } + + it("applies the callback's token as the Authorization header", async () => { + const SENTINEL = "sentinel-bearer-token-abc123"; + let calls = 0; + const getBearerToken: GetBearerToken = async () => { + calls += 1; + return SENTINEL; + }; + + const providers: NamedProviderConfig[] = [ + { + name: "mi", + type: "openai", + wireApi: "completions", + baseUrl: PRIMARY_BASE_URL, + getBearerToken, + }, + ]; + const models: ProviderModelConfig[] = [ + { id: "default", provider: "mi", wireModel: "byok-gpt-4o" }, + ]; + + await runTurn(providers, models, "mi/default", "What is 5+5?"); + + // The runtime acquired a token via the callback and applied it verbatim as + // the bearer credential on the outbound model request. + expect(handler.authHeaders()).toContain(`Bearer ${SENTINEL}`); + expect(calls).toBeGreaterThanOrEqual(1); + }); + + it("re-acquires a fresh token for each request (no runtime caching)", async () => { + let calls = 0; + const getBearerToken: GetBearerToken = async () => { + calls += 1; + // A distinct token per acquisition proves the runtime re-invokes the + // callback per request rather than caching a previous token. + return `rotating-token-${calls}`; + }; + + const providers: NamedProviderConfig[] = [ + { + name: "mi", + type: "openai", + wireApi: "completions", + baseUrl: PRIMARY_BASE_URL, + getBearerToken, + }, + ]; + const models: ProviderModelConfig[] = [ + { id: "default", provider: "mi", wireModel: "byok-gpt-4o" }, + ]; + + await runTurn(providers, models, "mi/default", "What is 1+1?"); + await runTurn(providers, models, "mi/default", "What is 2+2?"); + + // Each outbound request carries a freshly-acquired, distinct token. + const auths = handler.authHeaders(); + expect(auths.length).toBeGreaterThanOrEqual(2); + expect(auths[0]).toMatch(/^Bearer rotating-token-\d+$/); + expect(auths[1]).toMatch(/^Bearer rotating-token-\d+$/); + expect(auths[0]).not.toBe(auths[1]); + expect(calls).toBeGreaterThanOrEqual(2); + }); + + it("dispatches token acquisition per provider", async () => { + const tokenByProvider: Record = { + red: "token-for-red", + blue: "token-for-blue", + }; + const acquiredFor: string[] = []; + const makeCallback = + (providerName: string): GetBearerToken => + async (args) => { + // The runtime forwards the requesting provider's name so the client + // can dispatch to the right credential. + expect(args.providerName).toBe(providerName); + acquiredFor.push(providerName); + return tokenByProvider[providerName]; + }; + + const providers: NamedProviderConfig[] = [ + { + name: "red", + type: "openai", + wireApi: "completions", + baseUrl: RED_BASE_URL, + getBearerToken: makeCallback("red"), + }, + { + name: "blue", + type: "openai", + wireApi: "completions", + baseUrl: BLUE_BASE_URL, + getBearerToken: makeCallback("blue"), + }, + ]; + const models: ProviderModelConfig[] = [ + { id: "default", provider: "red", wireModel: "byok-gpt-4o" }, + { id: "default", provider: "blue", wireModel: "byok-gpt-4o" }, + ]; + + await runTurn(providers, models, "red/default", "What is 3+3?"); + await runTurn(providers, models, "blue/default", "What is 4+4?"); + + // Each provider's turn was authenticated with its own token AND that token + // was delivered to that provider's endpoint, proving per-provider dispatch + // (not a single session-global credential). + expect(handler.authHeaderForHost(RED_HOST)).toBe(`Bearer ${tokenByProvider.red}`); + expect(handler.authHeaderForHost(BLUE_HOST)).toBe(`Bearer ${tokenByProvider.blue}`); + expect(acquiredFor).toContain("red"); + expect(acquiredFor).toContain("blue"); + }); +}); diff --git a/python/copilot/__init__.py b/python/copilot/__init__.py index 06ecf4188..1e7a3afb1 100644 --- a/python/copilot/__init__.py +++ b/python/copilot/__init__.py @@ -100,6 +100,7 @@ ExitPlanModeHandler, ExitPlanModeRequest, ExitPlanModeResult, + GetBearerToken, InfiniteSessionConfig, InputOptions, LargeToolOutputConfig, @@ -128,6 +129,7 @@ PreToolUseHookOutput, ProviderConfig, ProviderModelConfig, + ProviderTokenArgs, ReasoningSummary, SessionCapabilities, SessionEndHandler, @@ -214,6 +216,7 @@ "ExtensionInfo", "CopilotWebSocketForwarder", "GetAuthStatusResponse", + "GetBearerToken", "GetStatusResponse", "InfiniteSessionConfig", "InputOptions", @@ -257,6 +260,7 @@ "PreToolUseHookOutput", "ProviderConfig", "ProviderModelConfig", + "ProviderTokenArgs", "ReasoningSummary", "RemoteSessionMode", "RuntimeConnection", diff --git a/python/copilot/client.py b/python/copilot/client.py index 4e32a2983..ebfdcf992 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -89,6 +89,7 @@ DefaultAgentConfig, ElicitationHandler, ExitPlanModeHandler, + GetBearerToken, InfiniteSessionConfig, LargeToolOutputConfig, MCPServerConfig, @@ -171,6 +172,36 @@ def _capi_session_options_to_wire(options: CapiSessionOptions) -> dict[str, Any] return wire +# Implicit provider name for the singular, whole-session ``provider`` config. +# Named providers are keyed by their own ``name``. +_DEFAULT_BEARER_TOKEN_PROVIDER_NAME = "default" + + +def _collect_bearer_token_callbacks( + provider: ProviderConfig | None, + providers: list[NamedProviderConfig] | None, +) -> dict[str, GetBearerToken]: + """Collect per-provider ``get_bearer_token`` callbacks keyed by provider name. + + The singular, whole-session ``provider`` uses the implicit + ``_DEFAULT_BEARER_TOKEN_PROVIDER_NAME``; ``providers`` entries use their own + ``name``. The callbacks are never serialized — the wire conversion emits + ``hasBearerTokenProvider: true`` instead and the runtime calls back over + ``providerToken.getToken``. + """ + callbacks: dict[str, GetBearerToken] = {} + if provider is not None: + singular = provider.get("get_bearer_token") + if singular is not None: + callbacks[_DEFAULT_BEARER_TOKEN_PROVIDER_NAME] = singular + if providers: + for named in providers: + callback = named.get("get_bearer_token") + if callback is not None: + callbacks[named["name"]] = callback + return callbacks + + def _validate_session_fs_config(config: SessionFsConfig) -> None: if not config.get("initial_working_directory"): raise ValueError("session_fs.initial_working_directory is required") @@ -2128,6 +2159,7 @@ def _initialize_session(sid: str) -> CopilotSession: s._register_auto_mode_switch_handler(on_auto_mode_switch_request) if canvas_handler is not None: s._register_canvas_handler(canvas_handler) + s._register_bearer_token_providers(_collect_bearer_token_callbacks(provider, providers)) if hooks: s._register_hooks(hooks) if transform_callbacks: @@ -2701,6 +2733,9 @@ async def resume_session( session._register_auto_mode_switch_handler(on_auto_mode_switch_request) if canvas_handler is not None: session._register_canvas_handler(canvas_handler) + session._register_bearer_token_providers( + _collect_bearer_token_callbacks(provider, providers) + ) if hooks: session._register_hooks(hooks) if transform_callbacks: @@ -3231,6 +3266,8 @@ def _convert_provider_to_wire_format( wire_provider["transport"] = provider["transport"] if "bearer_token" in provider: wire_provider["bearerToken"] = provider["bearer_token"] + if provider.get("get_bearer_token") is not None: + wire_provider["hasBearerTokenProvider"] = True if "headers" in provider: wire_provider["headers"] = provider["headers"] if "model_id" in provider: @@ -3267,6 +3304,8 @@ def _convert_named_provider_to_wire_format( wire["apiKey"] = provider["api_key"] if "bearer_token" in provider: wire["bearerToken"] = provider["bearer_token"] + if provider.get("get_bearer_token") is not None: + wire["hasBearerTokenProvider"] = True if "headers" in provider: wire["headers"] = provider["headers"] if "azure" in provider: diff --git a/python/copilot/session.py b/python/copilot/session.py index b4c01b885..94fba994a 100644 --- a/python/copilot/session.py +++ b/python/copilot/session.py @@ -44,6 +44,8 @@ PermissionDecisionApproveOnce, PermissionDecisionRequest, PermissionDecisionUserNotAvailable, + ProviderTokenAcquireRequest, + ProviderTokenAcquireResult, SessionLogLevel, SessionRpc, UIElicitationRequest, @@ -1077,6 +1079,29 @@ class AzureProviderOptions(TypedDict, total=False): api_version: str # Azure API version. Defaults to "2024-10-21". +class ProviderTokenArgs(TypedDict): + """Arguments passed to a :data:`GetBearerToken` callback when the runtime + needs a fresh bearer token for a BYOK provider. + + **Experimental.** Part of the bearer-token-provider surface and may change or + be removed in future SDK or CLI releases. + """ + + # Name of the BYOK provider needing a token. For the singular, whole-session + # ``provider`` this is the implicit provider name ("default"); for + # ``NamedProviderConfig`` entries it is ``NamedProviderConfig.name``. + provider_name: str + + +# Per-request callback that resolves a bearer token on demand for a BYOK +# provider (for example via Azure Managed Identity). The Copilot SDK takes no +# identity dependency: supply a callback backed by your own identity library. +# Never serialized — setting it makes the SDK send ``hasBearerTokenProvider`` on +# the wire and answer the runtime's ``providerToken.getToken`` requests. May be +# sync or async. +GetBearerToken = Callable[[ProviderTokenArgs], str | Awaitable[str]] + + class ProviderConfig(TypedDict, total=False): """Configuration for a custom API provider""" @@ -1113,6 +1138,12 @@ class ProviderConfig(TypedDict, total=False): # Overrides the resolved model's default max output tokens. When hit, the # model stops generating and returns a truncated response. max_output_tokens: int + # Per-request callback that resolves a bearer token on demand for this BYOK + # provider (for example via Azure Managed Identity). Never serialized — the + # SDK sends hasBearerTokenProvider: true on the wire and answers the + # runtime's providerToken.getToken requests with this callback's result. + # Mutually exclusive with api_key and bearer_token. + get_bearer_token: GetBearerToken class NamedProviderConfig(TypedDict, total=False): @@ -1139,6 +1170,11 @@ class NamedProviderConfig(TypedDict, total=False): bearer_token: str azure: AzureProviderOptions # Azure-specific options headers: dict[str, str] + # Per-request bearer-token callback for this named BYOK provider. Never + # serialized; the SDK sends hasBearerTokenProvider: true and answers the + # runtime's providerToken.getToken requests. Mutually exclusive with api_key + # and bearer_token. + get_bearer_token: GetBearerToken class ProviderModelConfig(TypedDict, total=False): @@ -1210,6 +1246,35 @@ def _canvas_handler_error(err: Exception) -> JsonRpcError: ) +class _BearerTokenProviderAdapter: + """Routes runtime ``providerToken.getToken`` requests to the matching + per-provider :data:`GetBearerToken` callback registered on the session. + + The runtime calls this once per outbound request for a BYOK provider that + declared ``hasBearerTokenProvider: true``; it does no caching, so the SDK + consumer's callback (typically backed by an identity library) owns + acquisition, caching, and refresh. + """ + + def __init__(self, session: CopilotSession) -> None: + self._session = session + + async def get_token(self, params: ProviderTokenAcquireRequest) -> ProviderTokenAcquireResult: + provider_name = params.provider_name + with self._session._bearer_token_providers_lock: + callback = self._session._bearer_token_providers.get(provider_name) + if callback is None: + raise JsonRpcError( + -32603, + f"No bearer-token provider registered for provider: {provider_name!r}", + ) + args: ProviderTokenArgs = {"provider_name": provider_name} + result = callback(args) + if inspect.isawaitable(result): + result = await result + return ProviderTokenAcquireResult(token=cast(str, result)) + + class CopilotSession: """ Represents a single conversation session with the Copilot CLI. @@ -1275,6 +1340,8 @@ def __init__( self._transform_callbacks_lock = threading.Lock() self._command_handlers: dict[str, CommandHandler] = {} self._command_handlers_lock = threading.Lock() + self._bearer_token_providers: dict[str, GetBearerToken] = {} + self._bearer_token_providers_lock = threading.Lock() self._elicitation_handler: ElicitationHandler | None = None self._elicitation_handler_lock = threading.Lock() self._capabilities: SessionCapabilities = {} @@ -2015,6 +2082,26 @@ def _register_commands(self, commands: list[CommandDefinition] | None) -> None: for cmd in commands: self._command_handlers[cmd.name] = cmd.handler + def _register_bearer_token_providers(self, providers: dict[str, GetBearerToken] | None) -> None: + """Register per-provider bearer-token callbacks for this session. + + The runtime never receives the callbacks themselves; the SDK strips them + from the provider config and instead sends ``hasBearerTokenProvider: + true``. When the runtime needs a token it issues a session-scoped + ``providerToken.getToken`` request, which the registered handler routes + to the matching per-provider callback. + + Args: + providers: Map of provider name -> callback, or None/empty to clear. + """ + with self._bearer_token_providers_lock: + self._bearer_token_providers.clear() + if not providers: + self._client_session_apis.provider_token = None + return + self._bearer_token_providers.update(providers) + self._client_session_apis.provider_token = _BearerTokenProviderAdapter(self) + def _register_elicitation_handler(self, handler: ElicitationHandler | None) -> None: """Register the elicitation handler for this session. diff --git a/python/e2e/test_byok_bearer_token_provider_e2e.py b/python/e2e/test_byok_bearer_token_provider_e2e.py new file mode 100644 index 000000000..28f9e0586 --- /dev/null +++ b/python/e2e/test_byok_bearer_token_provider_e2e.py @@ -0,0 +1,251 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# -------------------------------------------------------------------------------------------- + +"""E2E coverage for the experimental BYOK bearer-token-provider surface. + +Mirrors ``nodejs/test/e2e/byok_bearer_token_provider.e2e.test.ts``. A BYOK +provider config may carry a ``get_bearer_token`` callback; the callback stays +entirely on the SDK/client side. The SDK strips it from the wire config, sets +the ``hasBearerTokenProvider`` flag, and the runtime calls back over the +session-scoped ``providerToken.getToken`` RPC before each outbound model +request, applying the returned token as the ``Authorization`` header. + +Like the other ``copilot_request_*`` tests, this one installs a client-global +``CopilotRequestHandler`` instead of using the CAPI proxy: the handler +fabricates the bootstrap (catalog/policy) responses and intercepts the +runtime's outbound BYOK request in-process, capturing the ``Authorization`` +header and returning a synthetic ``404``. It validates, against a real runtime: + 1. the callback's token reaches the model request as ``Authorization: Bearer ``; + 2. the runtime re-acquires a token per request (no runtime-side caching); + 3. per-provider dispatch routes each provider's turn to its own callback, and + the resulting token reaches that provider's endpoint. +""" + +from __future__ import annotations + +import re + +import httpx +import pytest +import pytest_asyncio + +from copilot import CopilotRequestContext, CopilotRequestHandler +from copilot.session import GetBearerToken, PermissionHandler + +from ._copilot_request_helpers import build_isolated_client, build_non_inference_response +from .testharness import E2ETestContext + +pytestmark = pytest.mark.asyncio(loop_scope="module") + +# Fake BYOK provider base URLs. These hosts are never actually dialed: the +# client-global request interceptor fully answers any request aimed at a +# ``.invalid`` host, so they only need to be syntactically valid, non-resolving +# URLs. Distinct hosts let the per-provider test assert routing by host. +PRIMARY_HOST = "byok-endpoint.invalid" +PRIMARY_BASE_URL = f"https://{PRIMARY_HOST}/v1" +RED_HOST = "byok-red.invalid" +RED_BASE_URL = f"https://{RED_HOST}/v1" +BLUE_HOST = "byok-blue.invalid" +BLUE_BASE_URL = f"https://{BLUE_HOST}/v1" + + +class _CapturingRequestHandler(CopilotRequestHandler): + """Client-global HTTP interceptor used in place of a real BYOK listener. + + The runtime invokes :meth:`send_request` for every model-layer HTTP request. + Requests aimed at a fake BYOK host are captured — recording the + ``Authorization`` header the runtime applied after calling the provider's + ``get_bearer_token`` callback over ``providerToken.getToken`` — and answered + with a synthetic ``404`` (non-retryable, so each outbound model request + yields exactly one capture). Every other request (CAPI bootstrap: model + catalog, policy, …) is fabricated locally so no real network or CAPI proxy + is involved. + """ + + def __init__(self) -> None: + # (host, authorization) for each captured BYOK request, in arrival order. + self.captures: list[tuple[str, str | None]] = [] + + async def send_request( + self, request: httpx.Request, ctx: CopilotRequestContext + ) -> httpx.Response: + url = httpx.URL(request.url) + host = url.host + if host.endswith(".invalid"): + self.captures.append((host, request.headers.get("authorization"))) + return httpx.Response( + 404, + headers={"content-type": "application/json"}, + json={"error": {"message": "fake byok endpoint"}}, + request=request, + ) + return build_non_inference_response(str(request.url)) + + def reset(self) -> None: + self.captures.clear() + + def auth_headers(self) -> list[str]: + """The ``Authorization`` headers captured across BYOK requests, in order.""" + return [auth for (_host, auth) in self.captures if auth is not None] + + def auth_header_for_host(self, host: str) -> str | None: + """The ``Authorization`` header captured for requests aimed at ``host``.""" + for captured_host, auth in self.captures: + if captured_host == host: + return auth + return None + + +@pytest_asyncio.fixture(loop_scope="module") +async def bearer_fixture(ctx: E2ETestContext): + handler = _CapturingRequestHandler() + client = build_isolated_client(ctx, handler) + await client.start() + try: + yield client, handler + finally: + try: + await client.stop() + except Exception: + # Best-effort teardown during fixture cleanup. + pass + + +async def _run_turn(client, providers, models, selection_id: str, prompt: str) -> None: + """Drive one BYOK turn; the synthetic 404 errors the turn, which is expected.""" + session = await client.create_session( + on_permission_request=PermissionHandler.approve_all, + model=selection_id, + providers=providers, + models=models, + ) + try: + # The interceptor always 404s, so the turn errors after the runtime has + # already sent the (token-bearing) request — which is all we assert on. + try: + await session.send_and_wait(prompt) + except Exception: + pass + finally: + try: + await session.disconnect() + except Exception: + # ignore disconnect errors for the fake BYOK endpoint + pass + + +class TestByokBearerTokenProvider: + async def test_applies_the_callbacks_token_as_the_authorization_header(self, bearer_fixture): + client, handler = bearer_fixture + handler.reset() + + sentinel = "sentinel-bearer-token-abc123" + calls = 0 + + async def get_bearer_token(args) -> str: + nonlocal calls + calls += 1 + return sentinel + + providers = [ + { + "name": "mi", + "type": "openai", + "wire_api": "completions", + "base_url": PRIMARY_BASE_URL, + "get_bearer_token": get_bearer_token, + } + ] + models = [{"id": "default", "provider": "mi", "wire_model": "byok-gpt-4o"}] + + await _run_turn(client, providers, models, "mi/default", "What is 5+5?") + + # The runtime acquired a token via the callback and applied it verbatim + # as the bearer credential on the outbound model request. + assert f"Bearer {sentinel}" in handler.auth_headers() + assert calls >= 1 + + async def test_reacquires_a_fresh_token_for_each_request(self, bearer_fixture): + client, handler = bearer_fixture + handler.reset() + + calls = 0 + + async def get_bearer_token(args) -> str: + nonlocal calls + calls += 1 + # A distinct token per acquisition proves the runtime re-invokes the + # callback per request rather than caching a previous token. + return f"rotating-token-{calls}" + + providers = [ + { + "name": "mi", + "type": "openai", + "wire_api": "completions", + "base_url": PRIMARY_BASE_URL, + "get_bearer_token": get_bearer_token, + } + ] + models = [{"id": "default", "provider": "mi", "wire_model": "byok-gpt-4o"}] + + await _run_turn(client, providers, models, "mi/default", "What is 1+1?") + await _run_turn(client, providers, models, "mi/default", "What is 2+2?") + + # Each outbound request carries a freshly-acquired, distinct token. + auths = handler.auth_headers() + assert len(auths) >= 2 + assert re.match(r"^Bearer rotating-token-\d+$", auths[0]) + assert re.match(r"^Bearer rotating-token-\d+$", auths[1]) + assert auths[0] != auths[1] + assert calls >= 2 + + async def test_dispatches_token_acquisition_per_provider(self, bearer_fixture): + client, handler = bearer_fixture + handler.reset() + + token_by_provider = {"red": "token-for-red", "blue": "token-for-blue"} + acquired_for: list[str] = [] + + def make_callback(provider_name: str) -> GetBearerToken: + async def callback(args) -> str: + # The runtime forwards the requesting provider's name so the + # client can dispatch to the right credential. + assert args["provider_name"] == provider_name + acquired_for.append(provider_name) + return token_by_provider[provider_name] + + return callback + + providers = [ + { + "name": "red", + "type": "openai", + "wire_api": "completions", + "base_url": RED_BASE_URL, + "get_bearer_token": make_callback("red"), + }, + { + "name": "blue", + "type": "openai", + "wire_api": "completions", + "base_url": BLUE_BASE_URL, + "get_bearer_token": make_callback("blue"), + }, + ] + models = [ + {"id": "default", "provider": "red", "wire_model": "byok-gpt-4o"}, + {"id": "default", "provider": "blue", "wire_model": "byok-gpt-4o"}, + ] + + await _run_turn(client, providers, models, "red/default", "What is 3+3?") + await _run_turn(client, providers, models, "blue/default", "What is 4+4?") + + # Each provider's turn was authenticated with its own token AND that + # token was delivered to that provider's endpoint, proving per-provider + # dispatch (not a single session-global credential). + assert handler.auth_header_for_host(RED_HOST) == f"Bearer {token_by_provider['red']}" + assert handler.auth_header_for_host(BLUE_HOST) == f"Bearer {token_by_provider['blue']}" + assert "red" in acquired_for + assert "blue" in acquired_for diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 018907d99..22fdc53d7 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -22,6 +22,9 @@ pub mod hooks; mod jsonrpc; /// Permission-policy helpers that produce a [`handler::PermissionHandler`]. pub mod permission; +/// BYOK bearer-token provider callbacks. +pub mod provider_token; +mod provider_token_dispatch; /// GitHub Copilot CLI binary resolution (env var, embedded, dev cache). pub(crate) mod resolve; mod router; @@ -72,6 +75,7 @@ pub(crate) use jsonrpc::{ JsonRpcClient, JsonRpcError, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, error_codes, }; pub use mode::{BUILTIN_TOOLS_ISOLATED, ClientMode, ToolSet}; +pub use provider_token::{BearerTokenError, BearerTokenProvider, ProviderTokenArgs}; /// Re-exported JSON-RPC internals for integration tests (requires `test-support` feature). #[cfg(feature = "test-support")] diff --git a/rust/src/provider_token.rs b/rust/src/provider_token.rs new file mode 100644 index 000000000..f92715006 --- /dev/null +++ b/rust/src/provider_token.rs @@ -0,0 +1,105 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +//! BYOK bearer-token provider callbacks. +//! +//!

+//! +//! **Experimental.** These types are part of an experimental wire-protocol +//! surface and may change or be removed in future SDK or CLI releases. +//! +//!
+ +use std::future::Future; + +use async_trait::async_trait; + +/// Arguments passed to a BYOK bearer-token provider callback. +/// +///
+/// +/// **Experimental.** This type is part of an experimental wire-protocol +/// surface and may change or be removed in future SDK or CLI releases. +/// +///
+#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ProviderTokenArgs { + /// Name of the BYOK provider needing a token. + /// + /// This is `"default"` for the singular whole-session provider, otherwise + /// the named provider's `name`. + pub provider_name: String, +} + +/// Error returned by a [`BearerTokenProvider`]. +/// +///
+/// +/// **Experimental.** This type is part of an experimental wire-protocol +/// surface and may change or be removed in future SDK or CLI releases. +/// +///
+#[derive(Debug, Clone, PartialEq, Eq)] +pub struct BearerTokenError { + message: String, +} + +impl BearerTokenError { + /// Construct a bearer-token error with a human-readable message. + pub fn message(message: impl Into) -> Self { + Self { + message: message.into(), + } + } + + /// Return the human-readable error message. + pub fn as_str(&self) -> &str { + &self.message + } +} + +impl std::fmt::Display for BearerTokenError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(&self.message) + } +} + +impl std::error::Error for BearerTokenError {} + +impl From for BearerTokenError { + fn from(message: String) -> Self { + Self::message(message) + } +} + +impl From<&str> for BearerTokenError { + fn from(message: &str) -> Self { + Self::message(message) + } +} + +/// Provider-side callback used to acquire bearer tokens for BYOK providers. +/// +///
+/// +/// **Experimental.** This trait is part of an experimental wire-protocol +/// surface and may change or be removed in future SDK or CLI releases. +/// +///
+#[async_trait] +pub trait BearerTokenProvider: Send + Sync { + /// Acquire a bearer token without the `Bearer ` prefix. + async fn get_token(&self, args: ProviderTokenArgs) -> Result; +} + +#[async_trait] +impl BearerTokenProvider for F +where + F: Fn(ProviderTokenArgs) -> Fut + Send + Sync, + Fut: Future> + Send, +{ + async fn get_token(&self, args: ProviderTokenArgs) -> Result { + (self)(args).await + } +} diff --git a/rust/src/provider_token_dispatch.rs b/rust/src/provider_token_dispatch.rs new file mode 100644 index 000000000..c100443cd --- /dev/null +++ b/rust/src/provider_token_dispatch.rs @@ -0,0 +1,157 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +//! Inbound `providerToken.*` JSON-RPC request dispatch helpers. + +use std::collections::HashMap; +use std::sync::Arc; + +use serde::Serialize; +use serde_json::Value; +use tracing::warn; + +use crate::generated::api_types::{ + ProviderTokenAcquireRequest, ProviderTokenAcquireResult, rpc_methods, +}; +use crate::provider_token::{BearerTokenError, BearerTokenProvider, ProviderTokenArgs}; +use crate::{Client, JsonRpcRequest, JsonRpcResponse, error_codes}; + +async fn respond(client: &Client, request_id: u64, result: T) { + let value = match serde_json::to_value(&result) { + Ok(value) => value, + Err(error) => { + warn!(error = %error, "failed to serialize provider token response"); + send_error( + client, + request_id, + error_codes::INTERNAL_ERROR, + "serialization failure", + ) + .await; + return; + } + }; + + let _ = client + .send_response(&JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: request_id, + result: Some(value), + error: None, + }) + .await; +} + +async fn send_error(client: &Client, request_id: u64, code: i32, message: &str) { + let _ = client + .send_response(&JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: request_id, + result: None, + error: Some(crate::JsonRpcError { + code, + message: message.to_string(), + data: None, + }), + }) + .await; +} + +async fn parse_params( + client: &Client, + request: &JsonRpcRequest, +) -> Option { + let params = request + .params + .as_ref() + .cloned() + .unwrap_or(Value::Object(serde_json::Map::new())); + match serde_json::from_value(params) { + Ok(params) => Some(params), + Err(error) => { + send_error( + client, + request.id, + error_codes::INVALID_PARAMS, + &format!("invalid params: {error}"), + ) + .await; + None + } + } +} + +fn token_provider_or_err( + providers: &HashMap>, + provider_name: &str, +) -> Result, BearerTokenError> { + providers.get(provider_name).cloned().ok_or_else(|| { + BearerTokenError::message(format!( + "No bearer-token provider installed for BYOK provider {provider_name:?}" + )) + }) +} + +async fn get_token( + client: &Client, + providers: &HashMap>, + request: JsonRpcRequest, +) { + let Some(params) = parse_params::(client, &request).await else { + return; + }; + + let token_provider = match token_provider_or_err(providers, ¶ms.provider_name) { + Ok(provider) => provider, + Err(error) => { + send_error( + client, + request.id, + error_codes::INTERNAL_ERROR, + &error.to_string(), + ) + .await; + return; + } + }; + + match token_provider + .get_token(ProviderTokenArgs { + provider_name: params.provider_name, + }) + .await + { + Ok(token) => respond(client, request.id, ProviderTokenAcquireResult { token }).await, + Err(error) => { + send_error( + client, + request.id, + error_codes::INTERNAL_ERROR, + &format!("Bearer-token provider failed: {error}"), + ) + .await; + } + } +} + +pub(crate) async fn dispatch( + client: &Client, + providers: &HashMap>, + request: JsonRpcRequest, +) { + let method = request.method.as_str(); + match method { + rpc_methods::PROVIDERTOKEN_GETTOKEN => get_token(client, providers, request).await, + _ => { + warn!(method = %method, "unknown providerToken.* method"); + send_error( + client, + request.id, + error_codes::METHOD_NOT_FOUND, + &format!("unknown method: {method}"), + ) + .await; + } + } +} diff --git a/rust/src/session.rs b/rust/src/session.rs index fed6705da..18b91b437 100644 --- a/rust/src/session.rs +++ b/rust/src/session.rs @@ -21,6 +21,7 @@ use crate::handler::{ PermissionHandler, PermissionResult, UserInputHandler, UserInputResponse, }; use crate::hooks::SessionHooks; +use crate::provider_token::BearerTokenProvider; use crate::session_fs::SessionFsProvider; use crate::trace_context::inject_trace_context; use crate::transforms::SystemMessageTransform; @@ -893,6 +894,7 @@ impl Client { let command_handlers = build_command_handler_map(runtime.commands.as_deref()); let canvas_handler = runtime.canvas_handler.take(); let session_fs_provider = runtime.session_fs_provider.take(); + let bearer_token_providers = std::mem::take(&mut runtime.bearer_token_providers); if self.inner.session_fs_configured && session_fs_provider.is_none() { return Err(ErrorKind::Session(SessionErrorKind::SessionFsProviderRequired).into()); } @@ -1011,6 +1013,7 @@ impl Client { command_handlers, canvas_handler, session_fs_provider, + bearer_token_providers, channels, idle_waiter.clone(), capabilities.clone(), @@ -1149,6 +1152,7 @@ impl Client { let command_handlers = build_command_handler_map(runtime.commands.as_deref()); let canvas_handler = runtime.canvas_handler.take(); let session_fs_provider = runtime.session_fs_provider.take(); + let bearer_token_providers = std::mem::take(&mut runtime.bearer_token_providers); if self.inner.session_fs_configured && session_fs_provider.is_none() { return Err(ErrorKind::Session(SessionErrorKind::SessionFsProviderRequired).into()); } @@ -1183,6 +1187,7 @@ impl Client { command_handlers, canvas_handler, session_fs_provider, + bearer_token_providers, channels, idle_waiter.clone(), capabilities.clone(), @@ -1391,6 +1396,7 @@ fn spawn_event_loop( command_handlers: Arc, canvas_handler: Option>, session_fs_provider: Option>, + bearer_token_providers: HashMap>, channels: crate::router::SessionChannels, idle_waiter: Arc>>, capabilities: Arc>, @@ -1432,6 +1438,7 @@ fn spawn_event_loop( transforms: transforms.as_deref(), canvas_handler: canvas_handler.as_ref(), session_fs_provider: session_fs_provider.as_ref(), + bearer_token_providers: &bearer_token_providers, }; handle_request(&session_id, ctx, request).await; } @@ -2010,6 +2017,7 @@ struct RequestDispatchContext<'a> { transforms: Option<&'a dyn SystemMessageTransform>, canvas_handler: Option<&'a Arc>, session_fs_provider: Option<&'a Arc>, + bearer_token_providers: &'a HashMap>, } /// Process a JSON-RPC request from the CLI. @@ -2025,6 +2033,7 @@ async fn handle_request( let transforms = ctx.transforms; let canvas_handler = ctx.canvas_handler; let session_fs_provider = ctx.session_fs_provider; + let bearer_token_providers = ctx.bearer_token_providers; if request.method.starts_with("sessionFs.") { crate::session_fs_dispatch::dispatch(client, session_fs_provider, request).await; @@ -2036,6 +2045,11 @@ async fn handle_request( return; } + if request.method == crate::generated::api_types::rpc_methods::PROVIDERTOKEN_GETTOKEN { + crate::provider_token_dispatch::dispatch(client, bearer_token_providers, request).await; + return; + } + match request.method.as_str() { "hooks.invoke" => { let params = request.params.as_ref(); diff --git a/rust/src/types.rs b/rust/src/types.rs index 668eaccc1..d62e20f70 100644 --- a/rust/src/types.rs +++ b/rust/src/types.rs @@ -28,6 +28,7 @@ use crate::handler::{ UserInputHandler, }; use crate::hooks::SessionHooks; +use crate::provider_token::BearerTokenProvider; pub use crate::session_fs::{ DirEntry, DirEntryKind, FileInfo, FsError, SessionFsCapabilities, SessionFsConfig, SessionFsConventions, SessionFsProvider, SessionFsSqliteProvider, SessionFsSqliteQueryResult, @@ -1021,7 +1022,7 @@ pub struct McpHttpServerConfig { /// Routes session requests through an alternative model provider /// (OpenAI-compatible, Azure, Anthropic, or local) instead of GitHub /// Copilot's default routing. -#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[derive(Clone, Default, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] #[non_exhaustive] pub struct ProviderConfig { @@ -1049,6 +1050,12 @@ pub struct ProviderConfig { /// API key. Takes precedence over `api_key` when both are set. #[serde(default, skip_serializing_if = "Option::is_none")] pub bearer_token: Option, + /// **Experimental.** Callback used to acquire a bearer token before each + /// outbound request to this provider. + #[serde(skip)] + pub get_bearer_token: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub(crate) has_bearer_token_provider: Option, /// Azure-specific options. #[serde(default, skip_serializing_if = "Option::is_none")] pub azure: Option, @@ -1080,6 +1087,30 @@ pub struct ProviderConfig { pub max_output_tokens: Option, } +impl std::fmt::Debug for ProviderConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ProviderConfig") + .field("provider_type", &self.provider_type) + .field("wire_api", &self.wire_api) + .field("transport", &self.transport) + .field("base_url", &self.base_url) + .field("api_key", &self.api_key) + .field("bearer_token", &self.bearer_token) + .field( + "get_bearer_token", + &self.get_bearer_token.as_ref().map(|_| ""), + ) + .field("has_bearer_token_provider", &self.has_bearer_token_provider) + .field("azure", &self.azure) + .field("headers", &self.headers) + .field("model_id", &self.model_id) + .field("wire_model", &self.wire_model) + .field("max_prompt_tokens", &self.max_prompt_tokens) + .field("max_output_tokens", &self.max_output_tokens) + .finish() + } +} + impl ProviderConfig { /// Construct a [`ProviderConfig`] with the required `base_url` set; /// all other fields default to unset. @@ -1122,6 +1153,16 @@ impl ProviderConfig { self } + /// Set the callback used to acquire a bearer token before each outbound + /// request to this provider. + /// + /// **Experimental.** This method is part of an experimental wire-protocol + /// surface and may change or be removed in a future release. + pub fn with_get_bearer_token(mut self, provider: Arc) -> Self { + self.get_bearer_token = Some(provider); + self + } + /// Set Azure-specific options. pub fn with_azure(mut self, azure: AzureProviderOptions) -> Self { self.azure = Some(azure); @@ -1223,7 +1264,7 @@ pub struct AzureProviderOptions { /// default Copilot routing and exposes these providers' models alongside /// it. Models are attached via [`ProviderModelConfig`], which references a /// provider by [`name`](Self::name). -#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[derive(Clone, Default, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] #[non_exhaustive] pub struct NamedProviderConfig { @@ -1247,6 +1288,12 @@ pub struct NamedProviderConfig { /// directly. Takes precedence over `api_key` when both are set. #[serde(default, skip_serializing_if = "Option::is_none")] pub bearer_token: Option, + /// **Experimental.** Callback used to acquire a bearer token before each + /// outbound request to this provider. + #[serde(skip)] + pub get_bearer_token: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub(crate) has_bearer_token_provider: Option, /// Azure-specific options. #[serde(default, skip_serializing_if = "Option::is_none")] pub azure: Option, @@ -1255,6 +1302,26 @@ pub struct NamedProviderConfig { pub headers: Option>, } +impl std::fmt::Debug for NamedProviderConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("NamedProviderConfig") + .field("name", &self.name) + .field("provider_type", &self.provider_type) + .field("wire_api", &self.wire_api) + .field("base_url", &self.base_url) + .field("api_key", &self.api_key) + .field("bearer_token", &self.bearer_token) + .field( + "get_bearer_token", + &self.get_bearer_token.as_ref().map(|_| ""), + ) + .field("has_bearer_token_provider", &self.has_bearer_token_provider) + .field("azure", &self.azure) + .field("headers", &self.headers) + .finish() + } +} + impl NamedProviderConfig { /// Construct a [`NamedProviderConfig`] with the required `name` and /// `base_url` set; all other fields default to unset. @@ -1291,6 +1358,16 @@ impl NamedProviderConfig { self } + /// Set the callback used to acquire a bearer token before each outbound + /// request to this provider. + /// + /// **Experimental.** This method is part of an experimental wire-protocol + /// surface and may change or be removed in a future release. + pub fn with_get_bearer_token(mut self, provider: Arc) -> Self { + self.get_bearer_token = Some(provider); + self + } + /// Set Azure-specific options. pub fn with_azure(mut self, azure: AzureProviderOptions) -> Self { self.azure = Some(azure); @@ -1304,6 +1381,31 @@ impl NamedProviderConfig { } } +fn prepare_bearer_token_providers( + provider: &mut Option, + providers: &mut Option>, +) -> HashMap> { + let mut bearer_token_providers = HashMap::new(); + + if let Some(provider) = provider.as_mut() + && let Some(token_provider) = provider.get_bearer_token.take() + { + provider.has_bearer_token_provider = Some(true); + bearer_token_providers.insert("default".to_string(), token_provider); + } + + if let Some(providers) = providers.as_mut() { + for provider in providers { + if let Some(token_provider) = provider.get_bearer_token.take() { + provider.has_bearer_token_provider = Some(true); + bearer_token_providers.insert(provider.name.clone(), token_provider); + } + } + } + + bearer_token_providers +} + /// A BYOK model definition in the multi-provider registry. /// /// **Experimental.** Multi-provider BYOK configuration is part of an @@ -1919,6 +2021,7 @@ pub(crate) struct SessionConfigRuntime { pub tool_handlers: HashMap>, pub canvas_handler: Option>, pub session_fs_provider: Option>, + pub bearer_token_providers: HashMap>, pub commands: Option>, } @@ -1970,6 +2073,8 @@ impl SessionConfig { }); let wire_canvases = self.canvases.clone(); let canvas_handler = self.canvas_handler.clone(); + let bearer_token_providers = + prepare_bearer_token_providers(&mut self.provider, &mut self.providers); let wire = crate::wire::SessionCreateWire { session_id, @@ -2046,6 +2151,7 @@ impl SessionConfig { tool_handlers, canvas_handler, session_fs_provider: self.session_fs_provider, + bearer_token_providers, commands: self.commands, }; @@ -2926,6 +3032,8 @@ impl ResumeSessionConfig { }); let wire_canvases = self.canvases.clone(); let canvas_handler = self.canvas_handler.clone(); + let bearer_token_providers = + prepare_bearer_token_providers(&mut self.provider, &mut self.providers); let wire = crate::wire::SessionResumeWire { session_id: self.session_id, @@ -3003,6 +3111,7 @@ impl ResumeSessionConfig { tool_handlers, canvas_handler, session_fs_provider: self.session_fs_provider, + bearer_token_providers, commands: self.commands, }; diff --git a/rust/tests/e2e.rs b/rust/tests/e2e.rs index c46630e69..59b83ab27 100644 --- a/rust/tests/e2e.rs +++ b/rust/tests/e2e.rs @@ -7,6 +7,8 @@ mod abort; mod ask_user; #[path = "e2e/builtin_tools.rs"] mod builtin_tools; +#[path = "e2e/byok_bearer_token_provider.rs"] +mod byok_bearer_token_provider; #[path = "e2e/canvas.rs"] mod canvas; #[path = "e2e/client.rs"] diff --git a/rust/tests/e2e/byok_bearer_token_provider.rs b/rust/tests/e2e/byok_bearer_token_provider.rs new file mode 100644 index 000000000..c3cd9ef4b --- /dev/null +++ b/rust/tests/e2e/byok_bearer_token_provider.rs @@ -0,0 +1,314 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; + +use async_trait::async_trait; +use bytes::Bytes; +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::{ + BearerTokenError, CopilotHttpRequest, CopilotHttpResponse, CopilotRequestContext, + CopilotRequestError, CopilotRequestHandler, MessageOptions, NamedProviderConfig, + ProviderModelConfig, ProviderTokenArgs, SessionConfig, +}; +use http::HeaderMap; + +use super::support::with_e2e_context_no_snapshot; + +const PRIMARY_BASE_URL: &str = "https://byok-endpoint.invalid/v1"; +const RED_HOST: &str = "byok-red.invalid"; +const RED_BASE_URL: &str = "https://byok-red.invalid/v1"; +const BLUE_HOST: &str = "byok-blue.invalid"; +const BLUE_BASE_URL: &str = "https://byok-blue.invalid/v1"; + +#[derive(Debug, Clone)] +struct CapturedRequest { + host: String, + authorization: Option, +} + +#[derive(Default)] +struct CapturingRequestHandler { + captures: std::sync::Mutex>, +} + +impl CapturingRequestHandler { + fn auth_headers(&self) -> Vec { + self.captures + .lock() + .unwrap() + .iter() + .filter_map(|capture| capture.authorization.clone()) + .collect() + } + + fn auth_header_for_host(&self, host: &str) -> Option { + self.captures + .lock() + .unwrap() + .iter() + .find(|capture| capture.host == host) + .and_then(|capture| capture.authorization.clone()) + } + + fn reset(&self) { + self.captures.lock().unwrap().clear(); + } +} + +#[async_trait] +impl CopilotRequestHandler for CapturingRequestHandler { + async fn send_request( + &self, + request: CopilotHttpRequest, + _ctx: &CopilotRequestContext, + ) -> Result { + let uri: http::Uri = request + .url + .parse() + .map_err(|error| CopilotRequestError::message(format!("invalid URL: {error}")))?; + if let Some(host) = uri.host() + && host.ends_with(".invalid") + { + let authorization = request + .headers + .get("authorization") + .and_then(|value| value.to_str().ok()) + .map(str::to_string); + self.captures.lock().unwrap().push(CapturedRequest { + host: host.to_string(), + authorization, + }); + return Ok(json_response( + 404, + br#"{"error":{"message":"fake byok endpoint"}}"#.to_vec(), + )); + } + + Ok(synth_non_inference_response(&request.url)) + } +} + +fn json_response(status: u16, body: Vec) -> CopilotHttpResponse { + let mut headers = HeaderMap::new(); + headers.insert( + "content-type", + http::HeaderValue::from_static("application/json"), + ); + let body = futures_util::stream::iter([Ok::(Bytes::from(body))]); + CopilotHttpResponse::new(status, None, headers, Box::pin(body)) +} + +fn synth_non_inference_response(url: &str) -> CopilotHttpResponse { + let lower = url.to_lowercase(); + if lower.ends_with("/models") { + return json_response( + 200, + br#"{"data":[{"id":"gpt-4o","name":"GPT-4o","object":"model","vendor":"OpenAI","version":"1","preview":false,"model_picker_enabled":true,"capabilities":{"type":"chat","family":"gpt-4o","tokenizer":"o200k_base","limits":{"max_context_window_tokens":128000,"max_output_tokens":4096},"supports":{"streaming":true,"tool_calls":true,"parallel_tool_calls":true}}}]}"# + .to_vec(), + ); + } + if lower.contains("/models/session") { + return json_response(200, b"{}".to_vec()); + } + if lower.contains("/policy") { + return json_response(200, br#"{"state":"enabled"}"#.to_vec()); + } + json_response(200, b"{}".to_vec()) +} + +async fn run_turn( + client: &github_copilot_sdk::Client, + providers: Vec, + models: Vec, + selection_id: &str, + prompt: &str, +) { + let session = client + .create_session( + SessionConfig::default() + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_model(selection_id) + .with_providers(providers) + .with_models(models), + ) + .await + .expect("create session"); + let _ = session.send_and_wait(MessageOptions::new(prompt)).await; + let _ = session.disconnect().await; +} + +#[tokio::test] +async fn callback_token_is_applied_as_authorization_header() { + with_e2e_context_no_snapshot(|ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let handler = Arc::new(CapturingRequestHandler::default()); + let client = ctx.start_llm_client(handler.clone(), &[]).await; + handler.reset(); + + let calls = Arc::new(AtomicUsize::new(0)); + let callback_calls = calls.clone(); + let providers = vec![ + NamedProviderConfig::new("mi", PRIMARY_BASE_URL) + .with_provider_type("openai") + .with_wire_api("completions") + .with_get_bearer_token(Arc::new(move |_args: ProviderTokenArgs| { + let callback_calls = callback_calls.clone(); + async move { + callback_calls.fetch_add(1, Ordering::SeqCst); + Ok::<_, BearerTokenError>("sentinel-bearer-token-abc123".to_string()) + } + })), + ]; + let models = + vec![ProviderModelConfig::new("default", "mi").with_wire_model("byok-gpt-4o")]; + + run_turn(&client, providers, models, "mi/default", "What is 5+5?").await; + + assert!( + calls.load(Ordering::SeqCst) >= 1, + "expected callback to be invoked" + ); + // Validate the captured Authorization header is the final assertion. + assert!( + handler + .auth_headers() + .contains(&"Bearer sentinel-bearer-token-abc123".to_string()), + "expected captured Authorization headers to include the sentinel token, got {:?}", + handler.auth_headers() + ); + + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn reacquires_a_fresh_token_for_each_request() { + with_e2e_context_no_snapshot(|ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let handler = Arc::new(CapturingRequestHandler::default()); + let client = ctx.start_llm_client(handler.clone(), &[]).await; + handler.reset(); + + let calls = Arc::new(AtomicUsize::new(0)); + let callback_calls = calls.clone(); + let providers = vec![ + NamedProviderConfig::new("mi", PRIMARY_BASE_URL) + .with_provider_type("openai") + .with_wire_api("completions") + .with_get_bearer_token(Arc::new(move |_args: ProviderTokenArgs| { + let callback_calls = callback_calls.clone(); + async move { + let call = callback_calls.fetch_add(1, Ordering::SeqCst) + 1; + Ok::<_, BearerTokenError>(format!("rotating-token-{call}")) + } + })), + ]; + let models = + vec![ProviderModelConfig::new("default", "mi").with_wire_model("byok-gpt-4o")]; + + run_turn( + &client, + providers.clone(), + models.clone(), + "mi/default", + "What is 1+1?", + ) + .await; + run_turn(&client, providers, models, "mi/default", "What is 2+2?").await; + + let auths = handler.auth_headers(); + assert!( + auths.len() >= 2, + "expected at least 2 captured Authorization headers, got {auths:?}" + ); + assert!( + auths[0].starts_with("Bearer rotating-token-") + && auths[1].starts_with("Bearer rotating-token-"), + "expected rotating-token bearer headers, got {auths:?}" + ); + assert!( + calls.load(Ordering::SeqCst) >= 2, + "expected callback to be invoked at least twice" + ); + // Validate the captured Authorization header is the final assertion. + assert_ne!( + auths[0], auths[1], + "expected distinct tokens per request, both were {:?}", + auths[0] + ); + + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn dispatches_token_acquisition_per_provider() { + with_e2e_context_no_snapshot(|ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let handler = Arc::new(CapturingRequestHandler::default()); + let client = ctx.start_llm_client(handler.clone(), &[]).await; + handler.reset(); + + let acquired_for = Arc::new(std::sync::Mutex::new(Vec::new())); + let make_provider = + |name: &'static str, base_url: &'static str, token: &'static str| { + let acquired_for = acquired_for.clone(); + NamedProviderConfig::new(name, base_url) + .with_provider_type("openai") + .with_wire_api("completions") + .with_get_bearer_token(Arc::new(move |args: ProviderTokenArgs| { + let acquired_for = acquired_for.clone(); + async move { + assert_eq!(args.provider_name, name); + acquired_for.lock().unwrap().push(name.to_string()); + Ok::<_, BearerTokenError>(token.to_string()) + } + })) + }; + let providers = vec![ + make_provider("red", RED_BASE_URL, "token-for-red"), + make_provider("blue", BLUE_BASE_URL, "token-for-blue"), + ]; + let models = vec![ + ProviderModelConfig::new("default", "red").with_wire_model("byok-gpt-4o"), + ProviderModelConfig::new("default", "blue").with_wire_model("byok-gpt-4o"), + ]; + + run_turn( + &client, + providers.clone(), + models.clone(), + "red/default", + "What is 3+3?", + ) + .await; + run_turn(&client, providers, models, "blue/default", "What is 4+4?").await; + + let acquired = acquired_for.lock().unwrap().clone(); + assert!(acquired.contains(&"red".to_string())); + assert!(acquired.contains(&"blue".to_string())); + assert_eq!( + handler.auth_header_for_host(RED_HOST).as_deref(), + Some("Bearer token-for-red") + ); + // Validate the captured Authorization header is the final assertion. + assert_eq!( + handler.auth_header_for_host(BLUE_HOST).as_deref(), + Some("Bearer token-for-blue") + ); + + client.stop().await.expect("stop client"); + }) + }) + .await; +}