elicitationHandler = new AtomicReference<>();
@@ -1347,6 +1349,33 @@ void registerElicitationHandler(ElicitationHandler handler) {
elicitationHandler.set(handler);
}
+ /**
+ * Registers bearer-token provider callbacks for this session, replacing any
+ * callbacks that were registered earlier.
+ * <p>
+ * Called internally when creating or resuming a session with BYOK providers
+ * that use managed-identity token callbacks.
+ *
+ * @param providers
+ * the callbacks keyed by provider name; {@code null} simply clears the
+ * current registrations
+ */
+ void registerBearerTokenProviders(Map<String, GetBearerToken> providers) {
+ // Always start from an empty registry so stale callbacks never survive.
+ bearerTokenProviders.clear();
+ if (providers != null) {
+ bearerTokenProviders.putAll(providers);
+ }
+ }
+
+ /**
+ * Looks up the bearer-token callback registered under a provider name.
+ *
+ * @param providerName
+ * the name the callback was registered under
+ * @return the matching callback, or {@code null} when the name has no callback
+ */
+ GetBearerToken getBearerTokenProvider(String providerName) {
+ GetBearerToken registered = bearerTokenProviders.get(providerName);
+ return registered;
+ }
+
/**
* Registers an exit-plan-mode handler for this session.
*
diff --git a/java/src/main/java/com/github/copilot/RpcHandlerDispatcher.java b/java/src/main/java/com/github/copilot/RpcHandlerDispatcher.java
index 391f270db..b62e8c582 100644
--- a/java/src/main/java/com/github/copilot/RpcHandlerDispatcher.java
+++ b/java/src/main/java/com/github/copilot/RpcHandlerDispatcher.java
@@ -19,6 +19,8 @@
import com.github.copilot.generated.SessionEvent;
import com.github.copilot.rpc.AutoModeSwitchRequest;
import com.github.copilot.rpc.ExitPlanModeRequest;
+import com.github.copilot.rpc.GetBearerToken;
+import com.github.copilot.rpc.ProviderTokenArgs;
import com.github.copilot.rpc.PermissionRequestResult;
import com.github.copilot.rpc.PermissionRequestResultKind;
import com.github.copilot.rpc.SessionLifecycleEvent;
@@ -88,6 +90,8 @@ void registerHandlers(JsonRpcClient rpc) {
rpc.registerMethodHandler("hooks.invoke", (requestId, params) -> handleHooksInvoke(rpc, requestId, params));
rpc.registerMethodHandler("systemMessage.transform",
(requestId, params) -> handleSystemMessageTransform(rpc, requestId, params));
+ rpc.registerMethodHandler("providerToken.getToken",
+ (requestId, params) -> handleProviderTokenGetToken(rpc, requestId, params));
}
private void handleSessionEvent(JsonNode params) {
@@ -300,6 +304,68 @@ private void handleUserInputRequest(JsonRpcClient rpc, String requestId, JsonNod
});
}
+ /**
+ * Handles a {@code providerToken.getToken} request from the runtime.
+ * <p>
+ * Resolves the session named by {@code sessionId}, looks up the bearer-token
+ * callback registered for {@code providerName}, invokes it, and replies with
+ * {@code {"token": ...}} once the returned future completes. Failures that
+ * occur after the request id has been parsed are answered with a JSON-RPC
+ * error: {@code -32602} for malformed parameters or an unknown session,
+ * {@code -32603} for everything else.
+ *
+ * @param rpc
+ * the client used to send the response
+ * @param requestId
+ * the JSON-RPC request id as received
+ * @param params
+ * the request parameters; must carry {@code sessionId} and
+ * {@code providerName}
+ */
+ private void handleProviderTokenGetToken(JsonRpcClient rpc, String requestId, JsonNode params) {
+ LOG.fine("Received providerToken.getToken: " + params);
+ runAsync(() -> {
+ final long requestIdLong = parseRequestId(requestId, "providerToken.getToken");
+ if (requestIdLong == -1) {
+ // No usable request id, so there is nothing to answer.
+ return;
+ }
+ try {
+ // Reject malformed requests up front. Dereferencing a missing field would
+ // otherwise surface as an internal error carrying a "null" message.
+ if (params == null || !params.hasNonNull("sessionId") || !params.hasNonNull("providerName")) {
+ rpc.sendErrorResponse(requestIdLong, -32602,
+ "providerToken.getToken requires sessionId and providerName");
+ return;
+ }
+ String sessionId = params.get("sessionId").asText();
+ String providerName = params.get("providerName").asText();
+
+ CopilotSession session = sessions.get(sessionId);
+ if (session == null) {
+ rpc.sendErrorResponse(requestIdLong, -32602, "Unknown session " + sessionId);
+ return;
+ }
+
+ GetBearerToken provider = session.getBearerTokenProvider(providerName);
+ if (provider == null) {
+ rpc.sendErrorResponse(requestIdLong, -32603,
+ "No bearer-token provider registered for provider " + providerName);
+ return;
+ }
+
+ CompletableFuture<String> tokenFuture = provider.getToken(new ProviderTokenArgs(providerName));
+ if (tokenFuture == null) {
+ rpc.sendErrorResponse(requestIdLong, -32603,
+ "Bearer-token provider returned null future for provider " + providerName);
+ return;
+ }
+
+ tokenFuture.thenAccept(token -> {
+ try {
+ if (token == null) {
+ rpc.sendErrorResponse(requestIdLong, -32603,
+ "Bearer-token provider returned null token for provider " + providerName);
+ return;
+ }
+ rpc.sendResponse(requestIdLong, Map.of("token", token));
+ } catch (IOException e) {
+ LOG.log(Level.SEVERE, "Error sending provider token response", e);
+ }
+ }).exceptionally(ex -> {
+ // Reached when the callback's future (or the stage above) fails.
+ LOG.log(Level.WARNING, "Bearer-token provider exception", ex);
+ try {
+ rpc.sendErrorResponse(requestIdLong, -32603, "Bearer-token provider error: " + ex.getMessage());
+ } catch (IOException e) {
+ LOG.log(Level.SEVERE, "Error sending provider token error", e);
+ }
+ return null;
+ });
+ } catch (Exception e) {
+ // Covers synchronous failures, including a callback that throws directly.
+ LOG.log(Level.SEVERE, "Error handling providerToken.getToken", e);
+ try {
+ rpc.sendErrorResponse(requestIdLong, -32603, "Provider token handler error: " + e.getMessage());
+ } catch (IOException ioException) {
+ LOG.log(Level.SEVERE, "Error sending provider token handler error", ioException);
+ }
+ }
+ });
+ }
+
private void handleExitPlanModeRequest(JsonRpcClient rpc, String requestId, JsonNode params) {
runAsync(() -> {
final long requestIdLong = parseRequestId(requestId, "exitPlanMode.request");
diff --git a/java/src/main/java/com/github/copilot/SessionRequestBuilder.java b/java/src/main/java/com/github/copilot/SessionRequestBuilder.java
index 072cf480d..943894d28 100644
--- a/java/src/main/java/com/github/copilot/SessionRequestBuilder.java
+++ b/java/src/main/java/com/github/copilot/SessionRequestBuilder.java
@@ -5,11 +5,15 @@
package com.github.copilot;
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import com.github.copilot.rpc.CreateSessionRequest;
+import com.github.copilot.rpc.ProviderConfig;
+import com.github.copilot.rpc.NamedProviderConfig;
+import com.github.copilot.rpc.GetBearerToken;
import com.github.copilot.rpc.CommandWireDefinition;
import com.github.copilot.rpc.ResumeSessionConfig;
import com.github.copilot.rpc.ResumeSessionRequest;
@@ -331,6 +335,11 @@ static void configureSession(CopilotSession session, SessionConfig config) {
if (config.getOnElicitationRequest() != null) {
session.registerElicitationHandler(config.getOnElicitationRequest());
}
+ Map bearerTokenProviders = collectBearerTokenProviders(config.getProvider(),
+ config.getProviders());
+ if (!bearerTokenProviders.isEmpty()) {
+ session.registerBearerTokenProviders(bearerTokenProviders);
+ }
if (config.getOnExitPlanMode() != null) {
session.registerExitPlanModeHandler(config.getOnExitPlanMode());
}
@@ -373,6 +382,11 @@ static void configureSession(CopilotSession session, ResumeSessionConfig config)
if (config.getOnElicitationRequest() != null) {
session.registerElicitationHandler(config.getOnElicitationRequest());
}
+ Map bearerTokenProviders = collectBearerTokenProviders(config.getProvider(),
+ config.getProviders());
+ if (!bearerTokenProviders.isEmpty()) {
+ session.registerBearerTokenProviders(bearerTokenProviders);
+ }
if (config.getOnExitPlanMode() != null) {
session.registerExitPlanModeHandler(config.getOnExitPlanMode());
}
@@ -383,4 +397,21 @@ static void configureSession(CopilotSession session, ResumeSessionConfig config)
session.on(config.getOnEvent());
}
}
+
+ /**
+ * Collects the bearer-token callbacks configured on a session's providers.
+ * <p>
+ * The singular provider's callback is keyed as {@code "default"}; each named
+ * provider's callback is keyed by that provider's name. Providers without a
+ * callback, {@code null} list entries, and named providers without a name are
+ * skipped. Named providers are added last, so one named {@code "default"}
+ * replaces the singular provider's entry.
+ *
+ * @param provider
+ * the singular whole-session provider config, may be {@code null}
+ * @param providers
+ * the named provider configs, may be {@code null}
+ * @return a new mutable map of provider name to callback, never {@code null}
+ */
+ private static Map<String, GetBearerToken> collectBearerTokenProviders(ProviderConfig provider,
+ List<NamedProviderConfig> providers) {
+ Map<String, GetBearerToken> bearerTokenProviders = new HashMap<>();
+ if (provider != null && provider.getGetBearerToken() != null) {
+ bearerTokenProviders.put("default", provider.getGetBearerToken());
+ }
+ if (providers != null) {
+ for (NamedProviderConfig namedProvider : providers) {
+ if (namedProvider != null && namedProvider.getName() != null
+ && namedProvider.getGetBearerToken() != null) {
+ bearerTokenProviders.put(namedProvider.getName(), namedProvider.getGetBearerToken());
+ }
+ }
+ }
+ return bearerTokenProviders;
+ }
}
diff --git a/java/src/main/java/com/github/copilot/rpc/GetBearerToken.java b/java/src/main/java/com/github/copilot/rpc/GetBearerToken.java
new file mode 100644
index 000000000..27ec7f09c
--- /dev/null
+++ b/java/src/main/java/com/github/copilot/rpc/GetBearerToken.java
@@ -0,0 +1,40 @@
+/*---------------------------------------------------------------------------------------------
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ *--------------------------------------------------------------------------------------------*/
+
+package com.github.copilot.rpc;
+
+import java.util.concurrent.CompletableFuture;
+
+import com.github.copilot.CopilotExperimental;
+
+/**
+ * Functional interface for supplying per-provider bearer tokens for BYOK
+ * provider requests.
+ * <p>
+ * The callback returns the raw token without a {@code Bearer } prefix. The SDK
+ * keeps this callback client-side and the runtime requests a token via the
+ * session-scoped {@code providerToken.getToken} RPC before each outbound model
+ * request.
+ * <p>
+ * <strong>Experimental.</strong> This managed-identity surface may change or be
+ * removed in future SDK or CLI releases.
+ *
+ * @see ProviderConfig#setGetBearerToken(GetBearerToken)
+ * @see NamedProviderConfig#setGetBearerToken(GetBearerToken)
+ * @since 1.0.0
+ */
+@CopilotExperimental
+@FunctionalInterface
+public interface GetBearerToken {
+
+ /**
+ * Gets a bearer token for the provider identified by {@code args}.
+ *
+ * @param args
+ * the provider token request arguments
+ * @return a future that completes with the raw token, without a {@code Bearer }
+ * prefix
+ */
+ CompletableFuture<String> getToken(ProviderTokenArgs args);
+}
diff --git a/java/src/main/java/com/github/copilot/rpc/NamedProviderConfig.java b/java/src/main/java/com/github/copilot/rpc/NamedProviderConfig.java
index dbc157739..2bdf2678f 100644
--- a/java/src/main/java/com/github/copilot/rpc/NamedProviderConfig.java
+++ b/java/src/main/java/com/github/copilot/rpc/NamedProviderConfig.java
@@ -7,6 +7,7 @@
import java.util.Collections;
import java.util.Map;
+import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
@@ -59,6 +60,9 @@ public class NamedProviderConfig {
@JsonProperty("bearerToken")
private String bearerToken;
+ // Client-side callback; @JsonIgnore keeps it out of the serialized config.
+ @JsonIgnore
+ private GetBearerToken getBearerToken;
+
@JsonProperty("azure")
private AzureOptions azure;
@@ -212,6 +216,39 @@ public NamedProviderConfig setBearerToken(String bearerToken) {
return this;
}
+ /**
+ * Returns the callback that supplies bearer tokens for this named provider.
+ *
+ * @return the configured callback, or {@code null} when none has been set
+ */
+ public GetBearerToken getGetBearerToken() {
+ return this.getBearerToken;
+ }
+
+ /**
+ * Configures the callback used to obtain bearer tokens for outbound requests
+ * to this named provider.
+ * <p>
+ * <strong>Experimental.</strong> The callback is kept on the SDK side and is
+ * never serialized. The runtime is instead sent a
+ * {@code hasBearerTokenProvider} flag and calls back over the session-scoped
+ * {@code providerToken.getToken} RPC before each model request. The callback
+ * must return the raw token, without a {@code Bearer } prefix.
+ *
+ * @param getBearerToken
+ * the callback to use
+ * @return this config, so calls can be chained
+ */
+ public NamedProviderConfig setGetBearerToken(GetBearerToken getBearerToken) {
+ this.getBearerToken = getBearerToken;
+ return this;
+ }
+
+ /**
+ * Wire-only flag serialized as {@code hasBearerTokenProvider}.
+ *
+ * @return {@code Boolean.TRUE} when a callback is set; otherwise {@code null},
+ * which the {@code NON_NULL} inclusion rule leaves out of the payload
+ */
+ @JsonProperty("hasBearerTokenProvider")
+ @JsonInclude(JsonInclude.Include.NON_NULL)
+ Boolean hasBearerTokenProviderWireFlag() {
+ if (getBearerToken == null) {
+ return null;
+ }
+ return Boolean.TRUE;
+ }
+
/**
* Gets the Azure-specific options.
*
diff --git a/java/src/main/java/com/github/copilot/rpc/ProviderConfig.java b/java/src/main/java/com/github/copilot/rpc/ProviderConfig.java
index 8ba492ed9..ae59e7ead 100644
--- a/java/src/main/java/com/github/copilot/rpc/ProviderConfig.java
+++ b/java/src/main/java/com/github/copilot/rpc/ProviderConfig.java
@@ -56,6 +56,9 @@ public class ProviderConfig {
@JsonProperty("bearerToken")
private String bearerToken;
+ // Client-side callback; @JsonIgnore keeps it out of the serialized config.
+ // NOTE(review): this hunk shows no import for JsonIgnore being added to
+ // ProviderConfig - confirm the file already imports it.
+ @JsonIgnore
+ private GetBearerToken getBearerToken;
+
@JsonProperty("azure")
private AzureOptions azure;
@@ -222,6 +225,39 @@ public ProviderConfig setBearerToken(String bearerToken) {
return this;
}
+ /**
+ * Returns the callback that supplies bearer tokens for this provider.
+ *
+ * @return the configured callback, or {@code null} when none has been set
+ */
+ public GetBearerToken getGetBearerToken() {
+ return this.getBearerToken;
+ }
+
+ /**
+ * Configures the callback used to obtain bearer tokens for outbound requests
+ * to this provider.
+ * <p>
+ * <strong>Experimental.</strong> The callback is kept on the SDK side and is
+ * never serialized. The runtime is instead sent a
+ * {@code hasBearerTokenProvider} flag and calls back over the session-scoped
+ * {@code providerToken.getToken} RPC before each model request. The callback
+ * must return the raw token, without a {@code Bearer } prefix.
+ *
+ * @param getBearerToken
+ * the callback to use
+ * @return this config, so calls can be chained
+ */
+ public ProviderConfig setGetBearerToken(GetBearerToken getBearerToken) {
+ this.getBearerToken = getBearerToken;
+ return this;
+ }
+
+ /**
+ * Wire-only flag serialized as {@code hasBearerTokenProvider}.
+ *
+ * @return {@code Boolean.TRUE} when a callback is set; otherwise {@code null},
+ * which the {@code NON_NULL} inclusion rule leaves out of the payload
+ */
+ @JsonProperty("hasBearerTokenProvider")
+ @JsonInclude(JsonInclude.Include.NON_NULL)
+ Boolean hasBearerTokenProviderWireFlag() {
+ if (getBearerToken == null) {
+ return null;
+ }
+ return Boolean.TRUE;
+ }
+
/**
* Gets the Azure-specific options.
*
diff --git a/java/src/main/java/com/github/copilot/rpc/ProviderTokenArgs.java b/java/src/main/java/com/github/copilot/rpc/ProviderTokenArgs.java
new file mode 100644
index 000000000..3866cc0ad
--- /dev/null
+++ b/java/src/main/java/com/github/copilot/rpc/ProviderTokenArgs.java
@@ -0,0 +1,63 @@
+/*---------------------------------------------------------------------------------------------
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ *--------------------------------------------------------------------------------------------*/
+
+package com.github.copilot.rpc;
+
+import com.github.copilot.CopilotExperimental;
+
+/**
+ * Argument object handed to a BYOK bearer-token provider callback.
+ * <p>
+ * <strong>Experimental.</strong> This managed-identity surface may change or be
+ * removed in future SDK or CLI releases.
+ *
+ * @since 1.0.0
+ */
+@CopilotExperimental
+public class ProviderTokenArgs {
+
+ private String providerName;
+
+ /**
+ * Creates an instance with no provider name set.
+ */
+ public ProviderTokenArgs() {
+ // Intentionally empty: providerName stays null until it is set.
+ }
+
+ /**
+ * Creates an instance for the given provider.
+ *
+ * @param providerName
+ * the BYOK provider that needs a token: {@code "default"} for the
+ * singular whole-session provider, otherwise the named provider's
+ * {@code name}
+ */
+ public ProviderTokenArgs(String providerName) {
+ this.providerName = providerName;
+ }
+
+ /**
+ * Returns the name of the BYOK provider that needs a token.
+ * <p>
+ * This is {@code "default"} for the singular whole-session provider and the
+ * named provider's {@code name} otherwise.
+ *
+ * @return the provider name
+ */
+ public String getProviderName() {
+ return this.providerName;
+ }
+
+ /**
+ * Replaces the name of the BYOK provider that needs a token.
+ *
+ * @param providerName
+ * the new provider name
+ * @return this instance, so calls can be chained
+ */
+ public ProviderTokenArgs setProviderName(String providerName) {
+ this.providerName = providerName;
+ return this;
+ }
+}
diff --git a/java/src/test/java/com/github/copilot/ByokBearerTokenProviderE2ETest.java b/java/src/test/java/com/github/copilot/ByokBearerTokenProviderE2ETest.java
new file mode 100644
index 000000000..253ce136c
--- /dev/null
+++ b/java/src/test/java/com/github/copilot/ByokBearerTokenProviderE2ETest.java
@@ -0,0 +1,274 @@
+/*---------------------------------------------------------------------------------------------
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ *--------------------------------------------------------------------------------------------*/
+
+package com.github.copilot;
+
+import static com.github.copilot.CopilotRequestTestSupport.buildNonInferenceResponse;
+import static com.github.copilot.CopilotRequestTestSupport.newLlmClient;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import java.io.ByteArrayInputStream;
+import java.io.InputStream;
+import java.net.URI;
+import java.net.http.HttpClient;
+import java.net.http.HttpHeaders;
+import java.net.http.HttpRequest;
+import java.net.http.HttpResponse;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import javax.net.ssl.SSLSession;
+
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import com.github.copilot.rpc.GetBearerToken;
+import com.github.copilot.rpc.MessageOptions;
+import com.github.copilot.rpc.NamedProviderConfig;
+import com.github.copilot.rpc.PermissionHandler;
+import com.github.copilot.rpc.ProviderModelConfig;
+import com.github.copilot.rpc.SessionConfig;
+
+/**
+ * End-to-end coverage for the experimental BYOK bearer-token-provider surface
+ * ({@code getBearerToken} on a provider config). The callback stays entirely on
+ * the SDK/client side: the SDK keeps it off the wire, sends only the
+ * {@code hasBearerTokenProvider} flag, and the runtime calls back over the
+ * session-scoped {@code providerToken.getToken} RPC before each outbound model
+ * request.
+ */
+public class ByokBearerTokenProviderE2ETest {
+
+ private static final String PRIMARY_HOST = "byok-endpoint.invalid";
+ private static final String PRIMARY_BASE_URL = "https://" + PRIMARY_HOST + "/v1";
+ private static final String RED_HOST = "byok-red.invalid";
+ private static final String RED_BASE_URL = "https://" + RED_HOST + "/v1";
+ private static final String BLUE_HOST = "byok-blue.invalid";
+ private static final String BLUE_BASE_URL = "https://" + BLUE_HOST + "/v1";
+
+ private static E2ETestContext ctx;
+ private CapturingRequestHandler handler;
+
+ @BeforeAll
+ static void setup() throws Exception {
+ ctx = E2ETestContext.create();
+ }
+
+ @AfterAll
+ static void teardown() throws Exception {
+ if (ctx != null) {
+ ctx.close();
+ }
+ }
+
+ @BeforeEach
+ void resetHandler() {
+ // A fresh handler per test keeps captured requests from leaking across tests.
+ handler = new CapturingRequestHandler();
+ }
+
+ @Test
+ void appliesCallbackTokenAsAuthorizationHeader() throws Exception {
+ String sentinel = "sentinel-bearer-token-abc123";
+ AtomicInteger calls = new AtomicInteger();
+ GetBearerToken getBearerToken = args -> {
+ calls.incrementAndGet();
+ return CompletableFuture.completedFuture(sentinel);
+ };
+
+ List<NamedProviderConfig> providers = List.of(new NamedProviderConfig().setName("mi").setType("openai")
+ .setWireApi("completions").setBaseUrl(PRIMARY_BASE_URL).setGetBearerToken(getBearerToken));
+ List<ProviderModelConfig> models = List
+ .of(new ProviderModelConfig().setId("default").setProvider("mi").setWireModel("byok-gpt-4o"));
+
+ runTurn(providers, models, "mi/default", "What is 5+5?");
+
+ assertTrue(handler.authHeaders().contains("Bearer " + sentinel),
+ "Expected captured Authorization headers to contain the callback token: " + handler.authHeaders());
+ assertTrue(calls.get() >= 1, "Expected the callback to be invoked at least once");
+ }
+
+ @Test
+ void reacquiresFreshTokenForEachRequest() throws Exception {
+ AtomicInteger calls = new AtomicInteger();
+ GetBearerToken getBearerToken = args -> CompletableFuture
+ .completedFuture("rotating-token-" + calls.incrementAndGet());
+
+ List<NamedProviderConfig> providers = List.of(new NamedProviderConfig().setName("mi").setType("openai")
+ .setWireApi("completions").setBaseUrl(PRIMARY_BASE_URL).setGetBearerToken(getBearerToken));
+ List<ProviderModelConfig> models = List
+ .of(new ProviderModelConfig().setId("default").setProvider("mi").setWireModel("byok-gpt-4o"));
+
+ runTurn(providers, models, "mi/default", "What is 1+1?");
+ runTurn(providers, models, "mi/default", "What is 2+2?");
+
+ List<String> auths = handler.authHeaders();
+ assertTrue(auths.size() >= 2, "Expected at least two captured Authorization headers, got " + auths);
+ assertTrue(auths.get(0).startsWith("Bearer rotating-token-"), "Expected rotating token, got " + auths);
+ assertTrue(auths.get(1).startsWith("Bearer rotating-token-"), "Expected rotating token, got " + auths);
+ assertNotEquals(auths.get(0), auths.get(1), "Expected distinct tokens per request");
+ assertTrue(calls.get() >= 2, "Expected the callback to be invoked at least twice");
+ }
+
+ @Test
+ void dispatchesTokenAcquisitionPerProvider() throws Exception {
+ // Guarded by synchronized blocks: callbacks may run off the test thread.
+ List<String> acquiredFor = new ArrayList<>();
+ GetBearerToken redCallback = args -> {
+ assertEquals("red", args.getProviderName(), "Expected providerName to be forwarded");
+ synchronized (acquiredFor) {
+ acquiredFor.add("red");
+ }
+ return CompletableFuture.completedFuture("token-for-red");
+ };
+ GetBearerToken blueCallback = args -> {
+ assertEquals("blue", args.getProviderName(), "Expected providerName to be forwarded");
+ synchronized (acquiredFor) {
+ acquiredFor.add("blue");
+ }
+ return CompletableFuture.completedFuture("token-for-blue");
+ };
+
+ List<NamedProviderConfig> providers = List.of(
+ new NamedProviderConfig().setName("red").setType("openai").setWireApi("completions")
+ .setBaseUrl(RED_BASE_URL).setGetBearerToken(redCallback),
+ new NamedProviderConfig().setName("blue").setType("openai").setWireApi("completions")
+ .setBaseUrl(BLUE_BASE_URL).setGetBearerToken(blueCallback));
+ List<ProviderModelConfig> models = List.of(
+ new ProviderModelConfig().setId("default").setProvider("red").setWireModel("byok-gpt-4o"),
+ new ProviderModelConfig().setId("default").setProvider("blue").setWireModel("byok-gpt-4o"));
+
+ runTurn(providers, models, "red/default", "What is 3+3?");
+ runTurn(providers, models, "blue/default", "What is 4+4?");
+
+ assertEquals("Bearer token-for-red", handler.authHeaderForHost(RED_HOST));
+ assertEquals("Bearer token-for-blue", handler.authHeaderForHost(BLUE_HOST));
+ synchronized (acquiredFor) {
+ assertTrue(acquiredFor.contains("red"), "Expected red provider to acquire a token");
+ assertTrue(acquiredFor.contains("blue"), "Expected blue provider to acquire a token");
+ }
+ }
+
+ /**
+ * Creates a client and session for the given providers/models, sends one
+ * prompt, and closes the session. Failures from the send and the close are
+ * ignored because the fake endpoint always answers 404.
+ */
+ private void runTurn(List<NamedProviderConfig> providers, List<ProviderModelConfig> models, String selectionId,
+ String prompt) throws Exception {
+ try (CopilotClient client = newLlmClient(ctx, handler)) {
+ CopilotSession session = client
+ .createSession(new SessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL)
+ .setModel(selectionId).setProviders(providers).setModels(models))
+ .get(60, TimeUnit.SECONDS);
+ try {
+ session.sendAndWait(new MessageOptions().setPrompt(prompt)).get(60, TimeUnit.SECONDS);
+ } catch (Exception ignored) {
+ // The fake BYOK endpoint returns 404 after capturing the token-bearing request.
+ } finally {
+ try {
+ session.close();
+ } catch (Exception ignored) {
+ // Ignore disconnect errors for the fake BYOK endpoint.
+ }
+ }
+ }
+ }
+
+ /**
+ * Request handler that records host and Authorization header for requests to
+ * {@code *.invalid} hosts and answers them with a stub 404.
+ */
+ private static final class CapturingRequestHandler extends CopilotRequestHandler {
+
+ private final ConcurrentLinkedQueue<CapturedRequest> captures = new ConcurrentLinkedQueue<>();
+
+ @Override
+ protected HttpResponse<InputStream> sendRequest(HttpRequest request, CopilotRequestContext rctx)
+ throws Exception {
+ String host = request.uri().getHost();
+ if (host != null && host.endsWith(".invalid")) {
+ captures.add(new CapturedRequest(host, request.headers().firstValue("Authorization").orElse(null)));
+ return new StubHttpResponse(404, "{\"error\":{\"message\":\"fake byok endpoint\"}}");
+ }
+ return buildNonInferenceResponse(request.uri().toString());
+ }
+
+ /** Returns the captured Authorization headers in capture order, skipping requests that had none. */
+ List<String> authHeaders() {
+ List<String> auths = new ArrayList<>();
+ for (CapturedRequest capture : captures) {
+ if (capture.authorization() != null) {
+ auths.add(capture.authorization());
+ }
+ }
+ return auths;
+ }
+
+ /** Returns the Authorization header of the first request captured for {@code host}, or {@code null}. */
+ String authHeaderForHost(String host) {
+ for (CapturedRequest capture : captures) {
+ if (host.equals(capture.host())) {
+ return capture.authorization();
+ }
+ }
+ return null;
+ }
+ }
+
+ /** Minimal in-memory {@link HttpResponse} with a fixed status and a JSON body. */
+ private static final class StubHttpResponse implements HttpResponse<InputStream> {
+
+ private final int status;
+ private final HttpHeaders headers;
+ private final byte[] body;
+
+ StubHttpResponse(int status, String body) {
+ this.status = status;
+ this.body = body.getBytes(StandardCharsets.UTF_8);
+ this.headers = HttpHeaders.of(Map.of("content-type", List.of("application/json")), (k, v) -> true);
+ }
+
+ @Override
+ public int statusCode() {
+ return status;
+ }
+
+ @Override
+ public HttpRequest request() {
+ return null;
+ }
+
+ @Override
+ public Optional<HttpResponse<InputStream>> previousResponse() {
+ return Optional.empty();
+ }
+
+ @Override
+ public HttpHeaders headers() {
+ return headers;
+ }
+
+ @Override
+ public InputStream body() {
+ return new ByteArrayInputStream(body);
+ }
+
+ @Override
+ public Optional<SSLSession> sslSession() {
+ return Optional.empty();
+ }
+
+ @Override
+ public URI uri() {
+ return null;
+ }
+
+ @Override
+ public HttpClient.Version version() {
+ return HttpClient.Version.HTTP_1_1;
+ }
+ }
+
+ /** Host and Authorization header (possibly {@code null}) of one captured request. */
+ private record CapturedRequest(String host, String authorization) {
+ }
+}
diff --git a/nodejs/src/client.ts b/nodejs/src/client.ts
index a1b403930..e0315b211 100644
--- a/nodejs/src/client.ts
+++ b/nodejs/src/client.ts
@@ -51,11 +51,14 @@ import type {
ExitPlanModeResult,
ForegroundSessionInfo,
GetAuthStatusResponse,
+ GetBearerToken,
GetStatusResponse,
InternalRuntimeConnection,
LargeToolOutputConfig,
MCPServerConfig,
ModelInfo,
+ NamedProviderConfig,
+ ProviderConfig,
ResumeSessionConfig,
SectionTransformFn,
SessionConfig,
@@ -154,6 +157,64 @@ function toJsonSchema(parameters: Tool["parameters"]): Record |
return parameters;
}
+/**
+ * Implicit provider name for the singular, whole-session {@link ProviderConfig}.
+ * Used as the map key under which that provider's bearer-token callback is stored.
+ */
+const DEFAULT_PROVIDER_NAME = "default";
+
+/**
+ * Wire-safe singular provider config: {@link ProviderConfig} without the
+ * non-serializable `getBearerToken` callback, optionally carrying the
+ * `hasBearerTokenProvider` flag in its place.
+ */
+type WireProviderConfig = Omit<ProviderConfig, "getBearerToken"> & {
+ hasBearerTokenProvider?: boolean;
+};
+
+/**
+ * Wire-safe named provider config: {@link NamedProviderConfig} without the
+ * non-serializable `getBearerToken` callback, optionally carrying the
+ * `hasBearerTokenProvider` flag in its place.
+ */
+type WireNamedProviderConfig = Omit<NamedProviderConfig, "getBearerToken"> & {
+ hasBearerTokenProvider?: boolean;
+};
+
+/**
+ * Strips the non-serializable {@link GetBearerToken} callbacks from the singular
+ * and named provider configs before they cross the RPC boundary, replacing each
+ * with a `hasBearerTokenProvider: true` wire flag. The callback closes over its
+ * own token scope/audience, so nothing scope-related crosses the wire — the
+ * runtime only forwards the provider name back when it needs a token.
+ *
+ * Configs without a callback are passed through unchanged (same object
+ * reference); the inputs are never mutated.
+ *
+ * @param provider - The singular whole-session provider config, if any.
+ * @param providers - The named provider configs, if any.
+ * @returns Wire-safe provider configs alongside a map of provider name → callback
+ * for session-side registration. The singular provider's callback is keyed by
+ * `DEFAULT_PROVIDER_NAME`.
+ */
+function extractBearerTokenProviders(
+ provider: ProviderConfig | undefined,
+ providers: NamedProviderConfig[] | undefined
+): {
+ wireProvider: WireProviderConfig | undefined;
+ wireProviders: WireNamedProviderConfig[] | undefined;
+ callbacks: Map<string, GetBearerToken>;
+} {
+ const callbacks = new Map<string, GetBearerToken>();
+
+ let wireProvider: WireProviderConfig | undefined = provider;
+ if (provider?.getBearerToken) {
+ const { getBearerToken, ...rest } = provider;
+ callbacks.set(DEFAULT_PROVIDER_NAME, getBearerToken);
+ wireProvider = {
+ ...rest,
+ hasBearerTokenProvider: true,
+ };
+ }
+
+ let wireProviders: WireNamedProviderConfig[] | undefined = providers;
+ // Only rebuild the array when at least one entry actually carries a callback.
+ if (providers?.some((p) => p.getBearerToken)) {
+ wireProviders = providers.map((p) => {
+ if (!p.getBearerToken) return p;
+ const { getBearerToken, ...rest } = p;
+ callbacks.set(p.name, getBearerToken);
+ return {
+ ...rest,
+ hasBearerTokenProvider: true,
+ };
+ });
+ }
+
+ return { wireProvider, wireProviders, callbacks };
+}
+
/**
* Convert MCP server configs from public API format (workingDirectory) to
* wire format (cwd) expected by the runtime.
@@ -1244,6 +1305,15 @@ export class CopilotClient {
const useServerGeneratedId = config.cloud != null && callerSessionId == null;
const localSessionId = useServerGeneratedId ? undefined : (callerSessionId ?? randomUUID());
+ // Strip non-serializable getBearerToken callbacks from provider configs,
+ // replacing them with a wire flag; keep the callbacks for session-side
+ // registration so the runtime can call back to acquire tokens.
+ const {
+ wireProvider: bearerWireProvider,
+ wireProviders: bearerWireProviders,
+ callbacks: bearerTokenCallbacks,
+ } = extractBearerTokenProviders(config.provider, config.providers);
+
// Extract transform callbacks from system message config before serialization.
const { wirePayload: wireSystemMessage, transformCallbacks } = extractTransformCallbacks(
config.systemMessage
@@ -1261,6 +1331,9 @@ export class CopilotClient {
s.registerTools(config.tools);
s.registerCanvases(config.canvases);
s.registerCommands(config.commands);
+ if (bearerTokenCallbacks.size > 0) {
+ s.registerBearerTokenProviders(bearerTokenCallbacks);
+ }
s.registerPermissionHandler(config.onPermissionRequest);
if (config.onUserInputRequest) {
s.registerUserInputHandler(config.onUserInputRequest);
@@ -1332,9 +1405,9 @@ export class CopilotClient {
availableTools: toolFilterOptions.availableTools,
excludedTools: toolFilterOptions.excludedTools,
toolFilterPrecedence: toolFilterOptions.toolFilterPrecedence,
- provider: config.provider,
+ provider: bearerWireProvider,
capi: config.capi,
- providers: config.providers,
+ providers: bearerWireProviders,
models: config.models,
enableSessionTelemetry: config.enableSessionTelemetry,
modelCapabilities: config.modelCapabilities,
@@ -1454,6 +1527,14 @@ export class CopilotClient {
session.registerTools(config.tools);
session.registerCanvases(config.canvases);
session.registerCommands(config.commands);
+ const {
+ wireProvider: bearerWireProvider,
+ wireProviders: bearerWireProviders,
+ callbacks: bearerTokenCallbacks,
+ } = extractBearerTokenProviders(config.provider, config.providers);
+ if (bearerTokenCallbacks.size > 0) {
+ session.registerBearerTokenProviders(bearerTokenCallbacks);
+ }
session.registerPermissionHandler(config.onPermissionRequest);
if (config.onUserInputRequest) {
session.registerUserInputHandler(config.onUserInputRequest);
@@ -1520,9 +1601,9 @@ export class CopilotClient {
name: cmd.name,
description: cmd.description,
})),
- provider: config.provider,
+ provider: bearerWireProvider,
capi: config.capi,
- providers: config.providers,
+ providers: bearerWireProviders,
models: config.models,
modelCapabilities: config.modelCapabilities,
largeOutput: toWireLargeOutput(config.largeOutput),
diff --git a/nodejs/src/index.ts b/nodejs/src/index.ts
index 9bf02a32c..740a7bc89 100644
--- a/nodejs/src/index.ts
+++ b/nodejs/src/index.ts
@@ -84,6 +84,7 @@ export type {
MCPHTTPServerConfig,
MCPServerConfig,
DefaultAgentConfig,
+ GetBearerToken,
MessageOptions,
ModelBilling,
ModelBillingTokenPrices,
@@ -99,6 +100,7 @@ export type {
PermissionRequestResult,
ProviderConfig,
ProviderModelConfig,
+ ProviderTokenArgs,
RemoteSessionMode,
ResumeSessionConfig,
SectionOverride,
diff --git a/nodejs/src/session.ts b/nodejs/src/session.ts
index 83effdef7..d87d2b9de 100644
--- a/nodejs/src/session.ts
+++ b/nodejs/src/session.ts
@@ -26,6 +26,7 @@ import type {
ExitPlanModeHandler,
ExitPlanModeRequest,
ExitPlanModeResult,
+ GetBearerToken,
UiInputOptions,
MessageOptions,
PermissionHandler,
@@ -120,6 +121,7 @@ export class CopilotSession {
new Map();
private toolHandlers: Map = new Map();
private canvases: Map = new Map();
+ /** Bearer-token callbacks for this session, keyed by provider name. */
+ private bearerTokenProviders: Map<string, GetBearerToken> = new Map();
private commandHandlers: Map = new Map();
private permissionHandler?: PermissionHandler;
private userInputHandler?: UserInputHandler;
@@ -795,6 +797,45 @@ export class CopilotSession {
};
}
+ /**
+  * Registers per-provider {@link GetBearerToken} callbacks for BYOK providers
+  * configured with managed-identity / on-demand bearer-token auth.
+  *
+  * The runtime never receives the callback itself; the SDK strips it from the
+  * provider config and instead sends `hasBearerTokenProvider: true`. When the
+  * runtime needs a token it issues a session-scoped `providerToken.getToken`
+  * request, which this handler routes to the matching per-provider callback.
+  *
+  * Every call replaces the previously registered callbacks. The installed
+  * `getToken` handler rejects when no callback is registered for the requested
+  * provider name.
+  *
+  * @param providers - Map of provider name → callback, or undefined/empty to clear.
+  * @internal This method is called internally when creating/resuming a session.
+  */
+ registerBearerTokenProviders(providers?: Map<string, GetBearerToken>): void {
+     this.bearerTokenProviders.clear();
+     if (!providers || providers.size === 0) {
+         // Nothing to serve: remove the handler so the session no longer
+         // exposes `providerToken` at all.
+         delete this.clientSessionApis.providerToken;
+         return;
+     }
+     for (const [name, callback] of providers) {
+         this.bearerTokenProviders.set(name, callback);
+     }
+
+     // Arrow function keeps `this` bound to the session regardless of how the
+     // RPC layer invokes the handler, so no `self` alias is needed.
+     this.clientSessionApis.providerToken = {
+         getToken: async (params) => {
+             const callback = this.bearerTokenProviders.get(params.providerName);
+             if (!callback) {
+                 throw new Error(
+                     `No bearer-token provider registered for provider "${params.providerName}"`
+                 );
+             }
+             const token = await callback({
+                 providerName: params.providerName,
+             });
+             return { token };
+         },
+     };
+ }
+
/**
* Registers command handlers for this session.
*
diff --git a/nodejs/src/types.ts b/nodejs/src/types.ts
index 1db91df94..61d9ca06d 100644
--- a/nodejs/src/types.ts
+++ b/nodejs/src/types.ts
@@ -2214,6 +2214,39 @@ export interface ResumeSessionConfig extends SessionConfigBase {
openCanvases?: OpenCanvasInstance[];
}
+/**
+ * Arguments passed to a {@link GetBearerToken} callback when the runtime needs a
+ * fresh bearer token for a BYOK provider.
+ *
+ * @experimental Part of the experimental managed-identity / bearer-token-provider
+ * surface and may change or be removed in future SDK or CLI releases.
+ */
+export interface ProviderTokenArgs {
+ /**
+ * Name of the BYOK provider needing a token. For the singular, whole-session
+ * {@link ProviderConfig} this is the implicit provider name (`"default"`); for
+ * {@link NamedProviderConfig} entries it is {@link NamedProviderConfig.name}.
+ *
+ * The callback closes over its own token scope/audience; the runtime is
+ * provider-agnostic and forwards only the provider name.
+ */
+ providerName: string;
+}
+
+/**
+ * Per-provider callback that resolves a bearer token on demand, returning the
+ * raw token string (without the `Bearer ` prefix). The Copilot SDK itself takes
+ * no Azure dependency: the consumer supplies this callback backed by their own
+ * identity library (for example `@azure/identity`'s
+ * `DefaultAzureCredential.getToken(scope)`), and the runtime calls it once before
+ * each outbound model request. The runtime does no caching of its own, so the
+ * callback (or the identity library it wraps) owns token caching and refresh.
+ *
+ * @experimental Part of the experimental managed-identity / bearer-token-provider
+ * surface and may change or be removed in future SDK or CLI releases.
+ */
+export type GetBearerToken = (args: ProviderTokenArgs) => Promise<string>;
+
/**
* Configuration for a custom API provider.
*/
@@ -2256,6 +2289,18 @@ export interface ProviderConfig {
*/
bearerToken?: string;
+ /**
+ * Per-request bearer-token provider for managed-identity / on-demand auth.
+ * When set, the SDK keeps this function client-side (it is never serialized)
+ * and the runtime calls back into this client to acquire a token before each
+ * outbound request. The runtime does no caching of its own, so the callback
+ * owns token caching and refresh. Mutually exclusive with {@link apiKey} /
+ * {@link bearerToken}.
+ *
+ * @experimental
+ */
+ getBearerToken?: GetBearerToken;
+
/**
* Azure-specific options
*/
@@ -2347,6 +2392,18 @@ export interface NamedProviderConfig {
*/
bearerToken?: string;
+ /**
+ * Per-request bearer-token provider for managed-identity / on-demand auth.
+ * When set, the SDK keeps this function client-side (it is never serialized)
+ * and the runtime calls back into this client to acquire a token before each
+ * outbound request. The runtime does no caching of its own, so the callback
+ * owns token caching and refresh. Mutually exclusive with {@link apiKey} /
+ * {@link bearerToken}.
+ *
+ * @experimental
+ */
+ getBearerToken?: GetBearerToken;
+
/**
* Azure-specific options.
*/
diff --git a/nodejs/test/e2e/byok_bearer_token_provider.e2e.test.ts b/nodejs/test/e2e/byok_bearer_token_provider.e2e.test.ts
new file mode 100644
index 000000000..228b7a022
--- /dev/null
+++ b/nodejs/test/e2e/byok_bearer_token_provider.e2e.test.ts
@@ -0,0 +1,255 @@
+/*---------------------------------------------------------------------------------------------
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ *--------------------------------------------------------------------------------------------*/
+
+import { beforeEach, describe, expect, it } from "vitest";
+import { approveAll, CopilotRequestHandler } from "../../src/index.js";
+import type {
+ CopilotRequestContext,
+ GetBearerToken,
+ NamedProviderConfig,
+ ProviderModelConfig,
+} from "../../src/index.js";
+import { createSdkTestContext } from "./harness/sdkTestContext.js";
+
+/**
+ * A captured outbound HTTP request the runtime aimed at a fake BYOK provider
+ * endpoint: just the host and the `Authorization` header, which is all these
+ * tests need to assert on.
+ */
+interface CapturedRequest {
+ host: string;
+ authorization?: string;
+}
+
+// Fake BYOK provider base URLs. These hosts are never actually dialed: the
+// client-global request interceptor fully answers any request aimed at a
+// `.invalid` host, so they only need to be syntactically valid, non-resolving
+// URLs. Distinct hosts let the per-provider test assert routing by host.
+const PRIMARY_HOST = "byok-endpoint.invalid";
+const PRIMARY_BASE_URL = `https://${PRIMARY_HOST}/v1`;
+const RED_HOST = "byok-red.invalid";
+const RED_BASE_URL = `https://${RED_HOST}/v1`;
+const BLUE_HOST = "byok-blue.invalid";
+const BLUE_BASE_URL = `https://${BLUE_HOST}/v1`;
+
+/**
+ * Client-global HTTP request interceptor (from the SDK's `CopilotRequestHandler`
+ * surface) used in place of a real HTTP listener.
+ *
+ * The runtime invokes {@link sendRequest} for every model-layer HTTP request it
+ * would otherwise issue. We capture the ones aimed at a fake BYOK host —
+ * recording the `Authorization` header the runtime applied after calling the
+ * provider's `getBearerToken` callback over the session-scoped
+ * `providerToken.getToken` RPC — and answer them with a synthetic `404` (a
+ * non-retryable status, so each outbound model request yields exactly one
+ * capture). Every other request (CAPI bootstrap: model catalog, policy, …) is
+ * passed straight through to the real network via `super.sendRequest`.
+ *
+ * Because the handler is client-global (one per CLI process), it is installed
+ * once for the whole fixture and {@link reset} between tests.
+ */
+class CapturingRequestHandler extends CopilotRequestHandler {
+    public readonly captures: CapturedRequest[] = [];
+
+    protected override async sendRequest(
+        request: Request,
+        ctx: CopilotRequestContext
+    ): Promise<Response> {
+        const url = new URL(request.url);
+        // Only the fake `.invalid` BYOK hosts are answered locally; anything
+        // else falls through to the default handler below.
+        if (url.hostname.endsWith(".invalid")) {
+            this.captures.push({
+                host: url.host,
+                authorization: request.headers.get("authorization") ?? undefined,
+            });
+            return new Response(JSON.stringify({ error: { message: "fake byok endpoint" } }), {
+                status: 404,
+                headers: { "content-type": "application/json" },
+            });
+        }
+        return super.sendRequest(request, ctx);
+    }
+
+    /** Drops every capture in place so the shared handler starts each test empty. */
+    reset(): void {
+        this.captures.length = 0;
+    }
+
+    /** The `Authorization` headers captured across BYOK requests, in arrival order. */
+    authHeaders(): string[] {
+        return this.captures
+            .map((c) => c.authorization)
+            .filter((v): v is string => typeof v === "string");
+    }
+
+    /** The `Authorization` header of the first captured request aimed at `host`, if any. */
+    authHeaderForHost(host: string): string | undefined {
+        return this.captures.find((c) => c.host === host)?.authorization;
+    }
+}
+
+/**
+ * End-to-end coverage for the experimental BYOK bearer-token-provider surface
+ * (`getBearerToken` on a provider config). The callback stays entirely on the
+ * SDK/client side: the SDK strips it from the wire config, sets the
+ * `hasBearerTokenProvider` flag, and the runtime calls back over the session-scoped
+ * `providerToken.getToken` RPC before each outbound model request, applying the
+ * returned token as the `Authorization` header.
+ *
+ * Rather than standing up a real HTTP listener, these tests install a
+ * client-global {@link CapturingRequestHandler} that intercepts the runtime's
+ * outbound model request in-process, captures the `Authorization` header, and
+ * returns a synthetic response. They validate, against a real runtime:
+ * 1. the callback's token reaches the model request as `Authorization: Bearer <token>`;
+ * 2. the runtime re-acquires a token per request (no runtime-side caching);
+ * 3. per-provider dispatch routes each provider's turn to its own callback,
+ * and the resulting token reaches that provider's endpoint.
+ */
+describe("BYOK bearer-token provider", async () => {
+ const handler = new CapturingRequestHandler();
+ const { copilotClient: client } = await createSdkTestContext({
+ copilotClientOptions: { requestHandler: handler },
+ });
+
+ beforeEach(() => {
+ handler.reset();
+ });
+
+ /**
+  * Drive one BYOK turn; the synthetic 404 errors the turn, which is expected.
+  *
+  * Creates a session for `selectionId` with the given providers/models, sends
+  * `prompt`, and always disconnects the session afterwards. Errors from the
+  * turn and from the disconnect are deliberately ignored.
+  */
+ async function runTurn(
+     providers: NamedProviderConfig[],
+     models: ProviderModelConfig[],
+     selectionId: string,
+     prompt: string
+ ): Promise<void> {
+     const session = await client.createSession({
+         onPermissionRequest: approveAll,
+         model: selectionId,
+         providers,
+         models,
+     });
+     try {
+         // The interceptor always 404s, so the turn errors after the runtime
+         // has already sent the (token-bearing) request — which is all we
+         // assert on. Swallow the resulting error.
+         await session.sendAndWait({ prompt }).catch(() => undefined);
+     } finally {
+         try {
+             await session.disconnect();
+         } catch {
+             // ignore disconnect errors for the fake BYOK endpoint
+         }
+     }
+ }
+
+ it("applies the callback's token as the Authorization header", async () => {
+ const SENTINEL = "sentinel-bearer-token-abc123";
+ let calls = 0;
+ const getBearerToken: GetBearerToken = async () => {
+ calls += 1;
+ return SENTINEL;
+ };
+
+ const providers: NamedProviderConfig[] = [
+ {
+ name: "mi",
+ type: "openai",
+ wireApi: "completions",
+ baseUrl: PRIMARY_BASE_URL,
+ getBearerToken,
+ },
+ ];
+ const models: ProviderModelConfig[] = [
+ { id: "default", provider: "mi", wireModel: "byok-gpt-4o" },
+ ];
+
+ await runTurn(providers, models, "mi/default", "What is 5+5?");
+
+ // The runtime acquired a token via the callback and applied it verbatim as
+ // the bearer credential on the outbound model request.
+ expect(handler.authHeaders()).toContain(`Bearer ${SENTINEL}`);
+ expect(calls).toBeGreaterThanOrEqual(1);
+ });
+
+ it("re-acquires a fresh token for each request (no runtime caching)", async () => {
+ let calls = 0;
+ const getBearerToken: GetBearerToken = async () => {
+ calls += 1;
+ // A distinct token per acquisition proves the runtime re-invokes the
+ // callback per request rather than caching a previous token.
+ return `rotating-token-${calls}`;
+ };
+
+ const providers: NamedProviderConfig[] = [
+ {
+ name: "mi",
+ type: "openai",
+ wireApi: "completions",
+ baseUrl: PRIMARY_BASE_URL,
+ getBearerToken,
+ },
+ ];
+ const models: ProviderModelConfig[] = [
+ { id: "default", provider: "mi", wireModel: "byok-gpt-4o" },
+ ];
+
+ await runTurn(providers, models, "mi/default", "What is 1+1?");
+ await runTurn(providers, models, "mi/default", "What is 2+2?");
+
+ // Each outbound request carries a freshly-acquired, distinct token.
+ const auths = handler.authHeaders();
+ expect(auths.length).toBeGreaterThanOrEqual(2);
+ expect(auths[0]).toMatch(/^Bearer rotating-token-\d+$/);
+ expect(auths[1]).toMatch(/^Bearer rotating-token-\d+$/);
+ expect(auths[0]).not.toBe(auths[1]);
+ expect(calls).toBeGreaterThanOrEqual(2);
+ });
+
+ it("dispatches token acquisition per provider", async () => {
+ // Expected token per provider name; each provider's callback returns its own entry.
+ const tokenByProvider: Record<string, string> = {
+     red: "token-for-red",
+     blue: "token-for-blue",
+ };
+ const acquiredFor: string[] = [];
+ const makeCallback =
+ (providerName: string): GetBearerToken =>
+ async (args) => {
+ // The runtime forwards the requesting provider's name so the client
+ // can dispatch to the right credential.
+ expect(args.providerName).toBe(providerName);
+ acquiredFor.push(providerName);
+ return tokenByProvider[providerName];
+ };
+
+ const providers: NamedProviderConfig[] = [
+ {
+ name: "red",
+ type: "openai",
+ wireApi: "completions",
+ baseUrl: RED_BASE_URL,
+ getBearerToken: makeCallback("red"),
+ },
+ {
+ name: "blue",
+ type: "openai",
+ wireApi: "completions",
+ baseUrl: BLUE_BASE_URL,
+ getBearerToken: makeCallback("blue"),
+ },
+ ];
+ const models: ProviderModelConfig[] = [
+ { id: "default", provider: "red", wireModel: "byok-gpt-4o" },
+ { id: "default", provider: "blue", wireModel: "byok-gpt-4o" },
+ ];
+
+ await runTurn(providers, models, "red/default", "What is 3+3?");
+ await runTurn(providers, models, "blue/default", "What is 4+4?");
+
+ // Each provider's turn was authenticated with its own token AND that token
+ // was delivered to that provider's endpoint, proving per-provider dispatch
+ // (not a single session-global credential).
+ expect(handler.authHeaderForHost(RED_HOST)).toBe(`Bearer ${tokenByProvider.red}`);
+ expect(handler.authHeaderForHost(BLUE_HOST)).toBe(`Bearer ${tokenByProvider.blue}`);
+ expect(acquiredFor).toContain("red");
+ expect(acquiredFor).toContain("blue");
+ });
+});
diff --git a/python/copilot/__init__.py b/python/copilot/__init__.py
index 06ecf4188..1e7a3afb1 100644
--- a/python/copilot/__init__.py
+++ b/python/copilot/__init__.py
@@ -100,6 +100,7 @@
ExitPlanModeHandler,
ExitPlanModeRequest,
ExitPlanModeResult,
+ GetBearerToken,
InfiniteSessionConfig,
InputOptions,
LargeToolOutputConfig,
@@ -128,6 +129,7 @@
PreToolUseHookOutput,
ProviderConfig,
ProviderModelConfig,
+ ProviderTokenArgs,
ReasoningSummary,
SessionCapabilities,
SessionEndHandler,
@@ -214,6 +216,7 @@
"ExtensionInfo",
"CopilotWebSocketForwarder",
"GetAuthStatusResponse",
+ "GetBearerToken",
"GetStatusResponse",
"InfiniteSessionConfig",
"InputOptions",
@@ -257,6 +260,7 @@
"PreToolUseHookOutput",
"ProviderConfig",
"ProviderModelConfig",
+ "ProviderTokenArgs",
"ReasoningSummary",
"RemoteSessionMode",
"RuntimeConnection",
diff --git a/python/copilot/client.py b/python/copilot/client.py
index 4e32a2983..ebfdcf992 100644
--- a/python/copilot/client.py
+++ b/python/copilot/client.py
@@ -89,6 +89,7 @@
DefaultAgentConfig,
ElicitationHandler,
ExitPlanModeHandler,
+ GetBearerToken,
InfiniteSessionConfig,
LargeToolOutputConfig,
MCPServerConfig,
@@ -171,6 +172,36 @@ def _capi_session_options_to_wire(options: CapiSessionOptions) -> dict[str, Any]
return wire
+# Implicit provider name for the singular, whole-session ``provider`` config.
+# Named providers are keyed by their own ``name``.
+_DEFAULT_BEARER_TOKEN_PROVIDER_NAME = "default"
+
+
+def _collect_bearer_token_callbacks(
+ provider: ProviderConfig | None,
+ providers: list[NamedProviderConfig] | None,
+) -> dict[str, GetBearerToken]:
+ """Collect per-provider ``get_bearer_token`` callbacks keyed by provider name.
+
+ The singular, whole-session ``provider`` uses the implicit
+ ``_DEFAULT_BEARER_TOKEN_PROVIDER_NAME``; ``providers`` entries use their own
+ ``name``. The callbacks are never serialized — the wire conversion emits
+ ``hasBearerTokenProvider: true`` instead and the runtime calls back over
+ ``providerToken.getToken``.
+ """
+ callbacks: dict[str, GetBearerToken] = {}
+ if provider is not None:
+ singular = provider.get("get_bearer_token")
+ if singular is not None:
+ callbacks[_DEFAULT_BEARER_TOKEN_PROVIDER_NAME] = singular
+ if providers:
+ for named in providers:
+ callback = named.get("get_bearer_token")
+ if callback is not None:
+ callbacks[named["name"]] = callback
+ return callbacks
+
+
def _validate_session_fs_config(config: SessionFsConfig) -> None:
if not config.get("initial_working_directory"):
raise ValueError("session_fs.initial_working_directory is required")
@@ -2128,6 +2159,7 @@ def _initialize_session(sid: str) -> CopilotSession:
s._register_auto_mode_switch_handler(on_auto_mode_switch_request)
if canvas_handler is not None:
s._register_canvas_handler(canvas_handler)
+ s._register_bearer_token_providers(_collect_bearer_token_callbacks(provider, providers))
if hooks:
s._register_hooks(hooks)
if transform_callbacks:
@@ -2701,6 +2733,9 @@ async def resume_session(
session._register_auto_mode_switch_handler(on_auto_mode_switch_request)
if canvas_handler is not None:
session._register_canvas_handler(canvas_handler)
+ session._register_bearer_token_providers(
+ _collect_bearer_token_callbacks(provider, providers)
+ )
if hooks:
session._register_hooks(hooks)
if transform_callbacks:
@@ -3231,6 +3266,8 @@ def _convert_provider_to_wire_format(
wire_provider["transport"] = provider["transport"]
if "bearer_token" in provider:
wire_provider["bearerToken"] = provider["bearer_token"]
+ if provider.get("get_bearer_token") is not None:
+ wire_provider["hasBearerTokenProvider"] = True
if "headers" in provider:
wire_provider["headers"] = provider["headers"]
if "model_id" in provider:
@@ -3267,6 +3304,8 @@ def _convert_named_provider_to_wire_format(
wire["apiKey"] = provider["api_key"]
if "bearer_token" in provider:
wire["bearerToken"] = provider["bearer_token"]
+ if provider.get("get_bearer_token") is not None:
+ wire["hasBearerTokenProvider"] = True
if "headers" in provider:
wire["headers"] = provider["headers"]
if "azure" in provider:
diff --git a/python/copilot/session.py b/python/copilot/session.py
index b4c01b885..94fba994a 100644
--- a/python/copilot/session.py
+++ b/python/copilot/session.py
@@ -44,6 +44,8 @@
PermissionDecisionApproveOnce,
PermissionDecisionRequest,
PermissionDecisionUserNotAvailable,
+ ProviderTokenAcquireRequest,
+ ProviderTokenAcquireResult,
SessionLogLevel,
SessionRpc,
UIElicitationRequest,
@@ -1077,6 +1079,29 @@ class AzureProviderOptions(TypedDict, total=False):
api_version: str # Azure API version. Defaults to "2024-10-21".
+class ProviderTokenArgs(TypedDict):
+ """Arguments passed to a :data:`GetBearerToken` callback when the runtime
+ needs a fresh bearer token for a BYOK provider.
+
+ **Experimental.** Part of the bearer-token-provider surface and may change or
+ be removed in future SDK or CLI releases.
+ """
+
+ # Name of the BYOK provider needing a token. For the singular, whole-session
+ # ``provider`` this is the implicit provider name ("default"); for
+ # ``NamedProviderConfig`` entries it is ``NamedProviderConfig.name``.
+ provider_name: str
+
+
+# Per-request callback that resolves a bearer token on demand for a BYOK
+# provider (for example via Azure Managed Identity). The Copilot SDK takes no
+# identity dependency: supply a callback backed by your own identity library.
+# Never serialized — setting it makes the SDK send ``hasBearerTokenProvider`` on
+# the wire and answer the runtime's ``providerToken.getToken`` requests. May be
+# sync or async.
+GetBearerToken = Callable[[ProviderTokenArgs], str | Awaitable[str]]
+
+
class ProviderConfig(TypedDict, total=False):
"""Configuration for a custom API provider"""
@@ -1113,6 +1138,12 @@ class ProviderConfig(TypedDict, total=False):
# Overrides the resolved model's default max output tokens. When hit, the
# model stops generating and returns a truncated response.
max_output_tokens: int
+ # Per-request callback that resolves a bearer token on demand for this BYOK
+ # provider (for example via Azure Managed Identity). Never serialized — the
+ # SDK sends hasBearerTokenProvider: true on the wire and answers the
+ # runtime's providerToken.getToken requests with this callback's result.
+ # Mutually exclusive with api_key and bearer_token.
+ get_bearer_token: GetBearerToken
class NamedProviderConfig(TypedDict, total=False):
@@ -1139,6 +1170,11 @@ class NamedProviderConfig(TypedDict, total=False):
bearer_token: str
azure: AzureProviderOptions # Azure-specific options
headers: dict[str, str]
+ # Per-request bearer-token callback for this named BYOK provider. Never
+ # serialized; the SDK sends hasBearerTokenProvider: true and answers the
+ # runtime's providerToken.getToken requests. Mutually exclusive with api_key
+ # and bearer_token.
+ get_bearer_token: GetBearerToken
class ProviderModelConfig(TypedDict, total=False):
@@ -1210,6 +1246,35 @@ def _canvas_handler_error(err: Exception) -> JsonRpcError:
)
+class _BearerTokenProviderAdapter:
+ """Routes runtime ``providerToken.getToken`` requests to the matching
+ per-provider :data:`GetBearerToken` callback registered on the session.
+
+ The runtime calls this once per outbound request for a BYOK provider that
+ declared ``hasBearerTokenProvider: true``; it does no caching, so the SDK
+ consumer's callback (typically backed by an identity library) owns
+ acquisition, caching, and refresh.
+ """
+
+ def __init__(self, session: CopilotSession) -> None:
+ self._session = session
+
+ async def get_token(self, params: ProviderTokenAcquireRequest) -> ProviderTokenAcquireResult:
+ provider_name = params.provider_name
+ with self._session._bearer_token_providers_lock:
+ callback = self._session._bearer_token_providers.get(provider_name)
+ if callback is None:
+ raise JsonRpcError(
+ -32603,
+ f"No bearer-token provider registered for provider: {provider_name!r}",
+ )
+ args: ProviderTokenArgs = {"provider_name": provider_name}
+ result = callback(args)
+ if inspect.isawaitable(result):
+ result = await result
+ return ProviderTokenAcquireResult(token=cast(str, result))
+
+
class CopilotSession:
"""
Represents a single conversation session with the Copilot CLI.
@@ -1275,6 +1340,8 @@ def __init__(
self._transform_callbacks_lock = threading.Lock()
self._command_handlers: dict[str, CommandHandler] = {}
self._command_handlers_lock = threading.Lock()
+ self._bearer_token_providers: dict[str, GetBearerToken] = {}
+ self._bearer_token_providers_lock = threading.Lock()
self._elicitation_handler: ElicitationHandler | None = None
self._elicitation_handler_lock = threading.Lock()
self._capabilities: SessionCapabilities = {}
@@ -2015,6 +2082,26 @@ def _register_commands(self, commands: list[CommandDefinition] | None) -> None:
for cmd in commands:
self._command_handlers[cmd.name] = cmd.handler
+ def _register_bearer_token_providers(self, providers: dict[str, GetBearerToken] | None) -> None:
+ """Register per-provider bearer-token callbacks for this session.
+
+ The runtime never receives the callbacks themselves; the SDK strips them
+ from the provider config and instead sends ``hasBearerTokenProvider:
+ true``. When the runtime needs a token it issues a session-scoped
+ ``providerToken.getToken`` request, which the registered handler routes
+ to the matching per-provider callback.
+
+ Args:
+ providers: Map of provider name -> callback, or None/empty to clear.
+ """
+ with self._bearer_token_providers_lock:
+ self._bearer_token_providers.clear()
+ if not providers:
+ self._client_session_apis.provider_token = None
+ return
+ self._bearer_token_providers.update(providers)
+ self._client_session_apis.provider_token = _BearerTokenProviderAdapter(self)
+
def _register_elicitation_handler(self, handler: ElicitationHandler | None) -> None:
"""Register the elicitation handler for this session.
diff --git a/python/e2e/test_byok_bearer_token_provider_e2e.py b/python/e2e/test_byok_bearer_token_provider_e2e.py
new file mode 100644
index 000000000..28f9e0586
--- /dev/null
+++ b/python/e2e/test_byok_bearer_token_provider_e2e.py
@@ -0,0 +1,251 @@
+# --------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# --------------------------------------------------------------------------------------------
+
+"""E2E coverage for the experimental BYOK bearer-token-provider surface.
+
+Mirrors ``nodejs/test/e2e/byok_bearer_token_provider.e2e.test.ts``. A BYOK
+provider config may carry a ``get_bearer_token`` callback; the callback stays
+entirely on the SDK/client side. The SDK strips it from the wire config, sets
+the ``hasBearerTokenProvider`` flag, and the runtime calls back over the
+session-scoped ``providerToken.getToken`` RPC before each outbound model
+request, applying the returned token as the ``Authorization`` header.
+
+Like the other ``copilot_request_*`` tests, this one installs a client-global
+``CopilotRequestHandler`` instead of using the CAPI proxy: the handler
+fabricates the bootstrap (catalog/policy) responses and intercepts the
+runtime's outbound BYOK request in-process, capturing the ``Authorization``
+header and returning a synthetic ``404``. It validates, against a real runtime:
+ 1. the callback's token reaches the model request as ``Authorization: Bearer <token>``;
+ 2. the runtime re-acquires a token per request (no runtime-side caching);
+ 3. per-provider dispatch routes each provider's turn to its own callback, and
+ the resulting token reaches that provider's endpoint.
+"""
+
+from __future__ import annotations
+
+import re
+
+import httpx
+import pytest
+import pytest_asyncio
+
+from copilot import CopilotRequestContext, CopilotRequestHandler
+from copilot.session import GetBearerToken, PermissionHandler
+
+from ._copilot_request_helpers import build_isolated_client, build_non_inference_response
+from .testharness import E2ETestContext
+
+pytestmark = pytest.mark.asyncio(loop_scope="module")
+
+# Fake BYOK provider base URLs. These hosts are never actually dialed: the
+# client-global request interceptor fully answers any request aimed at a
+# ``.invalid`` host, so they only need to be syntactically valid, non-resolving
+# URLs. Distinct hosts let the per-provider test assert routing by host.
+PRIMARY_HOST = "byok-endpoint.invalid"
+PRIMARY_BASE_URL = f"https://{PRIMARY_HOST}/v1"
+RED_HOST = "byok-red.invalid"
+RED_BASE_URL = f"https://{RED_HOST}/v1"
+BLUE_HOST = "byok-blue.invalid"
+BLUE_BASE_URL = f"https://{BLUE_HOST}/v1"
+
+
+class _CapturingRequestHandler(CopilotRequestHandler):
+ """Client-global HTTP interceptor used in place of a real BYOK listener.
+
+ The runtime invokes :meth:`send_request` for every model-layer HTTP request.
+ Requests aimed at a fake BYOK host are captured — recording the
+ ``Authorization`` header the runtime applied after calling the provider's
+ ``get_bearer_token`` callback over ``providerToken.getToken`` — and answered
+ with a synthetic ``404`` (non-retryable, so each outbound model request
+ yields exactly one capture). Every other request (CAPI bootstrap: model
+ catalog, policy, …) is fabricated locally so no real network or CAPI proxy
+ is involved.
+ """
+
+ def __init__(self) -> None:
+ # (host, authorization) for each captured BYOK request, in arrival order.
+ self.captures: list[tuple[str, str | None]] = []
+
+ async def send_request(
+ self, request: httpx.Request, ctx: CopilotRequestContext
+ ) -> httpx.Response:
+ url = httpx.URL(request.url)
+ host = url.host
+ if host.endswith(".invalid"):
+ self.captures.append((host, request.headers.get("authorization")))
+ return httpx.Response(
+ 404,
+ headers={"content-type": "application/json"},
+ json={"error": {"message": "fake byok endpoint"}},
+ request=request,
+ )
+ return build_non_inference_response(str(request.url))
+
+ def reset(self) -> None:
+ self.captures.clear()
+
+ def auth_headers(self) -> list[str]:
+ """The ``Authorization`` headers captured across BYOK requests, in order."""
+ return [auth for (_host, auth) in self.captures if auth is not None]
+
+ def auth_header_for_host(self, host: str) -> str | None:
+ """The ``Authorization`` header captured for requests aimed at ``host``."""
+ for captured_host, auth in self.captures:
+ if captured_host == host:
+ return auth
+ return None
+
+
+@pytest_asyncio.fixture(loop_scope="module")
+async def bearer_fixture(ctx: E2ETestContext):
+ handler = _CapturingRequestHandler()
+ client = build_isolated_client(ctx, handler)
+ await client.start()
+ try:
+ yield client, handler
+ finally:
+ try:
+ await client.stop()
+ except Exception:
+ # Best-effort teardown during fixture cleanup.
+ pass
+
+
+async def _run_turn(client, providers, models, selection_id: str, prompt: str) -> None:
+ """Drive one BYOK turn; the synthetic 404 errors the turn, which is expected."""
+ session = await client.create_session(
+ on_permission_request=PermissionHandler.approve_all,
+ model=selection_id,
+ providers=providers,
+ models=models,
+ )
+ try:
+ # The interceptor always 404s, so the turn errors after the runtime has
+ # already sent the (token-bearing) request — which is all we assert on.
+ try:
+ await session.send_and_wait(prompt)
+ except Exception:
+ pass
+ finally:
+ try:
+ await session.disconnect()
+ except Exception:
+ # ignore disconnect errors for the fake BYOK endpoint
+ pass
+
+
+class TestByokBearerTokenProvider:
+ async def test_applies_the_callbacks_token_as_the_authorization_header(self, bearer_fixture):
+ client, handler = bearer_fixture
+ handler.reset()
+
+ sentinel = "sentinel-bearer-token-abc123"
+ calls = 0
+
+ async def get_bearer_token(args) -> str:
+ nonlocal calls
+ calls += 1
+ return sentinel
+
+ providers = [
+ {
+ "name": "mi",
+ "type": "openai",
+ "wire_api": "completions",
+ "base_url": PRIMARY_BASE_URL,
+ "get_bearer_token": get_bearer_token,
+ }
+ ]
+ models = [{"id": "default", "provider": "mi", "wire_model": "byok-gpt-4o"}]
+
+ await _run_turn(client, providers, models, "mi/default", "What is 5+5?")
+
+ # The runtime acquired a token via the callback and applied it verbatim
+ # as the bearer credential on the outbound model request.
+ assert f"Bearer {sentinel}" in handler.auth_headers()
+ assert calls >= 1
+
+ async def test_reacquires_a_fresh_token_for_each_request(self, bearer_fixture):
+ client, handler = bearer_fixture
+ handler.reset()
+
+ calls = 0
+
+ async def get_bearer_token(args) -> str:
+ nonlocal calls
+ calls += 1
+ # A distinct token per acquisition proves the runtime re-invokes the
+ # callback per request rather than caching a previous token.
+ return f"rotating-token-{calls}"
+
+ providers = [
+ {
+ "name": "mi",
+ "type": "openai",
+ "wire_api": "completions",
+ "base_url": PRIMARY_BASE_URL,
+ "get_bearer_token": get_bearer_token,
+ }
+ ]
+ models = [{"id": "default", "provider": "mi", "wire_model": "byok-gpt-4o"}]
+
+ await _run_turn(client, providers, models, "mi/default", "What is 1+1?")
+ await _run_turn(client, providers, models, "mi/default", "What is 2+2?")
+
+ # Each outbound request carries a freshly-acquired, distinct token.
+ auths = handler.auth_headers()
+ assert len(auths) >= 2
+ assert re.match(r"^Bearer rotating-token-\d+$", auths[0])
+ assert re.match(r"^Bearer rotating-token-\d+$", auths[1])
+ assert auths[0] != auths[1]
+ assert calls >= 2
+
+ async def test_dispatches_token_acquisition_per_provider(self, bearer_fixture):
+ client, handler = bearer_fixture
+ handler.reset()
+
+ token_by_provider = {"red": "token-for-red", "blue": "token-for-blue"}
+ acquired_for: list[str] = []
+
+ def make_callback(provider_name: str) -> GetBearerToken:
+ async def callback(args) -> str:
+ # The runtime forwards the requesting provider's name so the
+ # client can dispatch to the right credential.
+ assert args["provider_name"] == provider_name
+ acquired_for.append(provider_name)
+ return token_by_provider[provider_name]
+
+ return callback
+
+ providers = [
+ {
+ "name": "red",
+ "type": "openai",
+ "wire_api": "completions",
+ "base_url": RED_BASE_URL,
+ "get_bearer_token": make_callback("red"),
+ },
+ {
+ "name": "blue",
+ "type": "openai",
+ "wire_api": "completions",
+ "base_url": BLUE_BASE_URL,
+ "get_bearer_token": make_callback("blue"),
+ },
+ ]
+ models = [
+ {"id": "default", "provider": "red", "wire_model": "byok-gpt-4o"},
+ {"id": "default", "provider": "blue", "wire_model": "byok-gpt-4o"},
+ ]
+
+ await _run_turn(client, providers, models, "red/default", "What is 3+3?")
+ await _run_turn(client, providers, models, "blue/default", "What is 4+4?")
+
+ # Each provider's turn was authenticated with its own token AND that
+ # token was delivered to that provider's endpoint, proving per-provider
+ # dispatch (not a single session-global credential).
+ assert handler.auth_header_for_host(RED_HOST) == f"Bearer {token_by_provider['red']}"
+ assert handler.auth_header_for_host(BLUE_HOST) == f"Bearer {token_by_provider['blue']}"
+ assert "red" in acquired_for
+ assert "blue" in acquired_for
diff --git a/rust/src/lib.rs b/rust/src/lib.rs
index 018907d99..22fdc53d7 100644
--- a/rust/src/lib.rs
+++ b/rust/src/lib.rs
@@ -22,6 +22,9 @@ pub mod hooks;
mod jsonrpc;
/// Permission-policy helpers that produce a [`handler::PermissionHandler`].
pub mod permission;
+/// BYOK bearer-token provider callbacks.
+pub mod provider_token;
+mod provider_token_dispatch;
/// GitHub Copilot CLI binary resolution (env var, embedded, dev cache).
pub(crate) mod resolve;
mod router;
@@ -72,6 +75,7 @@ pub(crate) use jsonrpc::{
JsonRpcClient, JsonRpcError, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, error_codes,
};
pub use mode::{BUILTIN_TOOLS_ISOLATED, ClientMode, ToolSet};
+pub use provider_token::{BearerTokenError, BearerTokenProvider, ProviderTokenArgs};
/// Re-exported JSON-RPC internals for integration tests (requires `test-support` feature).
#[cfg(feature = "test-support")]
diff --git a/rust/src/provider_token.rs b/rust/src/provider_token.rs
new file mode 100644
index 000000000..f92715006
--- /dev/null
+++ b/rust/src/provider_token.rs
@@ -0,0 +1,105 @@
+/*---------------------------------------------------------------------------------------------
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ *--------------------------------------------------------------------------------------------*/
+
+//! BYOK bearer-token provider callbacks.
+//!
+//! <div class="warning">
+//!
+//! **Experimental.** These types are part of an experimental wire-protocol
+//! surface and may change or be removed in future SDK or CLI releases.
+//!
+//! </div>
+
+use std::future::Future;
+
+use async_trait::async_trait;
+
+/// Input handed to a BYOK bearer-token provider callback.
+///
+/// <div class="warning">
+///
+/// **Experimental.** This type belongs to an experimental wire-protocol
+/// surface; future SDK or CLI releases may change or remove it.
+///
+/// </div>
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct ProviderTokenArgs {
+    /// The BYOK provider that needs a token.
+    ///
+    /// `"default"` identifies the singular whole-session provider; any other
+    /// value is the `name` of a named provider.
+    pub provider_name: String,
+}
+
+/// Failure reported by a [`BearerTokenProvider`].
+///
+/// <div class="warning">
+///
+/// **Experimental.** This type belongs to an experimental wire-protocol
+/// surface; future SDK or CLI releases may change or remove it.
+///
+/// </div>
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct BearerTokenError {
+    message: String, // human-readable description of the failure
+}
+
+impl BearerTokenError {
+    /// Construct a bearer-token error from any value convertible into a `String` message.
+    pub fn message(message: impl Into<String>) -> Self {
+        Self {
+            message: message.into(),
+        }
+    }
+
+    /// Return the human-readable error message as a borrowed string slice.
+    pub fn as_str(&self) -> &str {
+        &self.message
+    }
+}
+
+impl std::fmt::Display for BearerTokenError {
+    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        out.write_str(self.as_str()) // Display output is exactly the stored message
+    }
+}
+
+impl std::error::Error for BearerTokenError {} // empty body: no `Error` methods are overridden
+
+impl From<String> for BearerTokenError {
+    fn from(message: String) -> Self {
+        Self::message(message) // the owned string becomes the error message
+    }
+}
+
+impl From<&str> for BearerTokenError {
+    fn from(text: &str) -> Self {
+        Self { message: text.to_owned() } // copy the borrowed text into an owned message
+    }
+}
+
+/// Provider-side callback used to acquire bearer tokens for BYOK providers.
+///
+/// <div class="warning">
+///
+/// **Experimental.** This trait is part of an experimental wire-protocol
+/// surface and may change or be removed in future SDK or CLI releases.
+///
+/// </div>
+#[async_trait]
+pub trait BearerTokenProvider: Send + Sync {
+    /// Acquire a bearer token for `args.provider_name`, without the `Bearer ` prefix.
+    async fn get_token(&self, args: ProviderTokenArgs) -> Result<String, BearerTokenError>;
+}
+
+#[async_trait]
+impl BearerTokenProvider for F
+where
+ F: Fn(ProviderTokenArgs) -> Fut + Send + Sync,
+ Fut: Future