From a7c6c0610cba1a11bebe796fea7a2859ee1c1d0e Mon Sep 17 00:00:00 2001 From: Maxence Dominici Date: Sat, 28 Mar 2026 20:22:39 +0100 Subject: [PATCH 1/2] feat: support explicit OAuth config for remote MCP servers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add an `oauth` block to the `remote` toolset config allowing users to supply pre-registered OAuth credentials (`clientId`, `clientSecret`, `callbackPort`, `scopes`, `tls`) for MCP servers that do not support Dynamic Client Registration (RFC 7591), such as Slack and GitHub MCP Server. Previously the OAuth flow fabricated a `registration_endpoint` by appending `/register` to the auth server URL when the server metadata omitted it. This caused a guaranteed 404/302 failure and silently dropped the toolset. The fabrication is removed entirely — servers that omit `registration_endpoint` now receive a clear error directing the user to set `remote.oauth.clientId`. Additional fixes included in this change: - OAuth toolsets are deferred at startup and only started once RunStream is active and an elicitation handler is available, preventing the `no elicitation handler configured` failure - User decline of the OAuth prompt is permanent for the session: a `UserDeclined` sentinel stored in the token store short-circuits RoundTrip on all subsequent requests without re-prompting - TLS support on the callback server via `WithTLS()` option using an in-memory self-signed cert, required by providers like Slack that reject http:// loopback redirect URIs Closes #2248 Related: #416 Signed-off-by: Maxence Dominici --- agent-schema.json | 38 ++++++ max-pr/error-before-oauth-fix.md | 12 ++ pkg/config/latest/types.go | 23 ++++ pkg/config/latest/validate.go | 14 ++ pkg/config/latest/validate_test.go | 185 ++++++++++++++++++++++--- pkg/runtime/remote_runtime.go | 1 + pkg/runtime/runtime.go | 10 +- pkg/runtime/runtime_test.go | 84 ++++++++++++ pkg/teamloader/registry.go | 4 +- pkg/tools/mcp/describe_test.go | 6 +- pkg/tools/mcp/mcp.go | 16 ++- pkg/tools/mcp/oauth.go | 66 ++++++--- pkg/tools/mcp/oauth_helpers.go | 56 +++++++- pkg/tools/mcp/oauth_helpers_test.go | 46 +++++++ pkg/tools/mcp/oauth_server.go | 66 ++++++++- pkg/tools/mcp/oauth_server_test.go | 71 ++++++++++ pkg/tools/mcp/oauth_test.go | 201 ++++++++++++++++++++++++++++ pkg/tools/mcp/reconnect_test.go | 4 +- pkg/tools/mcp/remote.go | 16 ++- pkg/tools/mcp/remote_test.go | 8 +- pkg/tools/mcp/tokenstore.go | 4 + 21 files changed, 869 insertions(+), 62 deletions(-) create mode 100644 max-pr/error-before-oauth-fix.md create mode 100644 pkg/tools/mcp/oauth_helpers_test.go create mode 100644 pkg/tools/mcp/oauth_server_test.go create mode 100644 pkg/tools/mcp/oauth_test.go diff --git a/agent-schema.json b/agent-schema.json index 7622d8b86..82783ae8d 100644 --- a/agent-schema.json +++ b/agent-schema.json @@ -1191,6 +1191,10 @@ "additionalProperties": { "type": "string" } + }, + "oauth": { + "$ref": "#/definitions/RemoteOAuthConfig", + "description": "Explicit OAuth configuration for servers that do not support Dynamic Client Registration" } }, "required": [ @@ -1198,6 +1202,40 @@ ], "additionalProperties": false }, + "RemoteOAuthConfig": { + "type": "object", + "description": "Explicit OAuth client credentials for remote MCP servers that do not support Dynamic Client Registration (RFC 7591)", + "properties": { + "clientId": { + "type": "string", + "description": "Pre-registered OAuth client identifier" + }, + "clientSecret": { + "type": "string", + "description": "OAuth client secret (only needed for confidential clients)" + }, + "callbackPort": { + "type": "integer", + "description": "Port for the local OAuth redirect server. When zero, a random available port is used.", + "minimum": 0 + }, + "scopes": { + "type": "array", + "description": "Scopes to request during the authorization flow, overriding server defaults", + "items": { + "type": "string" + } + }, + "tls": { + "type": "boolean", + "description": "Enable HTTPS on the local OAuth callback server using a self-signed certificate. Required for providers (e.g. Slack) that reject http redirect URIs even for loopback addresses." + } + }, + "required": [ + "clientId" + ], + "additionalProperties": false + }, "ScriptShellToolConfig": { "type": "object", "description": "Configuration for custom shell tool", diff --git a/max-pr/error-before-oauth-fix.md b/max-pr/error-before-oauth-fix.md new file mode 100644 index 000000000..b2c8acea2 --- /dev/null +++ b/max-pr/error-before-oauth-fix.md @@ -0,0 +1,12 @@ +``` +➜ docker-agent git:(main) ✗ tail -f ~/.cagent/cagent.debug.log | grep -v -i telemetry +time=2026-03-28T15:12:16.143+01:00 level=DEBUG msg="Sending OAuth elicitation request to client" +time=2026-03-28T15:12:16.253+01:00 level=DEBUG msg="Starting unmanaged OAuth flow for server" url=https://mcp.slack.com/mcp +time=2026-03-28T15:12:16.480+01:00 level=DEBUG msg="Sending OAuth elicitation request to client" +time=2026-03-28T15:12:16.480+01:00 level=ERROR msg="Failed to initialize MCP client" error="failed to connect to MCP server: calling \"initialize\": sending \"initialize\": rejected by transport: Post \"https://mcp.slack.com/mcp\": OAuth flow failed: failed to send elicitation request: no elicitation handler configured" +time=2026-03-28T15:12:16.480+01:00 level=WARN msg="Toolset start failed; skipping" agent=root toolset=*mcp.Toolset error="failed to initialize MCP client: failed to connect to MCP server: calling \"initialize\": sending \"initialize\": rejected by transport: Post \"https://mcp.slack.com/mcp\": OAuth flow failed: failed to send elicitation request: no elicitation handler configured" +time=2026-03-28T15:12:16.480+01:00 level=DEBUG msg="Forwarding event to sidebar" event_type=*runtime.ToolsetInfoEvent +time=2026-03-28T15:12:16.483+01:00 level=DEBUG msg="Forwarding event to sidebar" event_type=*runtime.ToolsetInfoEvent + +``` + diff --git a/pkg/config/latest/types.go b/pkg/config/latest/types.go index 54929d5f8..8b2749a93 100644 --- a/pkg/config/latest/types.go +++ b/pkg/config/latest/types.go @@ -714,10 +714,33 @@ func (t *Toolset) UnmarshalYAML(unmarshal func(any) error) error { return t.validate() } +// RemoteOAuthConfig holds explicit OAuth client credentials for remote MCP servers +// that do not support Dynamic Client Registration (RFC 7591). +// https://datatracker.ietf.org/doc/html/rfc7591 +type RemoteOAuthConfig struct { + // ClientID is the pre-registered OAuth client identifier. + ClientID string `json:"clientId,omitempty" yaml:"clientId,omitempty"` + // ClientSecret is optional; only needed for confidential clients. + ClientSecret string `json:"clientSecret,omitempty" yaml:"clientSecret,omitempty"` + // CallbackPort pins the local OAuth redirect server to a specific port. + // When zero, a random available port is used. + CallbackPort int `json:"callbackPort,omitempty" yaml:"callbackPort,omitempty"` + // Scopes overrides the default scopes requested during the authorization flow. + Scopes []string `json:"scopes,omitempty" yaml:"scopes,omitempty"` + // TLS enables HTTPS on the local OAuth callback server using an in-memory + // self-signed certificate. Required for providers (e.g. Slack) that reject + // http redirect URIs even for loopback addresses. + TLS bool `json:"tls,omitempty" yaml:"tls,omitempty"` +} + type Remote struct { URL string `json:"url"` TransportType string `json:"transport_type,omitempty"` Headers map[string]string `json:"headers,omitempty"` + // OAuth holds explicit OAuth configuration for servers that do not support + // Dynamic Client Registration. When set, the agent uses the provided + // ClientID/ClientSecret instead of attempting dynamic registration. + OAuth *RemoteOAuthConfig `json:"oauth,omitempty" yaml:"oauth,omitempty"` } // DeferConfig represents the deferred loading configuration for a toolset. diff --git a/pkg/config/latest/validate.go b/pkg/config/latest/validate.go index 6d2131418..b65294f6a 100644 --- a/pkg/config/latest/validate.go +++ b/pkg/config/latest/validate.go @@ -96,6 +96,20 @@ func (t *Toolset) validate() error { if (t.Remote.URL != "" || t.Remote.TransportType != "") && t.Type != "mcp" { return errors.New("remote can only be used with type 'mcp'") } + if t.Remote.OAuth != nil { + if t.Type != "mcp" { + return errors.New("remote.oauth can only be used with type 'mcp'") + } + if t.Remote.URL == "" { + return errors.New("remote.oauth requires remote.url to be set") + } + if t.Remote.OAuth.ClientID == "" { + return errors.New("remote.oauth.clientId must be set when oauth is configured") + } + if t.Remote.OAuth.CallbackPort < 0 { + return errors.New("remote.oauth.callbackPort must be >= 0") + } + } if (len(t.Remote.Headers) > 0) && (t.Type != "mcp" && t.Type != "a2a") { return errors.New("remote headers can only be used with type 'mcp' or 'a2a'") } diff --git a/pkg/config/latest/validate_test.go b/pkg/config/latest/validate_test.go index 6c90b9be7..88089181d 100644 --- a/pkg/config/latest/validate_test.go +++ b/pkg/config/latest/validate_test.go @@ -7,14 +7,35 @@ import ( "github.com/stretchr/testify/require" ) +type validateConfigCase struct { + name string + config string + wantErr string +} + +func runValidateConfigCases(t *testing.T, tests []validateConfigCase) { + t.Helper() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var cfg Config + err := yaml.Unmarshal([]byte(tt.config), &cfg) + + if tt.wantErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.wantErr) + } else { + require.NoError(t, err) + } + }) + } +} + func TestToolset_Validate_LSP(t *testing.T) { t.Parallel() - tests := []struct { - name string - config string - wantErr string - }{ + tests := []validateConfigCase{ { name: "valid lsp with command", config: ` @@ -99,19 +120,149 @@ agents: }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() + runValidateConfigCases(t, tests) +} - var cfg Config - err := yaml.Unmarshal([]byte(tt.config), &cfg) +func TestToolset_Validate_RemoteOAuth(t *testing.T) { + t.Parallel() - if tt.wantErr != "" { - require.Error(t, err) - require.Contains(t, err.Error(), tt.wantErr) - } else { - require.NoError(t, err) - } - }) + tests := []validateConfigCase{ + { + name: "valid oauth with clientId", + config: ` +version: "3" +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: mcp + remote: + url: https://mcp.slack.com/mcp + transport_type: streamable + oauth: + clientId: "my-client-id" +`, + wantErr: "", + }, + { + name: "valid oauth with all fields", + config: ` +version: "3" +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: mcp + remote: + url: https://mcp.slack.com/mcp + transport_type: streamable + oauth: + clientId: "my-client-id" + clientSecret: "my-secret" + callbackPort: 3118 + scopes: + - search:read + - chat:write +`, + wantErr: "", + }, + { + name: "valid oauth with zero callbackPort (random port)", + config: ` +version: "3" +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: mcp + remote: + url: https://mcp.slack.com/mcp + oauth: + clientId: "my-client-id" + callbackPort: 0 +`, + wantErr: "", + }, + { + name: "oauth missing clientId", + config: ` +version: "3" +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: mcp + remote: + url: https://mcp.slack.com/mcp + oauth: + clientSecret: "my-secret" +`, + wantErr: "remote.oauth.clientId must be set when oauth is configured", + }, + { + name: "oauth with negative callbackPort", + config: ` +version: "3" +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: mcp + remote: + url: https://mcp.slack.com/mcp + oauth: + clientId: "my-client-id" + callbackPort: -1 +`, + wantErr: "remote.oauth.callbackPort must be >= 0", + }, + { + name: "oauth on non-mcp toolset", + config: ` +version: "3" +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: shell + remote: + oauth: + clientId: "my-client-id" +`, + wantErr: "remote.oauth can only be used with type 'mcp'", + }, + { + name: "oauth without remote url", + config: ` +version: "3" +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: mcp + command: some-mcp-server + remote: + oauth: + clientId: "my-client-id" +`, + wantErr: "remote.oauth requires remote.url to be set", + }, + { + name: "remote mcp without oauth passes", + config: ` +version: "3" +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: mcp + remote: + url: https://mcp.example.com/mcp + transport_type: streamable +`, + wantErr: "", + }, } + + runValidateConfigCases(t, tests) } diff --git a/pkg/runtime/remote_runtime.go b/pkg/runtime/remote_runtime.go index 9f8a0db87..97bd4272d 100644 --- a/pkg/runtime/remote_runtime.go +++ b/pkg/runtime/remote_runtime.go @@ -362,6 +362,7 @@ func (r *RemoteRuntime) handleOAuthElicitation(ctx context.Context, req *Elicita state, oauth2.S256ChallengeFromVerifier(verifier), serverURL, + nil, ) slog.Debug("Authorization URL built", "url", authURL) diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index 03feeb013..b45203985 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -900,8 +900,16 @@ func (r *LocalRuntime) emitToolsProgressively(ctx context.Context, a *agent.Agen isLast := i == totalToolsets-1 - // Start the toolset if needed + // Start the toolset if needed. + // OAuth-protected toolsets are deferred — they require a live elicitation + // handler and events channel, which are only available once RunStream is + // active. They will start lazily on the first tool call. if startable, ok := toolset.(*tools.StartableToolSet); ok { + if mcpTS, isMCP := tools.As[*mcptools.Toolset](startable); isMCP && mcpTS.RequiresOAuth() { + slog.Debug("Deferring OAuth toolset start until first use", "toolset", tools.DescribeToolSet(startable)) + continue + } + if !startable.IsStarted() { if err := startable.Start(ctx); err != nil { slog.Warn("Toolset start failed; skipping", "agent", a.Name(), "toolset", fmt.Sprintf("%T", startable.ToolSet), "error", err) diff --git a/pkg/runtime/runtime_test.go b/pkg/runtime/runtime_test.go index 5b6d25caf..92f15bb71 100644 --- a/pkg/runtime/runtime_test.go +++ b/pkg/runtime/runtime_test.go @@ -22,6 +22,7 @@ import ( "github.com/docker/docker-agent/pkg/session" "github.com/docker/docker-agent/pkg/team" "github.com/docker/docker-agent/pkg/tools" + mcptools "github.com/docker/docker-agent/pkg/tools/mcp" ) type stubToolSet struct { @@ -820,6 +821,89 @@ func TestProcessToolCalls_UnknownTool_ReturnsErrorResponse(t *testing.T) { assert.Contains(t, toolContent, "not available") } +func TestEmitStartupInfo_OAuthToolsetDeferred(t *testing.T) { + // An MCP toolset with an explicit oauthConfig must not be started eagerly + // during EmitStartupInfo — the OAuth flow requires a live elicitation + // handler and events channel that are only available once RunStream is active. + // The toolset should be skipped and reported as 0 tools without loading. + prov := &mockProvider{id: "test/startup-model", stream: &mockStream{}} + + oauthToolset := mcptools.NewRemoteToolset( + "slack", + "https://mcp.slack.com/mcp", + "streamable", + nil, + &latest.RemoteOAuthConfig{ClientID: "my-client-id"}, + ) + startable := tools.NewStartable(oauthToolset) + + root := agent.New("root", "agent", + agent.WithModel(prov), + agent.WithToolSets(startable), + ) + tm := team.New(team.WithAgents(root)) + + rt, err := NewLocalRuntime(tm, WithModelStore(mockModelStore{})) + require.NoError(t, err) + + events := make(chan Event, 10) + rt.EmitStartupInfo(t.Context(), nil, events) + close(events) + + var toolsetEvents []*ToolsetInfoEvent + for ev := range events { + if te, ok := ev.(*ToolsetInfoEvent); ok { + toolsetEvents = append(toolsetEvents, te) + } + } + + // The toolset must not have been started. + require.False(t, startable.IsStarted(), "OAuth toolset must not be started during EmitStartupInfo") + + // Must have emitted at least one ToolsetInfoEvent with 0 tools and not loading. + require.NotEmpty(t, toolsetEvents, "expected at least one ToolsetInfoEvent") + last := toolsetEvents[len(toolsetEvents)-1] + assert.Equal(t, 0, last.AvailableTools, "OAuth toolset should report 0 tools at startup") + assert.False(t, last.Loading, "OAuth toolset should not leave the UI in loading state") +} + +func TestEmitStartupInfo_NonOAuthToolsetStartedEagerly(t *testing.T) { + // A regular (non-OAuth) MCP toolset must still be started eagerly during + // EmitStartupInfo so its tools are available immediately. + prov := &mockProvider{id: "test/startup-model", stream: &mockStream{}} + + stub := newStubToolSet(nil, []tools.Tool{{Name: "my_tool", Parameters: map[string]any{}}}, nil) + startable := tools.NewStartable(stub) + + root := agent.New("root", "agent", + agent.WithModel(prov), + agent.WithToolSets(startable), + ) + tm := team.New(team.WithAgents(root)) + + rt, err := NewLocalRuntime(tm, WithModelStore(mockModelStore{})) + require.NoError(t, err) + + events := make(chan Event, 10) + rt.EmitStartupInfo(t.Context(), nil, events) + close(events) + + // The regular toolset must have been started. + require.True(t, startable.IsStarted(), "non-OAuth toolset must be started eagerly during EmitStartupInfo") + + var toolsetEvents []*ToolsetInfoEvent + for ev := range events { + if te, ok := ev.(*ToolsetInfoEvent); ok { + toolsetEvents = append(toolsetEvents, te) + } + } + + require.NotEmpty(t, toolsetEvents) + last := toolsetEvents[len(toolsetEvents)-1] + assert.Equal(t, 1, last.AvailableTools, "non-OAuth toolset tools should be counted at startup") + assert.False(t, last.Loading) +} + func TestEmitStartupInfo(t *testing.T) { // Create a simple agent with mock provider prov := &mockProvider{id: "test/startup-model", stream: &mockStream{}} diff --git a/pkg/teamloader/registry.go b/pkg/teamloader/registry.go index 3cf729053..e4c7a1660 100644 --- a/pkg/teamloader/registry.go +++ b/pkg/teamloader/registry.go @@ -253,7 +253,7 @@ func createMCPTool(ctx context.Context, toolset latest.Toolset, _ string, runCon // TODO(dga): until the MCP Gateway supports oauth with docker agent, we fetch the remote url and directly connect to it. if serverSpec.Type == "remote" { - return mcp.NewRemoteToolset(toolset.Name, serverSpec.Remote.URL, serverSpec.Remote.TransportType, nil), nil + return mcp.NewRemoteToolset(toolset.Name, serverSpec.Remote.URL, serverSpec.Remote.TransportType, nil, nil), nil } env, err := environment.ExpandAll(ctx, environment.ToValues(toolset.Env), envProvider) @@ -294,7 +294,7 @@ func createMCPTool(ctx context.Context, toolset latest.Toolset, _ string, runCon headers := expander.ExpandMap(ctx, toolset.Remote.Headers) url := expander.Expand(ctx, toolset.Remote.URL, nil) - return mcp.NewRemoteToolset(toolset.Name, url, toolset.Remote.TransportType, headers), nil + return mcp.NewRemoteToolset(toolset.Name, url, toolset.Remote.TransportType, headers, toolset.Remote.OAuth), nil default: return nil, errors.New("mcp toolset requires either ref, command, or remote configuration") diff --git a/pkg/tools/mcp/describe_test.go b/pkg/tools/mcp/describe_test.go index e20c93f26..8a8ba99e6 100644 --- a/pkg/tools/mcp/describe_test.go +++ b/pkg/tools/mcp/describe_test.go @@ -24,21 +24,21 @@ func TestToolsetDescribe_StdioNoArgs(t *testing.T) { func TestToolsetDescribe_RemoteHostAndPort(t *testing.T) { t.Parallel() - ts := NewRemoteToolset("", "http://example.com:8443/mcp/v1?key=secret", "sse", nil) + ts := NewRemoteToolset("", "http://example.com:8443/mcp/v1?key=secret", "sse", nil, nil) assert.Check(t, is.Equal(ts.Describe(), "mcp(remote host=example.com:8443 transport=sse)")) } func TestToolsetDescribe_RemoteDefaultPort(t *testing.T) { t.Parallel() - ts := NewRemoteToolset("", "https://api.example.com/mcp", "streamable", nil) + ts := NewRemoteToolset("", "https://api.example.com/mcp", "streamable", nil, nil) assert.Check(t, is.Equal(ts.Describe(), "mcp(remote host=api.example.com transport=streamable)")) } func TestToolsetDescribe_RemoteInvalidURL(t *testing.T) { t.Parallel() - ts := NewRemoteToolset("", "://bad-url", "sse", nil) + ts := NewRemoteToolset("", "://bad-url", "sse", nil, nil) assert.Check(t, is.Equal(ts.Describe(), "mcp(remote transport=sse)")) } diff --git a/pkg/tools/mcp/mcp.go b/pkg/tools/mcp/mcp.go index 46ea4b77e..acb9e09e5 100644 --- a/pkg/tools/mcp/mcp.go +++ b/pkg/tools/mcp/mcp.go @@ -18,6 +18,7 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/tools" ) @@ -105,13 +106,13 @@ func NewToolsetCommand(name, command string, args, env []string, cwd string) *To } // NewRemoteToolset creates a new MCP toolset from a remote MCP Server. -func NewRemoteToolset(name, urlString, transport string, headers map[string]string) *Toolset { +func NewRemoteToolset(name, urlString, transport string, headers map[string]string, oauthConfig *latest.RemoteOAuthConfig) *Toolset { slog.Debug("Creating Remote MCP toolset", "url", urlString, "transport", transport, "headers", headers) desc := buildRemoteDescription(urlString, transport) return &Toolset{ name: name, - mcpClient: newRemoteClient(urlString, transport, headers, NewInMemoryTokenStore()), + mcpClient: newRemoteClient(urlString, transport, headers, NewInMemoryTokenStore(), oauthConfig), logID: urlString, description: desc, } @@ -590,6 +591,17 @@ func encodeMedia(data []byte, mimeType string) tools.MediaContent { } } +// RequiresOAuth reports whether this toolset has explicit OAuth credentials +// configured. When true, the toolset must not be started eagerly at startup +// because the OAuth flow requires a live elicitation handler and events channel, +// which are only available once RunStream is active. +func (ts *Toolset) RequiresOAuth() bool { + if c, ok := ts.mcpClient.(*remoteMCPClient); ok { + return c.oauthConfig != nil + } + return false +} + func (ts *Toolset) SetElicitationHandler(handler tools.ElicitationHandler) { ts.mcpClient.SetElicitationHandler(handler) } diff --git a/pkg/tools/mcp/oauth.go b/pkg/tools/mcp/oauth.go index 5a216e2cc..d5d10132e 100644 --- a/pkg/tools/mcp/oauth.go +++ b/pkg/tools/mcp/oauth.go @@ -16,6 +16,7 @@ import ( mcpsdk "github.com/modelcontextprotocol/go-sdk/mcp" "golang.org/x/oauth2" + "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/tools" ) @@ -128,18 +129,18 @@ func validateAndFillDefaults(metadata *AuthorizationServerMetadata, authServerUR metadata.AuthorizationEndpoint = cmp.Or(metadata.AuthorizationEndpoint, authServerURL+"/authorize") metadata.TokenEndpoint = cmp.Or(metadata.TokenEndpoint, authServerURL+"/token") - metadata.RegistrationEndpoint = cmp.Or(metadata.RegistrationEndpoint, authServerURL+"/register") return metadata } -// createDefaultMetadata creates minimal metadata when discovery fails +// createDefaultMetadata creates minimal metadata when discovery fails. +// RegistrationEndpoint is intentionally omitted — since discovery failed +// we cannot know whether the server supports dynamic client registration. func createDefaultMetadata(authServerURL string) *AuthorizationServerMetadata { return &AuthorizationServerMetadata{ Issuer: authServerURL, AuthorizationEndpoint: authServerURL + "/authorize", TokenEndpoint: authServerURL + "/token", - RegistrationEndpoint: authServerURL + "/register", ResponseTypesSupported: []string{"code"}, ResponseModesSupported: []string{"query", "fragment"}, GrantTypesSupported: []string{"authorization_code"}, @@ -161,10 +162,11 @@ func resourceMetadataFromWWWAuth(wwwAuth string) string { type oauthTransport struct { base http.RoundTripper // TODO(rumpl): remove client reference, we need to find a better way to send elicitation requests - client *remoteMCPClient - tokenStore OAuthTokenStore - baseURL string - managed bool + client *remoteMCPClient + tokenStore OAuthTokenStore + baseURL string + managed bool + oauthConfig *latest.RemoteOAuthConfig } func (t *oauthTransport) RoundTrip(req *http.Request) (*http.Response, error) { @@ -180,8 +182,13 @@ func (t *oauthTransport) RoundTrip(req *http.Request) (*http.Response, error) { reqClone := req.Clone(req.Context()) - if token, err := t.tokenStore.GetToken(t.baseURL); err == nil && !token.IsExpired() { - reqClone.Header.Set("Authorization", "Bearer "+token.AccessToken) + if token, err := t.tokenStore.GetToken(t.baseURL); err == nil { + if token.UserDeclined { + return nil, errors.New("OAuth authorization was declined for this session") + } + if !token.IsExpired() { + reqClone.Header.Set("Authorization", "Bearer "+token.AccessToken) + } } resp, err := t.base.RoundTrip(reqClone) @@ -212,7 +219,9 @@ func (t *oauthTransport) RoundTrip(req *http.Request) (*http.Response, error) { // handleOAuthFlow performs the OAuth flow when a 401 response is received func (t *oauthTransport) handleOAuthFlow(ctx context.Context, authServer, wwwAuth string) error { - if t.managed { + // When explicit OAuth credentials are configured, always use the managed flow + // regardless of the managed flag — we have everything needed to drive it ourselves. + if t.managed || (t.oauthConfig != nil && t.oauthConfig.ClientID != "") { return t.handleManagedOAuthFlow(ctx, authServer, wwwAuth) } @@ -253,7 +262,16 @@ func (t *oauthTransport) handleManagedOAuthFlow(ctx context.Context, authServer, } slog.Debug("Creating OAuth callback server") - callbackServer, err := NewCallbackServer() + var callbackServer *CallbackServer + var callbackOpts []CallbackServerOption + if t.oauthConfig != nil && t.oauthConfig.TLS { + callbackOpts = append(callbackOpts, WithTLS()) + } + if t.oauthConfig != nil && t.oauthConfig.CallbackPort > 0 { + callbackServer, err = NewCallbackServerWithOptions(t.oauthConfig.CallbackPort, callbackOpts...) + } else { + callbackServer, err = NewCallbackServerWithOptions(0, callbackOpts...) + } if err != nil { return fmt.Errorf("failed to create callback server: %w", err) } @@ -275,17 +293,21 @@ func (t *oauthTransport) handleManagedOAuthFlow(ctx context.Context, authServer, var clientID string var clientSecret string - if authServerMetadata.RegistrationEndpoint != "" { + switch { + case t.oauthConfig != nil && t.oauthConfig.ClientID != "": + // Use pre-registered credentials — skip dynamic registration entirely. + clientID = t.oauthConfig.ClientID + clientSecret = t.oauthConfig.ClientSecret + slog.Debug("Using explicit OAuth client credentials") + case authServerMetadata.RegistrationEndpoint != "": slog.Debug("Attempting dynamic client registration") clientID, clientSecret, err = RegisterClient(ctx, authServerMetadata, redirectURI, nil) if err != nil { - slog.Debug("Dynamic registration failed", "error", err) - // TODO(rumpl): fall back to requesting client ID from user - return err + slog.Debug("Dynamic registration failed, provide explicit oauth.clientId in config", "error", err) + return fmt.Errorf("dynamic client registration failed; set remote.oauth.clientId in your config: %w", err) } - } else { - // TODO(rumpl): fall back to requesting client ID from user - return errors.New("authorization server does not support dynamic client registration") + default: + return errors.New("authorization server does not support dynamic client registration; set remote.oauth.clientId in your config") } state, err := GenerateState() @@ -296,6 +318,11 @@ func (t *oauthTransport) handleManagedOAuthFlow(ctx context.Context, authServer, callbackServer.SetExpectedState(state) verifier := GeneratePKCEVerifier() + scopes := authServerMetadata.ScopesSupported + if t.oauthConfig != nil && len(t.oauthConfig.Scopes) > 0 { + scopes = t.oauthConfig.Scopes + } + authURL := BuildAuthorizationURL( authServerMetadata.AuthorizationEndpoint, clientID, @@ -303,6 +330,7 @@ func (t *oauthTransport) handleManagedOAuthFlow(ctx context.Context, authServer, state, oauth2.S256ChallengeFromVerifier(verifier), t.baseURL, + scopes, ) result, err := t.client.requestElicitation(ctx, &mcpsdk.ElicitParams{ @@ -320,6 +348,7 @@ func (t *oauthTransport) handleManagedOAuthFlow(ctx context.Context, authServer, slog.Debug("Elicitation response received", "result", result) if result.Action != tools.ElicitationActionAccept { + _ = t.tokenStore.StoreToken(t.baseURL, &OAuthToken{UserDeclined: true}) return errors.New("user declined OAuth authorization") } @@ -415,6 +444,7 @@ func (t *oauthTransport) handleUnmanagedOAuthFlow(ctx context.Context, authServe slog.Debug("Received elicitation response from client", "action", result.Action) if result.Action != tools.ElicitationActionAccept { + _ = t.tokenStore.StoreToken(t.baseURL, &OAuthToken{UserDeclined: true}) return errors.New("OAuth flow declined or cancelled by client") } if result.Content == nil { diff --git a/pkg/tools/mcp/oauth_helpers.go b/pkg/tools/mcp/oauth_helpers.go index 588bcbe08..fb94f1a1c 100644 --- a/pkg/tools/mcp/oauth_helpers.go +++ b/pkg/tools/mcp/oauth_helpers.go @@ -2,12 +2,20 @@ package mcp import ( "context" + "crypto/ecdsa" + "crypto/elliptic" "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" "encoding/base64" "encoding/json" + "encoding/pem" "errors" "fmt" "io" + "math/big" + "net" "net/http" "net/url" "strings" @@ -28,7 +36,7 @@ func GenerateState() (string, error) { } // BuildAuthorizationURL builds the OAuth authorization URL with PKCE -func BuildAuthorizationURL(authEndpoint, clientID, redirectURI, state, codeChallenge, resourceURL string) string { +func BuildAuthorizationURL(authEndpoint, clientID, redirectURI, state, codeChallenge, resourceURL string, scopes []string) string { params := url.Values{} params.Set("response_type", "code") params.Set("client_id", clientID) @@ -37,6 +45,9 @@ func BuildAuthorizationURL(authEndpoint, clientID, redirectURI, state, codeChall params.Set("code_challenge", codeChallenge) params.Set("code_challenge_method", "S256") params.Set("resource", resourceURL) // RFC 8707: Resource Indicators + if len(scopes) > 0 { + params.Set("scope", strings.Join(scopes, " ")) + } return authEndpoint + "?" + params.Encode() } @@ -159,3 +170,46 @@ func RegisterClient(ctx context.Context, authMetadata *AuthorizationServerMetada func GeneratePKCEVerifier() string { return oauth2.GenerateVerifier() } + +// selfSignedTLSConfig generates an in-memory self-signed TLS certificate for +// 127.0.0.1 / localhost. The cert is valid for 24 hours — long enough for any +// OAuth flow but short enough to limit exposure if somehow leaked. +func selfSignedTLSConfig() (*tls.Config, error) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, err + } + + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "localhost"}, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + DNSNames: []string{"localhost"}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(24 * time.Hour), + } + + pub := key.Public() + certDER, err := x509.CreateCertificate(rand.Reader, template, template, pub, key) + if err != nil { + return nil, err + } + + keyDER, err := x509.MarshalECPrivateKey(key) + if err != nil { + return nil, err + } + + cert, err := tls.X509KeyPair( + pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}), + pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}), + ) + if err != nil { + return nil, err + } + + return &tls.Config{ + Certificates: []tls.Certificate{cert}, + MinVersion: tls.VersionTLS12, + }, nil +} diff --git a/pkg/tools/mcp/oauth_helpers_test.go b/pkg/tools/mcp/oauth_helpers_test.go new file mode 100644 index 000000000..967655c69 --- /dev/null +++ b/pkg/tools/mcp/oauth_helpers_test.go @@ -0,0 +1,46 @@ +package mcp + +import ( + "crypto/tls" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSelfSignedTLSConfig(t *testing.T) { + t.Parallel() + + cfg, err := selfSignedTLSConfig() + require.NoError(t, err) + require.NotNil(t, cfg) + assert.Len(t, cfg.Certificates, 1) + assert.Equal(t, uint16(tls.VersionTLS12), cfg.MinVersion) +} + +func TestBuildAuthorizationURL(t *testing.T) { + rawURL := BuildAuthorizationURL( + "https://auth.example.com/authorize", + "my-client-id", + "http://127.0.0.1:3118/callback", + "random-state-value", + "code-challenge-value", + "https://mcp.example.com/mcp", + []string{"search:read", "chat:write"}, + ) + + parsed, err := url.Parse(rawURL) + require.NoError(t, err) + + q := parsed.Query() + assert.Equal(t, "https://auth.example.com/authorize", parsed.Scheme+"://"+parsed.Host+parsed.Path) + assert.Equal(t, "code", q.Get("response_type")) + assert.Equal(t, "my-client-id", q.Get("client_id")) + assert.Equal(t, "http://127.0.0.1:3118/callback", q.Get("redirect_uri")) + assert.Equal(t, "random-state-value", q.Get("state")) + assert.Equal(t, "code-challenge-value", q.Get("code_challenge")) + assert.Equal(t, "S256", q.Get("code_challenge_method")) + assert.Equal(t, "https://mcp.example.com/mcp", q.Get("resource")) + assert.Equal(t, "search:read chat:write", q.Get("scope")) +} diff --git a/pkg/tools/mcp/oauth_server.go b/pkg/tools/mcp/oauth_server.go index b66f30b6a..f9bdc9db2 100644 --- a/pkg/tools/mcp/oauth_server.go +++ b/pkg/tools/mcp/oauth_server.go @@ -2,6 +2,7 @@ package mcp import ( "context" + "crypto/tls" "errors" "fmt" "log/slog" @@ -16,6 +17,7 @@ type CallbackServer struct { server *http.Server listener net.Listener mu sync.Mutex + useTLS bool // Channels for communicating the authorization code and state codeCh chan string @@ -26,12 +28,47 @@ type CallbackServer struct { expectedState string } -// NewCallbackServer creates a new OAuth callback server -func NewCallbackServer() (*CallbackServer, error) { - // Find an available port - listener, err := net.Listen("tcp", "127.0.0.1:0") +// CallbackServerOption configures a CallbackServer. +type CallbackServerOption func(*CallbackServer) + +// WithTLS enables HTTPS on the callback server using an in-memory self-signed +// certificate. Use this when the OAuth provider requires an https redirect_uri +// even for loopback addresses (e.g. Slack). +func WithTLS() CallbackServerOption { + return func(cs *CallbackServer) { + cs.useTLS = true + } +} + +// NewCallbackServer creates a new OAuth callback server. +// An optional port can be provided to bind to a specific local port. +// When no port is given, or port is 0, a random available port is used. +func NewCallbackServer(port ...int) (*CallbackServer, error) { + return newCallbackServer(port, nil) +} + +// NewCallbackServerWithOptions creates a new OAuth callback server with functional options. +// An optional port can be provided as the first argument. +func NewCallbackServerWithOptions(port int, opts ...CallbackServerOption) (*CallbackServer, error) { + var ports []int + if port > 0 { + ports = []int{port} + } + return newCallbackServer(ports, opts) +} + +func newCallbackServer(port []int, opts []CallbackServerOption) (*CallbackServer, error) { + addr := "127.0.0.1:0" + if len(port) > 0 && port[0] > 0 { + addr = fmt.Sprintf("127.0.0.1:%d", port[0]) + } + + listener, err := net.Listen("tcp", addr) if err != nil { - return nil, fmt.Errorf("failed to find available port: %w", err) + if len(port) > 0 && port[0] > 0 { + return nil, fmt.Errorf("callback port %d is already in use: %w", port[0], err) + } + return nil, fmt.Errorf("failed to find available port for callback: %w", err) } cs := &CallbackServer{ @@ -41,6 +78,10 @@ func NewCallbackServer() (*CallbackServer, error) { errCh: make(chan error, 1), } + for _, opt := range opts { + opt(cs) + } + mux := http.NewServeMux() mux.HandleFunc("/callback", cs.handleCallback) @@ -50,6 +91,15 @@ func NewCallbackServer() (*CallbackServer, error) { WriteTimeout: 10 * time.Second, } + if cs.useTLS { + tlsCfg, tlsErr := selfSignedTLSConfig() + if tlsErr != nil { + _ = listener.Close() + return nil, fmt.Errorf("failed to generate self-signed certificate: %w", tlsErr) + } + cs.listener = tls.NewListener(listener, tlsCfg) + } + return cs, nil } @@ -66,7 +116,11 @@ func (cs *CallbackServer) Start() error { func (cs *CallbackServer) GetRedirectURI() string { addr := cs.listener.Addr().String() - return fmt.Sprintf("http://%s/callback", addr) + scheme := "http" + if cs.useTLS { + scheme = "https" + } + return fmt.Sprintf("%s://%s/callback", scheme, addr) } func (cs *CallbackServer) SetExpectedState(state string) { diff --git a/pkg/tools/mcp/oauth_server_test.go b/pkg/tools/mcp/oauth_server_test.go new file mode 100644 index 000000000..c057e0001 --- /dev/null +++ b/pkg/tools/mcp/oauth_server_test.go @@ -0,0 +1,71 @@ +package mcp + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewCallbackServer_FixedPort(t *testing.T) { + t.Parallel() + + cs, err := NewCallbackServer(3118) + require.NoError(t, err) + require.NotNil(t, cs) + + assert.Equal(t, "http://127.0.0.1:3118/callback", cs.GetRedirectURI()) + + _ = cs.Shutdown(t.Context()) +} + +func TestNewCallbackServer_RandomPort(t *testing.T) { + t.Parallel() + + cs, err := NewCallbackServer() + require.NoError(t, err) + require.NotNil(t, cs) + + assert.Contains(t, cs.GetRedirectURI(), "http://127.0.0.1:") + assert.Contains(t, cs.GetRedirectURI(), "/callback") + + _ = cs.Shutdown(t.Context()) +} + +func TestNewCallbackServer_PortAlreadyInUse(t *testing.T) { + t.Parallel() + + first, err := NewCallbackServer(3119) + require.NoError(t, err) + defer func() { _ = first.Shutdown(t.Context()) }() + + _, err = NewCallbackServer(3119) + require.Error(t, err) + assert.Contains(t, err.Error(), "callback port 3119 is already in use") +} + +func TestGetRedirectURI(t *testing.T) { + t.Parallel() + + t.Run("http when tls is false", func(t *testing.T) { + t.Parallel() + + cs, err := NewCallbackServer() + require.NoError(t, err) + defer func() { _ = cs.Shutdown(t.Context()) }() + + assert.Contains(t, cs.GetRedirectURI(), "http://") + assert.Contains(t, cs.GetRedirectURI(), "/callback") + }) + + t.Run("https when tls is true", func(t *testing.T) { + t.Parallel() + + cs, err := NewCallbackServerWithOptions(0, WithTLS()) + require.NoError(t, err) + defer func() { _ = cs.Shutdown(t.Context()) }() + + assert.Contains(t, cs.GetRedirectURI(), "https://") + assert.Contains(t, cs.GetRedirectURI(), "/callback") + }) +} diff --git a/pkg/tools/mcp/oauth_test.go b/pkg/tools/mcp/oauth_test.go new file mode 100644 index 000000000..d0ad8e158 --- /dev/null +++ b/pkg/tools/mcp/oauth_test.go @@ -0,0 +1,201 @@ +package mcp + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + + gomcp "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/config/latest" + "github.com/docker/docker-agent/pkg/tools" +) + +func TestValidateAndFillDefaults(t *testing.T) { + t.Parallel() + + t.Run("fills missing endpoints from authServerURL", func(t *testing.T) { + t.Parallel() + + metadata := &AuthorizationServerMetadata{} + result := validateAndFillDefaults(metadata, "https://auth.example.com") + + assert.Equal(t, "https://auth.example.com", result.Issuer) + assert.Equal(t, "https://auth.example.com/authorize", result.AuthorizationEndpoint) + assert.Equal(t, "https://auth.example.com/token", result.TokenEndpoint) + }) + + t.Run("preserves existing endpoints", func(t *testing.T) { + t.Parallel() + + metadata := &AuthorizationServerMetadata{ + Issuer: "https://issuer.example.com", + AuthorizationEndpoint: "https://auth.example.com/oauth/authorize", + TokenEndpoint: "https://auth.example.com/oauth/token", + } + result := validateAndFillDefaults(metadata, "https://auth.example.com") + + assert.Equal(t, "https://issuer.example.com", result.Issuer) + assert.Equal(t, "https://auth.example.com/oauth/authorize", result.AuthorizationEndpoint) + assert.Equal(t, "https://auth.example.com/oauth/token", result.TokenEndpoint) + }) + + t.Run("preserves registration endpoint when server advertises it", func(t *testing.T) { + t.Parallel() + + metadata := &AuthorizationServerMetadata{ + RegistrationEndpoint: "https://auth.example.com/register", + } + result := validateAndFillDefaults(metadata, "https://auth.example.com") + + assert.Equal(t, "https://auth.example.com/register", result.RegistrationEndpoint) + }) + + t.Run("does not fabricate registration endpoint when server omits it", func(t *testing.T) { + t.Parallel() + + // Servers like Slack do not advertise a registration_endpoint + // because they do not support Dynamic Client Registration (RFC 7591). + // validateAndFillDefaults must not invent one — doing so causes a + // guaranteed 404/302 and misleads the caller into attempting registration. + metadata := &AuthorizationServerMetadata{} + result := validateAndFillDefaults(metadata, "https://mcp.slack.com") + + assert.Empty(t, result.RegistrationEndpoint) + }) + + t.Run("fills default response types when empty", func(t *testing.T) { + t.Parallel() + + metadata := &AuthorizationServerMetadata{} + result := validateAndFillDefaults(metadata, "https://auth.example.com") + + assert.Equal(t, []string{"code"}, result.ResponseTypesSupported) + }) + + t.Run("preserves existing response types", func(t *testing.T) { + t.Parallel() + + metadata := &AuthorizationServerMetadata{ + ResponseTypesSupported: []string{"code", "token"}, + } + result := validateAndFillDefaults(metadata, "https://auth.example.com") + + assert.Equal(t, []string{"code", "token"}, result.ResponseTypesSupported) + }) +} + +// TestOAuthTransport_UserDeclined verifies that once a user declines the OAuth +// prompt, all subsequent requests are rejected immediately without re-prompting. +func TestOAuthTransport_UserDeclined(t *testing.T) { + t.Parallel() + + var promptCount atomic.Int32 + + // Minimal server that: + // - returns 401 on the MCP endpoint to trigger the OAuth flow + // - returns 404 for /.well-known/oauth-protected-resource (acceptable; flow continues) + // - returns minimal auth server metadata for /.well-known/oauth-authorization-server + // so the managed flow reaches the elicitation prompt + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/.well-known/oauth-protected-resource": + w.WriteHeader(http.StatusNotFound) + case "/.well-known/oauth-authorization-server": + w.Header().Set("Content-Type", "application/json") + fmt.Fprintf(w, `{"issuer":%q,"authorization_endpoint":%q,"token_endpoint":%q}`, + "http://"+r.Host, "http://"+r.Host+"/authorize", "http://"+r.Host+"/token") + default: + w.Header().Set("WWW-Authenticate", `Bearer error="unauthorized"`) + w.WriteHeader(http.StatusUnauthorized) + } + })) + defer server.Close() + + client := newRemoteClient(server.URL, "streamable", nil, NewInMemoryTokenStore(), &latest.RemoteOAuthConfig{ClientID: "test-client-id"}) + client.SetElicitationHandler(func(_ context.Context, _ *gomcp.ElicitParams) (tools.ElicitationResult, error) { + promptCount.Add(1) + return tools.ElicitationResult{Action: tools.ElicitationActionDecline}, nil + }) + + httpClient := client.createHTTPClient() + + // First request — OAuth flow fires, user declines once. + req1, err := http.NewRequestWithContext(t.Context(), http.MethodGet, server.URL, http.NoBody) + require.NoError(t, err) + resp1, err := httpClient.Do(req1) + if resp1 != nil { + resp1.Body.Close() + } + require.Error(t, err) + assert.Contains(t, err.Error(), "declined") + assert.Equal(t, int32(1), promptCount.Load(), "elicitation should fire exactly once") + + // Second request — UserDeclined sentinel must short-circuit without re-prompting. + req2, err := http.NewRequestWithContext(t.Context(), http.MethodGet, server.URL, http.NoBody) + require.NoError(t, err) + resp2, err := httpClient.Do(req2) + if resp2 != nil { + resp2.Body.Close() + } + require.Error(t, err) + assert.Contains(t, err.Error(), "declined") + assert.Equal(t, int32(1), promptCount.Load(), "elicitation must not fire again after user declined") + + // Third request — same guarantee holds for any number of follow-up calls. + req3, err := http.NewRequestWithContext(t.Context(), http.MethodGet, server.URL, http.NoBody) + require.NoError(t, err) + resp3, err := httpClient.Do(req3) + if resp3 != nil { + resp3.Body.Close() + } + require.Error(t, err) + assert.Equal(t, int32(1), promptCount.Load(), "elicitation count must stay at 1 regardless of further requests") +} + +// TestOAuthUserDeclinedSentinel groups unit-level checks on the UserDeclined +// sentinel: that it is written to the token store correctly, that it does not +// bleed across server URLs, and that IsExpired does not misclassify it. +func TestOAuthUserDeclinedSentinel(t *testing.T) { + t.Parallel() + + t.Run("sentinel written to store after decline", func(t *testing.T) { + t.Parallel() + + store := NewInMemoryTokenStore() + require.NoError(t, store.StoreToken("https://mcp.example.com", &OAuthToken{UserDeclined: true})) + + token, err := store.GetToken("https://mcp.example.com") + require.NoError(t, err) + assert.True(t, token.UserDeclined) + assert.Empty(t, token.AccessToken, "declined sentinel must not carry an access token") + assert.False(t, token.IsExpired(), "sentinel with zero ExpiresAt must not be considered expired") + }) + + t.Run("decline is scoped to the declined server URL only", func(t *testing.T) { + t.Parallel() + + store := NewInMemoryTokenStore() + require.NoError(t, store.StoreToken("https://mcp.slack.com", &OAuthToken{UserDeclined: true})) + + declined, err := store.GetToken("https://mcp.slack.com") + require.NoError(t, err) + assert.True(t, declined.UserDeclined) + + _, err = store.GetToken("https://mcp.github.com") + assert.Error(t, err, "unrelated server must have no token in the store") + }) + + t.Run("IsExpired returns false for sentinel with zero ExpiresAt", func(t *testing.T) { + t.Parallel() + + sentinel := &OAuthToken{UserDeclined: true} + assert.False(t, sentinel.IsExpired()) + assert.Empty(t, sentinel.AccessToken) + }) +} diff --git a/pkg/tools/mcp/reconnect_test.go b/pkg/tools/mcp/reconnect_test.go index eeacc52a6..91ceb1d8a 100644 --- a/pkg/tools/mcp/reconnect_test.go +++ b/pkg/tools/mcp/reconnect_test.go @@ -121,7 +121,7 @@ func TestRemoteReconnectAfterServerRestart(t *testing.T) { // --- Step 1–2: Start first server, connect toolset --- shutdown1 := startServer(t) - ts := NewRemoteToolset("test", fmt.Sprintf("http://%s/mcp", addr), "streamable-http", nil) + ts := NewRemoteToolset("test", fmt.Sprintf("http://%s/mcp", addr), "streamable-http", nil, nil) require.NoError(t, ts.Start(t.Context())) toolList, err := ts.Tools(t.Context()) @@ -184,7 +184,7 @@ func TestRemoteReconnectRefreshesTools(t *testing.T) { // --- Start server v1 with tools "alpha" + "shared" --- shutdown1 := startMCPServer(t, addr, alphaTool, sharedTool) - ts := NewRemoteToolset("ns", fmt.Sprintf("http://%s/mcp", addr), "streamable-http", nil) + ts := NewRemoteToolset("ns", fmt.Sprintf("http://%s/mcp", addr), "streamable-http", nil, nil) // Track toolsChangedHandler invocations. toolsChangedCh := make(chan struct{}, 1) diff --git a/pkg/tools/mcp/remote.go b/pkg/tools/mcp/remote.go index 83269883e..53abc858b 100644 --- a/pkg/tools/mcp/remote.go +++ b/pkg/tools/mcp/remote.go @@ -8,6 +8,7 @@ import ( gomcp "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/upstream" ) @@ -19,9 +20,10 @@ type remoteMCPClient struct { headers map[string]string tokenStore OAuthTokenStore managed bool + oauthConfig *latest.RemoteOAuthConfig } -func newRemoteClient(url, transportType string, headers map[string]string, tokenStore OAuthTokenStore) *remoteMCPClient { +func newRemoteClient(url, transportType string, headers map[string]string, tokenStore OAuthTokenStore, oauthConfig *latest.RemoteOAuthConfig) *remoteMCPClient { slog.Debug("Creating remote MCP client", "url", url, "transport", transportType, "headers", headers) if tokenStore == nil { @@ -33,6 +35,7 @@ func newRemoteClient(url, transportType string, headers map[string]string, token transportType: transportType, headers: headers, tokenStore: tokenStore, + oauthConfig: oauthConfig, } } @@ -102,11 +105,12 @@ func (c *remoteMCPClient) createHTTPClient() *http.Client { // Then wrap with OAuth support transport = &oauthTransport{ - base: transport, - client: c, - tokenStore: c.tokenStore, - baseURL: c.url, - managed: c.managed, + base: transport, + client: c, + tokenStore: c.tokenStore, + baseURL: c.url, + managed: c.managed, + oauthConfig: c.oauthConfig, } return &http.Client{ diff --git a/pkg/tools/mcp/remote_test.go b/pkg/tools/mcp/remote_test.go index 4655bf60f..f41a54143 100644 --- a/pkg/tools/mcp/remote_test.go +++ b/pkg/tools/mcp/remote_test.go @@ -42,7 +42,7 @@ func TestRemoteClientCustomHeaders(t *testing.T) { "Authorization": "Bearer custom-token", } - client := newRemoteClient(server.URL, "sse", expectedHeaders, NewInMemoryTokenStore()) + client := newRemoteClient(server.URL, "sse", expectedHeaders, NewInMemoryTokenStore(), nil) // Try to initialize (which will make the HTTP request) // We don't care if it succeeds or fails, we just need it to make the request @@ -91,7 +91,7 @@ func TestRemoteClientHeadersWithStreamable(t *testing.T) { "X-Custom-Auth": "custom-auth-value", } - client := newRemoteClient(server.URL, "streamable", expectedHeaders, NewInMemoryTokenStore()) + client := newRemoteClient(server.URL, "streamable", expectedHeaders, NewInMemoryTokenStore(), nil) // Try to initialize _, _ = client.Initialize(t.Context(), nil) @@ -131,7 +131,7 @@ func TestRemoteClientNoHeaders(t *testing.T) { defer server.Close() // Create remote client without custom headers (nil) - client := newRemoteClient(server.URL, "sse", nil, NewInMemoryTokenStore()) + client := newRemoteClient(server.URL, "sse", nil, NewInMemoryTokenStore(), nil) _, _ = client.Initialize(t.Context(), nil) @@ -167,7 +167,7 @@ func TestRemoteClientEmptyHeaders(t *testing.T) { defer server.Close() // Create remote client with empty headers map - client := newRemoteClient(server.URL, "sse", map[string]string{}, NewInMemoryTokenStore()) + client := newRemoteClient(server.URL, "sse", map[string]string{}, NewInMemoryTokenStore(), nil) _, _ = client.Initialize(t.Context(), nil) diff --git a/pkg/tools/mcp/tokenstore.go b/pkg/tools/mcp/tokenstore.go index 340d73cd8..38934fd25 100644 --- a/pkg/tools/mcp/tokenstore.go +++ b/pkg/tools/mcp/tokenstore.go @@ -24,6 +24,10 @@ type OAuthToken struct { RefreshToken string `json:"refresh_token,omitempty"` Scope string `json:"scope,omitempty"` ExpiresAt time.Time `json:"expires_at"` + // UserDeclined is set to true when the user explicitly declines the OAuth + // authorization prompt. It acts as a session-scoped sentinel that prevents + // the flow from being re-triggered on subsequent requests. + UserDeclined bool `json:"user_declined,omitempty"` } // IsExpired checks if the token is expired From 15239bf3e92f0856bd944dd8ecefc0476dc204ec Mon Sep 17 00:00:00 2001 From: Elsha <38140638+iElsha@users.noreply.github.com> Date: Sat, 28 Mar 2026 20:49:14 +0100 Subject: [PATCH 2/2] Delete max-pr/error-before-oauth-fix.md --- max-pr/error-before-oauth-fix.md | 12 ------------ 1 file changed, 12 deletions(-) delete mode 100644 max-pr/error-before-oauth-fix.md diff --git a/max-pr/error-before-oauth-fix.md b/max-pr/error-before-oauth-fix.md deleted file mode 100644 index b2c8acea2..000000000 --- a/max-pr/error-before-oauth-fix.md +++ /dev/null @@ -1,12 +0,0 @@ -``` -➜ docker-agent git:(main) ✗ tail -f ~/.cagent/cagent.debug.log | grep -v -i telemetry -time=2026-03-28T15:12:16.143+01:00 level=DEBUG msg="Sending OAuth elicitation request to client" -time=2026-03-28T15:12:16.253+01:00 level=DEBUG msg="Starting unmanaged OAuth flow for server" url=https://mcp.slack.com/mcp -time=2026-03-28T15:12:16.480+01:00 level=DEBUG msg="Sending OAuth elicitation request to client" -time=2026-03-28T15:12:16.480+01:00 level=ERROR msg="Failed to initialize MCP client" error="failed to connect to MCP server: calling \"initialize\": sending \"initialize\": rejected by transport: Post \"https://mcp.slack.com/mcp\": OAuth flow failed: failed to send elicitation request: no elicitation handler configured" -time=2026-03-28T15:12:16.480+01:00 level=WARN msg="Toolset start failed; skipping" agent=root toolset=*mcp.Toolset error="failed to initialize MCP client: failed to connect to MCP server: calling \"initialize\": sending \"initialize\": rejected by transport: Post \"https://mcp.slack.com/mcp\": OAuth flow failed: failed to send elicitation request: no elicitation handler configured" -time=2026-03-28T15:12:16.480+01:00 level=DEBUG msg="Forwarding event to sidebar" event_type=*runtime.ToolsetInfoEvent -time=2026-03-28T15:12:16.483+01:00 level=DEBUG msg="Forwarding event to sidebar" event_type=*runtime.ToolsetInfoEvent - -``` -