diff --git a/config/config.go b/config/config.go index 370a68b..f4107e4 100644 --- a/config/config.go +++ b/config/config.go @@ -14,6 +14,7 @@ type Anthropic struct { APIDumpDir string CircuitBreaker *CircuitBreaker SendActorHeaders bool + ExtraHeaders map[string]string } type AWSBedrock struct { diff --git a/intercept/messages/base.go b/intercept/messages/base.go index 6a5f512..c61a3a7 100644 --- a/intercept/messages/base.go +++ b/intercept/messages/base.go @@ -181,6 +181,12 @@ func (i *interceptionBase) newMessagesService(ctx context.Context, opts ...optio opts = append(opts, option.WithAPIKey(i.cfg.Key)) opts = append(opts, option.WithBaseURL(i.cfg.BaseURL)) + // Add extra headers if configured. + // Some providers require additional headers that are not added by the SDK. + for key, value := range i.cfg.ExtraHeaders { + opts = append(opts, option.WithHeader(key, value)) + } + // Add API dump middleware if configured if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, aibconfig.ProviderAnthropic, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { opts = append(opts, option.WithMiddleware(mw)) diff --git a/provider/anthropic.go b/provider/anthropic.go index 5195cec..e682fdb 100644 --- a/provider/anthropic.go +++ b/provider/anthropic.go @@ -19,6 +19,13 @@ import ( "go.opentelemetry.io/otel/trace" ) +// anthropicForwardHeaders lists headers from incoming requests that should be +// forwarded to the Anthropic API. +// TODO(ssncferreira): remove as part of https://github.com/coder/aibridge/issues/192 +var anthropicForwardHeaders = []string{ + "Anthropic-Beta", +} + var _ Provider = &Anthropic{} // Anthropic allows for interactions with the Anthropic API. @@ -100,11 +107,14 @@ func (p *Anthropic) CreateInterceptor(w http.ResponseWriter, r *http.Request, tr return nil, fmt.Errorf("unmarshal request body: %w", err) } + cfg := p.cfg + cfg.ExtraHeaders = extractAnthropicHeaders(r) + var interceptor intercept.Interceptor if req.Stream { - interceptor = messages.NewStreamingInterceptor(id, &req, payload, p.cfg, p.bedrockCfg, tracer) + interceptor = messages.NewStreamingInterceptor(id, &req, payload, cfg, p.bedrockCfg, tracer) } else { - interceptor = messages.NewBlockingInterceptor(id, &req, payload, p.cfg, p.bedrockCfg, tracer) + interceptor = messages.NewBlockingInterceptor(id, &req, payload, cfg, p.bedrockCfg, tracer) } span.SetAttributes(interceptor.TraceAttributes(r)...) return interceptor, nil @@ -137,3 +147,16 @@ func (p *Anthropic) CircuitBreakerConfig() *config.CircuitBreaker { func (p *Anthropic) APIDumpDir() string { return p.cfg.APIDumpDir } + +// extractAnthropicHeaders extracts headers required by the Anthropic API from +// the incoming request. +// TODO(ssncferreira): remove as part of https://github.com/coder/aibridge/issues/192 +func extractAnthropicHeaders(r *http.Request) map[string]string { + headers := make(map[string]string, len(anthropicForwardHeaders)) + for _, h := range anthropicForwardHeaders { + if v := r.Header.Get(h); v != "" { + headers[h] = v + } + } + return headers +} diff --git a/provider/anthropic_test.go b/provider/anthropic_test.go index 49b2c6c..924c0f9 100644 --- a/provider/anthropic_test.go +++ b/provider/anthropic_test.go @@ -1,12 +1,171 @@ package provider import ( + "bytes" "net/http" + "net/http/httptest" "testing" + "cdr.dev/slog/v3" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/coder/aibridge/config" + "github.com/coder/aibridge/internal/testutil" ) +func TestAnthropic_CreateInterceptor(t *testing.T) { + t.Parallel() + + provider := NewAnthropic(config.Anthropic{Key: "test-key"}, nil) + + t.Run("Messages_NonStreamingRequest_BlockingInterceptor", func(t *testing.T) { + t.Parallel() + + body := `{"model": "claude-opus-4-5", "max_tokens": 1024, "messages": [{"role": "user", "content": "hello"}], "stream": false}` + req := httptest.NewRequest(http.MethodPost, routeMessages, bytes.NewBufferString(body)) + w := httptest.NewRecorder() + + interceptor, err := provider.CreateInterceptor(w, req, testTracer) + + require.NoError(t, err) + require.NotNil(t, interceptor) + assert.False(t, interceptor.Streaming()) + }) + + t.Run("Messages_StreamingRequest_StreamingInterceptor", func(t *testing.T) { + t.Parallel() + + body := `{"model": "claude-opus-4-5", "max_tokens": 1024, "messages": [{"role": "user", "content": "hello"}], "stream": true}` + req := httptest.NewRequest(http.MethodPost, routeMessages, bytes.NewBufferString(body)) + w := httptest.NewRecorder() + + interceptor, err := provider.CreateInterceptor(w, req, testTracer) + + require.NoError(t, err) + require.NotNil(t, interceptor) + assert.True(t, interceptor.Streaming()) + }) + + t.Run("Messages_InvalidRequestBody", func(t *testing.T) { + t.Parallel() + + body := `invalid json` + req := httptest.NewRequest(http.MethodPost, routeMessages, bytes.NewBufferString(body)) + w := httptest.NewRecorder() + + interceptor, err := provider.CreateInterceptor(w, req, testTracer) + + require.Error(t, err) + require.Nil(t, interceptor) + assert.Contains(t, err.Error(), "unmarshal request body") + }) + + t.Run("Messages_ForwardsAnthropicBetaHeaderToUpstream", func(t *testing.T) { + t.Parallel() + + var receivedHeaders http.Header + + // Mock upstream that captures headers. + mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeaders = r.Header.Clone() + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"id":"msg-123","type":"message","role":"assistant","content":[{"type":"text","text":"Hello!"}],"model":"claude-opus-4-5","stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}`)) + })) + t.Cleanup(mockUpstream.Close) + + provider := NewAnthropic(config.Anthropic{ + BaseURL: mockUpstream.URL, + Key: "test-key", + }, nil) + + // Use a realistic multi-beta value as sent by Claude Code clients. + betaHeader := "claude-code-20250219,adaptive-thinking-2026-01-28,context-management-2025-06-27,prompt-caching-scope-2026-01-05,effort-2025-11-24" + + body := `{"model": "claude-opus-4-5", "max_tokens": 1024, "messages": [{"role": "user", "content": "hello"}], "stream": false}` + req := httptest.NewRequest(http.MethodPost, routeMessages, bytes.NewBufferString(body)) + req.Header.Set("Anthropic-Beta", betaHeader) + req.Header.Set("X-Custom-Header", "should-not-forward") + w := httptest.NewRecorder() + + interceptor, err := provider.CreateInterceptor(w, req, testTracer) + require.NoError(t, err) + require.NotNil(t, interceptor) + + logger := slog.Make() + interceptor.Setup(logger, &testutil.MockRecorder{}, nil) + + processReq := httptest.NewRequest(http.MethodPost, routeMessages, nil) + err = interceptor.ProcessRequest(w, processReq) + require.NoError(t, err) + + // Verify the full Anthropic-Beta header (all betas) was forwarded unchanged. + assert.Equal(t, betaHeader, receivedHeaders.Get("Anthropic-Beta")) + + // Verify non-Anthropic headers are not forwarded. + assert.Empty(t, receivedHeaders.Get("X-Custom-Header"), "non-Anthropic headers should not be forwarded") + }) + + t.Run("UnknownRoute", func(t *testing.T) { + t.Parallel() + + body := `{"model": "claude-opus-4-5", "max_tokens": 1024, "messages": [{"role": "user", "content": "hello"}]}` + req := httptest.NewRequest(http.MethodPost, "/anthropic/unknown/route", bytes.NewBufferString(body)) + w := httptest.NewRecorder() + + interceptor, err := provider.CreateInterceptor(w, req, testTracer) + + require.ErrorIs(t, err, UnknownRoute) + require.Nil(t, interceptor) + }) +} + +func TestExtractAnthropicHeaders(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + headers map[string]string + expected map[string]string + }{ + { + name: "no headers", + headers: map[string]string{}, + expected: map[string]string{}, + }, + { + name: "single beta", + headers: map[string]string{"Anthropic-Beta": "claude-code-20250219"}, + expected: map[string]string{"Anthropic-Beta": "claude-code-20250219"}, + }, + { + name: "multiple betas in single header", + headers: map[string]string{"Anthropic-Beta": "claude-code-20250219,adaptive-thinking-2026-01-28,context-management-2025-06-27,prompt-caching-scope-2026-01-05,effort-2025-11-24"}, + expected: map[string]string{"Anthropic-Beta": "claude-code-20250219,adaptive-thinking-2026-01-28,context-management-2025-06-27,prompt-caching-scope-2026-01-05,effort-2025-11-24"}, + }, + { + name: "ignores other headers", + headers: map[string]string{"Anthropic-Beta": "claude-code-20250219,context-management-2025-06-27", "X-Api-Key": "secret"}, + expected: map[string]string{"Anthropic-Beta": "claude-code-20250219,context-management-2025-06-27"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest(http.MethodPost, "/", nil) + for header, value := range tc.headers { + req.Header.Set(header, value) + } + + result := extractAnthropicHeaders(req) + assert.Equal(t, tc.expected, result) + }) + } +} + func Test_anthropicIsFailure(t *testing.T) { t.Parallel()