From dff66d95d075a568fc3d4c356d9b431f3b2d30b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Thu, 5 Mar 2026 15:45:38 +0000 Subject: [PATCH 01/32] refactor: create internal/integrationtest package infrastructure - Move upstream.go and mockmcp.go from internal/testutil to internal/integrationtest - Create bridge.go with NewBridgeTestServer, NewLogger, DefaultActorID, DefaultTracer - Create requests.go with shared request/config helpers (promoted from test files) - internal/testutil retains only lightweight mocks (MockRecorder, MockProvider) --- internal/integrationtest/bridge.go | 151 ++++++++++++++++++ .../{testutil => integrationtest}/mockmcp.go | 2 +- internal/integrationtest/requests.go | 61 +++++++ .../{testutil => integrationtest}/upstream.go | 2 +- 4 files changed, 214 insertions(+), 2 deletions(-) create mode 100644 internal/integrationtest/bridge.go rename internal/{testutil => integrationtest}/mockmcp.go (99%) create mode 100644 internal/integrationtest/requests.go rename internal/{testutil => integrationtest}/upstream.go (99%) diff --git a/internal/integrationtest/bridge.go b/internal/integrationtest/bridge.go new file mode 100644 index 0000000..0c4da14 --- /dev/null +++ b/internal/integrationtest/bridge.go @@ -0,0 +1,151 @@ +package integrationtest + +import ( + "context" + "net" + "net/http/httptest" + "testing" + + "cdr.dev/slog/v3" + "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/aibridge" + aibcontext "github.com/coder/aibridge/context" + "github.com/coder/aibridge/internal/testutil" + "github.com/coder/aibridge/mcp" + "github.com/coder/aibridge/metrics" + "github.com/coder/aibridge/recorder" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/trace" +) + +// DefaultActorID is the actor ID used by default in test servers. +const DefaultActorID = "ae235cc1-9f8f-417d-a636-a7b170bac62e" + +// DefaultTracer is the default OTel tracer used in integration tests. 
+var DefaultTracer = otel.Tracer("integrationtest") + +// NewLogger creates a test logger at Debug level. +// Eliminates the repeated slogtest.Make(t, &slogtest.Options{...}).Leveled(slog.LevelDebug) pattern. +func NewLogger(t *testing.T, opts ...*slogtest.Options) slog.Logger { + t.Helper() + var o *slogtest.Options + if len(opts) > 0 { + o = opts[0] + } else { + o = &slogtest.Options{} + } + return slogtest.Make(t, o).Leveled(slog.LevelDebug) +} + +// BridgeTestServer wraps an httptest.Server running a RequestBridge. +type BridgeTestServer struct { + *httptest.Server + Recorder *testutil.MockRecorder + Bridge *aibridge.RequestBridge +} + +// BridgeOption configures a [BridgeTestServer]. +type BridgeOption func(*bridgeConfig) + +type bridgeConfig struct { + metrics *metrics.Metrics + tracer trace.Tracer + mcpProxy mcp.ServerProxier + userID string + metadata recorder.Metadata + logger slog.Logger + loggerSet bool + wrapRecorder bool +} + +// WithMetrics sets the Prometheus metrics for the bridge. +func WithMetrics(m *metrics.Metrics) BridgeOption { + return func(c *bridgeConfig) { c.metrics = m } +} + +// WithTracer overrides the default tracer. +func WithTracer(t trace.Tracer) BridgeOption { + return func(c *bridgeConfig) { c.tracer = t } +} + +// WithMCP sets the MCP server proxier (default: NoopMCPManager). +func WithMCP(p mcp.ServerProxier) BridgeOption { + return func(c *bridgeConfig) { c.mcpProxy = p } +} + +// WithActor sets the actor ID and metadata for the BaseContext. +func WithActor(id string, md recorder.Metadata) BridgeOption { + return func(c *bridgeConfig) { c.userID = id; c.metadata = md } +} + +// WithLogger overrides the default slogtest debug logger. +func WithLogger(l slog.Logger) BridgeOption { + return func(c *bridgeConfig) { c.logger = l; c.loggerSet = true } +} + +// WithWrappedRecorder wraps the MockRecorder through aibridge.NewRecorder +// (the production recorder wrapper). Use when testing the recorder pipeline. 
+func WithWrappedRecorder() BridgeOption { + return func(c *bridgeConfig) { c.wrapRecorder = true } +} + +// NewBridgeTestServer creates a fully configured test server running +// a RequestBridge with sensible defaults: +// - MockRecorder (raw, unless WithWrappedRecorder) +// - NoopMCPManager (unless WithMCP) +// - slogtest debug logger (unless WithLogger) +// - DefaultTracer (unless WithTracer) +// - DefaultActorID (unless WithActor) +func NewBridgeTestServer( + t *testing.T, + ctx context.Context, + providers []aibridge.Provider, + opts ...BridgeOption, +) *BridgeTestServer { + t.Helper() + + cfg := &bridgeConfig{ + userID: DefaultActorID, + } + for _, o := range opts { + o(cfg) + } + if cfg.tracer == nil { + cfg.tracer = DefaultTracer + } + if !cfg.loggerSet { + cfg.logger = NewLogger(t) + } + if cfg.mcpProxy == nil { + cfg.mcpProxy = NewNoopMCPManager() + } + + mockRec := &testutil.MockRecorder{} + var rec aibridge.Recorder = mockRec + if cfg.wrapRecorder { + rec = aibridge.NewRecorder(cfg.logger, cfg.tracer, func() (aibridge.Recorder, error) { + return mockRec, nil + }) + } + + bridge, err := aibridge.NewRequestBridge( + ctx, providers, rec, cfg.mcpProxy, + cfg.logger, cfg.metrics, cfg.tracer, + ) + require.NoError(t, err) + + actorID, md := cfg.userID, cfg.metadata + srv := httptest.NewUnstartedServer(bridge) + srv.Config.BaseContext = func(_ net.Listener) context.Context { + return aibcontext.AsActor(ctx, actorID, md) + } + srv.Start() + t.Cleanup(srv.Close) + + return &BridgeTestServer{ + Server: srv, + Recorder: mockRec, + Bridge: bridge, + } +} diff --git a/internal/testutil/mockmcp.go b/internal/integrationtest/mockmcp.go similarity index 99% rename from internal/testutil/mockmcp.go rename to internal/integrationtest/mockmcp.go index 212e400..67c06ec 100644 --- a/internal/testutil/mockmcp.go +++ b/internal/integrationtest/mockmcp.go @@ -1,4 +1,4 @@ -package testutil +package integrationtest import ( "context" diff --git 
a/internal/integrationtest/requests.go b/internal/integrationtest/requests.go new file mode 100644 index 0000000..92fe817 --- /dev/null +++ b/internal/integrationtest/requests.go @@ -0,0 +1,61 @@ +package integrationtest + +import ( + "bytes" + "net/http" + "testing" + + "github.com/coder/aibridge/config" + "github.com/stretchr/testify/require" +) + +// APIKey is the default API key used across integration tests. +const APIKey = "api-key" + +// CreateAnthropicMessagesReq builds an HTTP request targeting the Anthropic messages endpoint. +func CreateAnthropicMessagesReq(t *testing.T, baseURL string, input []byte) *http.Request { + t.Helper() + + req, err := http.NewRequestWithContext(t.Context(), "POST", baseURL+"/anthropic/v1/messages", bytes.NewReader(input)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + + return req +} + +// CreateOpenAIChatCompletionsReq builds an HTTP request targeting the OpenAI chat completions endpoint. +func CreateOpenAIChatCompletionsReq(t *testing.T, baseURL string, input []byte) *http.Request { + t.Helper() + + req, err := http.NewRequestWithContext(t.Context(), "POST", baseURL+"/openai/v1/chat/completions", bytes.NewReader(input)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + + return req +} + +// CreateOpenAIResponsesReq builds an HTTP request targeting the OpenAI responses endpoint. +func CreateOpenAIResponsesReq(t *testing.T, baseURL string, input []byte) *http.Request { + t.Helper() + + req, err := http.NewRequestWithContext(t.Context(), "POST", baseURL+"/openai/v1/responses", bytes.NewReader(input)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + return req +} + +// OpenAICfg creates a minimal OpenAI config for testing. +func OpenAICfg(url, key string) config.OpenAI { + return config.OpenAI{ + BaseURL: url, + Key: key, + } +} + +// AnthropicCfg creates a minimal Anthropic config for testing. 
+func AnthropicCfg(url, key string) config.Anthropic { + return config.Anthropic{ + BaseURL: url, + Key: key, + } +} diff --git a/internal/testutil/upstream.go b/internal/integrationtest/upstream.go similarity index 99% rename from internal/testutil/upstream.go rename to internal/integrationtest/upstream.go index bb935a8..356ec94 100644 --- a/internal/testutil/upstream.go +++ b/internal/integrationtest/upstream.go @@ -1,4 +1,4 @@ -package testutil +package integrationtest import ( "bufio" From 934b8167cdd2eec0a4fc99756b5b5ba33edb0a81 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Thu, 5 Mar 2026 15:53:15 +0000 Subject: [PATCH 02/32] refactor: migrate metrics and trace integration tests to internal/integrationtest/ MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Move metrics_integration_test.go → internal/integrationtest/metrics_test.go - Move trace_integration_test.go → internal/integrationtest/trace_test.go - Change package from aibridge_test to integrationtest_test - Replace all newTestSrv calls with integrationtest.NewBridgeTestServer - Delete the newTestSrv helper function entirely - Update all imports: testutil symbols → integrationtest where moved - Replace local helpers (apiKey, userID, openaiCfg, etc.) 
with exported versions - All former newTestSrv sites use WithWrappedRecorder() to preserve behavior - Remove duplicate testBedrockCfg (already in bridge_test.go) --- .../integrationtest/metrics_test.go | 161 ++++++++---------- .../integrationtest/trace_test.go | 131 +++++++------- 2 files changed, 142 insertions(+), 150 deletions(-) rename metrics_integration_test.go => internal/integrationtest/metrics_test.go (60%) rename trace_integration_test.go => internal/integrationtest/trace_test.go (85%) diff --git a/metrics_integration_test.go b/internal/integrationtest/metrics_test.go similarity index 60% rename from metrics_integration_test.go rename to internal/integrationtest/metrics_test.go index ac0d42a..8242221 100644 --- a/metrics_integration_test.go +++ b/internal/integrationtest/metrics_test.go @@ -1,27 +1,22 @@ -package aibridge_test +package integrationtest_test import ( "context" "io" - "net" "net/http" "net/http/httptest" "testing" "time" - "cdr.dev/slog/v3" - "cdr.dev/slog/v3/sloggers/slogtest" "github.com/coder/aibridge" "github.com/coder/aibridge/config" - aibcontext "github.com/coder/aibridge/context" "github.com/coder/aibridge/fixtures" - "github.com/coder/aibridge/internal/testutil" + "github.com/coder/aibridge/internal/integrationtest" "github.com/coder/aibridge/metrics" "github.com/coder/aibridge/provider" "github.com/prometheus/client_golang/prometheus" promtest "github.com/prometheus/client_golang/prometheus/testutil" "github.com/stretchr/testify/require" - "go.opentelemetry.io/otel/trace" ) func TestMetrics_Interception(t *testing.T) { @@ -40,7 +35,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "ant_simple", fixture: fixtures.AntSimple, - reqFunc: createAnthropicMessagesReq, + reqFunc: integrationtest.CreateAnthropicMessagesReq, expectStatus: metrics.InterceptionCountStatusCompleted, expectModel: "claude-sonnet-4-0", expectRoute: "/v1/messages", @@ -49,7 +44,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "ant_error", 
fixture: fixtures.AntNonStreamError, - reqFunc: createAnthropicMessagesReq, + reqFunc: integrationtest.CreateAnthropicMessagesReq, expectStatus: metrics.InterceptionCountStatusFailed, expectModel: "claude-sonnet-4-0", expectRoute: "/v1/messages", @@ -59,7 +54,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "oai_chat_simple", fixture: fixtures.OaiChatSimple, - reqFunc: createOpenAIChatCompletionsReq, + reqFunc: integrationtest.CreateOpenAIChatCompletionsReq, expectStatus: metrics.InterceptionCountStatusCompleted, expectModel: "gpt-4.1", expectRoute: "/v1/chat/completions", @@ -68,7 +63,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "oai_chat_error", fixture: fixtures.OaiChatNonStreamError, - reqFunc: createOpenAIChatCompletionsReq, + reqFunc: integrationtest.CreateOpenAIChatCompletionsReq, expectStatus: metrics.InterceptionCountStatusFailed, expectModel: "gpt-4.1", expectRoute: "/v1/chat/completions", @@ -78,7 +73,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "oai_responses_blocking_simple", fixture: fixtures.OaiResponsesBlockingSimple, - reqFunc: createOpenAIResponsesReq, + reqFunc: integrationtest.CreateOpenAIResponsesReq, expectStatus: metrics.InterceptionCountStatusCompleted, expectModel: "gpt-4o-mini", expectRoute: "/v1/responses", @@ -87,7 +82,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "oai_responses_blocking_error", fixture: fixtures.OaiResponsesBlockingHttpErr, - reqFunc: createOpenAIResponsesReq, + reqFunc: integrationtest.CreateOpenAIResponsesReq, expectStatus: metrics.InterceptionCountStatusFailed, expectModel: "gpt-4o-mini", expectRoute: "/v1/responses", @@ -97,7 +92,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "oai_responses_streaming_simple", fixture: fixtures.OaiResponsesStreamingSimple, - reqFunc: createOpenAIResponsesReq, + reqFunc: integrationtest.CreateOpenAIResponsesReq, expectStatus: metrics.InterceptionCountStatusCompleted, expectModel: "gpt-4o-mini", expectRoute: 
"/v1/responses", @@ -106,7 +101,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "oai_responses_streaming_error", fixture: fixtures.OaiResponsesStreamingHttpErr, - reqFunc: createOpenAIResponsesReq, + reqFunc: integrationtest.CreateOpenAIResponsesReq, expectStatus: metrics.InterceptionCountStatusFailed, expectModel: "gpt-4o-mini", expectRoute: "/v1/responses", @@ -123,29 +118,32 @@ func TestMetrics_Interception(t *testing.T) { t.Cleanup(cancel) fix := fixtures.Parse(t, tc.fixture) - upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) + upstream := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix)) upstream.AllowOverflow = tc.allowOverflow - metrics := aibridge.NewMetrics(prometheus.NewRegistry()) + m := aibridge.NewMetrics(prometheus.NewRegistry()) var prov aibridge.Provider if tc.expectProvider == config.ProviderAnthropic { - prov = provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), nil) + prov = provider.NewAnthropic(integrationtest.AnthropicCfg(upstream.URL, integrationtest.APIKey), nil) } else { - prov = provider.NewOpenAI(openaiCfg(upstream.URL, apiKey)) + prov = provider.NewOpenAI(integrationtest.OpenAICfg(upstream.URL, integrationtest.APIKey)) } - srv, _ := newTestSrv(t, ctx, prov, metrics, testTracer) + ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}, + integrationtest.WithMetrics(m), + integrationtest.WithWrappedRecorder(), + ) - req := tc.reqFunc(t, srv.URL, fix.Request()) + req := tc.reqFunc(t, ts.URL, fix.Request()) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) defer resp.Body.Close() _, _ = io.ReadAll(resp.Body) - count := promtest.ToFloat64(metrics.InterceptionCount.WithLabelValues( - tc.expectProvider, tc.expectModel, tc.expectStatus, tc.expectRoute, "POST", userID)) + count := promtest.ToFloat64(m.InterceptionCount.WithLabelValues( + tc.expectProvider, tc.expectModel, tc.expectStatus, tc.expectRoute, "POST", 
integrationtest.DefaultActorID)) require.Equal(t, 1.0, count) - require.Equal(t, 1, promtest.CollectAndCount(metrics.InterceptionDuration)) - require.Equal(t, 1, promtest.CollectAndCount(metrics.InterceptionCount)) + require.Equal(t, 1, promtest.CollectAndCount(m.InterceptionDuration)) + require.Equal(t, 1, promtest.CollectAndCount(m.InterceptionCount)) }) } } @@ -166,15 +164,18 @@ func TestMetrics_InterceptionsInflight(t *testing.T) { })) t.Cleanup(srv.Close) - metrics := aibridge.NewMetrics(prometheus.NewRegistry()) - provider := provider.NewAnthropic(anthropicCfg(srv.URL, apiKey), nil) - bridgeSrv, _ := newTestSrv(t, ctx, provider, metrics, testTracer) + m := aibridge.NewMetrics(prometheus.NewRegistry()) + prov := provider.NewAnthropic(integrationtest.AnthropicCfg(srv.URL, integrationtest.APIKey), nil) + ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}, + integrationtest.WithMetrics(m), + integrationtest.WithWrappedRecorder(), + ) // Make request in background. doneCh := make(chan struct{}) go func() { defer close(doneCh) - req := createAnthropicMessagesReq(t, bridgeSrv.URL, fix.Request()) + req := integrationtest.CreateAnthropicMessagesReq(t, ts.URL, fix.Request()) resp, err := http.DefaultClient.Do(req) if err == nil { defer resp.Body.Close() @@ -185,7 +186,7 @@ func TestMetrics_InterceptionsInflight(t *testing.T) { // Wait until request is detected as inflight. require.Eventually(t, func() bool { return promtest.ToFloat64( - metrics.InterceptionsInflight.WithLabelValues(config.ProviderAnthropic, "claude-sonnet-4-0", "/v1/messages"), + m.InterceptionsInflight.WithLabelValues(config.ProviderAnthropic, "claude-sonnet-4-0", "/v1/messages"), ) == 1 }, time.Second*10, time.Millisecond*50) @@ -200,7 +201,7 @@ func TestMetrics_InterceptionsInflight(t *testing.T) { // Metric is not updated immediately after request completes, so wait until it is. 
require.Eventually(t, func() bool { return promtest.ToFloat64( - metrics.InterceptionsInflight.WithLabelValues(config.ProviderAnthropic, "claude-sonnet-4-0", "/v1/messages"), + m.InterceptionsInflight.WithLabelValues(config.ProviderAnthropic, "claude-sonnet-4-0", "/v1/messages"), ) == 0 }, time.Second*10, time.Millisecond*50) } @@ -211,11 +212,14 @@ func TestMetrics_PassthroughCount(t *testing.T) { upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) t.Cleanup(upstream.Close) - metrics := aibridge.NewMetrics(prometheus.NewRegistry()) - provider := provider.NewOpenAI(openaiCfg(upstream.URL, apiKey)) - srv, _ := newTestSrv(t, t.Context(), provider, metrics, testTracer) + m := aibridge.NewMetrics(prometheus.NewRegistry()) + prov := provider.NewOpenAI(integrationtest.OpenAICfg(upstream.URL, integrationtest.APIKey)) + ts := integrationtest.NewBridgeTestServer(t, t.Context(), []aibridge.Provider{prov}, + integrationtest.WithMetrics(m), + integrationtest.WithWrappedRecorder(), + ) - req, err := http.NewRequestWithContext(t.Context(), "GET", srv.URL+"/openai/v1/models", nil) + req, err := http.NewRequestWithContext(t.Context(), "GET", ts.URL+"/openai/v1/models", nil) require.NoError(t, err) resp, err := http.DefaultClient.Do(req) @@ -223,7 +227,7 @@ func TestMetrics_PassthroughCount(t *testing.T) { defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) - count := promtest.ToFloat64(metrics.PassthroughCount.WithLabelValues( + count := promtest.ToFloat64(m.PassthroughCount.WithLabelValues( config.ProviderOpenAI, "/models", "GET")) require.Equal(t, 1.0, count) } @@ -235,21 +239,24 @@ func TestMetrics_PromptCount(t *testing.T) { t.Cleanup(cancel) fix := fixtures.Parse(t, fixtures.OaiChatSimple) - upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) + upstream := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix)) - metrics := 
aibridge.NewMetrics(prometheus.NewRegistry()) - provider := provider.NewOpenAI(openaiCfg(upstream.URL, apiKey)) - srv, _ := newTestSrv(t, ctx, provider, metrics, testTracer) + m := aibridge.NewMetrics(prometheus.NewRegistry()) + prov := provider.NewOpenAI(integrationtest.OpenAICfg(upstream.URL, integrationtest.APIKey)) + ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}, + integrationtest.WithMetrics(m), + integrationtest.WithWrappedRecorder(), + ) - req := createOpenAIChatCompletionsReq(t, srv.URL, fix.Request()) + req := integrationtest.CreateOpenAIChatCompletionsReq(t, ts.URL, fix.Request()) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) defer resp.Body.Close() _, _ = io.ReadAll(resp.Body) - prompts := promtest.ToFloat64(metrics.PromptCount.WithLabelValues( - config.ProviderOpenAI, "gpt-4.1", userID)) + prompts := promtest.ToFloat64(m.PromptCount.WithLabelValues( + config.ProviderOpenAI, "gpt-4.1", integrationtest.DefaultActorID)) require.Equal(t, 1.0, prompts) } @@ -260,20 +267,23 @@ func TestMetrics_NonInjectedToolUseCount(t *testing.T) { t.Cleanup(cancel) fix := fixtures.Parse(t, fixtures.OaiChatSingleBuiltinTool) - upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) + upstream := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix)) - metrics := aibridge.NewMetrics(prometheus.NewRegistry()) - provider := provider.NewOpenAI(openaiCfg(upstream.URL, apiKey)) - srv, _ := newTestSrv(t, ctx, provider, metrics, testTracer) + m := aibridge.NewMetrics(prometheus.NewRegistry()) + prov := provider.NewOpenAI(integrationtest.OpenAICfg(upstream.URL, integrationtest.APIKey)) + ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}, + integrationtest.WithMetrics(m), + integrationtest.WithWrappedRecorder(), + ) - req := createOpenAIChatCompletionsReq(t, srv.URL, fix.Request()) + req := 
integrationtest.CreateOpenAIChatCompletionsReq(t, ts.URL, fix.Request()) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) defer resp.Body.Close() _, _ = io.ReadAll(resp.Body) - count := promtest.ToFloat64(metrics.NonInjectedToolUseCount.WithLabelValues( + count := promtest.ToFloat64(m.NonInjectedToolUseCount.WithLabelValues( config.ProviderOpenAI, "gpt-4.1", "read_file")) require.Equal(t, 1.0, count) } @@ -286,27 +296,20 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) { // First request returns the tool invocation, the second returns the mocked response to the tool result. fix := fixtures.Parse(t, fixtures.AntSingleInjectedTool) - upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix), testutil.NewFixtureToolResponse(fix)) + upstream := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix), integrationtest.NewFixtureToolResponse(fix)) - recorder := &testutil.MockRecorder{} - logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - metrics := aibridge.NewMetrics(prometheus.NewRegistry()) - provider := provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), nil) + m := aibridge.NewMetrics(prometheus.NewRegistry()) + prov := provider.NewAnthropic(integrationtest.AnthropicCfg(upstream.URL, integrationtest.APIKey), nil) // Setup mocked MCP server & tools. 
- mockMCP := testutil.SetupMCPForTest(t, testTracer) + mockMCP := integrationtest.SetupMCPForTest(t, integrationtest.DefaultTracer) - bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, recorder, mockMCP, logger, metrics, testTracer) - require.NoError(t, err) - - srv := httptest.NewUnstartedServer(bridge) - srv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(ctx, userID, nil) - } - srv.Start() - t.Cleanup(srv.Close) + ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}, + integrationtest.WithMetrics(m), + integrationtest.WithMCP(mockMCP), + ) - req := createAnthropicMessagesReq(t, srv.URL, fix.Request()) + req := integrationtest.CreateAnthropicMessagesReq(t, ts.URL, fix.Request()) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) @@ -318,35 +321,13 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) { return upstream.Calls.Load() == 2 }, time.Second*10, time.Millisecond*50) + recorder := ts.Recorder require.Len(t, recorder.ToolUsages(), 1) require.True(t, recorder.ToolUsages()[0].Injected) require.NotNil(t, recorder.ToolUsages()[0].ServerURL) actualServerURL := *recorder.ToolUsages()[0].ServerURL - count := promtest.ToFloat64(metrics.InjectedToolUseCount.WithLabelValues( - config.ProviderAnthropic, "claude-sonnet-4-20250514", actualServerURL, testutil.MockToolName)) + count := promtest.ToFloat64(m.InjectedToolUseCount.WithLabelValues( + config.ProviderAnthropic, "claude-sonnet-4-20250514", actualServerURL, integrationtest.MockToolName)) require.Equal(t, 1.0, count) } - -func newTestSrv(t *testing.T, ctx context.Context, provider aibridge.Provider, metrics *metrics.Metrics, tracer trace.Tracer) (*httptest.Server, *testutil.MockRecorder) { - t.Helper() - - logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - mockRecorder := &testutil.MockRecorder{} - clientFn := func() (aibridge.Recorder, 
error) { - return mockRecorder, nil - } - wrappedRecorder := aibridge.NewRecorder(logger, tracer, clientFn) - - bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, wrappedRecorder, testutil.NewNoopMCPManager(), logger, metrics, tracer) - require.NoError(t, err) - - srv := httptest.NewUnstartedServer(bridge) - srv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(ctx, userID, nil) - } - srv.Start() - t.Cleanup(srv.Close) - - return srv, mockRecorder -} diff --git a/trace_integration_test.go b/internal/integrationtest/trace_test.go similarity index 85% rename from trace_integration_test.go rename to internal/integrationtest/trace_test.go index a62e58e..41dd968 100644 --- a/trace_integration_test.go +++ b/internal/integrationtest/trace_test.go @@ -1,4 +1,4 @@ -package aibridge_test +package integrationtest_test import ( "context" @@ -9,9 +9,10 @@ import ( "testing" "time" + "github.com/coder/aibridge" "github.com/coder/aibridge/config" "github.com/coder/aibridge/fixtures" - "github.com/coder/aibridge/internal/testutil" + "github.com/coder/aibridge/internal/integrationtest" "github.com/coder/aibridge/provider" "github.com/coder/aibridge/tracing" "github.com/stretchr/testify/assert" @@ -98,25 +99,29 @@ func TestTraceAnthropic(t *testing.T) { tracer := tp.Tracer(t.Name()) defer func() { _ = tp.Shutdown(t.Context()) }() - upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) + upstream := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix)) var bedrockCfg *config.AWSBedrock if tc.bedrock { bedrockCfg = testBedrockCfg(upstream.URL) } - provider := provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), bedrockCfg) - srv, recorder := newTestSrv(t, ctx, provider, nil, tracer) + prov := provider.NewAnthropic(integrationtest.AnthropicCfg(upstream.URL, integrationtest.APIKey), bedrockCfg) + ts := integrationtest.NewBridgeTestServer(t, ctx, 
[]aibridge.Provider{prov}, + integrationtest.WithTracer(tracer), + integrationtest.WithWrappedRecorder(), + ) reqBody, err := sjson.SetBytes(fixtureReqBody, "stream", tc.streaming) require.NoError(t, err) - req := createAnthropicMessagesReq(t, srv.URL, reqBody) + req := integrationtest.CreateAnthropicMessagesReq(t, ts.URL, reqBody) client := &http.Client{} resp, err := client.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) defer resp.Body.Close() - srv.Close() + ts.Close() + recorder := ts.Recorder require.Equal(t, 1, len(recorder.RecordedInterceptions())) intcID := recorder.RecordedInterceptions()[0].ID @@ -135,7 +140,7 @@ func TestTraceAnthropic(t *testing.T) { attribute.String(tracing.InterceptionID, intcID), attribute.String(tracing.Provider, config.ProviderAnthropic), attribute.String(tracing.Model, model), - attribute.String(tracing.InitiatorID, userID), + attribute.String(tracing.InitiatorID, integrationtest.DefaultActorID), attribute.Bool(tracing.Streaming, tc.streaming), attribute.Bool(tracing.IsBedrock, tc.bedrock), } @@ -214,18 +219,21 @@ func TestTraceAnthropicErr(t *testing.T) { defer func() { _ = tp.Shutdown(t.Context()) }() fix := fixtures.Parse(t, tc.fixture) - upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) + upstream := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix)) var bedrockCfg *config.AWSBedrock if tc.bedrock { bedrockCfg = testBedrockCfg(upstream.URL) } - provider := provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), bedrockCfg) - srv, recorder := newTestSrv(t, ctx, provider, nil, tracer) + prov := provider.NewAnthropic(integrationtest.AnthropicCfg(upstream.URL, integrationtest.APIKey), bedrockCfg) + ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}, + integrationtest.WithTracer(tracer), + integrationtest.WithWrappedRecorder(), + ) reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) 
require.NoError(t, err) - req := createAnthropicMessagesReq(t, srv.URL, reqBody) + req := integrationtest.CreateAnthropicMessagesReq(t, ts.URL, reqBody) client := &http.Client{} resp, err := client.Do(req) require.NoError(t, err) @@ -235,8 +243,9 @@ func TestTraceAnthropicErr(t *testing.T) { require.Equal(t, tc.expectCode, resp.StatusCode) } defer resp.Body.Close() - srv.Close() + ts.Close() + recorder := ts.Recorder require.Equal(t, 1, len(recorder.RecordedInterceptions())) intcID := recorder.RecordedInterceptions()[0].ID @@ -259,7 +268,7 @@ func TestTraceAnthropicErr(t *testing.T) { attribute.String(tracing.InterceptionID, intcID), attribute.String(tracing.Provider, config.ProviderAnthropic), attribute.String(tracing.Model, model), - attribute.String(tracing.InitiatorID, userID), + attribute.String(tracing.InitiatorID, integrationtest.DefaultActorID), attribute.Bool(tracing.Streaming, tc.streaming), attribute.Bool(tracing.IsBedrock, tc.bedrock), } @@ -288,7 +297,7 @@ func TestInjectedToolsTrace(t *testing.T) { streaming: false, fixture: fixtures.AntSingleInjectedTool, providerFn: newAnthropicProvider, - createReqFn: createAnthropicMessagesReq, + createReqFn: integrationtest.CreateAnthropicMessagesReq, expectModel: "claude-sonnet-4-20250514", expectPath: "/anthropic/v1/messages", expectProvider: config.ProviderAnthropic, @@ -298,7 +307,7 @@ func TestInjectedToolsTrace(t *testing.T) { streaming: true, fixture: fixtures.AntSingleInjectedTool, providerFn: newAnthropicProvider, - createReqFn: createAnthropicMessagesReq, + createReqFn: integrationtest.CreateAnthropicMessagesReq, expectModel: "claude-sonnet-4-20250514", expectPath: "/anthropic/v1/messages", expectProvider: config.ProviderAnthropic, @@ -309,7 +318,7 @@ func TestInjectedToolsTrace(t *testing.T) { bedrock: true, fixture: fixtures.AntSingleInjectedTool, providerFn: newBedrockProvider, - createReqFn: createAnthropicMessagesReq, + createReqFn: integrationtest.CreateAnthropicMessagesReq, expectModel: "beddel", 
expectPath: "/anthropic/v1/messages", expectProvider: config.ProviderAnthropic, @@ -320,7 +329,7 @@ func TestInjectedToolsTrace(t *testing.T) { bedrock: true, fixture: fixtures.AntSingleInjectedTool, providerFn: newBedrockProvider, - createReqFn: createAnthropicMessagesReq, + createReqFn: integrationtest.CreateAnthropicMessagesReq, expectModel: "beddel", expectPath: "/anthropic/v1/messages", expectProvider: config.ProviderAnthropic, @@ -330,7 +339,7 @@ func TestInjectedToolsTrace(t *testing.T) { streaming: false, fixture: fixtures.OaiChatSingleInjectedTool, providerFn: newOpenAIProvider, - createReqFn: createOpenAIChatCompletionsReq, + createReqFn: integrationtest.CreateOpenAIChatCompletionsReq, expectModel: "gpt-4.1", expectPath: "/openai/v1/chat/completions", expectProvider: config.ProviderOpenAI, @@ -340,7 +349,7 @@ func TestInjectedToolsTrace(t *testing.T) { streaming: true, fixture: fixtures.OaiChatSingleInjectedTool, providerFn: newOpenAIProvider, - createReqFn: createOpenAIChatCompletionsReq, + createReqFn: integrationtest.CreateOpenAIChatCompletionsReq, expectModel: "gpt-4.1", expectPath: "/openai/v1/chat/completions", expectProvider: config.ProviderOpenAI, @@ -364,7 +373,7 @@ func TestInjectedToolsTrace(t *testing.T) { } recorderClient, mockMCP, resp := setupInjectedToolTest( - t, tc.fixture, tc.streaming, tc.providerFn, tracer, userID, + t, tc.fixture, tc.streaming, tc.providerFn, tracer, integrationtest.DefaultActorID, tc.createReqFn, validatorFn, ) defer resp.Body.Close() @@ -379,7 +388,7 @@ func TestInjectedToolsTrace(t *testing.T) { attribute.String(tracing.InterceptionID, intcID), attribute.String(tracing.Provider, tc.expectProvider), attribute.String(tracing.Model, tc.expectModel), - attribute.String(tracing.InitiatorID, userID), + attribute.String(tracing.InitiatorID, integrationtest.DefaultActorID), attribute.String(tracing.MCPInput, `{"owner":"admin"}`), attribute.String(tracing.MCPToolName, "coder_list_workspaces"), 
attribute.String(tracing.MCPServerName, tool.ServerName), @@ -409,7 +418,7 @@ func TestTraceOpenAI(t *testing.T) { fixture: fixtures.OaiChatSimple, streaming: true, expectPath: "/openai/v1/chat/completions", - reqFunc: createOpenAIChatCompletionsReq, + reqFunc: integrationtest.CreateOpenAIChatCompletionsReq, expect: []expectTrace{ {"Intercept", 1, codes.Unset}, {"Intercept.CreateInterceptor", 1, codes.Unset}, @@ -424,7 +433,7 @@ func TestTraceOpenAI(t *testing.T) { { name: "trace_openai_chat_blocking", fixture: fixtures.OaiChatSimple, - reqFunc: createOpenAIChatCompletionsReq, + reqFunc: integrationtest.CreateOpenAIChatCompletionsReq, streaming: false, expectPath: "/openai/v1/chat/completions", expect: []expectTrace{ @@ -443,7 +452,7 @@ func TestTraceOpenAI(t *testing.T) { fixture: fixtures.OaiResponsesStreamingSimple, streaming: true, expectPath: "/openai/v1/responses", - reqFunc: createOpenAIResponsesReq, + reqFunc: integrationtest.CreateOpenAIResponsesReq, expect: []expectTrace{ {"Intercept", 1, codes.Unset}, {"Intercept.CreateInterceptor", 1, codes.Unset}, @@ -460,7 +469,7 @@ func TestTraceOpenAI(t *testing.T) { fixture: fixtures.OaiResponsesBlockingSimple, streaming: false, expectPath: "/openai/v1/responses", - reqFunc: createOpenAIResponsesReq, + reqFunc: integrationtest.CreateOpenAIResponsesReq, expect: []expectTrace{ {"Intercept", 1, codes.Unset}, {"Intercept.CreateInterceptor", 1, codes.Unset}, @@ -485,20 +494,24 @@ func TestTraceOpenAI(t *testing.T) { defer func() { _ = tp.Shutdown(t.Context()) }() fix := fixtures.Parse(t, tc.fixture) - upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) - provider := provider.NewOpenAI(openaiCfg(upstream.URL, apiKey)) - srv, recorder := newTestSrv(t, ctx, provider, nil, tracer) + upstream := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix)) + prov := provider.NewOpenAI(integrationtest.OpenAICfg(upstream.URL, integrationtest.APIKey)) + ts := 
integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}, + integrationtest.WithTracer(tracer), + integrationtest.WithWrappedRecorder(), + ) reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) require.NoError(t, err) - req := tc.reqFunc(t, srv.URL, reqBody) + req := tc.reqFunc(t, ts.URL, reqBody) client := &http.Client{} resp, err := client.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) defer resp.Body.Close() - srv.Close() + ts.Close() + recorder := ts.Recorder require.Equal(t, 1, len(recorder.RecordedInterceptions())) intcID := recorder.RecordedInterceptions()[0].ID @@ -513,7 +526,7 @@ func TestTraceOpenAI(t *testing.T) { attribute.String(tracing.InterceptionID, intcID), attribute.String(tracing.Provider, config.ProviderOpenAI), attribute.String(tracing.Model, gjson.Get(string(reqBody), "model").Str), - attribute.String(tracing.InitiatorID, userID), + attribute.String(tracing.InitiatorID, integrationtest.DefaultActorID), attribute.Bool(tracing.Streaming, tc.streaming), } verifyTraces(t, sr, tc.expect, attrs) @@ -537,7 +550,7 @@ func TestTraceOpenAIErr(t *testing.T) { fixture: fixtures.OaiChatMidStreamError, streaming: true, expectPath: "/openai/v1/chat/completions", - reqFunc: createOpenAIChatCompletionsReq, + reqFunc: integrationtest.CreateOpenAIChatCompletionsReq, expectCode: http.StatusOK, expect: []expectTrace{ {"Intercept", 1, codes.Error}, @@ -554,7 +567,7 @@ func TestTraceOpenAIErr(t *testing.T) { fixture: fixtures.OaiChatNonStreamError, streaming: false, expectPath: "/openai/v1/chat/completions", - reqFunc: createOpenAIChatCompletionsReq, + reqFunc: integrationtest.CreateOpenAIChatCompletionsReq, expectCode: http.StatusBadRequest, expect: []expectTrace{ {"Intercept", 1, codes.Error}, @@ -570,7 +583,7 @@ func TestTraceOpenAIErr(t *testing.T) { streaming: true, fixture: fixtures.OaiResponsesStreamingWrongResponseFormat, expectPath: "/openai/v1/responses", - reqFunc: createOpenAIResponsesReq, + 
reqFunc: integrationtest.CreateOpenAIResponsesReq, expectCode: http.StatusOK, expect: []expectTrace{ {"Intercept", 1, codes.Error}, @@ -587,7 +600,7 @@ func TestTraceOpenAIErr(t *testing.T) { fixture: fixtures.OaiResponsesBlockingWrongResponseFormat, streaming: false, expectPath: "/openai/v1/responses", - reqFunc: createOpenAIResponsesReq, + reqFunc: integrationtest.CreateOpenAIResponsesReq, // Fixture returns http 200 response with wrong body // responses forward received response as is so // expected code == 200 even though ProcessRequest @@ -609,7 +622,7 @@ func TestTraceOpenAIErr(t *testing.T) { allowOverflow: true, // 429 error causes retries expectPath: "/openai/v1/responses", - reqFunc: createOpenAIResponsesReq, + reqFunc: integrationtest.CreateOpenAIResponsesReq, expectCode: http.StatusTooManyRequests, expect: []expectTrace{ {"Intercept", 1, codes.Error}, @@ -626,7 +639,7 @@ func TestTraceOpenAIErr(t *testing.T) { streaming: false, expectPath: "/openai/v1/responses", - reqFunc: createOpenAIResponsesReq, + reqFunc: integrationtest.CreateOpenAIResponsesReq, expectCode: http.StatusUnauthorized, expect: []expectTrace{ {"Intercept", 1, codes.Error}, @@ -651,22 +664,26 @@ func TestTraceOpenAIErr(t *testing.T) { fix := fixtures.Parse(t, tc.fixture) - mockAPI := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) + mockAPI := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix)) mockAPI.AllowOverflow = tc.allowOverflow - prov := provider.NewOpenAI(openaiCfg(mockAPI.URL, apiKey)) - srv, recorder := newTestSrv(t, ctx, prov, nil, tracer) + prov := provider.NewOpenAI(integrationtest.OpenAICfg(mockAPI.URL, integrationtest.APIKey)) + ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}, + integrationtest.WithTracer(tracer), + integrationtest.WithWrappedRecorder(), + ) reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) require.NoError(t, err) - req := tc.reqFunc(t, srv.URL, reqBody) + req 
:= tc.reqFunc(t, ts.URL, reqBody) client := &http.Client{} resp, err := client.Do(req) require.NoError(t, err) require.Equal(t, tc.expectCode, resp.StatusCode) defer resp.Body.Close() - srv.Close() + ts.Close() + recorder := ts.Recorder require.Equal(t, 1, len(recorder.RecordedInterceptions())) intcID := recorder.RecordedInterceptions()[0].ID @@ -681,7 +698,7 @@ func TestTraceOpenAIErr(t *testing.T) { attribute.String(tracing.InterceptionID, intcID), attribute.String(tracing.Provider, config.ProviderOpenAI), attribute.String(tracing.Model, gjson.Get(string(reqBody), "model").Str), - attribute.String(tracing.InitiatorID, userID), + attribute.String(tracing.InitiatorID, integrationtest.DefaultActorID), attribute.Bool(tracing.Streaming, tc.streaming), } verifyTraces(t, sr, tc.expect, attrs) @@ -694,24 +711,27 @@ func TestTracePassthrough(t *testing.T) { fix := fixtures.Parse(t, fixtures.OaiChatFallthrough) - upstream := testutil.NewMockUpstream(t, t.Context(), testutil.NewFixtureResponse(fix)) + upstream := integrationtest.NewMockUpstream(t, t.Context(), integrationtest.NewFixtureResponse(fix)) sr := tracetest.NewSpanRecorder() tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) tracer := tp.Tracer(t.Name()) defer func() { _ = tp.Shutdown(t.Context()) }() - provider := provider.NewOpenAI(openaiCfg(upstream.URL, apiKey)) - srv, _ := newTestSrv(t, t.Context(), provider, nil, tracer) + prov := provider.NewOpenAI(integrationtest.OpenAICfg(upstream.URL, integrationtest.APIKey)) + ts := integrationtest.NewBridgeTestServer(t, t.Context(), []aibridge.Provider{prov}, + integrationtest.WithTracer(tracer), + integrationtest.WithWrappedRecorder(), + ) - req, err := http.NewRequestWithContext(t.Context(), "GET", srv.URL+"/openai/v1/models", nil) + req, err := http.NewRequestWithContext(t.Context(), "GET", ts.URL+"/openai/v1/models", nil) require.NoError(t, err) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) defer resp.Body.Close() require.Equal(t, 
http.StatusOK, resp.StatusCode) - srv.Close() + ts.Close() spans := sr.Ended() require.Len(t, spans, 1) @@ -733,7 +753,7 @@ func TestNewServerProxyManagerTraces(t *testing.T) { defer func() { _ = tp.Shutdown(t.Context()) }() serverName := "serverName" - mockMCP := testutil.SetupMCPForTestWithName(t, serverName, tracer) + mockMCP := integrationtest.SetupMCPForTestWithName(t, serverName, tracer) tool := mockMCP.ListTools()[0] require.Len(t, sr.Ended(), 3) @@ -776,13 +796,4 @@ func verifyTraces(t *testing.T, spanRecorder *tracetest.SpanRecorder, expect []e } } -func testBedrockCfg(url string) *config.AWSBedrock { - return &config.AWSBedrock{ - Region: "us-west-2", - AccessKey: "test-access-key", - AccessKeySecret: "test-secret-key", - Model: "beddel", // This model should override the request's given one. - SmallFastModel: "modrock", // Unused but needed for validation. - BaseURL: url, - } -} + From 6bd29e0faaec5f8c5cf77a5074bdb8e1199dad9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Thu, 5 Mar 2026 15:56:22 +0000 Subject: [PATCH 03/32] refactor: migrate responses, circuit_breaker, and apidump integration tests to internal/integrationtest/ MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move the following files from the repo root: - responses_integration_test.go → internal/integrationtest/responses_test.go - circuit_breaker_integration_test.go → internal/integrationtest/circuit_breaker_test.go - apidump_integration_test.go → internal/integrationtest/apidump_test.go Changes in all files: - Package: aibridge_test → integrationtest_test - Replace newTestSrv/NewRequestBridge with integrationtest.NewBridgeTestServer - Replace testutil.* moved symbols with integrationtest.* - Replace local helper refs (openaiCfg, apiKey, userID, etc.) 
with integrationtest.OpenAICfg, integrationtest.APIKey, integrationtest.DefaultActorID - Delete local createOpenAIResponsesReq (now in requests.go) - Keep local helpers specific to each file --- .../integrationtest/apidump_test.go | 62 +++----- .../integrationtest/circuit_breaker_test.go | 133 +++++------------- .../integrationtest/responses_test.go | 93 +++++------- 3 files changed, 86 insertions(+), 202 deletions(-) rename apidump_integration_test.go => internal/integrationtest/apidump_test.go (78%) rename circuit_breaker_integration_test.go => internal/integrationtest/circuit_breaker_test.go (83%) rename responses_integration_test.go => internal/integrationtest/responses_test.go (91%) diff --git a/apidump_integration_test.go b/internal/integrationtest/apidump_test.go similarity index 78% rename from apidump_integration_test.go rename to internal/integrationtest/apidump_test.go index 8aac244..fd13883 100644 --- a/apidump_integration_test.go +++ b/internal/integrationtest/apidump_test.go @@ -1,11 +1,10 @@ -package aibridge_test +package integrationtest_test import ( "bufio" "bytes" "context" "io" - "net" "net/http" "net/http/httptest" "os" @@ -14,14 +13,11 @@ import ( "testing" "time" - "cdr.dev/slog/v3" - "cdr.dev/slog/v3/sloggers/slogtest" "github.com/coder/aibridge" "github.com/coder/aibridge/config" - aibcontext "github.com/coder/aibridge/context" "github.com/coder/aibridge/fixtures" "github.com/coder/aibridge/intercept/apidump" - "github.com/coder/aibridge/internal/testutil" + "github.com/coder/aibridge/internal/integrationtest" "github.com/coder/aibridge/provider" "github.com/stretchr/testify/require" ) @@ -55,25 +51,25 @@ func TestAPIDump(t *testing.T) { name: "anthropic", fixture: fixtures.AntSimple, providersFunc: func(addr, dumpDir string) []aibridge.Provider { - return []aibridge.Provider{provider.NewAnthropic(anthropicCfgWithAPIDump(addr, apiKey, dumpDir), nil)} + return []aibridge.Provider{provider.NewAnthropic(anthropicCfgWithAPIDump(addr, 
integrationtest.APIKey, dumpDir), nil)} }, - createRequestFunc: createAnthropicMessagesReq, + createRequestFunc: integrationtest.CreateAnthropicMessagesReq, }, { name: "openai_chat_completions", fixture: fixtures.OaiChatSimple, providersFunc: func(addr, dumpDir string) []aibridge.Provider { - return []aibridge.Provider{provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir))} + return []aibridge.Provider{provider.NewOpenAI(openaiCfgWithAPIDump(addr, integrationtest.APIKey, dumpDir))} }, - createRequestFunc: createOpenAIChatCompletionsReq, + createRequestFunc: integrationtest.CreateOpenAIChatCompletionsReq, }, { name: "openai_responses", fixture: fixtures.OaiResponsesBlockingSimple, providersFunc: func(addr, dumpDir string) []aibridge.Provider { - return []aibridge.Provider{provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir))} + return []aibridge.Provider{provider.NewOpenAI(openaiCfgWithAPIDump(addr, integrationtest.APIKey, dumpDir))} }, - createRequestFunc: createOpenAIResponsesReq, + createRequestFunc: integrationtest.CreateOpenAIResponsesReq, }, } @@ -81,30 +77,20 @@ func TestAPIDump(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) // Setup mock upstream server. fix := fixtures.Parse(t, tc.fixture) - srv := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) + srv := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix)) // Create temp dir for API dumps. 
dumpDir := t.TempDir() - recorderClient := &testutil.MockRecorder{} - b, err := aibridge.NewRequestBridge(t.Context(), tc.providersFunc(srv.URL, dumpDir), recorderClient, testutil.NewNoopMCPManager(), logger, nil, testTracer) - require.NoError(t, err) - - mockSrv := httptest.NewUnstartedServer(b) - t.Cleanup(mockSrv.Close) - mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(ctx, userID, nil) - } - mockSrv.Start() + providers := tc.providersFunc(srv.URL, dumpDir) + ts := integrationtest.NewBridgeTestServer(t, ctx, providers) - req := tc.createRequestFunc(t, mockSrv.URL, fix.Request()) + req := tc.createRequestFunc(t, ts.URL, fix.Request()) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) @@ -112,7 +98,7 @@ func TestAPIDump(t *testing.T) { _, _ = io.ReadAll(resp.Body) // Verify dump files were created. - interceptions := recorderClient.RecordedInterceptions() + interceptions := ts.Recorder.RecordedInterceptions() require.Len(t, interceptions, 1) interceptionID := interceptions[0].ID @@ -167,7 +153,7 @@ func TestAPIDump(t *testing.T) { expectedRespBody := fix.NonStreaming() require.JSONEq(t, string(expectedRespBody), string(dumpRespBody), "response body JSON should match semantically") - recorderClient.VerifyAllInterceptionsEnded(t) + ts.Recorder.VerifyAllInterceptionsEnded(t) }) } } @@ -186,7 +172,7 @@ func TestAPIDumpPassthrough(t *testing.T) { { name: "anthropic", providerFunc: func(addr string, dumpDir string) aibridge.Provider { - return provider.NewAnthropic(anthropicCfgWithAPIDump(addr, apiKey, dumpDir), nil) + return provider.NewAnthropic(anthropicCfgWithAPIDump(addr, integrationtest.APIKey, dumpDir), nil) }, requestPath: "/anthropic/v1/models", expectDumpName: "-v1-models-", @@ -194,7 +180,7 @@ func TestAPIDumpPassthrough(t *testing.T) { { name: "openai", providerFunc: func(addr string, dumpDir string) aibridge.Provider { - return 
provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir)) + return provider.NewOpenAI(openaiCfgWithAPIDump(addr, integrationtest.APIKey, dumpDir)) }, requestPath: "/openai/v1/models", expectDumpName: "-models-", @@ -213,8 +199,6 @@ func TestAPIDumpPassthrough(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) @@ -226,20 +210,10 @@ func TestAPIDumpPassthrough(t *testing.T) { dumpDir := t.TempDir() - recorderClient := &testutil.MockRecorder{} prov := tc.providerFunc(upstream.URL, dumpDir) - provs := []aibridge.Provider{prov} - b, err := aibridge.NewRequestBridge(t.Context(), provs, recorderClient, testutil.NewNoopMCPManager(), logger, nil, testTracer) - require.NoError(t, err) - - bridgeSrv := httptest.NewUnstartedServer(b) - t.Cleanup(bridgeSrv.Close) - bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(ctx, userID, nil) - } - bridgeSrv.Start() + ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}) - req, err := http.NewRequestWithContext(ctx, http.MethodGet, bridgeSrv.URL+tc.requestPath, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, ts.URL+tc.requestPath, nil) require.NoError(t, err) resp, err := http.DefaultClient.Do(req) diff --git a/circuit_breaker_integration_test.go b/internal/integrationtest/circuit_breaker_test.go similarity index 83% rename from circuit_breaker_integration_test.go rename to internal/integrationtest/circuit_breaker_test.go index 643b868..61bd01f 100644 --- a/circuit_breaker_integration_test.go +++ b/internal/integrationtest/circuit_breaker_test.go @@ -1,10 +1,8 @@ -package aibridge_test +package integrationtest_test import ( - "context" "fmt" "io" - "net" "net/http" "net/http/httptest" "strings" @@ -13,18 +11,15 @@ import ( "testing" "time" - 
"cdr.dev/slog/v3" - "cdr.dev/slog/v3/sloggers/slogtest" "github.com/coder/aibridge" "github.com/coder/aibridge/config" - "github.com/coder/aibridge/internal/testutil" + "github.com/coder/aibridge/internal/integrationtest" "github.com/coder/aibridge/metrics" "github.com/coder/aibridge/provider" "github.com/prometheus/client_golang/prometheus" promtest "github.com/prometheus/client_golang/prometheus/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "go.opentelemetry.io/otel" ) // Common response bodies for circuit breaker tests. @@ -72,7 +67,7 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { req.Header.Set("x-api-key", "test") req.Header.Set("anthropic-version", "2023-06-01") }, - createRequest: createAnthropicMessagesReq, + createRequest: integrationtest.CreateAnthropicMessagesReq, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewAnthropic(config.Anthropic{ BaseURL: baseURL, @@ -92,7 +87,7 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { setupHeaders: func(req *http.Request) { req.Header.Set("Authorization", "Bearer test-key") }, - createRequest: createOpenAIChatCompletionsReq, + createRequest: integrationtest.CreateOpenAIChatCompletionsReq, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewOpenAI(config.OpenAI{ BaseURL: baseURL, @@ -127,7 +122,7 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { })) defer mockUpstream.Close() - metrics := metrics.NewMetrics(prometheus.NewRegistry()) + m := metrics.NewMetrics(prometheus.NewRegistry()) // Create provider with circuit breaker config cbConfig := &config.CircuitBreaker{ @@ -139,27 +134,13 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { prov := tc.createProvider(mockUpstream.URL, cbConfig) ctx := t.Context() - tracer := otel.Tracer("forTesting") - logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - 
bridge, err := aibridge.NewRequestBridge(ctx, - []provider.Provider{prov}, - &testutil.MockRecorder{}, - testutil.NewNoopMCPManager(), - logger, - metrics, - tracer, + ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}, + integrationtest.WithMetrics(m), + integrationtest.WithActor("test-user-id", nil), ) - require.NoError(t, err) - - mockSrv := httptest.NewUnstartedServer(bridge) - t.Cleanup(mockSrv.Close) - mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibridge.AsActor(ctx, "test-user-id", nil) - } - mockSrv.Start() makeRequest := func() *http.Response { - req := tc.createRequest(t, mockSrv.URL, []byte(tc.requestBody)) + req := tc.createRequest(t, ts.URL, []byte(tc.requestBody)) tc.setupHeaders(req) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) @@ -184,13 +165,13 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { assert.Equal(t, int32(cbConfig.FailureThreshold), upstreamCalls.Load(), "No new upstream call when circuit is open") // Verify metrics show circuit is open - trips := promtest.ToFloat64(metrics.CircuitBreakerTrips.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) + trips := promtest.ToFloat64(m.CircuitBreakerTrips.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) assert.Equal(t, 1.0, trips, "CircuitBreakerTrips should be 1") - state := promtest.ToFloat64(metrics.CircuitBreakerState.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) + state := promtest.ToFloat64(m.CircuitBreakerState.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) assert.Equal(t, 1.0, state, "CircuitBreakerState should be 1 (open)") - rejects := promtest.ToFloat64(metrics.CircuitBreakerRejects.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) + rejects := promtest.ToFloat64(m.CircuitBreakerRejects.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) assert.Equal(t, 1.0, rejects, 
"CircuitBreakerRejects should be 1") // Phase 3: Wait for timeout to transition to half-open @@ -206,7 +187,7 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { assert.Equal(t, upstreamCallsBefore+1, upstreamCalls.Load(), "Request should reach upstream in half-open state") // Verify circuit is now closed - state = promtest.ToFloat64(metrics.CircuitBreakerState.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) + state = promtest.ToFloat64(m.CircuitBreakerState.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) assert.Equal(t, 0.0, state, "CircuitBreakerState should be 0 (closed) after recovery") // Phase 5: Verify circuit is fully functional again @@ -220,7 +201,7 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { assert.Equal(t, upstreamCallsBefore+4, upstreamCalls.Load(), "All requests should reach upstream after circuit closes") // Rejects count should not have increased - rejects = promtest.ToFloat64(metrics.CircuitBreakerRejects.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) + rejects = promtest.ToFloat64(m.CircuitBreakerRejects.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) assert.Equal(t, 1.0, rejects, "CircuitBreakerRejects should still be 1 (no new rejects)") }) } @@ -255,7 +236,7 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { req.Header.Set("x-api-key", "test") req.Header.Set("anthropic-version", "2023-06-01") }, - createRequest: createAnthropicMessagesReq, + createRequest: integrationtest.CreateAnthropicMessagesReq, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewAnthropic(config.Anthropic{ BaseURL: baseURL, @@ -274,7 +255,7 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { setupHeaders: func(req *http.Request) { req.Header.Set("Authorization", "Bearer test-key") }, - createRequest: createOpenAIChatCompletionsReq, + createRequest: 
integrationtest.CreateOpenAIChatCompletionsReq, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewOpenAI(config.OpenAI{ BaseURL: baseURL, @@ -301,7 +282,7 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { })) defer mockUpstream.Close() - metrics := metrics.NewMetrics(prometheus.NewRegistry()) + m := metrics.NewMetrics(prometheus.NewRegistry()) cbConfig := &config.CircuitBreaker{ FailureThreshold: 2, @@ -312,27 +293,13 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { prov := tc.createProvider(mockUpstream.URL, cbConfig) ctx := t.Context() - tracer := otel.Tracer("forTesting") - logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - bridge, err := aibridge.NewRequestBridge(ctx, - []provider.Provider{prov}, - &testutil.MockRecorder{}, - testutil.NewNoopMCPManager(), - logger, - metrics, - tracer, + ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}, + integrationtest.WithMetrics(m), + integrationtest.WithActor("test-user-id", nil), ) - require.NoError(t, err) - - mockSrv := httptest.NewUnstartedServer(bridge) - t.Cleanup(mockSrv.Close) - mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibridge.AsActor(ctx, "test-user-id", nil) - } - mockSrv.Start() makeRequest := func() *http.Response { - req := tc.createRequest(t, mockSrv.URL, []byte(tc.requestBody)) + req := tc.createRequest(t, ts.URL, []byte(tc.requestBody)) tc.setupHeaders(req) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) @@ -352,7 +319,7 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { resp := makeRequest() assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) - trips := promtest.ToFloat64(metrics.CircuitBreakerTrips.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) + trips := promtest.ToFloat64(m.CircuitBreakerTrips.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) assert.Equal(t, 
1.0, trips, "CircuitBreakerTrips should be 1") // Phase 2: Wait for half-open state @@ -370,10 +337,10 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { assert.Equal(t, upstreamCallsBefore+1, upstreamCalls.Load(), "Request should NOT reach upstream when circuit re-opens") // Verify metrics: trips should be 2 now (tripped twice) - trips = promtest.ToFloat64(metrics.CircuitBreakerTrips.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) + trips = promtest.ToFloat64(m.CircuitBreakerTrips.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) assert.Equal(t, 2.0, trips, "CircuitBreakerTrips should be 2 after half-open failure") - state := promtest.ToFloat64(metrics.CircuitBreakerState.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) + state := promtest.ToFloat64(m.CircuitBreakerState.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) assert.Equal(t, 1.0, state, "CircuitBreakerState should be 1 (open) after half-open failure") }) } @@ -410,7 +377,7 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { req.Header.Set("x-api-key", "test") req.Header.Set("anthropic-version", "2023-06-01") }, - createRequest: createAnthropicMessagesReq, + createRequest: integrationtest.CreateAnthropicMessagesReq, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewAnthropic(config.Anthropic{ BaseURL: baseURL, @@ -430,7 +397,7 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { setupHeaders: func(req *http.Request) { req.Header.Set("Authorization", "Bearer test-key") }, - createRequest: createOpenAIChatCompletionsReq, + createRequest: integrationtest.CreateOpenAIChatCompletionsReq, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewOpenAI(config.OpenAI{ BaseURL: baseURL, @@ -466,7 +433,7 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { })) defer 
mockUpstream.Close() - metrics := metrics.NewMetrics(prometheus.NewRegistry()) + m := metrics.NewMetrics(prometheus.NewRegistry()) const maxRequests = 2 cbConfig := &config.CircuitBreaker{ @@ -478,27 +445,13 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { prov := tc.createProvider(mockUpstream.URL, cbConfig) ctx := t.Context() - tracer := otel.Tracer("forTesting") - logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - bridge, err := aibridge.NewRequestBridge(ctx, - []provider.Provider{prov}, - &testutil.MockRecorder{}, - testutil.NewNoopMCPManager(), - logger, - metrics, - tracer, + ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}, + integrationtest.WithMetrics(m), + integrationtest.WithActor("test-user-id", nil), ) - require.NoError(t, err) - - mockSrv := httptest.NewUnstartedServer(bridge) - t.Cleanup(mockSrv.Close) - mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibridge.AsActor(ctx, "test-user-id", nil) - } - mockSrv.Start() makeRequest := func() *http.Response { - req := tc.createRequest(t, mockSrv.URL, []byte(tc.requestBody)) + req := tc.createRequest(t, ts.URL, []byte(tc.requestBody)) tc.setupHeaders(req) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) @@ -562,7 +515,7 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { "%d requests should be rejected (ErrTooManyRequests)", totalRequests-maxRequests) // Verify rejects metric increased - rejects := promtest.ToFloat64(metrics.CircuitBreakerRejects.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) + rejects := promtest.ToFloat64(m.CircuitBreakerRejects.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) assert.Equal(t, float64(1+totalRequests-maxRequests), rejects, "CircuitBreakerRejects should include half-open rejections") }) @@ -616,28 +569,14 @@ func TestCircuitBreaker_PerModelIsolation(t *testing.T) { }, nil) ctx := t.Context() - tracer := 
otel.Tracer("forTesting") - logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - bridge, err := aibridge.NewRequestBridge(ctx, - []provider.Provider{prov}, - &testutil.MockRecorder{}, - testutil.NewNoopMCPManager(), - logger, - m, - tracer, + ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}, + integrationtest.WithMetrics(m), + integrationtest.WithActor("test-user-id", nil), ) - require.NoError(t, err) - - mockSrv := httptest.NewUnstartedServer(bridge) - t.Cleanup(mockSrv.Close) - mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibridge.AsActor(ctx, "test-user-id", nil) - } - mockSrv.Start() makeRequest := func(model string) *http.Response { body := fmt.Sprintf(`{"model":"%s","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`, model) - req := createAnthropicMessagesReq(t, mockSrv.URL, []byte(body)) + req := integrationtest.CreateAnthropicMessagesReq(t, ts.URL, []byte(body)) req.Header.Set("x-api-key", "test") req.Header.Set("anthropic-version", "2023-06-01") resp, err := http.DefaultClient.Do(req) diff --git a/responses_integration_test.go b/internal/integrationtest/responses_test.go similarity index 91% rename from responses_integration_test.go rename to internal/integrationtest/responses_test.go index 7a26fff..61ce20f 100644 --- a/responses_integration_test.go +++ b/internal/integrationtest/responses_test.go @@ -1,7 +1,6 @@ -package aibridge_test +package integrationtest_test import ( - "bytes" "context" "encoding/json" "io" @@ -14,12 +13,10 @@ import ( "testing" "time" - "cdr.dev/slog/v3" - "cdr.dev/slog/v3/sloggers/slogtest" "github.com/coder/aibridge" "github.com/coder/aibridge/config" - aibcontext "github.com/coder/aibridge/context" "github.com/coder/aibridge/fixtures" + "github.com/coder/aibridge/internal/integrationtest" "github.com/coder/aibridge/internal/testutil" "github.com/coder/aibridge/provider" "github.com/coder/aibridge/recorder" @@ -335,16 +332,14 @@ func 
TestResponsesOutputMatchesUpstream(t *testing.T) { ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - ctx = aibcontext.AsActor(ctx, userID, nil) fix := fixtures.Parse(t, tc.fixture) - upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) + upstream := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix)) - provider := provider.NewOpenAI(openaiCfg(upstream.URL, apiKey)) - srv, mockRecorder := newTestSrv(t, ctx, provider, nil, testTracer) - defer srv.Close() + prov := provider.NewOpenAI(integrationtest.OpenAICfg(upstream.URL, integrationtest.APIKey)) + ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}, integrationtest.WithWrappedRecorder()) - req := createOpenAIResponsesReq(t, srv.URL, fix.Request()) + req := integrationtest.CreateOpenAIResponsesReq(t, ts.URL, fix.Request()) req.Header.Set("User-Agent", tc.userAgent) client := &http.Client{} @@ -361,16 +356,16 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { require.Equal(t, string(fix.NonStreaming()), string(got)) } - interceptions := mockRecorder.RecordedInterceptions() + interceptions := ts.Recorder.RecordedInterceptions() require.Len(t, interceptions, 1) intc := interceptions[0] - require.Equal(t, intc.InitiatorID, userID) + require.Equal(t, intc.InitiatorID, integrationtest.DefaultActorID) require.Equal(t, intc.Provider, config.ProviderOpenAI) require.Equal(t, intc.Model, tc.expectModel) require.Equal(t, tc.userAgent, intc.UserAgent) require.Equal(t, string(tc.expectedClient), intc.Client) - recordedPrompts := mockRecorder.RecordedPromptUsages() + recordedPrompts := ts.Recorder.RecordedPromptUsages() if tc.expectPromptRecorded != "" { require.Len(t, recordedPrompts, 1) promptEq := func(pur *recorder.PromptUsageRecord) bool { return pur.Prompt == tc.expectPromptRecorded } @@ -379,7 +374,7 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { require.Empty(t, recordedPrompts) } - 
recordedTools := mockRecorder.RecordedToolUsages() + recordedTools := ts.Recorder.RecordedToolUsages() if tc.expectToolRecorded != nil { require.Len(t, recordedTools, 1) recordedTools[0].InterceptionID = tc.expectToolRecorded.InterceptionID // ignore interception id (interception id is not constant and response doesn't contain it) @@ -389,7 +384,7 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { require.Empty(t, recordedTools) } - recordedTokens := mockRecorder.RecordedTokenUsages() + recordedTokens := ts.Recorder.RecordedTokenUsages() if tc.expectTokenUsage != nil { require.Len(t, recordedTokens, 1) recordedTokens[0].InterceptionID = tc.expectTokenUsage.InterceptionID // ignore interception id @@ -433,13 +428,12 @@ func TestResponsesBackgroundModeForbidden(t *testing.T) { })) t.Cleanup(upstream.Close) - prov := provider.NewOpenAI(openaiCfg(upstream.URL, apiKey)) - srv, _ := newTestSrv(t, ctx, prov, nil, testTracer) - defer srv.Close() + prov := provider.NewOpenAI(integrationtest.OpenAICfg(upstream.URL, integrationtest.APIKey)) + ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}) // Create a request with background mode enabled reqBytes := responsesRequestBytes(t, tc.streaming, keyVal{"background", true}) - req := createOpenAIResponsesReq(t, srv.URL, reqBytes) + req := integrationtest.CreateOpenAIResponsesReq(t, ts.URL, reqBytes) client := &http.Client{} resp, err := client.Do(req) @@ -547,11 +541,10 @@ func TestResponsesParallelToolsOverwritten(t *testing.T) { })) t.Cleanup(upstream.Close) - prov := provider.NewOpenAI(openaiCfg(upstream.URL, apiKey)) - srv, _ := newTestSrv(t, ctx, prov, nil, testTracer) - defer srv.Close() + prov := provider.NewOpenAI(integrationtest.OpenAICfg(upstream.URL, integrationtest.APIKey)) + ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}) - req := createOpenAIResponsesReq(t, srv.URL, []byte(tc.request)) + req := integrationtest.CreateOpenAIResponsesReq(t, ts.URL, 
[]byte(tc.request)) client := &http.Client{} resp, err := client.Do(req) @@ -608,12 +601,11 @@ func TestClientAndConnectionError(t *testing.T) { ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - prov := provider.NewOpenAI(openaiCfg(tc.addr, apiKey)) - srv, mockRecorder := newTestSrv(t, ctx, prov, nil, testTracer) - defer srv.Close() + prov := provider.NewOpenAI(integrationtest.OpenAICfg(tc.addr, integrationtest.APIKey)) + ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}, integrationtest.WithWrappedRecorder()) reqBytes := responsesRequestBytes(t, tc.streaming) - req := createOpenAIResponsesReq(t, srv.URL, reqBytes) + req := integrationtest.CreateOpenAIResponsesReq(t, ts.URL, reqBytes) client := &http.Client{} resp, err := client.Do(req) @@ -626,7 +618,7 @@ func TestClientAndConnectionError(t *testing.T) { body, err := io.ReadAll(resp.Body) require.NoError(t, err) requireResponsesError(t, http.StatusInternalServerError, tc.errContains, body) - require.Empty(t, mockRecorder.RecordedPromptUsages()) + require.Empty(t, ts.Recorder.RecordedPromptUsages()) }) } } @@ -692,12 +684,11 @@ func TestUpstreamError(t *testing.T) { })) t.Cleanup(upstream.Close) - prov := provider.NewOpenAI(openaiCfg(upstream.URL, apiKey)) - srv, _ := newTestSrv(t, ctx, prov, nil, testTracer) - defer srv.Close() + prov := provider.NewOpenAI(integrationtest.OpenAICfg(upstream.URL, integrationtest.APIKey)) + ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}) reqBytes := responsesRequestBytes(t, tc.streaming) - req := createOpenAIResponsesReq(t, srv.URL, reqBytes) + req := integrationtest.CreateOpenAIResponsesReq(t, ts.URL, reqBytes) client := &http.Client{} resp, err := client.Do(req) @@ -868,29 +859,18 @@ func TestResponsesInjectedTool(t *testing.T) { // Setup mock server for multi-turn interaction. // First request → tool call response, second → tool response. 
fix := fixtures.Parse(t, tc.fixture) - upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix), testutil.NewFixtureToolResponse(fix)) + upstream := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix), integrationtest.NewFixtureToolResponse(fix)) // Setup MCP server proxies (with mock tools). - mockMCP := testutil.SetupMCPForTest(t, testTracer) + mockMCP := integrationtest.SetupMCPForTest(t, integrationtest.DefaultTracer) if tc.expectToolError != "" { mockMCP.SetToolError(tc.mcpToolName, tc.expectToolError) } - prov := provider.NewOpenAI(openaiCfg(upstream.URL, apiKey)) - mockRecorder := &testutil.MockRecorder{} - logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) + prov := provider.NewOpenAI(integrationtest.OpenAICfg(upstream.URL, integrationtest.APIKey)) + ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}, integrationtest.WithMCP(mockMCP)) - bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{prov}, mockRecorder, mockMCP, logger, nil, testTracer) - require.NoError(t, err) - - srv := httptest.NewUnstartedServer(bridge) - srv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(ctx, userID, nil) - } - srv.Start() - t.Cleanup(srv.Close) - - req := createOpenAIResponsesReq(t, srv.URL, fix.Request()) + req := integrationtest.CreateOpenAIResponsesReq(t, ts.URL, fix.Request()) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) defer resp.Body.Close() @@ -909,7 +889,7 @@ func TestResponsesInjectedTool(t *testing.T) { require.Len(t, invocations, 1, "expected MCP tool to be invoked once") // Verify the injected tool usage was recorded. 
- toolUsages := mockRecorder.RecordedToolUsages() + toolUsages := ts.Recorder.RecordedToolUsages() require.Len(t, toolUsages, 1) require.Equal(t, tc.mcpToolName, toolUsages[0].Tool) require.Equal(t, tc.expectToolArgs, toolUsages[0].Args) @@ -919,11 +899,11 @@ func TestResponsesInjectedTool(t *testing.T) { } // Verify prompt was recorded. - prompts := mockRecorder.RecordedPromptUsages() + prompts := ts.Recorder.RecordedPromptUsages() require.Len(t, prompts, 1) require.Equal(t, tc.expectPrompt, prompts[0].Prompt) - tokenUsages := mockRecorder.RecordedTokenUsages() + tokenUsages := ts.Recorder.RecordedTokenUsages() require.Len(t, tokenUsages, len(tc.expectTokenUsages)) for i := range tokenUsages { tokenUsages[i].InterceptionID = "" // ignore interception ID and time creation when comparing @@ -941,15 +921,6 @@ func TestResponsesInjectedTool(t *testing.T) { } } -func createOpenAIResponsesReq(t *testing.T, baseURL string, input []byte) *http.Request { - t.Helper() - - req, err := http.NewRequestWithContext(t.Context(), "POST", baseURL+"/openai/v1/responses", bytes.NewReader(input)) - require.NoError(t, err) - req.Header.Set("Content-Type", "application/json") - return req -} - func requireResponsesError(t *testing.T, code int, message string, body []byte) { var respErr responses.Error err := json.Unmarshal(body, &respErr) From cd91c6cee9fd5f078beef32098b529371a29d4ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Thu, 5 Mar 2026 16:08:57 +0000 Subject: [PATCH 04/32] Finish bridge_test.go migration: replace all configureFunc/inline boilerplate with NewBridgeTestServer - Convert all 18 remaining aibridge.NewRequestBridge call sites to use NewBridgeTestServer with functional options - Replace all configureFunc closures (4 different signatures, ~12 inline closures) with providerFn field + NewBridgeTestServer in test body - Convert setupInjectedToolTest to use NewBridgeTestServer internally - Fix stale recorderClient references to ts.Recorder - 
Remove unused imports (slog, aibcontext, testutil from responses_test) - Delete bridge_integration_test.go from root Net result: -424 lines of boilerplate, root directory down to 6 files. All tests pass (go test ./... -count=1). --- .../integrationtest/bridge_test.go | 615 ++++++------------ internal/integrationtest/responses_test.go | 1 - 2 files changed, 192 insertions(+), 424 deletions(-) rename bridge_integration_test.go => internal/integrationtest/bridge_test.go (70%) diff --git a/bridge_integration_test.go b/internal/integrationtest/bridge_test.go similarity index 70% rename from bridge_integration_test.go rename to internal/integrationtest/bridge_test.go index 5eed920..c52f944 100644 --- a/bridge_integration_test.go +++ b/internal/integrationtest/bridge_test.go @@ -1,4 +1,4 @@ -package aibridge_test +package integrationtest_test import ( "bytes" @@ -13,16 +13,15 @@ import ( "testing" "time" - "cdr.dev/slog/v3" "cdr.dev/slog/v3/sloggers/slogtest" "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/packages/ssestream" "github.com/anthropics/anthropic-sdk-go/shared/constant" "github.com/coder/aibridge" "github.com/coder/aibridge/config" - aibcontext "github.com/coder/aibridge/context" "github.com/coder/aibridge/fixtures" "github.com/coder/aibridge/intercept" + "github.com/coder/aibridge/internal/integrationtest" "github.com/coder/aibridge/internal/testutil" "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/provider" @@ -34,17 +33,21 @@ import ( "github.com/stretchr/testify/require" "github.com/tidwall/gjson" "github.com/tidwall/sjson" - "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/trace" "go.uber.org/goleak" ) -var testTracer = otel.Tracer("forTesting") - -const ( - apiKey = "api-key" - userID = "ae235cc1-9f8f-417d-a636-a7b170bac62e" -) +// testBedrockCfg returns a test AWS Bedrock config pointing at the given URL. 
+func testBedrockCfg(url string) *config.AWSBedrock { + return &config.AWSBedrock{ + Region: "us-west-2", + AccessKey: "test-access-key", + AccessKeySecret: "test-secret-key", + Model: "beddel", // This model should override the request's given one. + SmallFastModel: "modrock", // Unused but needed for validation. + BaseURL: url, + } +} type ( providerFunc func(addr string) aibridge.Provider @@ -52,15 +55,15 @@ type ( ) func newAnthropicProvider(addr string) aibridge.Provider { - return provider.NewAnthropic(anthropicCfg(addr, apiKey), nil) + return provider.NewAnthropic(integrationtest.AnthropicCfg(addr, integrationtest.APIKey), nil) } func newOpenAIProvider(addr string) aibridge.Provider { - return provider.NewOpenAI(openaiCfg(addr, apiKey)) + return provider.NewOpenAI(integrationtest.OpenAICfg(addr, integrationtest.APIKey)) } func newBedrockProvider(addr string) aibridge.Provider { - return provider.NewAnthropic(anthropicCfg(addr, apiKey), testBedrockCfg(addr)) + return provider.NewAnthropic(integrationtest.AnthropicCfg(addr, integrationtest.APIKey), testBedrockCfg(addr)) } func TestMain(m *testing.M) { @@ -101,25 +104,16 @@ func TestAnthropicMessages(t *testing.T) { t.Cleanup(cancel) fix := fixtures.Parse(t, fixtures.AntSingleBuiltinTool) - upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) + upstream := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix)) - recorderClient := &testutil.MockRecorder{} - logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), nil)} - b, err := aibridge.NewRequestBridge(ctx, providers, recorderClient, testutil.NewNoopMCPManager(), logger, nil, testTracer) - require.NoError(t, err) - - mockSrv := httptest.NewUnstartedServer(b) - t.Cleanup(mockSrv.Close) - mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(ctx, userID, nil) - 
} - mockSrv.Start() + ts := integrationtest.NewBridgeTestServer(t, ctx, + []aibridge.Provider{provider.NewAnthropic(integrationtest.AnthropicCfg(upstream.URL, integrationtest.APIKey), nil)}, + ) // Make API call to aibridge for Anthropic /v1/messages reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) require.NoError(t, err) - req := createAnthropicMessagesReq(t, mockSrv.URL, reqBody) + req := integrationtest.CreateAnthropicMessagesReq(t, ts.URL, reqBody) client := &http.Client{} resp, err := client.Do(req) require.NoError(t, err) @@ -141,13 +135,13 @@ func TestAnthropicMessages(t *testing.T) { // One for message_start, one for message_delta. expectedTokenRecordings = 2 } - tokenUsages := recorderClient.RecordedTokenUsages() + tokenUsages := ts.Recorder.RecordedTokenUsages() require.Len(t, tokenUsages, expectedTokenRecordings) assert.EqualValues(t, tc.expectedInputTokens, calculateTotalInputTokens(tokenUsages), "input tokens miscalculated") assert.EqualValues(t, tc.expectedOutputTokens, calculateTotalOutputTokens(tokenUsages), "output tokens miscalculated") - toolUsages := recorderClient.RecordedToolUsages() + toolUsages := ts.Recorder.RecordedToolUsages() require.Len(t, toolUsages, 1) assert.Equal(t, "Read", toolUsages[0].Tool) assert.Equal(t, tc.expectedToolCallID, toolUsages[0].ToolCallID) @@ -157,11 +151,11 @@ func TestAnthropicMessages(t *testing.T) { require.Contains(t, args, "file_path") assert.Equal(t, "/tmp/blah/foo", args["file_path"]) - promptUsages := recorderClient.RecordedPromptUsages() + promptUsages := ts.Recorder.RecordedPromptUsages() require.Len(t, promptUsages, 1) assert.Equal(t, "read the foo file", promptUsages[0].Prompt) - recorderClient.VerifyAllInterceptionsEnded(t) + ts.Recorder.VerifyAllInterceptionsEnded(t) }) } }) @@ -185,21 +179,12 @@ func TestAWSBedrockIntegration(t *testing.T) { SmallFastModel: "test-haiku", } - recorderClient := &testutil.MockRecorder{} - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: 
true}).Leveled(slog.LevelDebug) - b, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{ - provider.NewAnthropic(anthropicCfg("http://unused", apiKey), bedrockCfg), - }, recorderClient, testutil.NewNoopMCPManager(), logger, nil, testTracer) - require.NoError(t, err) + ts := integrationtest.NewBridgeTestServer(t, ctx, + []aibridge.Provider{provider.NewAnthropic(integrationtest.AnthropicCfg("http://unused", integrationtest.APIKey), bedrockCfg)}, + integrationtest.WithLogger(integrationtest.NewLogger(t, &slogtest.Options{IgnoreErrors: true})), + ) - mockSrv := httptest.NewUnstartedServer(b) - t.Cleanup(mockSrv.Close) - mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(ctx, userID, nil) - } - mockSrv.Start() - - req := createAnthropicMessagesReq(t, mockSrv.URL, fixtures.Request(t, fixtures.AntSingleBuiltinTool)) + req := integrationtest.CreateAnthropicMessagesReq(t, ts.URL, fixtures.Request(t, fixtures.AntSingleBuiltinTool)) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) defer resp.Body.Close() @@ -220,7 +205,7 @@ func TestAWSBedrockIntegration(t *testing.T) { t.Cleanup(cancel) fix := fixtures.Parse(t, fixtures.AntSingleBuiltinTool) - upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) + upstream := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix)) // We define region here to validate that with Region & BaseURL defined, the latter takes precedence. bedrockCfg := &config.AWSBedrock{ @@ -232,25 +217,16 @@ func TestAWSBedrockIntegration(t *testing.T) { BaseURL: upstream.URL, // Use the mock server. 
} - recorderClient := &testutil.MockRecorder{} - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) - b, err := aibridge.NewRequestBridge( - ctx, []aibridge.Provider{provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), bedrockCfg)}, - recorderClient, testutil.NewNoopMCPManager(), logger, nil, testTracer) - require.NoError(t, err) - - mockBridgeSrv := httptest.NewUnstartedServer(b) - t.Cleanup(mockBridgeSrv.Close) - mockBridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(ctx, userID, nil) - } - mockBridgeSrv.Start() + ts := integrationtest.NewBridgeTestServer(t, ctx, + []aibridge.Provider{provider.NewAnthropic(integrationtest.AnthropicCfg(upstream.URL, integrationtest.APIKey), bedrockCfg)}, + integrationtest.WithLogger(integrationtest.NewLogger(t, &slogtest.Options{IgnoreErrors: true})), + ) // Make API call to aibridge for Anthropic /v1/messages, which will be routed via AWS Bedrock. // We override the AWS Bedrock client to route requests through our mock server. 
reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) require.NoError(t, err) - req := createAnthropicMessagesReq(t, mockBridgeSrv.URL, reqBody) + req := integrationtest.CreateAnthropicMessagesReq(t, ts.URL, reqBody) client := &http.Client{} resp, err := client.Do(req) require.NoError(t, err) @@ -277,10 +253,10 @@ func TestAWSBedrockIntegration(t *testing.T) { require.False(t, gjson.GetBytes(received[0].Body, "model").Exists(), "model should be stripped from body") require.False(t, gjson.GetBytes(received[0].Body, "stream").Exists(), "stream should be stripped from body") - interceptions := recorderClient.RecordedInterceptions() + interceptions := ts.Recorder.RecordedInterceptions() require.Len(t, interceptions, 1) require.Equal(t, interceptions[0].Model, bedrockCfg.Model) - recorderClient.VerifyAllInterceptionsEnded(t) + ts.Recorder.VerifyAllInterceptionsEnded(t) }) } }) @@ -319,25 +295,16 @@ func TestOpenAIChatCompletions(t *testing.T) { t.Cleanup(cancel) fix := fixtures.Parse(t, fixtures.OaiChatSingleBuiltinTool) - upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) - - recorderClient := &testutil.MockRecorder{} - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewOpenAI(openaiCfg(upstream.URL, apiKey))} - b, err := aibridge.NewRequestBridge(t.Context(), providers, recorderClient, testutil.NewNoopMCPManager(), logger, nil, testTracer) - require.NoError(t, err) + upstream := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix)) - mockSrv := httptest.NewUnstartedServer(b) - t.Cleanup(mockSrv.Close) - mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(ctx, userID, nil) - } - mockSrv.Start() + ts := integrationtest.NewBridgeTestServer(t, ctx, + []aibridge.Provider{provider.NewOpenAI(integrationtest.OpenAICfg(upstream.URL, integrationtest.APIKey))}, + ) // Make API call 
to aibridge for OpenAI /v1/chat/completions reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) require.NoError(t, err) - req := createOpenAIChatCompletionsReq(t, mockSrv.URL, reqBody) + req := integrationtest.CreateOpenAIChatCompletionsReq(t, ts.URL, reqBody) client := &http.Client{} resp, err := client.Do(req) @@ -359,12 +326,12 @@ func TestOpenAIChatCompletions(t *testing.T) { assert.Equal(t, "[DONE]", lastEvent.Data) } - tokenUsages := recorderClient.RecordedTokenUsages() + tokenUsages := ts.Recorder.RecordedTokenUsages() require.Len(t, tokenUsages, 1) assert.EqualValues(t, tc.expectedInputTokens, calculateTotalInputTokens(tokenUsages), "input tokens miscalculated") assert.EqualValues(t, tc.expectedOutputTokens, calculateTotalOutputTokens(tokenUsages), "output tokens miscalculated") - toolUsages := recorderClient.RecordedToolUsages() + toolUsages := ts.Recorder.RecordedToolUsages() require.Len(t, toolUsages, 1) assert.Equal(t, "read_file", toolUsages[0].Tool) assert.Equal(t, tc.expectedToolCallID, toolUsages[0].ToolCallID) @@ -372,11 +339,11 @@ func TestOpenAIChatCompletions(t *testing.T) { require.Contains(t, toolUsages[0].Args, "path") assert.Equal(t, "README.md", toolUsages[0].Args.(map[string]any)["path"]) - promptUsages := recorderClient.RecordedPromptUsages() + promptUsages := ts.Recorder.RecordedPromptUsages() require.Len(t, promptUsages, 1) assert.Equal(t, "how large is the README.md file in my current path", promptUsages[0].Prompt) - recorderClient.VerifyAllInterceptionsEnded(t) + ts.Recorder.VerifyAllInterceptionsEnded(t) }) } }) @@ -411,29 +378,20 @@ func TestOpenAIChatCompletions(t *testing.T) { // Setup mock server for multi-turn interaction. // First request → tool call response, second → tool response. 
fix := fixtures.Parse(t, tc.fixture) - upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix), testutil.NewFixtureToolResponse(fix)) - - recorderClient := &testutil.MockRecorder{} + upstream := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix), integrationtest.NewFixtureToolResponse(fix)) // Setup MCP proxies with the tool from the fixture - mockMCP := testutil.SetupMCPForTest(t, testTracer) - - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewOpenAI(openaiCfg(upstream.URL, apiKey))} - b, err := aibridge.NewRequestBridge(t.Context(), providers, recorderClient, mockMCP, logger, nil, testTracer) - require.NoError(t, err) + mockMCP := integrationtest.SetupMCPForTest(t, integrationtest.DefaultTracer) - mockSrv := httptest.NewUnstartedServer(b) - t.Cleanup(mockSrv.Close) - mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(ctx, userID, nil) - } - mockSrv.Start() + ts := integrationtest.NewBridgeTestServer(t, ctx, + []aibridge.Provider{provider.NewOpenAI(integrationtest.OpenAICfg(upstream.URL, integrationtest.APIKey))}, + integrationtest.WithMCP(mockMCP), + ) // Add the stream param to the request. 
reqBody, err := sjson.SetBytes(fix.Request(), "stream", true) require.NoError(t, err) - req := createOpenAIChatCompletionsReq(t, mockSrv.URL, reqBody) + req := integrationtest.CreateOpenAIChatCompletionsReq(t, ts.URL, reqBody) client := &http.Client{} resp, err := client.Do(req) @@ -451,7 +409,7 @@ func TestOpenAIChatCompletions(t *testing.T) { resp.Body.Close() // Verify the MCP tool was actually invoked - invocations := mockMCP.GetCallsByTool(testutil.MockToolName) + invocations := mockMCP.GetCallsByTool(integrationtest.MockToolName) require.Len(t, invocations, 1, "expected MCP tool to be invoked") // Verify tool was invoked with the expected args (if specified) @@ -464,11 +422,11 @@ func TestOpenAIChatCompletions(t *testing.T) { } // Verify tool usage was recorded - toolUsages := recorderClient.RecordedToolUsages() + toolUsages := ts.Recorder.RecordedToolUsages() require.Len(t, toolUsages, 1) - assert.Equal(t, testutil.MockToolName, toolUsages[0].Tool) + assert.Equal(t, integrationtest.MockToolName, toolUsages[0].Tool) - recorderClient.VerifyAllInterceptionsEnded(t) + ts.Recorder.VerifyAllInterceptionsEnded(t) }) } }) @@ -535,29 +493,14 @@ func TestSimple(t *testing.T) { return message.ID, nil } - // Common configuration functions for each provider type. 
- configureAnthropic := func(t *testing.T, addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { - t.Helper() - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)} - return aibridge.NewRequestBridge(t.Context(), providers, client, testutil.NewNoopMCPManager(), logger, nil, testTracer) - } - - configureOpenAI := func(t *testing.T, addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { - t.Helper() - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewOpenAI(openaiCfg(addr, apiKey))} - return aibridge.NewRequestBridge(t.Context(), providers, client, testutil.NewNoopMCPManager(), logger, nil, testTracer) - } - testCases := []struct { name string fixture []byte basePath string expectedPath string - configureFunc func(*testing.T, string, aibridge.Recorder) (*aibridge.RequestBridge, error) + providerFn providerFunc getResponseIDFunc func(streaming bool, resp *http.Response) (string, error) - createRequest func(*testing.T, string, []byte) *http.Request + createRequest createRequestFunc expectedMsgID string userAgent string expectedClient aibridge.Client @@ -567,9 +510,9 @@ func TestSimple(t *testing.T) { fixture: fixtures.AntSimple, basePath: "", expectedPath: "/v1/messages", - configureFunc: configureAnthropic, + providerFn: newAnthropicProvider, getResponseIDFunc: getAnthropicResponseID, - createRequest: createAnthropicMessagesReq, + createRequest: integrationtest.CreateAnthropicMessagesReq, expectedMsgID: "msg_01Pvyf26bY17RcjmWfJsXGBn", userAgent: "claude-cli/2.0.67 (external, cli)", expectedClient: aibridge.ClientClaudeCode, @@ -579,9 +522,9 @@ func TestSimple(t *testing.T) { fixture: fixtures.OaiChatSimple, basePath: "", expectedPath: "/chat/completions", - configureFunc: configureOpenAI, + providerFn: 
newOpenAIProvider, getResponseIDFunc: getOpenAIResponseID, - createRequest: createOpenAIChatCompletionsReq, + createRequest: integrationtest.CreateOpenAIChatCompletionsReq, expectedMsgID: "chatcmpl-BwoiPTGRbKkY5rncfaM0s9KtWrq5N", userAgent: "codex_cli_rs/0.87.0 (Mac OS 26.2.0; arm64)", expectedClient: aibridge.ClientCodex, @@ -591,9 +534,9 @@ func TestSimple(t *testing.T) { fixture: fixtures.AntSimple, basePath: "/api", expectedPath: "/api/v1/messages", - configureFunc: configureAnthropic, + providerFn: newAnthropicProvider, getResponseIDFunc: getAnthropicResponseID, - createRequest: createAnthropicMessagesReq, + createRequest: integrationtest.CreateAnthropicMessagesReq, expectedMsgID: "msg_01Pvyf26bY17RcjmWfJsXGBn", userAgent: "GitHubCopilotChat/0.37.2026011603", expectedClient: aibridge.ClientCopilotVSC, @@ -603,9 +546,9 @@ func TestSimple(t *testing.T) { fixture: fixtures.OaiChatSimple, basePath: "/api", expectedPath: "/api/chat/completions", - configureFunc: configureOpenAI, + providerFn: newOpenAIProvider, getResponseIDFunc: getOpenAIResponseID, - createRequest: createOpenAIChatCompletionsReq, + createRequest: integrationtest.CreateOpenAIChatCompletionsReq, expectedMsgID: "chatcmpl-BwoiPTGRbKkY5rncfaM0s9KtWrq5N", userAgent: "Zed/0.219.4+stable.119.abc123 (macos; aarch64)", expectedClient: aibridge.ClientZed, @@ -624,24 +567,16 @@ func TestSimple(t *testing.T) { t.Cleanup(cancel) fix := fixtures.Parse(t, tc.fixture) - upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) - - recorderClient := &testutil.MockRecorder{} + upstream := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix)) - b, err := tc.configureFunc(t, upstream.URL+tc.basePath, recorderClient) - require.NoError(t, err) - mockSrv := httptest.NewUnstartedServer(b) - t.Cleanup(mockSrv.Close) - - mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(ctx, userID, nil) - } - mockSrv.Start() + ts := 
integrationtest.NewBridgeTestServer(t, ctx, + []aibridge.Provider{tc.providerFn(upstream.URL + tc.basePath)}, + ) // When: calling the "API server" with the fixture's request body. reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) require.NoError(t, err) - req := tc.createRequest(t, mockSrv.URL, reqBody) + req := tc.createRequest(t, ts.URL, reqBody) req.Header.Set("User-Agent", tc.userAgent) client := &http.Client{} resp, err := client.Do(req) @@ -663,7 +598,7 @@ func TestSimple(t *testing.T) { resp.Body = io.NopCloser(bytes.NewReader(bodyBytes)) // Then: I expect the prompt to have been tracked. - promptUsages := recorderClient.RecordedPromptUsages() + promptUsages := ts.Recorder.RecordedPromptUsages() require.NotEmpty(t, promptUsages, "no prompts tracked") assert.Contains(t, promptUsages[0].Prompt, "how many angels can dance on the head of a pin") @@ -674,17 +609,17 @@ func TestSimple(t *testing.T) { require.NoError(t, err, "failed to retrieve response ID") require.Nilf(t, uuid.Validate(id), "%s is not a valid UUID", id) - tokenUsages := recorderClient.RecordedTokenUsages() + tokenUsages := ts.Recorder.RecordedTokenUsages() require.GreaterOrEqual(t, len(tokenUsages), 1) require.Equal(t, tokenUsages[0].MsgID, tc.expectedMsgID) // Validate user agent and client have been recorded. 
- interceptions := recorderClient.RecordedInterceptions() + interceptions := ts.Recorder.RecordedInterceptions() require.Len(t, interceptions, 1, "expected exactly one interception, got: %v", interceptions) assert.Equal(t, tc.userAgent, interceptions[0].UserAgent) assert.Equal(t, string(tc.expectedClient), interceptions[0].Client) - recorderClient.VerifyAllInterceptionsEnded(t) + ts.Recorder.VerifyAllInterceptionsEnded(t) }) } }) @@ -822,7 +757,7 @@ func TestFallthrough(t *testing.T) { basePath string requestPath string expectedUpstreamPath string - configureFunc func(string, aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) + providerFn providerFunc }{ { name: "ant_empty_base_url_path", @@ -831,13 +766,7 @@ func TestFallthrough(t *testing.T) { basePath: "", requestPath: "/anthropic/v1/models", expectedUpstreamPath: "/v1/models", - configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) { - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - provider := provider.NewAnthropic(anthropicCfg(addr, apiKey), nil) - bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, testutil.NewNoopMCPManager(), logger, nil, testTracer) - require.NoError(t, err) - return provider, bridge - }, + providerFn: newAnthropicProvider, }, { name: "oai_empty_base_url_path", @@ -846,13 +775,7 @@ func TestFallthrough(t *testing.T) { basePath: "", requestPath: "/openai/v1/models", expectedUpstreamPath: "/models", - configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) { - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - provider := provider.NewOpenAI(openaiCfg(addr, apiKey)) - bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, testutil.NewNoopMCPManager(), logger, nil, testTracer) - require.NoError(t, err) - return 
provider, bridge - }, + providerFn: newOpenAIProvider, }, { name: "ant_some_base_url_path", @@ -861,13 +784,7 @@ func TestFallthrough(t *testing.T) { basePath: "/api", requestPath: "/anthropic/v1/models", expectedUpstreamPath: "/api/v1/models", - configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) { - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - provider := provider.NewAnthropic(anthropicCfg(addr, apiKey), nil) - bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, testutil.NewNoopMCPManager(), logger, nil, testTracer) - require.NoError(t, err) - return provider, bridge - }, + providerFn: newAnthropicProvider, }, { name: "oai_some_base_url_path", @@ -876,13 +793,7 @@ func TestFallthrough(t *testing.T) { basePath: "/api", requestPath: "/openai/v1/models", expectedUpstreamPath: "/api/models", - configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) { - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - provider := provider.NewOpenAI(openaiCfg(addr, apiKey)) - bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, testutil.NewNoopMCPManager(), logger, nil, testTracer) - require.NoError(t, err) - return provider, bridge - }, + providerFn: newOpenAIProvider, }, } @@ -891,18 +802,13 @@ func TestFallthrough(t *testing.T) { t.Parallel() fix := fixtures.Parse(t, tc.fixture) - upstream := testutil.NewMockUpstream(t, t.Context(), testutil.NewFixtureResponse(fix)) - recorderClient := &testutil.MockRecorder{} - provider, bridge := tc.configureFunc(upstream.URL+tc.basePath, recorderClient) - - bridgeSrv := httptest.NewUnstartedServer(bridge) - bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(t.Context(), userID, nil) - } - bridgeSrv.Start() - 
t.Cleanup(bridgeSrv.Close) + upstream := integrationtest.NewMockUpstream(t, t.Context(), integrationtest.NewFixtureResponse(fix)) + p := tc.providerFn(upstream.URL + tc.basePath) + ts := integrationtest.NewBridgeTestServer(t, t.Context(), + []aibridge.Provider{p}, + ) - req, err := http.NewRequestWithContext(t.Context(), "GET", fmt.Sprintf("%s%s", bridgeSrv.URL, tc.requestPath), nil) + req, err := http.NewRequestWithContext(t.Context(), "GET", fmt.Sprintf("%s%s", ts.URL, tc.requestPath), nil) require.NoError(t, err) resp, err := http.DefaultClient.Do(req) @@ -916,7 +822,7 @@ func TestFallthrough(t *testing.T) { received := upstream.ReceivedRequests() require.Len(t, received, 1) require.Equal(t, tc.expectedUpstreamPath, received[0].Path) - require.Contains(t, received[0].Header.Get(provider.AuthHeader()), apiKey) + require.Contains(t, received[0].Header.Get(p.AuthHeader()), integrationtest.APIKey) gotBytes, err := io.ReadAll(resp.Body) require.NoError(t, err) @@ -939,18 +845,18 @@ func TestAnthropicInjectedTools(t *testing.T) { t.Parallel() // Build the requirements & make the assertions which are common to all providers. - recorderClient, mockMCP, resp := setupInjectedToolTest(t, fixtures.AntSingleInjectedTool, streaming, newAnthropicProvider, testTracer, userID, createAnthropicMessagesReq, anthropicToolResultValidator(t)) + recorderClient, mockMCP, resp := setupInjectedToolTest(t, fixtures.AntSingleInjectedTool, streaming, newAnthropicProvider, integrationtest.DefaultTracer, integrationtest.DefaultActorID, integrationtest.CreateAnthropicMessagesReq, anthropicToolResultValidator(t)) // Ensure expected tool was invoked with expected input. 
toolUsages := recorderClient.RecordedToolUsages() require.Len(t, toolUsages, 1) - require.Equal(t, testutil.MockToolName, toolUsages[0].Tool) + require.Equal(t, integrationtest.MockToolName, toolUsages[0].Tool) expected, err := json.Marshal(map[string]any{"owner": "admin"}) require.NoError(t, err) actual, err := json.Marshal(toolUsages[0].Args) require.NoError(t, err) require.EqualValues(t, expected, actual) - invocations := mockMCP.GetCallsByTool(testutil.MockToolName) + invocations := mockMCP.GetCallsByTool(integrationtest.MockToolName) require.Len(t, invocations, 1) actual, err = json.Marshal(invocations[0]) require.NoError(t, err) @@ -1023,18 +929,18 @@ func TestOpenAIInjectedTools(t *testing.T) { t.Parallel() // Build the requirements & make the assertions which are common to all providers. - recorderClient, mockMCP, resp := setupInjectedToolTest(t, fixtures.OaiChatSingleInjectedTool, streaming, newOpenAIProvider, testTracer, userID, createOpenAIChatCompletionsReq, openaiChatToolResultValidator(t)) + recorderClient, mockMCP, resp := setupInjectedToolTest(t, fixtures.OaiChatSingleInjectedTool, streaming, newOpenAIProvider, integrationtest.DefaultTracer, integrationtest.DefaultActorID, integrationtest.CreateOpenAIChatCompletionsReq, openaiChatToolResultValidator(t)) // Ensure expected tool was invoked with expected input. 
toolUsages := recorderClient.RecordedToolUsages() require.Len(t, toolUsages, 1) - require.Equal(t, testutil.MockToolName, toolUsages[0].Tool) + require.Equal(t, integrationtest.MockToolName, toolUsages[0].Tool) expected, err := json.Marshal(map[string]any{"owner": "admin"}) require.NoError(t, err) actual, err := json.Marshal(toolUsages[0].Args) require.NoError(t, err) require.EqualValues(t, expected, actual) - invocations := mockMCP.GetCallsByTool(testutil.MockToolName) + invocations := mockMCP.GetCallsByTool(integrationtest.MockToolName) require.Len(t, invocations, 1) actual, err = json.Marshal(invocations[0]) require.NoError(t, err) @@ -1194,10 +1100,10 @@ func setupInjectedToolTest( streaming bool, providerFn providerFunc, tracer trace.Tracer, - userID string, + actorID string, createRequestFn func(*testing.T, string, []byte) *http.Request, toolRequestValidatorFn func(*http.Request, []byte), -) (*testutil.MockRecorder, *testutil.MockMCP, *http.Response) { +) (*testutil.MockRecorder, *integrationtest.MockMCP, *http.Response) { t.Helper() ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) @@ -1207,39 +1113,25 @@ func setupInjectedToolTest( // Setup mock server for multi-turn interaction. // First request → tool call response, second → tool response. 
- firstResp := testutil.NewFixtureResponse(fix) - toolResp := testutil.NewFixtureToolResponse(fix) + firstResp := integrationtest.NewFixtureResponse(fix) + toolResp := integrationtest.NewFixtureToolResponse(fix) toolResp.OnRequest = toolRequestValidatorFn - upstream := testutil.NewMockUpstream(t, ctx, firstResp, toolResp) + upstream := integrationtest.NewMockUpstream(t, ctx, firstResp, toolResp) - recorderClient := &testutil.MockRecorder{} + mockMCP := integrationtest.SetupMCPForTest(t, tracer) - mockMCP := testutil.SetupMCPForTest(t, tracer) - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - b, err := aibridge.NewRequestBridge( - t.Context(), + ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{providerFn(upstream.URL)}, - recorderClient, - mockMCP, - logger, - nil, - tracer, + integrationtest.WithMCP(mockMCP), + integrationtest.WithTracer(tracer), + integrationtest.WithActor(actorID, nil), ) - require.NoError(t, err) - - // Invoke request to mocked API via aibridge. - bridgeSrv := httptest.NewUnstartedServer(b) - bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(ctx, userID, nil) - } - bridgeSrv.Start() - t.Cleanup(bridgeSrv.Close) // Add the stream param to the request. 
reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) require.NoError(t, err) - req := createRequestFn(t, bridgeSrv.URL, reqBody) + req := createRequestFn(t, ts.URL, reqBody) client := &http.Client{} resp, err := client.Do(req) require.NoError(t, err) @@ -1253,7 +1145,7 @@ func setupInjectedToolTest( return upstream.Calls.Load() == 2 }, time.Second*10, time.Millisecond*50) - return recorderClient, mockMCP, resp + return ts.Recorder, mockMCP, resp } func TestErrorHandling(t *testing.T) { @@ -1265,18 +1157,14 @@ func TestErrorHandling(t *testing.T) { name string fixture []byte createRequestFunc createRequestFunc - configureFunc func(string, aibridge.Recorder, mcp.ServerProxier) (*aibridge.RequestBridge, error) + providerFn providerFunc responseHandlerFn func(resp *http.Response) }{ { name: config.ProviderAnthropic, fixture: fixtures.AntNonStreamError, - createRequestFunc: createAnthropicMessagesReq, - configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr mcp.ServerProxier) (*aibridge.RequestBridge, error) { - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) - }, + createRequestFunc: integrationtest.CreateAnthropicMessagesReq, + providerFn: newAnthropicProvider, responseHandlerFn: func(resp *http.Response) { require.Equal(t, http.StatusBadRequest, resp.StatusCode) body, err := io.ReadAll(resp.Body) @@ -1289,12 +1177,8 @@ func TestErrorHandling(t *testing.T) { { name: config.ProviderOpenAI, fixture: fixtures.OaiChatNonStreamError, - createRequestFunc: createOpenAIChatCompletionsReq, - configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr mcp.ServerProxier) (*aibridge.RequestBridge, error) { - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - 
providers := []aibridge.Provider{provider.NewOpenAI(openaiCfg(addr, apiKey))} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) - }, + createRequestFunc: integrationtest.CreateOpenAIChatCompletionsReq, + providerFn: newOpenAIProvider, responseHandlerFn: func(resp *http.Response) { require.Equal(t, http.StatusBadRequest, resp.StatusCode) body, err := io.ReadAll(resp.Body) @@ -1320,32 +1204,23 @@ func TestErrorHandling(t *testing.T) { // Setup mock server. Error fixtures contain raw HTTP // responses that may cause the bridge to retry. fix := fixtures.Parse(t, tc.fixture) - mockSrv := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) - - recorderClient := &testutil.MockRecorder{} + upstream := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix)) - b, err := tc.configureFunc(mockSrv.URL, recorderClient, testutil.NewNoopMCPManager()) - require.NoError(t, err) - - // Invoke request to mocked API via aibridge. - bridgeSrv := httptest.NewUnstartedServer(b) - bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(ctx, userID, nil) - } - bridgeSrv.Start() - t.Cleanup(bridgeSrv.Close) + ts := integrationtest.NewBridgeTestServer(t, ctx, + []aibridge.Provider{tc.providerFn(upstream.URL)}, + ) // Add the stream param to the request. 
reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) require.NoError(t, err) - req := tc.createRequestFunc(t, bridgeSrv.URL, reqBody) + req := tc.createRequestFunc(t, ts.URL, reqBody) resp, err := http.DefaultClient.Do(req) t.Cleanup(func() { _ = resp.Body.Close() }) require.NoError(t, err) tc.responseHandlerFn(resp) - recorderClient.VerifyAllInterceptionsEnded(t) + ts.Recorder.VerifyAllInterceptionsEnded(t) }) } }) @@ -1358,18 +1233,14 @@ func TestErrorHandling(t *testing.T) { name string fixture []byte createRequestFunc createRequestFunc - configureFunc func(string, aibridge.Recorder, mcp.ServerProxier) (*aibridge.RequestBridge, error) + providerFn providerFunc responseHandlerFn func(resp *http.Response) }{ { name: config.ProviderAnthropic, fixture: fixtures.AntMidStreamError, - createRequestFunc: createAnthropicMessagesReq, - configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr mcp.ServerProxier) (*aibridge.RequestBridge, error) { - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) - }, + createRequestFunc: integrationtest.CreateAnthropicMessagesReq, + providerFn: newAnthropicProvider, responseHandlerFn: func(resp *http.Response) { // Server responds first with 200 OK then starts streaming. 
require.Equal(t, http.StatusOK, resp.StatusCode) @@ -1383,12 +1254,8 @@ func TestErrorHandling(t *testing.T) { { name: config.ProviderOpenAI, fixture: fixtures.OaiChatMidStreamError, - createRequestFunc: createOpenAIChatCompletionsReq, - configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr mcp.ServerProxier) (*aibridge.RequestBridge, error) { - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewOpenAI(openaiCfg(addr, apiKey))} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) - }, + createRequestFunc: integrationtest.CreateOpenAIChatCompletionsReq, + providerFn: newOpenAIProvider, responseHandlerFn: func(resp *http.Response) { // Server responds first with 200 OK then starts streaming. require.Equal(t, http.StatusOK, resp.StatusCode) @@ -1415,30 +1282,21 @@ func TestErrorHandling(t *testing.T) { // Setup mock server. fix := fixtures.Parse(t, tc.fixture) - upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) + upstream := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix)) upstream.StatusCode = http.StatusInternalServerError - recorderClient := &testutil.MockRecorder{} - - b, err := tc.configureFunc(upstream.URL, recorderClient, testutil.NewNoopMCPManager()) - require.NoError(t, err) - - // Invoke request to mocked API via aibridge. 
- bridgeSrv := httptest.NewUnstartedServer(b) - bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(ctx, userID, nil) - } - bridgeSrv.Start() - t.Cleanup(bridgeSrv.Close) + ts := integrationtest.NewBridgeTestServer(t, ctx, + []aibridge.Provider{tc.providerFn(upstream.URL)}, + ) - req := tc.createRequestFunc(t, bridgeSrv.URL, fix.Request()) + req := tc.createRequestFunc(t, ts.URL, fix.Request()) resp, err := http.DefaultClient.Do(req) t.Cleanup(func() { _ = resp.Body.Close() }) require.NoError(t, err) - bridgeSrv.Close() + ts.Close() tc.responseHandlerFn(resp) - recorderClient.VerifyAllInterceptionsEnded(t) + ts.Recorder.VerifyAllInterceptionsEnded(t) }) } }) @@ -1451,31 +1309,23 @@ func TestErrorHandling(t *testing.T) { func TestStableRequestEncoding(t *testing.T) { t.Parallel() - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - cases := []struct { name string fixture []byte createRequestFunc createRequestFunc - configureFunc func(string, aibridge.Recorder, mcp.ServerProxier) (*aibridge.RequestBridge, error) + providerFn providerFunc }{ { name: config.ProviderAnthropic, fixture: fixtures.AntSimple, - createRequestFunc: createAnthropicMessagesReq, - configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr mcp.ServerProxier) (*aibridge.RequestBridge, error) { - providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) - }, + createRequestFunc: integrationtest.CreateAnthropicMessagesReq, + providerFn: newAnthropicProvider, }, { name: config.ProviderOpenAI, fixture: fixtures.OaiChatSimple, - createRequestFunc: createOpenAIChatCompletionsReq, - configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr mcp.ServerProxier) (*aibridge.RequestBridge, error) { - providers := 
[]aibridge.Provider{provider.NewOpenAI(openaiCfg(addr, apiKey))} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) - }, + createRequestFunc: integrationtest.CreateOpenAIChatCompletionsReq, + providerFn: newOpenAIProvider, }, } @@ -1487,33 +1337,26 @@ func TestStableRequestEncoding(t *testing.T) { t.Cleanup(cancel) // Setup MCP tools. - mockMCP := testutil.SetupMCPForTest(t, testTracer) + mockMCP := integrationtest.SetupMCPForTest(t, integrationtest.DefaultTracer) fix := fixtures.Parse(t, tc.fixture) // Create a mock upstream that serves the same blocking response for each request. count := 10 - responses := make([]testutil.UpstreamResponse, count) + responses := make([]integrationtest.UpstreamResponse, count) for i := range count { - responses[i] = testutil.NewFixtureResponse(fix) + responses[i] = integrationtest.NewFixtureResponse(fix) } - upstream := testutil.NewMockUpstream(t, ctx, responses...) - - recorder := &testutil.MockRecorder{} - bridge, err := tc.configureFunc(upstream.URL, recorder, mockMCP) - require.NoError(t, err) + upstream := integrationtest.NewMockUpstream(t, ctx, responses...) - // Invoke request to mocked API via aibridge. - bridgeSrv := httptest.NewUnstartedServer(bridge) - bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(ctx, userID, nil) - } - bridgeSrv.Start() - t.Cleanup(bridgeSrv.Close) + ts := integrationtest.NewBridgeTestServer(t, ctx, + []aibridge.Provider{tc.providerFn(upstream.URL)}, + integrationtest.WithMCP(mockMCP), + ) // Make multiple requests and verify they all have identical payloads. for range count { - req := tc.createRequestFunc(t, bridgeSrv.URL, fix.Request()) + req := tc.createRequestFunc(t, ts.URL, fix.Request()) client := &http.Client{} resp, err := client.Do(req) require.NoError(t, err) @@ -1615,33 +1458,24 @@ func TestAnthropicToolChoiceParallelDisabled(t *testing.T) { // Setup MCP tools conditionally. 
var mcpMgr mcp.ServerProxier if tc.withInjectedTools { - mcpMgr = testutil.SetupMCPForTest(t, testTracer) + mcpMgr = integrationtest.SetupMCPForTest(t, integrationtest.DefaultTracer) } else { - mcpMgr = testutil.NewNoopMCPManager() + mcpMgr = integrationtest.NewNoopMCPManager() } fix := fixtures.Parse(t, fixtures.AntSimple) - upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) + upstream := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix)) - recorder := &testutil.MockRecorder{} - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), nil)} - bridge, err := aibridge.NewRequestBridge(ctx, providers, recorder, mcpMgr, logger, nil, testTracer) - require.NoError(t, err) - - // Invoke request to mocked API via aibridge. - bridgeSrv := httptest.NewUnstartedServer(bridge) - bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(ctx, userID, nil) - } - bridgeSrv.Start() - t.Cleanup(bridgeSrv.Close) + ts := integrationtest.NewBridgeTestServer(t, ctx, + []aibridge.Provider{provider.NewAnthropic(integrationtest.AnthropicCfg(upstream.URL, integrationtest.APIKey), nil)}, + integrationtest.WithMCP(mcpMgr), + ) // Prepare request body with tool_choice set. reqBody, err := sjson.SetBytes(fix.Request(), "tool_choice", tc.toolChoice) require.NoError(t, err) - req := createAnthropicMessagesReq(t, bridgeSrv.URL, reqBody) + req := integrationtest.CreateAnthropicMessagesReq(t, ts.URL, reqBody) client := &http.Client{} resp, err := client.Do(req) require.NoError(t, err) @@ -1691,20 +1525,11 @@ func TestThinkingAdaptiveIsPreserved(t *testing.T) { t.Cleanup(cancel) // Create a mock server that captures the request body sent upstream. 
- upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) - - recorderClient := &testutil.MockRecorder{} - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), nil)} - bridge, err := aibridge.NewRequestBridge(ctx, providers, recorderClient, testutil.NewNoopMCPManager(), logger, nil, testTracer) - require.NoError(t, err) + upstream := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix)) - bridgeSrv := httptest.NewUnstartedServer(bridge) - bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(ctx, userID, nil) - } - bridgeSrv.Start() - t.Cleanup(bridgeSrv.Close) + ts := integrationtest.NewBridgeTestServer(t, ctx, + []aibridge.Provider{provider.NewAnthropic(integrationtest.AnthropicCfg(upstream.URL, integrationtest.APIKey), nil)}, + ) // Inject adaptive thinking into the fixture request. 
reqBody, err := sjson.SetBytes(fix.Request(), "thinking", map[string]string{"type": "adaptive"}) @@ -1712,7 +1537,7 @@ func TestThinkingAdaptiveIsPreserved(t *testing.T) { reqBody, err = sjson.SetBytes(reqBody, "stream", streaming) require.NoError(t, err) - req := createAnthropicMessagesReq(t, bridgeSrv.URL, reqBody) + req := integrationtest.CreateAnthropicMessagesReq(t, ts.URL, reqBody) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) @@ -1735,34 +1560,26 @@ func TestEnvironmentDoNotLeak(t *testing.T) { testCases := []struct { name string fixture []byte - configureFunc func(string, aibridge.Recorder) (*aibridge.RequestBridge, error) - createRequest func(*testing.T, string, []byte) *http.Request + providerFn providerFunc + createRequest createRequestFunc envVars map[string]string headerName string }{ { - name: config.ProviderAnthropic, - fixture: fixtures.AntSimple, - configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { - logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)} - return aibridge.NewRequestBridge(t.Context(), providers, client, testutil.NewNoopMCPManager(), logger, nil, testTracer) - }, - createRequest: createAnthropicMessagesReq, + name: config.ProviderAnthropic, + fixture: fixtures.AntSimple, + providerFn: newAnthropicProvider, + createRequest: integrationtest.CreateAnthropicMessagesReq, envVars: map[string]string{ "ANTHROPIC_AUTH_TOKEN": "should-not-leak", }, headerName: "Authorization", // We only send through the X-Api-Key, so this one should not be present. 
}, { - name: config.ProviderOpenAI, - fixture: fixtures.OaiChatSimple, - configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { - logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewOpenAI(openaiCfg(addr, apiKey))} - return aibridge.NewRequestBridge(t.Context(), providers, client, testutil.NewNoopMCPManager(), logger, nil, testTracer) - }, - createRequest: createOpenAIChatCompletionsReq, + name: config.ProviderOpenAI, + fixture: fixtures.OaiChatSimple, + providerFn: newOpenAIProvider, + createRequest: integrationtest.CreateOpenAIChatCompletionsReq, envVars: map[string]string{ "OPENAI_ORG_ID": "should-not-leak", }, @@ -1778,7 +1595,7 @@ func TestEnvironmentDoNotLeak(t *testing.T) { t.Cleanup(cancel) fix := fixtures.Parse(t, tc.fixture) - upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) + upstream := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix)) // Set environment variables that the SDK would automatically read. // These should NOT leak into upstream requests. 
@@ -1786,18 +1603,11 @@ func TestEnvironmentDoNotLeak(t *testing.T) { t.Setenv(key, val) } - recorderClient := &testutil.MockRecorder{} - b, err := tc.configureFunc(upstream.URL, recorderClient) - require.NoError(t, err) - - mockSrv := httptest.NewUnstartedServer(b) - t.Cleanup(mockSrv.Close) - mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(ctx, userID, nil) - } - mockSrv.Start() + ts := integrationtest.NewBridgeTestServer(t, ctx, + []aibridge.Provider{tc.providerFn(upstream.URL)}, + ) - req := tc.createRequest(t, mockSrv.URL, fix.Request()) + req := tc.createRequest(t, ts.URL, fix.Request()) client := &http.Client{} resp, err := client.Do(req) require.NoError(t, err) @@ -1826,9 +1636,9 @@ func TestActorHeaders(t *testing.T) { }{ { name: "openai/v1/chat/completions", - createRequest: createOpenAIChatCompletionsReq, + createRequest: integrationtest.CreateOpenAIChatCompletionsReq, createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { - cfg := openaiCfg(url, key) + cfg := integrationtest.OpenAICfg(url, key) cfg.SendActorHeaders = sendHeaders return provider.NewOpenAI(cfg) }, @@ -1837,9 +1647,9 @@ func TestActorHeaders(t *testing.T) { }, { name: "openai/v1/chat/completions", - createRequest: createOpenAIChatCompletionsReq, + createRequest: integrationtest.CreateOpenAIChatCompletionsReq, createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { - cfg := openaiCfg(url, key) + cfg := integrationtest.OpenAICfg(url, key) cfg.SendActorHeaders = sendHeaders return provider.NewOpenAI(cfg) }, @@ -1848,9 +1658,9 @@ func TestActorHeaders(t *testing.T) { }, { name: "openai/v1/responses", - createRequest: createOpenAIResponsesReq, + createRequest: integrationtest.CreateOpenAIResponsesReq, createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { - cfg := openaiCfg(url, key) + cfg := integrationtest.OpenAICfg(url, key) cfg.SendActorHeaders = sendHeaders return 
provider.NewOpenAI(cfg) }, @@ -1859,9 +1669,9 @@ func TestActorHeaders(t *testing.T) { }, { name: "openai/v1/responses", - createRequest: createOpenAIResponsesReq, + createRequest: integrationtest.CreateOpenAIResponsesReq, createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { - cfg := openaiCfg(url, key) + cfg := integrationtest.OpenAICfg(url, key) cfg.SendActorHeaders = sendHeaders return provider.NewOpenAI(cfg) }, @@ -1870,9 +1680,9 @@ func TestActorHeaders(t *testing.T) { }, { name: "anthropic/v1/messages", - createRequest: createAnthropicMessagesReq, + createRequest: integrationtest.CreateAnthropicMessagesReq, createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { - cfg := anthropicCfg(url, key) + cfg := integrationtest.AnthropicCfg(url, key) cfg.SendActorHeaders = sendHeaders return provider.NewAnthropic(cfg, nil) }, @@ -1881,9 +1691,9 @@ func TestActorHeaders(t *testing.T) { }, { name: "anthropic/v1/messages", - createRequest: createAnthropicMessagesReq, + createRequest: integrationtest.CreateAnthropicMessagesReq, createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { - cfg := anthropicCfg(url, key) + cfg := integrationtest.AnthropicCfg(url, key) cfg.SendActorHeaders = sendHeaders return provider.NewAnthropic(cfg, nil) }, @@ -1912,30 +1722,21 @@ func TestActorHeaders(t *testing.T) { srv.Start() t.Cleanup(srv.Close) - rec := &testutil.MockRecorder{} - provider := tc.createProviderFn(srv.URL, apiKey, send) - logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - - b, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, rec, testutil.NewNoopMCPManager(), logger, nil, testTracer) - require.NoError(t, err, "failed to create handler") - - mockSrv := httptest.NewUnstartedServer(b) - t.Cleanup(mockSrv.Close) + p := tc.createProviderFn(srv.URL, integrationtest.APIKey, send) metadataKey := "Username" - mockSrv.Config.BaseContext = func(_ net.Listener) 
context.Context { - // Attach an actor to the request context. - return aibcontext.AsActor(ctx, userID, recorder.Metadata{ + ts := integrationtest.NewBridgeTestServer(t, ctx, + []aibridge.Provider{p}, + integrationtest.WithActor(integrationtest.DefaultActorID, recorder.Metadata{ metadataKey: actorUsername, - }) - } - mockSrv.Start() + }), + ) // Add the stream param to the request. reqBody, err := sjson.SetBytes(fixtures.Request(t, tc.fixture), "stream", tc.streaming) require.NoError(t, err) - req := tc.createRequest(t, mockSrv.URL, reqBody) + req := tc.createRequest(t, ts.URL, reqBody) client := &http.Client{} resp, err := client.Do(req) require.NoError(t, err) @@ -1952,7 +1753,7 @@ func TestActorHeaders(t *testing.T) { } if send { - require.Equal(t, found[strings.ToLower(intercept.ActorIDHeader())], []string{userID}) + require.Equal(t, found[strings.ToLower(intercept.ActorIDHeader())], []string{integrationtest.DefaultActorID}) require.Equal(t, found[strings.ToLower(intercept.ActorMetadataHeader(metadataKey))], []string{actorUsername}) } else { require.Empty(t, found) @@ -1978,36 +1779,4 @@ func calculateTotalOutputTokens(in []*recorder.TokenUsageRecord) int64 { return total } -func createAnthropicMessagesReq(t *testing.T, baseURL string, input []byte) *http.Request { - t.Helper() - req, err := http.NewRequestWithContext(t.Context(), "POST", baseURL+"/anthropic/v1/messages", bytes.NewReader(input)) - require.NoError(t, err) - req.Header.Set("Content-Type", "application/json") - - return req -} - -func createOpenAIChatCompletionsReq(t *testing.T, baseURL string, input []byte) *http.Request { - t.Helper() - - req, err := http.NewRequestWithContext(t.Context(), "POST", baseURL+"/openai/v1/chat/completions", bytes.NewReader(input)) - require.NoError(t, err) - req.Header.Set("Content-Type", "application/json") - - return req -} - -func openaiCfg(url, key string) config.OpenAI { - return config.OpenAI{ - BaseURL: url, - Key: key, - } -} - -func anthropicCfg(url, key 
string) config.Anthropic { - return config.Anthropic{ - BaseURL: url, - Key: key, - } -} diff --git a/internal/integrationtest/responses_test.go b/internal/integrationtest/responses_test.go index 61ce20f..b178384 100644 --- a/internal/integrationtest/responses_test.go +++ b/internal/integrationtest/responses_test.go @@ -17,7 +17,6 @@ import ( "github.com/coder/aibridge/config" "github.com/coder/aibridge/fixtures" "github.com/coder/aibridge/internal/integrationtest" - "github.com/coder/aibridge/internal/testutil" "github.com/coder/aibridge/provider" "github.com/coder/aibridge/recorder" "github.com/openai/openai-go/v3/responses" From d6e6df048ca3d7558c9df43c8ae612663b12d46d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Thu, 5 Mar 2026 16:23:01 +0000 Subject: [PATCH 05/32] refactor: convert bridge_test.go from external to internal test package Change package from integrationtest_test to integrationtest, remove the self-import, and strip all integrationtest. prefixes from identifiers. 
--- internal/integrationtest/bridge_test.go | 201 ++++++++++++------------ 1 file changed, 100 insertions(+), 101 deletions(-) diff --git a/internal/integrationtest/bridge_test.go b/internal/integrationtest/bridge_test.go index c52f944..6edd053 100644 --- a/internal/integrationtest/bridge_test.go +++ b/internal/integrationtest/bridge_test.go @@ -1,4 +1,4 @@ -package integrationtest_test +package integrationtest import ( "bytes" @@ -21,7 +21,6 @@ import ( "github.com/coder/aibridge/config" "github.com/coder/aibridge/fixtures" "github.com/coder/aibridge/intercept" - "github.com/coder/aibridge/internal/integrationtest" "github.com/coder/aibridge/internal/testutil" "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/provider" @@ -55,15 +54,15 @@ type ( ) func newAnthropicProvider(addr string) aibridge.Provider { - return provider.NewAnthropic(integrationtest.AnthropicCfg(addr, integrationtest.APIKey), nil) + return provider.NewAnthropic(AnthropicCfg(addr, APIKey), nil) } func newOpenAIProvider(addr string) aibridge.Provider { - return provider.NewOpenAI(integrationtest.OpenAICfg(addr, integrationtest.APIKey)) + return provider.NewOpenAI(OpenAICfg(addr, APIKey)) } func newBedrockProvider(addr string) aibridge.Provider { - return provider.NewAnthropic(integrationtest.AnthropicCfg(addr, integrationtest.APIKey), testBedrockCfg(addr)) + return provider.NewAnthropic(AnthropicCfg(addr, APIKey), testBedrockCfg(addr)) } func TestMain(m *testing.M) { @@ -104,16 +103,16 @@ func TestAnthropicMessages(t *testing.T) { t.Cleanup(cancel) fix := fixtures.Parse(t, fixtures.AntSingleBuiltinTool) - upstream := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix)) + upstream := NewMockUpstream(t, ctx, NewFixtureResponse(fix)) - ts := integrationtest.NewBridgeTestServer(t, ctx, - []aibridge.Provider{provider.NewAnthropic(integrationtest.AnthropicCfg(upstream.URL, integrationtest.APIKey), nil)}, + ts := NewBridgeTestServer(t, ctx, + 
[]aibridge.Provider{provider.NewAnthropic(AnthropicCfg(upstream.URL, APIKey), nil)}, ) // Make API call to aibridge for Anthropic /v1/messages reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) require.NoError(t, err) - req := integrationtest.CreateAnthropicMessagesReq(t, ts.URL, reqBody) + req := CreateAnthropicMessagesReq(t, ts.URL, reqBody) client := &http.Client{} resp, err := client.Do(req) require.NoError(t, err) @@ -179,12 +178,12 @@ func TestAWSBedrockIntegration(t *testing.T) { SmallFastModel: "test-haiku", } - ts := integrationtest.NewBridgeTestServer(t, ctx, - []aibridge.Provider{provider.NewAnthropic(integrationtest.AnthropicCfg("http://unused", integrationtest.APIKey), bedrockCfg)}, - integrationtest.WithLogger(integrationtest.NewLogger(t, &slogtest.Options{IgnoreErrors: true})), + ts := NewBridgeTestServer(t, ctx, + []aibridge.Provider{provider.NewAnthropic(AnthropicCfg("http://unused", APIKey), bedrockCfg)}, + WithLogger(NewLogger(t, &slogtest.Options{IgnoreErrors: true})), ) - req := integrationtest.CreateAnthropicMessagesReq(t, ts.URL, fixtures.Request(t, fixtures.AntSingleBuiltinTool)) + req := CreateAnthropicMessagesReq(t, ts.URL, fixtures.Request(t, fixtures.AntSingleBuiltinTool)) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) defer resp.Body.Close() @@ -205,7 +204,7 @@ func TestAWSBedrockIntegration(t *testing.T) { t.Cleanup(cancel) fix := fixtures.Parse(t, fixtures.AntSingleBuiltinTool) - upstream := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix)) + upstream := NewMockUpstream(t, ctx, NewFixtureResponse(fix)) // We define region here to validate that with Region & BaseURL defined, the latter takes precedence. bedrockCfg := &config.AWSBedrock{ @@ -217,16 +216,16 @@ func TestAWSBedrockIntegration(t *testing.T) { BaseURL: upstream.URL, // Use the mock server. 
} - ts := integrationtest.NewBridgeTestServer(t, ctx, - []aibridge.Provider{provider.NewAnthropic(integrationtest.AnthropicCfg(upstream.URL, integrationtest.APIKey), bedrockCfg)}, - integrationtest.WithLogger(integrationtest.NewLogger(t, &slogtest.Options{IgnoreErrors: true})), + ts := NewBridgeTestServer(t, ctx, + []aibridge.Provider{provider.NewAnthropic(AnthropicCfg(upstream.URL, APIKey), bedrockCfg)}, + WithLogger(NewLogger(t, &slogtest.Options{IgnoreErrors: true})), ) // Make API call to aibridge for Anthropic /v1/messages, which will be routed via AWS Bedrock. // We override the AWS Bedrock client to route requests through our mock server. reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) require.NoError(t, err) - req := integrationtest.CreateAnthropicMessagesReq(t, ts.URL, reqBody) + req := CreateAnthropicMessagesReq(t, ts.URL, reqBody) client := &http.Client{} resp, err := client.Do(req) require.NoError(t, err) @@ -295,16 +294,16 @@ func TestOpenAIChatCompletions(t *testing.T) { t.Cleanup(cancel) fix := fixtures.Parse(t, fixtures.OaiChatSingleBuiltinTool) - upstream := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix)) + upstream := NewMockUpstream(t, ctx, NewFixtureResponse(fix)) - ts := integrationtest.NewBridgeTestServer(t, ctx, - []aibridge.Provider{provider.NewOpenAI(integrationtest.OpenAICfg(upstream.URL, integrationtest.APIKey))}, + ts := NewBridgeTestServer(t, ctx, + []aibridge.Provider{provider.NewOpenAI(OpenAICfg(upstream.URL, APIKey))}, ) // Make API call to aibridge for OpenAI /v1/chat/completions reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) require.NoError(t, err) - req := integrationtest.CreateOpenAIChatCompletionsReq(t, ts.URL, reqBody) + req := CreateOpenAIChatCompletionsReq(t, ts.URL, reqBody) client := &http.Client{} resp, err := client.Do(req) @@ -378,20 +377,20 @@ func TestOpenAIChatCompletions(t *testing.T) { // Setup mock server for multi-turn interaction. 
// First request → tool call response, second → tool response. fix := fixtures.Parse(t, tc.fixture) - upstream := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix), integrationtest.NewFixtureToolResponse(fix)) + upstream := NewMockUpstream(t, ctx, NewFixtureResponse(fix), NewFixtureToolResponse(fix)) // Setup MCP proxies with the tool from the fixture - mockMCP := integrationtest.SetupMCPForTest(t, integrationtest.DefaultTracer) + mockMCP := SetupMCPForTest(t, DefaultTracer) - ts := integrationtest.NewBridgeTestServer(t, ctx, - []aibridge.Provider{provider.NewOpenAI(integrationtest.OpenAICfg(upstream.URL, integrationtest.APIKey))}, - integrationtest.WithMCP(mockMCP), + ts := NewBridgeTestServer(t, ctx, + []aibridge.Provider{provider.NewOpenAI(OpenAICfg(upstream.URL, APIKey))}, + WithMCP(mockMCP), ) // Add the stream param to the request. reqBody, err := sjson.SetBytes(fix.Request(), "stream", true) require.NoError(t, err) - req := integrationtest.CreateOpenAIChatCompletionsReq(t, ts.URL, reqBody) + req := CreateOpenAIChatCompletionsReq(t, ts.URL, reqBody) client := &http.Client{} resp, err := client.Do(req) @@ -409,7 +408,7 @@ func TestOpenAIChatCompletions(t *testing.T) { resp.Body.Close() // Verify the MCP tool was actually invoked - invocations := mockMCP.GetCallsByTool(integrationtest.MockToolName) + invocations := mockMCP.GetCallsByTool(MockToolName) require.Len(t, invocations, 1, "expected MCP tool to be invoked") // Verify tool was invoked with the expected args (if specified) @@ -424,7 +423,7 @@ func TestOpenAIChatCompletions(t *testing.T) { // Verify tool usage was recorded toolUsages := ts.Recorder.RecordedToolUsages() require.Len(t, toolUsages, 1) - assert.Equal(t, integrationtest.MockToolName, toolUsages[0].Tool) + assert.Equal(t, MockToolName, toolUsages[0].Tool) ts.Recorder.VerifyAllInterceptionsEnded(t) }) @@ -512,7 +511,7 @@ func TestSimple(t *testing.T) { expectedPath: "/v1/messages", providerFn: newAnthropicProvider, 
getResponseIDFunc: getAnthropicResponseID, - createRequest: integrationtest.CreateAnthropicMessagesReq, + createRequest: CreateAnthropicMessagesReq, expectedMsgID: "msg_01Pvyf26bY17RcjmWfJsXGBn", userAgent: "claude-cli/2.0.67 (external, cli)", expectedClient: aibridge.ClientClaudeCode, @@ -524,7 +523,7 @@ func TestSimple(t *testing.T) { expectedPath: "/chat/completions", providerFn: newOpenAIProvider, getResponseIDFunc: getOpenAIResponseID, - createRequest: integrationtest.CreateOpenAIChatCompletionsReq, + createRequest: CreateOpenAIChatCompletionsReq, expectedMsgID: "chatcmpl-BwoiPTGRbKkY5rncfaM0s9KtWrq5N", userAgent: "codex_cli_rs/0.87.0 (Mac OS 26.2.0; arm64)", expectedClient: aibridge.ClientCodex, @@ -536,7 +535,7 @@ func TestSimple(t *testing.T) { expectedPath: "/api/v1/messages", providerFn: newAnthropicProvider, getResponseIDFunc: getAnthropicResponseID, - createRequest: integrationtest.CreateAnthropicMessagesReq, + createRequest: CreateAnthropicMessagesReq, expectedMsgID: "msg_01Pvyf26bY17RcjmWfJsXGBn", userAgent: "GitHubCopilotChat/0.37.2026011603", expectedClient: aibridge.ClientCopilotVSC, @@ -548,7 +547,7 @@ func TestSimple(t *testing.T) { expectedPath: "/api/chat/completions", providerFn: newOpenAIProvider, getResponseIDFunc: getOpenAIResponseID, - createRequest: integrationtest.CreateOpenAIChatCompletionsReq, + createRequest: CreateOpenAIChatCompletionsReq, expectedMsgID: "chatcmpl-BwoiPTGRbKkY5rncfaM0s9KtWrq5N", userAgent: "Zed/0.219.4+stable.119.abc123 (macos; aarch64)", expectedClient: aibridge.ClientZed, @@ -567,9 +566,9 @@ func TestSimple(t *testing.T) { t.Cleanup(cancel) fix := fixtures.Parse(t, tc.fixture) - upstream := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix)) + upstream := NewMockUpstream(t, ctx, NewFixtureResponse(fix)) - ts := integrationtest.NewBridgeTestServer(t, ctx, + ts := NewBridgeTestServer(t, ctx, []aibridge.Provider{tc.providerFn(upstream.URL + tc.basePath)}, ) @@ -802,9 +801,9 @@ func 
TestFallthrough(t *testing.T) { t.Parallel() fix := fixtures.Parse(t, tc.fixture) - upstream := integrationtest.NewMockUpstream(t, t.Context(), integrationtest.NewFixtureResponse(fix)) + upstream := NewMockUpstream(t, t.Context(), NewFixtureResponse(fix)) p := tc.providerFn(upstream.URL + tc.basePath) - ts := integrationtest.NewBridgeTestServer(t, t.Context(), + ts := NewBridgeTestServer(t, t.Context(), []aibridge.Provider{p}, ) @@ -822,7 +821,7 @@ func TestFallthrough(t *testing.T) { received := upstream.ReceivedRequests() require.Len(t, received, 1) require.Equal(t, tc.expectedUpstreamPath, received[0].Path) - require.Contains(t, received[0].Header.Get(p.AuthHeader()), integrationtest.APIKey) + require.Contains(t, received[0].Header.Get(p.AuthHeader()), APIKey) gotBytes, err := io.ReadAll(resp.Body) require.NoError(t, err) @@ -845,18 +844,18 @@ func TestAnthropicInjectedTools(t *testing.T) { t.Parallel() // Build the requirements & make the assertions which are common to all providers. - recorderClient, mockMCP, resp := setupInjectedToolTest(t, fixtures.AntSingleInjectedTool, streaming, newAnthropicProvider, integrationtest.DefaultTracer, integrationtest.DefaultActorID, integrationtest.CreateAnthropicMessagesReq, anthropicToolResultValidator(t)) + recorderClient, mockMCP, resp := setupInjectedToolTest(t, fixtures.AntSingleInjectedTool, streaming, newAnthropicProvider, DefaultTracer, DefaultActorID, CreateAnthropicMessagesReq, anthropicToolResultValidator(t)) // Ensure expected tool was invoked with expected input. 
toolUsages := recorderClient.RecordedToolUsages() require.Len(t, toolUsages, 1) - require.Equal(t, integrationtest.MockToolName, toolUsages[0].Tool) + require.Equal(t, MockToolName, toolUsages[0].Tool) expected, err := json.Marshal(map[string]any{"owner": "admin"}) require.NoError(t, err) actual, err := json.Marshal(toolUsages[0].Args) require.NoError(t, err) require.EqualValues(t, expected, actual) - invocations := mockMCP.GetCallsByTool(integrationtest.MockToolName) + invocations := mockMCP.GetCallsByTool(MockToolName) require.Len(t, invocations, 1) actual, err = json.Marshal(invocations[0]) require.NoError(t, err) @@ -929,18 +928,18 @@ func TestOpenAIInjectedTools(t *testing.T) { t.Parallel() // Build the requirements & make the assertions which are common to all providers. - recorderClient, mockMCP, resp := setupInjectedToolTest(t, fixtures.OaiChatSingleInjectedTool, streaming, newOpenAIProvider, integrationtest.DefaultTracer, integrationtest.DefaultActorID, integrationtest.CreateOpenAIChatCompletionsReq, openaiChatToolResultValidator(t)) + recorderClient, mockMCP, resp := setupInjectedToolTest(t, fixtures.OaiChatSingleInjectedTool, streaming, newOpenAIProvider, DefaultTracer, DefaultActorID, CreateOpenAIChatCompletionsReq, openaiChatToolResultValidator(t)) // Ensure expected tool was invoked with expected input. 
toolUsages := recorderClient.RecordedToolUsages() require.Len(t, toolUsages, 1) - require.Equal(t, integrationtest.MockToolName, toolUsages[0].Tool) + require.Equal(t, MockToolName, toolUsages[0].Tool) expected, err := json.Marshal(map[string]any{"owner": "admin"}) require.NoError(t, err) actual, err := json.Marshal(toolUsages[0].Args) require.NoError(t, err) require.EqualValues(t, expected, actual) - invocations := mockMCP.GetCallsByTool(integrationtest.MockToolName) + invocations := mockMCP.GetCallsByTool(MockToolName) require.Len(t, invocations, 1) actual, err = json.Marshal(invocations[0]) require.NoError(t, err) @@ -1103,7 +1102,7 @@ func setupInjectedToolTest( actorID string, createRequestFn func(*testing.T, string, []byte) *http.Request, toolRequestValidatorFn func(*http.Request, []byte), -) (*testutil.MockRecorder, *integrationtest.MockMCP, *http.Response) { +) (*testutil.MockRecorder, *MockMCP, *http.Response) { t.Helper() ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) @@ -1113,18 +1112,18 @@ func setupInjectedToolTest( // Setup mock server for multi-turn interaction. // First request → tool call response, second → tool response. 
- firstResp := integrationtest.NewFixtureResponse(fix) - toolResp := integrationtest.NewFixtureToolResponse(fix) + firstResp := NewFixtureResponse(fix) + toolResp := NewFixtureToolResponse(fix) toolResp.OnRequest = toolRequestValidatorFn - upstream := integrationtest.NewMockUpstream(t, ctx, firstResp, toolResp) + upstream := NewMockUpstream(t, ctx, firstResp, toolResp) - mockMCP := integrationtest.SetupMCPForTest(t, tracer) + mockMCP := SetupMCPForTest(t, tracer) - ts := integrationtest.NewBridgeTestServer(t, ctx, + ts := NewBridgeTestServer(t, ctx, []aibridge.Provider{providerFn(upstream.URL)}, - integrationtest.WithMCP(mockMCP), - integrationtest.WithTracer(tracer), - integrationtest.WithActor(actorID, nil), + WithMCP(mockMCP), + WithTracer(tracer), + WithActor(actorID, nil), ) // Add the stream param to the request. @@ -1163,7 +1162,7 @@ func TestErrorHandling(t *testing.T) { { name: config.ProviderAnthropic, fixture: fixtures.AntNonStreamError, - createRequestFunc: integrationtest.CreateAnthropicMessagesReq, + createRequestFunc: CreateAnthropicMessagesReq, providerFn: newAnthropicProvider, responseHandlerFn: func(resp *http.Response) { require.Equal(t, http.StatusBadRequest, resp.StatusCode) @@ -1177,7 +1176,7 @@ func TestErrorHandling(t *testing.T) { { name: config.ProviderOpenAI, fixture: fixtures.OaiChatNonStreamError, - createRequestFunc: integrationtest.CreateOpenAIChatCompletionsReq, + createRequestFunc: CreateOpenAIChatCompletionsReq, providerFn: newOpenAIProvider, responseHandlerFn: func(resp *http.Response) { require.Equal(t, http.StatusBadRequest, resp.StatusCode) @@ -1204,9 +1203,9 @@ func TestErrorHandling(t *testing.T) { // Setup mock server. Error fixtures contain raw HTTP // responses that may cause the bridge to retry. 
fix := fixtures.Parse(t, tc.fixture) - upstream := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix)) + upstream := NewMockUpstream(t, ctx, NewFixtureResponse(fix)) - ts := integrationtest.NewBridgeTestServer(t, ctx, + ts := NewBridgeTestServer(t, ctx, []aibridge.Provider{tc.providerFn(upstream.URL)}, ) @@ -1239,7 +1238,7 @@ func TestErrorHandling(t *testing.T) { { name: config.ProviderAnthropic, fixture: fixtures.AntMidStreamError, - createRequestFunc: integrationtest.CreateAnthropicMessagesReq, + createRequestFunc: CreateAnthropicMessagesReq, providerFn: newAnthropicProvider, responseHandlerFn: func(resp *http.Response) { // Server responds first with 200 OK then starts streaming. @@ -1254,7 +1253,7 @@ func TestErrorHandling(t *testing.T) { { name: config.ProviderOpenAI, fixture: fixtures.OaiChatMidStreamError, - createRequestFunc: integrationtest.CreateOpenAIChatCompletionsReq, + createRequestFunc: CreateOpenAIChatCompletionsReq, providerFn: newOpenAIProvider, responseHandlerFn: func(resp *http.Response) { // Server responds first with 200 OK then starts streaming. @@ -1282,10 +1281,10 @@ func TestErrorHandling(t *testing.T) { // Setup mock server. 
fix := fixtures.Parse(t, tc.fixture) - upstream := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix)) + upstream := NewMockUpstream(t, ctx, NewFixtureResponse(fix)) upstream.StatusCode = http.StatusInternalServerError - ts := integrationtest.NewBridgeTestServer(t, ctx, + ts := NewBridgeTestServer(t, ctx, []aibridge.Provider{tc.providerFn(upstream.URL)}, ) @@ -1318,13 +1317,13 @@ func TestStableRequestEncoding(t *testing.T) { { name: config.ProviderAnthropic, fixture: fixtures.AntSimple, - createRequestFunc: integrationtest.CreateAnthropicMessagesReq, + createRequestFunc: CreateAnthropicMessagesReq, providerFn: newAnthropicProvider, }, { name: config.ProviderOpenAI, fixture: fixtures.OaiChatSimple, - createRequestFunc: integrationtest.CreateOpenAIChatCompletionsReq, + createRequestFunc: CreateOpenAIChatCompletionsReq, providerFn: newOpenAIProvider, }, } @@ -1337,21 +1336,21 @@ func TestStableRequestEncoding(t *testing.T) { t.Cleanup(cancel) // Setup MCP tools. - mockMCP := integrationtest.SetupMCPForTest(t, integrationtest.DefaultTracer) + mockMCP := SetupMCPForTest(t, DefaultTracer) fix := fixtures.Parse(t, tc.fixture) // Create a mock upstream that serves the same blocking response for each request. count := 10 - responses := make([]integrationtest.UpstreamResponse, count) + responses := make([]UpstreamResponse, count) for i := range count { - responses[i] = integrationtest.NewFixtureResponse(fix) + responses[i] = NewFixtureResponse(fix) } - upstream := integrationtest.NewMockUpstream(t, ctx, responses...) + upstream := NewMockUpstream(t, ctx, responses...) - ts := integrationtest.NewBridgeTestServer(t, ctx, + ts := NewBridgeTestServer(t, ctx, []aibridge.Provider{tc.providerFn(upstream.URL)}, - integrationtest.WithMCP(mockMCP), + WithMCP(mockMCP), ) // Make multiple requests and verify they all have identical payloads. 
@@ -1458,24 +1457,24 @@ func TestAnthropicToolChoiceParallelDisabled(t *testing.T) { // Setup MCP tools conditionally. var mcpMgr mcp.ServerProxier if tc.withInjectedTools { - mcpMgr = integrationtest.SetupMCPForTest(t, integrationtest.DefaultTracer) + mcpMgr = SetupMCPForTest(t, DefaultTracer) } else { - mcpMgr = integrationtest.NewNoopMCPManager() + mcpMgr = NewNoopMCPManager() } fix := fixtures.Parse(t, fixtures.AntSimple) - upstream := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix)) + upstream := NewMockUpstream(t, ctx, NewFixtureResponse(fix)) - ts := integrationtest.NewBridgeTestServer(t, ctx, - []aibridge.Provider{provider.NewAnthropic(integrationtest.AnthropicCfg(upstream.URL, integrationtest.APIKey), nil)}, - integrationtest.WithMCP(mcpMgr), + ts := NewBridgeTestServer(t, ctx, + []aibridge.Provider{provider.NewAnthropic(AnthropicCfg(upstream.URL, APIKey), nil)}, + WithMCP(mcpMgr), ) // Prepare request body with tool_choice set. reqBody, err := sjson.SetBytes(fix.Request(), "tool_choice", tc.toolChoice) require.NoError(t, err) - req := integrationtest.CreateAnthropicMessagesReq(t, ts.URL, reqBody) + req := CreateAnthropicMessagesReq(t, ts.URL, reqBody) client := &http.Client{} resp, err := client.Do(req) require.NoError(t, err) @@ -1525,10 +1524,10 @@ func TestThinkingAdaptiveIsPreserved(t *testing.T) { t.Cleanup(cancel) // Create a mock server that captures the request body sent upstream. - upstream := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix)) + upstream := NewMockUpstream(t, ctx, NewFixtureResponse(fix)) - ts := integrationtest.NewBridgeTestServer(t, ctx, - []aibridge.Provider{provider.NewAnthropic(integrationtest.AnthropicCfg(upstream.URL, integrationtest.APIKey), nil)}, + ts := NewBridgeTestServer(t, ctx, + []aibridge.Provider{provider.NewAnthropic(AnthropicCfg(upstream.URL, APIKey), nil)}, ) // Inject adaptive thinking into the fixture request. 
@@ -1537,7 +1536,7 @@ func TestThinkingAdaptiveIsPreserved(t *testing.T) { reqBody, err = sjson.SetBytes(reqBody, "stream", streaming) require.NoError(t, err) - req := integrationtest.CreateAnthropicMessagesReq(t, ts.URL, reqBody) + req := CreateAnthropicMessagesReq(t, ts.URL, reqBody) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) @@ -1569,7 +1568,7 @@ func TestEnvironmentDoNotLeak(t *testing.T) { name: config.ProviderAnthropic, fixture: fixtures.AntSimple, providerFn: newAnthropicProvider, - createRequest: integrationtest.CreateAnthropicMessagesReq, + createRequest: CreateAnthropicMessagesReq, envVars: map[string]string{ "ANTHROPIC_AUTH_TOKEN": "should-not-leak", }, @@ -1579,7 +1578,7 @@ func TestEnvironmentDoNotLeak(t *testing.T) { name: config.ProviderOpenAI, fixture: fixtures.OaiChatSimple, providerFn: newOpenAIProvider, - createRequest: integrationtest.CreateOpenAIChatCompletionsReq, + createRequest: CreateOpenAIChatCompletionsReq, envVars: map[string]string{ "OPENAI_ORG_ID": "should-not-leak", }, @@ -1595,7 +1594,7 @@ func TestEnvironmentDoNotLeak(t *testing.T) { t.Cleanup(cancel) fix := fixtures.Parse(t, tc.fixture) - upstream := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix)) + upstream := NewMockUpstream(t, ctx, NewFixtureResponse(fix)) // Set environment variables that the SDK would automatically read. // These should NOT leak into upstream requests. 
@@ -1603,7 +1602,7 @@ func TestEnvironmentDoNotLeak(t *testing.T) { t.Setenv(key, val) } - ts := integrationtest.NewBridgeTestServer(t, ctx, + ts := NewBridgeTestServer(t, ctx, []aibridge.Provider{tc.providerFn(upstream.URL)}, ) @@ -1636,9 +1635,9 @@ func TestActorHeaders(t *testing.T) { }{ { name: "openai/v1/chat/completions", - createRequest: integrationtest.CreateOpenAIChatCompletionsReq, + createRequest: CreateOpenAIChatCompletionsReq, createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { - cfg := integrationtest.OpenAICfg(url, key) + cfg := OpenAICfg(url, key) cfg.SendActorHeaders = sendHeaders return provider.NewOpenAI(cfg) }, @@ -1647,9 +1646,9 @@ func TestActorHeaders(t *testing.T) { }, { name: "openai/v1/chat/completions", - createRequest: integrationtest.CreateOpenAIChatCompletionsReq, + createRequest: CreateOpenAIChatCompletionsReq, createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { - cfg := integrationtest.OpenAICfg(url, key) + cfg := OpenAICfg(url, key) cfg.SendActorHeaders = sendHeaders return provider.NewOpenAI(cfg) }, @@ -1658,9 +1657,9 @@ func TestActorHeaders(t *testing.T) { }, { name: "openai/v1/responses", - createRequest: integrationtest.CreateOpenAIResponsesReq, + createRequest: CreateOpenAIResponsesReq, createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { - cfg := integrationtest.OpenAICfg(url, key) + cfg := OpenAICfg(url, key) cfg.SendActorHeaders = sendHeaders return provider.NewOpenAI(cfg) }, @@ -1669,9 +1668,9 @@ func TestActorHeaders(t *testing.T) { }, { name: "openai/v1/responses", - createRequest: integrationtest.CreateOpenAIResponsesReq, + createRequest: CreateOpenAIResponsesReq, createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { - cfg := integrationtest.OpenAICfg(url, key) + cfg := OpenAICfg(url, key) cfg.SendActorHeaders = sendHeaders return provider.NewOpenAI(cfg) }, @@ -1680,9 +1679,9 @@ func TestActorHeaders(t *testing.T) { }, { 
name: "anthropic/v1/messages", - createRequest: integrationtest.CreateAnthropicMessagesReq, + createRequest: CreateAnthropicMessagesReq, createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { - cfg := integrationtest.AnthropicCfg(url, key) + cfg := AnthropicCfg(url, key) cfg.SendActorHeaders = sendHeaders return provider.NewAnthropic(cfg, nil) }, @@ -1691,9 +1690,9 @@ func TestActorHeaders(t *testing.T) { }, { name: "anthropic/v1/messages", - createRequest: integrationtest.CreateAnthropicMessagesReq, + createRequest: CreateAnthropicMessagesReq, createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { - cfg := integrationtest.AnthropicCfg(url, key) + cfg := AnthropicCfg(url, key) cfg.SendActorHeaders = sendHeaders return provider.NewAnthropic(cfg, nil) }, @@ -1722,12 +1721,12 @@ func TestActorHeaders(t *testing.T) { srv.Start() t.Cleanup(srv.Close) - p := tc.createProviderFn(srv.URL, integrationtest.APIKey, send) + p := tc.createProviderFn(srv.URL, APIKey, send) metadataKey := "Username" - ts := integrationtest.NewBridgeTestServer(t, ctx, + ts := NewBridgeTestServer(t, ctx, []aibridge.Provider{p}, - integrationtest.WithActor(integrationtest.DefaultActorID, recorder.Metadata{ + WithActor(DefaultActorID, recorder.Metadata{ metadataKey: actorUsername, }), ) @@ -1753,7 +1752,7 @@ func TestActorHeaders(t *testing.T) { } if send { - require.Equal(t, found[strings.ToLower(intercept.ActorIDHeader())], []string{integrationtest.DefaultActorID}) + require.Equal(t, found[strings.ToLower(intercept.ActorIDHeader())], []string{DefaultActorID}) require.Equal(t, found[strings.ToLower(intercept.ActorMetadataHeader(metadataKey))], []string{actorUsername}) } else { require.Empty(t, found) From 0c8152e362b2f986916046930602324609c0c92a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Thu, 5 Mar 2026 16:27:12 +0000 Subject: [PATCH 06/32] refactor: use internal test package, unexport all helpers MIME-Version: 1.0 
Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Switch all _test.go files from package integrationtest_test to package integrationtest. Since nothing outside this directory imports the package, all exported symbols are unexported: - Types: BridgeTestServer → bridgeTestServer, MockMCP → mockMCP, MockUpstream → mockUpstream, UpstreamResponse → upstreamResponse, etc. - Functions: NewBridgeTestServer → newBridgeTestServer, CreateAnthropicMessagesReq → createAnthropicMessagesReq, etc. - Options: WithMetrics → withMetrics, WithTracer → withTracer, etc. - Constants: APIKey → apiKey, DefaultActorID → defaultActorID, etc. - Methods: GetCallsByTool → getCallsByTool, ReceivedRequests → receivedRequests, etc. This eliminates the integrationtest. prefix from all call sites and removes the artificial export surface. All tests pass (go test ./... -count=1). --- internal/integrationtest/apidump_test.go | 25 +-- internal/integrationtest/bridge.go | 72 +++--- internal/integrationtest/bridge_test.go | 212 +++++++++--------- .../integrationtest/circuit_breaker_test.go | 41 ++-- internal/integrationtest/metrics_test.go | 93 ++++---- internal/integrationtest/mockmcp.go | 30 +-- internal/integrationtest/requests.go | 24 +- internal/integrationtest/responses_test.go | 51 +++-- internal/integrationtest/trace_test.go | 103 +++++---- internal/integrationtest/upstream.go | 58 ++--- 10 files changed, 352 insertions(+), 357 deletions(-) diff --git a/internal/integrationtest/apidump_test.go b/internal/integrationtest/apidump_test.go index fd13883..52d8e9f 100644 --- a/internal/integrationtest/apidump_test.go +++ b/internal/integrationtest/apidump_test.go @@ -1,4 +1,4 @@ -package integrationtest_test +package integrationtest import ( "bufio" @@ -17,7 +17,6 @@ import ( "github.com/coder/aibridge/config" "github.com/coder/aibridge/fixtures" "github.com/coder/aibridge/intercept/apidump" - "github.com/coder/aibridge/internal/integrationtest" 
"github.com/coder/aibridge/provider" "github.com/stretchr/testify/require" ) @@ -51,25 +50,25 @@ func TestAPIDump(t *testing.T) { name: "anthropic", fixture: fixtures.AntSimple, providersFunc: func(addr, dumpDir string) []aibridge.Provider { - return []aibridge.Provider{provider.NewAnthropic(anthropicCfgWithAPIDump(addr, integrationtest.APIKey, dumpDir), nil)} + return []aibridge.Provider{provider.NewAnthropic(anthropicCfgWithAPIDump(addr, apiKey, dumpDir), nil)} }, - createRequestFunc: integrationtest.CreateAnthropicMessagesReq, + createRequestFunc: createAnthropicMessagesReq, }, { name: "openai_chat_completions", fixture: fixtures.OaiChatSimple, providersFunc: func(addr, dumpDir string) []aibridge.Provider { - return []aibridge.Provider{provider.NewOpenAI(openaiCfgWithAPIDump(addr, integrationtest.APIKey, dumpDir))} + return []aibridge.Provider{provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir))} }, - createRequestFunc: integrationtest.CreateOpenAIChatCompletionsReq, + createRequestFunc: createOpenAIChatCompletionsReq, }, { name: "openai_responses", fixture: fixtures.OaiResponsesBlockingSimple, providersFunc: func(addr, dumpDir string) []aibridge.Provider { - return []aibridge.Provider{provider.NewOpenAI(openaiCfgWithAPIDump(addr, integrationtest.APIKey, dumpDir))} + return []aibridge.Provider{provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir))} }, - createRequestFunc: integrationtest.CreateOpenAIResponsesReq, + createRequestFunc: createOpenAIResponsesReq, }, } @@ -82,13 +81,13 @@ func TestAPIDump(t *testing.T) { // Setup mock upstream server. fix := fixtures.Parse(t, tc.fixture) - srv := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix)) + srv := newMockUpstream(t, ctx, newFixtureResponse(fix)) // Create temp dir for API dumps. 
dumpDir := t.TempDir() providers := tc.providersFunc(srv.URL, dumpDir) - ts := integrationtest.NewBridgeTestServer(t, ctx, providers) + ts := newBridgeTestServer(t, ctx, providers) req := tc.createRequestFunc(t, ts.URL, fix.Request()) resp, err := http.DefaultClient.Do(req) @@ -172,7 +171,7 @@ func TestAPIDumpPassthrough(t *testing.T) { { name: "anthropic", providerFunc: func(addr string, dumpDir string) aibridge.Provider { - return provider.NewAnthropic(anthropicCfgWithAPIDump(addr, integrationtest.APIKey, dumpDir), nil) + return provider.NewAnthropic(anthropicCfgWithAPIDump(addr, apiKey, dumpDir), nil) }, requestPath: "/anthropic/v1/models", expectDumpName: "-v1-models-", @@ -180,7 +179,7 @@ func TestAPIDumpPassthrough(t *testing.T) { { name: "openai", providerFunc: func(addr string, dumpDir string) aibridge.Provider { - return provider.NewOpenAI(openaiCfgWithAPIDump(addr, integrationtest.APIKey, dumpDir)) + return provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir)) }, requestPath: "/openai/v1/models", expectDumpName: "-models-", @@ -211,7 +210,7 @@ func TestAPIDumpPassthrough(t *testing.T) { dumpDir := t.TempDir() prov := tc.providerFunc(upstream.URL, dumpDir) - ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}) + ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}) req, err := http.NewRequestWithContext(ctx, http.MethodGet, ts.URL+tc.requestPath, nil) require.NoError(t, err) diff --git a/internal/integrationtest/bridge.go b/internal/integrationtest/bridge.go index 0c4da14..62f2755 100644 --- a/internal/integrationtest/bridge.go +++ b/internal/integrationtest/bridge.go @@ -19,15 +19,15 @@ import ( "go.opentelemetry.io/otel/trace" ) -// DefaultActorID is the actor ID used by default in test servers. -const DefaultActorID = "ae235cc1-9f8f-417d-a636-a7b170bac62e" +// defaultActorID is the actor ID used by default in test servers. 
+const defaultActorID = "ae235cc1-9f8f-417d-a636-a7b170bac62e" -// DefaultTracer is the default OTel tracer used in integration tests. -var DefaultTracer = otel.Tracer("integrationtest") +// defaultTracer is the default OTel tracer used in integration tests. +var defaultTracer = otel.Tracer("integrationtest") -// NewLogger creates a test logger at Debug level. +// newLogger creates a test logger at Debug level. // Eliminates the repeated slogtest.Make(t, &slogtest.Options{...}).Leveled(slog.LevelDebug) pattern. -func NewLogger(t *testing.T, opts ...*slogtest.Options) slog.Logger { +func newLogger(t *testing.T, opts ...*slogtest.Options) slog.Logger { t.Helper() var o *slogtest.Options if len(opts) > 0 { @@ -38,15 +38,15 @@ func NewLogger(t *testing.T, opts ...*slogtest.Options) slog.Logger { return slogtest.Make(t, o).Leveled(slog.LevelDebug) } -// BridgeTestServer wraps an httptest.Server running a RequestBridge. -type BridgeTestServer struct { +// bridgeTestServer wraps an httptest.Server running a RequestBridge. +type bridgeTestServer struct { *httptest.Server Recorder *testutil.MockRecorder Bridge *aibridge.RequestBridge } -// BridgeOption configures a [BridgeTestServer]. -type BridgeOption func(*bridgeConfig) +// bridgeOption configures a [bridgeTestServer]. +type bridgeOption func(*bridgeConfig) type bridgeConfig struct { metrics *metrics.Metrics @@ -59,66 +59,66 @@ type bridgeConfig struct { wrapRecorder bool } -// WithMetrics sets the Prometheus metrics for the bridge. -func WithMetrics(m *metrics.Metrics) BridgeOption { +// withMetrics sets the Prometheus metrics for the bridge. +func withMetrics(m *metrics.Metrics) bridgeOption { return func(c *bridgeConfig) { c.metrics = m } } -// WithTracer overrides the default tracer. -func WithTracer(t trace.Tracer) BridgeOption { +// withTracer overrides the default tracer. 
+func withTracer(t trace.Tracer) bridgeOption { return func(c *bridgeConfig) { c.tracer = t } } -// WithMCP sets the MCP server proxier (default: NoopMCPManager). -func WithMCP(p mcp.ServerProxier) BridgeOption { +// withMCP sets the MCP server proxier (default: NoopMCPManager). +func withMCP(p mcp.ServerProxier) bridgeOption { return func(c *bridgeConfig) { c.mcpProxy = p } } -// WithActor sets the actor ID and metadata for the BaseContext. -func WithActor(id string, md recorder.Metadata) BridgeOption { +// withActor sets the actor ID and metadata for the BaseContext. +func withActor(id string, md recorder.Metadata) bridgeOption { return func(c *bridgeConfig) { c.userID = id; c.metadata = md } } -// WithLogger overrides the default slogtest debug logger. -func WithLogger(l slog.Logger) BridgeOption { +// withLogger overrides the default slogtest debug logger. +func withLogger(l slog.Logger) bridgeOption { return func(c *bridgeConfig) { c.logger = l; c.loggerSet = true } } -// WithWrappedRecorder wraps the MockRecorder through aibridge.NewRecorder +// withWrappedRecorder wraps the MockRecorder through aibridge.NewRecorder // (the production recorder wrapper). Use when testing the recorder pipeline. 
-func WithWrappedRecorder() BridgeOption { +func withWrappedRecorder() bridgeOption { return func(c *bridgeConfig) { c.wrapRecorder = true } } -// NewBridgeTestServer creates a fully configured test server running +// newBridgeTestServer creates a fully configured test server running // a RequestBridge with sensible defaults: -// - MockRecorder (raw, unless WithWrappedRecorder) -// - NoopMCPManager (unless WithMCP) -// - slogtest debug logger (unless WithLogger) -// - DefaultTracer (unless WithTracer) -// - DefaultActorID (unless WithActor) -func NewBridgeTestServer( +// - MockRecorder (raw, unless withWrappedRecorder) +// - NoopMCPManager (unless withMCP) +// - slogtest debug logger (unless withLogger) +// - defaultTracer (unless withTracer) +// - defaultActorID (unless withActor) +func newBridgeTestServer( t *testing.T, ctx context.Context, providers []aibridge.Provider, - opts ...BridgeOption, -) *BridgeTestServer { + opts ...bridgeOption, +) *bridgeTestServer { t.Helper() cfg := &bridgeConfig{ - userID: DefaultActorID, + userID: defaultActorID, } for _, o := range opts { o(cfg) } if cfg.tracer == nil { - cfg.tracer = DefaultTracer + cfg.tracer = defaultTracer } if !cfg.loggerSet { - cfg.logger = NewLogger(t) + cfg.logger = newLogger(t) } if cfg.mcpProxy == nil { - cfg.mcpProxy = NewNoopMCPManager() + cfg.mcpProxy = newNoopMCPManager() } mockRec := &testutil.MockRecorder{} @@ -143,7 +143,7 @@ func NewBridgeTestServer( srv.Start() t.Cleanup(srv.Close) - return &BridgeTestServer{ + return &bridgeTestServer{ Server: srv, Recorder: mockRec, Bridge: bridge, diff --git a/internal/integrationtest/bridge_test.go b/internal/integrationtest/bridge_test.go index 6edd053..5ebbef0 100644 --- a/internal/integrationtest/bridge_test.go +++ b/internal/integrationtest/bridge_test.go @@ -54,15 +54,15 @@ type ( ) func newAnthropicProvider(addr string) aibridge.Provider { - return provider.NewAnthropic(AnthropicCfg(addr, APIKey), nil) + return 
provider.NewAnthropic(anthropicCfg(addr, apiKey), nil) } func newOpenAIProvider(addr string) aibridge.Provider { - return provider.NewOpenAI(OpenAICfg(addr, APIKey)) + return provider.NewOpenAI(openAICfg(addr, apiKey)) } func newBedrockProvider(addr string) aibridge.Provider { - return provider.NewAnthropic(AnthropicCfg(addr, APIKey), testBedrockCfg(addr)) + return provider.NewAnthropic(anthropicCfg(addr, apiKey), testBedrockCfg(addr)) } func TestMain(m *testing.M) { @@ -103,16 +103,16 @@ func TestAnthropicMessages(t *testing.T) { t.Cleanup(cancel) fix := fixtures.Parse(t, fixtures.AntSingleBuiltinTool) - upstream := NewMockUpstream(t, ctx, NewFixtureResponse(fix)) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - ts := NewBridgeTestServer(t, ctx, - []aibridge.Provider{provider.NewAnthropic(AnthropicCfg(upstream.URL, APIKey), nil)}, + ts := newBridgeTestServer(t, ctx, + []aibridge.Provider{provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), nil)}, ) // Make API call to aibridge for Anthropic /v1/messages reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) require.NoError(t, err) - req := CreateAnthropicMessagesReq(t, ts.URL, reqBody) + req := createAnthropicMessagesReq(t, ts.URL, reqBody) client := &http.Client{} resp, err := client.Do(req) require.NoError(t, err) @@ -178,12 +178,12 @@ func TestAWSBedrockIntegration(t *testing.T) { SmallFastModel: "test-haiku", } - ts := NewBridgeTestServer(t, ctx, - []aibridge.Provider{provider.NewAnthropic(AnthropicCfg("http://unused", APIKey), bedrockCfg)}, - WithLogger(NewLogger(t, &slogtest.Options{IgnoreErrors: true})), + ts := newBridgeTestServer(t, ctx, + []aibridge.Provider{provider.NewAnthropic(anthropicCfg("http://unused", apiKey), bedrockCfg)}, + withLogger(newLogger(t, &slogtest.Options{IgnoreErrors: true})), ) - req := CreateAnthropicMessagesReq(t, ts.URL, fixtures.Request(t, fixtures.AntSingleBuiltinTool)) + req := createAnthropicMessagesReq(t, ts.URL, fixtures.Request(t, 
fixtures.AntSingleBuiltinTool)) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) defer resp.Body.Close() @@ -204,7 +204,7 @@ func TestAWSBedrockIntegration(t *testing.T) { t.Cleanup(cancel) fix := fixtures.Parse(t, fixtures.AntSingleBuiltinTool) - upstream := NewMockUpstream(t, ctx, NewFixtureResponse(fix)) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) // We define region here to validate that with Region & BaseURL defined, the latter takes precedence. bedrockCfg := &config.AWSBedrock{ @@ -216,16 +216,16 @@ func TestAWSBedrockIntegration(t *testing.T) { BaseURL: upstream.URL, // Use the mock server. } - ts := NewBridgeTestServer(t, ctx, - []aibridge.Provider{provider.NewAnthropic(AnthropicCfg(upstream.URL, APIKey), bedrockCfg)}, - WithLogger(NewLogger(t, &slogtest.Options{IgnoreErrors: true})), + ts := newBridgeTestServer(t, ctx, + []aibridge.Provider{provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), bedrockCfg)}, + withLogger(newLogger(t, &slogtest.Options{IgnoreErrors: true})), ) // Make API call to aibridge for Anthropic /v1/messages, which will be routed via AWS Bedrock. // We override the AWS Bedrock client to route requests through our mock server. reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) require.NoError(t, err) - req := CreateAnthropicMessagesReq(t, ts.URL, reqBody) + req := createAnthropicMessagesReq(t, ts.URL, reqBody) client := &http.Client{} resp, err := client.Do(req) require.NoError(t, err) @@ -240,7 +240,7 @@ func TestAWSBedrockIntegration(t *testing.T) { // Verify that Bedrock-specific model name was used in the request to the mock server // and the interception data. 
- received := upstream.ReceivedRequests() + received := upstream.receivedRequests() require.Len(t, received, 1) // The Anthropic SDK's Bedrock middleware extracts "model" and "stream" @@ -294,16 +294,16 @@ func TestOpenAIChatCompletions(t *testing.T) { t.Cleanup(cancel) fix := fixtures.Parse(t, fixtures.OaiChatSingleBuiltinTool) - upstream := NewMockUpstream(t, ctx, NewFixtureResponse(fix)) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - ts := NewBridgeTestServer(t, ctx, - []aibridge.Provider{provider.NewOpenAI(OpenAICfg(upstream.URL, APIKey))}, + ts := newBridgeTestServer(t, ctx, + []aibridge.Provider{provider.NewOpenAI(openAICfg(upstream.URL, apiKey))}, ) // Make API call to aibridge for OpenAI /v1/chat/completions reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) require.NoError(t, err) - req := CreateOpenAIChatCompletionsReq(t, ts.URL, reqBody) + req := createOpenAIChatCompletionsReq(t, ts.URL, reqBody) client := &http.Client{} resp, err := client.Do(req) @@ -377,20 +377,20 @@ func TestOpenAIChatCompletions(t *testing.T) { // Setup mock server for multi-turn interaction. // First request → tool call response, second → tool response. fix := fixtures.Parse(t, tc.fixture) - upstream := NewMockUpstream(t, ctx, NewFixtureResponse(fix), NewFixtureToolResponse(fix)) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix), newFixtureToolResponse(fix)) // Setup MCP proxies with the tool from the fixture - mockMCP := SetupMCPForTest(t, DefaultTracer) + mockMCP := setupMCPForTest(t, defaultTracer) - ts := NewBridgeTestServer(t, ctx, - []aibridge.Provider{provider.NewOpenAI(OpenAICfg(upstream.URL, APIKey))}, - WithMCP(mockMCP), + ts := newBridgeTestServer(t, ctx, + []aibridge.Provider{provider.NewOpenAI(openAICfg(upstream.URL, apiKey))}, + withMCP(mockMCP), ) // Add the stream param to the request. 
reqBody, err := sjson.SetBytes(fix.Request(), "stream", true) require.NoError(t, err) - req := CreateOpenAIChatCompletionsReq(t, ts.URL, reqBody) + req := createOpenAIChatCompletionsReq(t, ts.URL, reqBody) client := &http.Client{} resp, err := client.Do(req) @@ -408,7 +408,7 @@ func TestOpenAIChatCompletions(t *testing.T) { resp.Body.Close() // Verify the MCP tool was actually invoked - invocations := mockMCP.GetCallsByTool(MockToolName) + invocations := mockMCP.getCallsByTool(mockToolName) require.Len(t, invocations, 1, "expected MCP tool to be invoked") // Verify tool was invoked with the expected args (if specified) @@ -423,7 +423,7 @@ func TestOpenAIChatCompletions(t *testing.T) { // Verify tool usage was recorded toolUsages := ts.Recorder.RecordedToolUsages() require.Len(t, toolUsages, 1) - assert.Equal(t, MockToolName, toolUsages[0].Tool) + assert.Equal(t, mockToolName, toolUsages[0].Tool) ts.Recorder.VerifyAllInterceptionsEnded(t) }) @@ -511,7 +511,7 @@ func TestSimple(t *testing.T) { expectedPath: "/v1/messages", providerFn: newAnthropicProvider, getResponseIDFunc: getAnthropicResponseID, - createRequest: CreateAnthropicMessagesReq, + createRequest: createAnthropicMessagesReq, expectedMsgID: "msg_01Pvyf26bY17RcjmWfJsXGBn", userAgent: "claude-cli/2.0.67 (external, cli)", expectedClient: aibridge.ClientClaudeCode, @@ -523,7 +523,7 @@ func TestSimple(t *testing.T) { expectedPath: "/chat/completions", providerFn: newOpenAIProvider, getResponseIDFunc: getOpenAIResponseID, - createRequest: CreateOpenAIChatCompletionsReq, + createRequest: createOpenAIChatCompletionsReq, expectedMsgID: "chatcmpl-BwoiPTGRbKkY5rncfaM0s9KtWrq5N", userAgent: "codex_cli_rs/0.87.0 (Mac OS 26.2.0; arm64)", expectedClient: aibridge.ClientCodex, @@ -535,7 +535,7 @@ func TestSimple(t *testing.T) { expectedPath: "/api/v1/messages", providerFn: newAnthropicProvider, getResponseIDFunc: getAnthropicResponseID, - createRequest: CreateAnthropicMessagesReq, + createRequest: 
createAnthropicMessagesReq, expectedMsgID: "msg_01Pvyf26bY17RcjmWfJsXGBn", userAgent: "GitHubCopilotChat/0.37.2026011603", expectedClient: aibridge.ClientCopilotVSC, @@ -547,7 +547,7 @@ func TestSimple(t *testing.T) { expectedPath: "/api/chat/completions", providerFn: newOpenAIProvider, getResponseIDFunc: getOpenAIResponseID, - createRequest: CreateOpenAIChatCompletionsReq, + createRequest: createOpenAIChatCompletionsReq, expectedMsgID: "chatcmpl-BwoiPTGRbKkY5rncfaM0s9KtWrq5N", userAgent: "Zed/0.219.4+stable.119.abc123 (macos; aarch64)", expectedClient: aibridge.ClientZed, @@ -566,9 +566,9 @@ func TestSimple(t *testing.T) { t.Cleanup(cancel) fix := fixtures.Parse(t, tc.fixture) - upstream := NewMockUpstream(t, ctx, NewFixtureResponse(fix)) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - ts := NewBridgeTestServer(t, ctx, + ts := newBridgeTestServer(t, ctx, []aibridge.Provider{tc.providerFn(upstream.URL + tc.basePath)}, ) @@ -584,7 +584,7 @@ func TestSimple(t *testing.T) { defer resp.Body.Close() // Then: I expect the upstream request to have the correct path. - received := upstream.ReceivedRequests() + received := upstream.receivedRequests() require.Len(t, received, 1) require.Equal(t, tc.expectedPath, received[0].Path) @@ -801,9 +801,9 @@ func TestFallthrough(t *testing.T) { t.Parallel() fix := fixtures.Parse(t, tc.fixture) - upstream := NewMockUpstream(t, t.Context(), NewFixtureResponse(fix)) + upstream := newMockUpstream(t, t.Context(), newFixtureResponse(fix)) p := tc.providerFn(upstream.URL + tc.basePath) - ts := NewBridgeTestServer(t, t.Context(), + ts := newBridgeTestServer(t, t.Context(), []aibridge.Provider{p}, ) @@ -818,10 +818,10 @@ func TestFallthrough(t *testing.T) { // Verify upstream received the request at the expected path // with the API key header. 
- received := upstream.ReceivedRequests() + received := upstream.receivedRequests() require.Len(t, received, 1) require.Equal(t, tc.expectedUpstreamPath, received[0].Path) - require.Contains(t, received[0].Header.Get(p.AuthHeader()), APIKey) + require.Contains(t, received[0].Header.Get(p.AuthHeader()), apiKey) gotBytes, err := io.ReadAll(resp.Body) require.NoError(t, err) @@ -844,18 +844,18 @@ func TestAnthropicInjectedTools(t *testing.T) { t.Parallel() // Build the requirements & make the assertions which are common to all providers. - recorderClient, mockMCP, resp := setupInjectedToolTest(t, fixtures.AntSingleInjectedTool, streaming, newAnthropicProvider, DefaultTracer, DefaultActorID, CreateAnthropicMessagesReq, anthropicToolResultValidator(t)) + recorderClient, mockMCP, resp := setupInjectedToolTest(t, fixtures.AntSingleInjectedTool, streaming, newAnthropicProvider, defaultTracer, defaultActorID, createAnthropicMessagesReq, anthropicToolResultValidator(t)) // Ensure expected tool was invoked with expected input. toolUsages := recorderClient.RecordedToolUsages() require.Len(t, toolUsages, 1) - require.Equal(t, MockToolName, toolUsages[0].Tool) + require.Equal(t, mockToolName, toolUsages[0].Tool) expected, err := json.Marshal(map[string]any{"owner": "admin"}) require.NoError(t, err) actual, err := json.Marshal(toolUsages[0].Args) require.NoError(t, err) require.EqualValues(t, expected, actual) - invocations := mockMCP.GetCallsByTool(MockToolName) + invocations := mockMCP.getCallsByTool(mockToolName) require.Len(t, invocations, 1) actual, err = json.Marshal(invocations[0]) require.NoError(t, err) @@ -928,18 +928,18 @@ func TestOpenAIInjectedTools(t *testing.T) { t.Parallel() // Build the requirements & make the assertions which are common to all providers. 
- recorderClient, mockMCP, resp := setupInjectedToolTest(t, fixtures.OaiChatSingleInjectedTool, streaming, newOpenAIProvider, DefaultTracer, DefaultActorID, CreateOpenAIChatCompletionsReq, openaiChatToolResultValidator(t)) + recorderClient, mockMCP, resp := setupInjectedToolTest(t, fixtures.OaiChatSingleInjectedTool, streaming, newOpenAIProvider, defaultTracer, defaultActorID, createOpenAIChatCompletionsReq, openaiChatToolResultValidator(t)) // Ensure expected tool was invoked with expected input. toolUsages := recorderClient.RecordedToolUsages() require.Len(t, toolUsages, 1) - require.Equal(t, MockToolName, toolUsages[0].Tool) + require.Equal(t, mockToolName, toolUsages[0].Tool) expected, err := json.Marshal(map[string]any{"owner": "admin"}) require.NoError(t, err) actual, err := json.Marshal(toolUsages[0].Args) require.NoError(t, err) require.EqualValues(t, expected, actual) - invocations := mockMCP.GetCallsByTool(MockToolName) + invocations := mockMCP.getCallsByTool(mockToolName) require.Len(t, invocations, 1) actual, err = json.Marshal(invocations[0]) require.NoError(t, err) @@ -1102,7 +1102,7 @@ func setupInjectedToolTest( actorID string, createRequestFn func(*testing.T, string, []byte) *http.Request, toolRequestValidatorFn func(*http.Request, []byte), -) (*testutil.MockRecorder, *MockMCP, *http.Response) { +) (*testutil.MockRecorder, *mockMCP, *http.Response) { t.Helper() ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) @@ -1112,18 +1112,18 @@ func setupInjectedToolTest( // Setup mock server for multi-turn interaction. // First request → tool call response, second → tool response. 
- firstResp := NewFixtureResponse(fix) - toolResp := NewFixtureToolResponse(fix) + firstResp := newFixtureResponse(fix) + toolResp := newFixtureToolResponse(fix) toolResp.OnRequest = toolRequestValidatorFn - upstream := NewMockUpstream(t, ctx, firstResp, toolResp) + upstream := newMockUpstream(t, ctx, firstResp, toolResp) - mockMCP := SetupMCPForTest(t, tracer) + mockMCP := setupMCPForTest(t, tracer) - ts := NewBridgeTestServer(t, ctx, + ts := newBridgeTestServer(t, ctx, []aibridge.Provider{providerFn(upstream.URL)}, - WithMCP(mockMCP), - WithTracer(tracer), - WithActor(actorID, nil), + withMCP(mockMCP), + withTracer(tracer), + withActor(actorID, nil), ) // Add the stream param to the request. @@ -1162,7 +1162,7 @@ func TestErrorHandling(t *testing.T) { { name: config.ProviderAnthropic, fixture: fixtures.AntNonStreamError, - createRequestFunc: CreateAnthropicMessagesReq, + createRequestFunc: createAnthropicMessagesReq, providerFn: newAnthropicProvider, responseHandlerFn: func(resp *http.Response) { require.Equal(t, http.StatusBadRequest, resp.StatusCode) @@ -1176,7 +1176,7 @@ func TestErrorHandling(t *testing.T) { { name: config.ProviderOpenAI, fixture: fixtures.OaiChatNonStreamError, - createRequestFunc: CreateOpenAIChatCompletionsReq, + createRequestFunc: createOpenAIChatCompletionsReq, providerFn: newOpenAIProvider, responseHandlerFn: func(resp *http.Response) { require.Equal(t, http.StatusBadRequest, resp.StatusCode) @@ -1203,9 +1203,9 @@ func TestErrorHandling(t *testing.T) { // Setup mock server. Error fixtures contain raw HTTP // responses that may cause the bridge to retry. 
fix := fixtures.Parse(t, tc.fixture) - upstream := NewMockUpstream(t, ctx, NewFixtureResponse(fix)) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - ts := NewBridgeTestServer(t, ctx, + ts := newBridgeTestServer(t, ctx, []aibridge.Provider{tc.providerFn(upstream.URL)}, ) @@ -1238,7 +1238,7 @@ func TestErrorHandling(t *testing.T) { { name: config.ProviderAnthropic, fixture: fixtures.AntMidStreamError, - createRequestFunc: CreateAnthropicMessagesReq, + createRequestFunc: createAnthropicMessagesReq, providerFn: newAnthropicProvider, responseHandlerFn: func(resp *http.Response) { // Server responds first with 200 OK then starts streaming. @@ -1253,7 +1253,7 @@ func TestErrorHandling(t *testing.T) { { name: config.ProviderOpenAI, fixture: fixtures.OaiChatMidStreamError, - createRequestFunc: CreateOpenAIChatCompletionsReq, + createRequestFunc: createOpenAIChatCompletionsReq, providerFn: newOpenAIProvider, responseHandlerFn: func(resp *http.Response) { // Server responds first with 200 OK then starts streaming. @@ -1281,10 +1281,10 @@ func TestErrorHandling(t *testing.T) { // Setup mock server. 
fix := fixtures.Parse(t, tc.fixture) - upstream := NewMockUpstream(t, ctx, NewFixtureResponse(fix)) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) upstream.StatusCode = http.StatusInternalServerError - ts := NewBridgeTestServer(t, ctx, + ts := newBridgeTestServer(t, ctx, []aibridge.Provider{tc.providerFn(upstream.URL)}, ) @@ -1317,13 +1317,13 @@ func TestStableRequestEncoding(t *testing.T) { { name: config.ProviderAnthropic, fixture: fixtures.AntSimple, - createRequestFunc: CreateAnthropicMessagesReq, + createRequestFunc: createAnthropicMessagesReq, providerFn: newAnthropicProvider, }, { name: config.ProviderOpenAI, fixture: fixtures.OaiChatSimple, - createRequestFunc: CreateOpenAIChatCompletionsReq, + createRequestFunc: createOpenAIChatCompletionsReq, providerFn: newOpenAIProvider, }, } @@ -1336,21 +1336,21 @@ func TestStableRequestEncoding(t *testing.T) { t.Cleanup(cancel) // Setup MCP tools. - mockMCP := SetupMCPForTest(t, DefaultTracer) + mockMCP := setupMCPForTest(t, defaultTracer) fix := fixtures.Parse(t, tc.fixture) // Create a mock upstream that serves the same blocking response for each request. count := 10 - responses := make([]UpstreamResponse, count) + responses := make([]upstreamResponse, count) for i := range count { - responses[i] = NewFixtureResponse(fix) + responses[i] = newFixtureResponse(fix) } - upstream := NewMockUpstream(t, ctx, responses...) + upstream := newMockUpstream(t, ctx, responses...) - ts := NewBridgeTestServer(t, ctx, + ts := newBridgeTestServer(t, ctx, []aibridge.Provider{tc.providerFn(upstream.URL)}, - WithMCP(mockMCP), + withMCP(mockMCP), ) // Make multiple requests and verify they all have identical payloads. @@ -1364,7 +1364,7 @@ func TestStableRequestEncoding(t *testing.T) { } // All upstream request bodies should be identical. 
- received := upstream.ReceivedRequests() + received := upstream.receivedRequests() require.Len(t, received, count) reference := string(received[0].Body) for _, r := range received[1:] { @@ -1457,24 +1457,24 @@ func TestAnthropicToolChoiceParallelDisabled(t *testing.T) { // Setup MCP tools conditionally. var mcpMgr mcp.ServerProxier if tc.withInjectedTools { - mcpMgr = SetupMCPForTest(t, DefaultTracer) + mcpMgr = setupMCPForTest(t, defaultTracer) } else { - mcpMgr = NewNoopMCPManager() + mcpMgr = newNoopMCPManager() } fix := fixtures.Parse(t, fixtures.AntSimple) - upstream := NewMockUpstream(t, ctx, NewFixtureResponse(fix)) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - ts := NewBridgeTestServer(t, ctx, - []aibridge.Provider{provider.NewAnthropic(AnthropicCfg(upstream.URL, APIKey), nil)}, - WithMCP(mcpMgr), + ts := newBridgeTestServer(t, ctx, + []aibridge.Provider{provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), nil)}, + withMCP(mcpMgr), ) // Prepare request body with tool_choice set. reqBody, err := sjson.SetBytes(fix.Request(), "tool_choice", tc.toolChoice) require.NoError(t, err) - req := CreateAnthropicMessagesReq(t, ts.URL, reqBody) + req := createAnthropicMessagesReq(t, ts.URL, reqBody) client := &http.Client{} resp, err := client.Do(req) require.NoError(t, err) @@ -1482,7 +1482,7 @@ func TestAnthropicToolChoiceParallelDisabled(t *testing.T) { _ = resp.Body.Close() // Verify tool_choice in the upstream request. - received := upstream.ReceivedRequests() + received := upstream.receivedRequests() require.Len(t, received, 1) var receivedRequest map[string]any require.NoError(t, json.Unmarshal(received[0].Body, &receivedRequest)) @@ -1524,10 +1524,10 @@ func TestThinkingAdaptiveIsPreserved(t *testing.T) { t.Cleanup(cancel) // Create a mock server that captures the request body sent upstream. 
- upstream := NewMockUpstream(t, ctx, NewFixtureResponse(fix)) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - ts := NewBridgeTestServer(t, ctx, - []aibridge.Provider{provider.NewAnthropic(AnthropicCfg(upstream.URL, APIKey), nil)}, + ts := newBridgeTestServer(t, ctx, + []aibridge.Provider{provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), nil)}, ) // Inject adaptive thinking into the fixture request. @@ -1536,7 +1536,7 @@ func TestThinkingAdaptiveIsPreserved(t *testing.T) { reqBody, err = sjson.SetBytes(reqBody, "stream", streaming) require.NoError(t, err) - req := CreateAnthropicMessagesReq(t, ts.URL, reqBody) + req := createAnthropicMessagesReq(t, ts.URL, reqBody) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) @@ -1544,7 +1544,7 @@ func TestThinkingAdaptiveIsPreserved(t *testing.T) { _ = resp.Body.Close() // Verify the thinking field was preserved in the upstream request. - received := upstream.ReceivedRequests() + received := upstream.receivedRequests() require.Len(t, received, 1) assert.Equal(t, "adaptive", gjson.GetBytes(received[0].Body, "thinking.type").Str) }) @@ -1568,7 +1568,7 @@ func TestEnvironmentDoNotLeak(t *testing.T) { name: config.ProviderAnthropic, fixture: fixtures.AntSimple, providerFn: newAnthropicProvider, - createRequest: CreateAnthropicMessagesReq, + createRequest: createAnthropicMessagesReq, envVars: map[string]string{ "ANTHROPIC_AUTH_TOKEN": "should-not-leak", }, @@ -1578,7 +1578,7 @@ func TestEnvironmentDoNotLeak(t *testing.T) { name: config.ProviderOpenAI, fixture: fixtures.OaiChatSimple, providerFn: newOpenAIProvider, - createRequest: CreateOpenAIChatCompletionsReq, + createRequest: createOpenAIChatCompletionsReq, envVars: map[string]string{ "OPENAI_ORG_ID": "should-not-leak", }, @@ -1594,7 +1594,7 @@ func TestEnvironmentDoNotLeak(t *testing.T) { t.Cleanup(cancel) fix := fixtures.Parse(t, tc.fixture) - upstream := NewMockUpstream(t, ctx, 
NewFixtureResponse(fix)) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) // Set environment variables that the SDK would automatically read. // These should NOT leak into upstream requests. @@ -1602,7 +1602,7 @@ func TestEnvironmentDoNotLeak(t *testing.T) { t.Setenv(key, val) } - ts := NewBridgeTestServer(t, ctx, + ts := newBridgeTestServer(t, ctx, []aibridge.Provider{tc.providerFn(upstream.URL)}, ) @@ -1614,7 +1614,7 @@ func TestEnvironmentDoNotLeak(t *testing.T) { defer resp.Body.Close() // Verify that environment values did not leak. - received := upstream.ReceivedRequests() + received := upstream.receivedRequests() require.Len(t, received, 1) require.Empty(t, received[0].Header.Get(tc.headerName)) }) @@ -1635,9 +1635,9 @@ func TestActorHeaders(t *testing.T) { }{ { name: "openai/v1/chat/completions", - createRequest: CreateOpenAIChatCompletionsReq, + createRequest: createOpenAIChatCompletionsReq, createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { - cfg := OpenAICfg(url, key) + cfg := openAICfg(url, key) cfg.SendActorHeaders = sendHeaders return provider.NewOpenAI(cfg) }, @@ -1646,9 +1646,9 @@ func TestActorHeaders(t *testing.T) { }, { name: "openai/v1/chat/completions", - createRequest: CreateOpenAIChatCompletionsReq, + createRequest: createOpenAIChatCompletionsReq, createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { - cfg := OpenAICfg(url, key) + cfg := openAICfg(url, key) cfg.SendActorHeaders = sendHeaders return provider.NewOpenAI(cfg) }, @@ -1657,9 +1657,9 @@ func TestActorHeaders(t *testing.T) { }, { name: "openai/v1/responses", - createRequest: CreateOpenAIResponsesReq, + createRequest: createOpenAIResponsesReq, createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { - cfg := OpenAICfg(url, key) + cfg := openAICfg(url, key) cfg.SendActorHeaders = sendHeaders return provider.NewOpenAI(cfg) }, @@ -1668,9 +1668,9 @@ func TestActorHeaders(t *testing.T) { }, { name: 
"openai/v1/responses", - createRequest: CreateOpenAIResponsesReq, + createRequest: createOpenAIResponsesReq, createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { - cfg := OpenAICfg(url, key) + cfg := openAICfg(url, key) cfg.SendActorHeaders = sendHeaders return provider.NewOpenAI(cfg) }, @@ -1679,9 +1679,9 @@ func TestActorHeaders(t *testing.T) { }, { name: "anthropic/v1/messages", - createRequest: CreateAnthropicMessagesReq, + createRequest: createAnthropicMessagesReq, createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { - cfg := AnthropicCfg(url, key) + cfg := anthropicCfg(url, key) cfg.SendActorHeaders = sendHeaders return provider.NewAnthropic(cfg, nil) }, @@ -1690,9 +1690,9 @@ func TestActorHeaders(t *testing.T) { }, { name: "anthropic/v1/messages", - createRequest: CreateAnthropicMessagesReq, + createRequest: createAnthropicMessagesReq, createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { - cfg := AnthropicCfg(url, key) + cfg := anthropicCfg(url, key) cfg.SendActorHeaders = sendHeaders return provider.NewAnthropic(cfg, nil) }, @@ -1721,12 +1721,12 @@ func TestActorHeaders(t *testing.T) { srv.Start() t.Cleanup(srv.Close) - p := tc.createProviderFn(srv.URL, APIKey, send) + p := tc.createProviderFn(srv.URL, apiKey, send) metadataKey := "Username" - ts := NewBridgeTestServer(t, ctx, + ts := newBridgeTestServer(t, ctx, []aibridge.Provider{p}, - WithActor(DefaultActorID, recorder.Metadata{ + withActor(defaultActorID, recorder.Metadata{ metadataKey: actorUsername, }), ) @@ -1752,7 +1752,7 @@ func TestActorHeaders(t *testing.T) { } if send { - require.Equal(t, found[strings.ToLower(intercept.ActorIDHeader())], []string{DefaultActorID}) + require.Equal(t, found[strings.ToLower(intercept.ActorIDHeader())], []string{defaultActorID}) require.Equal(t, found[strings.ToLower(intercept.ActorMetadataHeader(metadataKey))], []string{actorUsername}) } else { require.Empty(t, found) diff --git 
a/internal/integrationtest/circuit_breaker_test.go b/internal/integrationtest/circuit_breaker_test.go index 61bd01f..014828f 100644 --- a/internal/integrationtest/circuit_breaker_test.go +++ b/internal/integrationtest/circuit_breaker_test.go @@ -1,4 +1,4 @@ -package integrationtest_test +package integrationtest import ( "fmt" @@ -13,7 +13,6 @@ import ( "github.com/coder/aibridge" "github.com/coder/aibridge/config" - "github.com/coder/aibridge/internal/integrationtest" "github.com/coder/aibridge/metrics" "github.com/coder/aibridge/provider" "github.com/prometheus/client_golang/prometheus" @@ -67,7 +66,7 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { req.Header.Set("x-api-key", "test") req.Header.Set("anthropic-version", "2023-06-01") }, - createRequest: integrationtest.CreateAnthropicMessagesReq, + createRequest: createAnthropicMessagesReq, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewAnthropic(config.Anthropic{ BaseURL: baseURL, @@ -87,7 +86,7 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { setupHeaders: func(req *http.Request) { req.Header.Set("Authorization", "Bearer test-key") }, - createRequest: integrationtest.CreateOpenAIChatCompletionsReq, + createRequest: createOpenAIChatCompletionsReq, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewOpenAI(config.OpenAI{ BaseURL: baseURL, @@ -134,9 +133,9 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { prov := tc.createProvider(mockUpstream.URL, cbConfig) ctx := t.Context() - ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}, - integrationtest.WithMetrics(m), - integrationtest.WithActor("test-user-id", nil), + ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}, + withMetrics(m), + withActor("test-user-id", nil), ) makeRequest := func() *http.Response { @@ -236,7 +235,7 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { 
req.Header.Set("x-api-key", "test") req.Header.Set("anthropic-version", "2023-06-01") }, - createRequest: integrationtest.CreateAnthropicMessagesReq, + createRequest: createAnthropicMessagesReq, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewAnthropic(config.Anthropic{ BaseURL: baseURL, @@ -255,7 +254,7 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { setupHeaders: func(req *http.Request) { req.Header.Set("Authorization", "Bearer test-key") }, - createRequest: integrationtest.CreateOpenAIChatCompletionsReq, + createRequest: createOpenAIChatCompletionsReq, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewOpenAI(config.OpenAI{ BaseURL: baseURL, @@ -293,9 +292,9 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { prov := tc.createProvider(mockUpstream.URL, cbConfig) ctx := t.Context() - ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}, - integrationtest.WithMetrics(m), - integrationtest.WithActor("test-user-id", nil), + ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}, + withMetrics(m), + withActor("test-user-id", nil), ) makeRequest := func() *http.Response { @@ -377,7 +376,7 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { req.Header.Set("x-api-key", "test") req.Header.Set("anthropic-version", "2023-06-01") }, - createRequest: integrationtest.CreateAnthropicMessagesReq, + createRequest: createAnthropicMessagesReq, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewAnthropic(config.Anthropic{ BaseURL: baseURL, @@ -397,7 +396,7 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { setupHeaders: func(req *http.Request) { req.Header.Set("Authorization", "Bearer test-key") }, - createRequest: integrationtest.CreateOpenAIChatCompletionsReq, + createRequest: createOpenAIChatCompletionsReq, createProvider: func(baseURL 
string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewOpenAI(config.OpenAI{ BaseURL: baseURL, @@ -445,9 +444,9 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { prov := tc.createProvider(mockUpstream.URL, cbConfig) ctx := t.Context() - ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}, - integrationtest.WithMetrics(m), - integrationtest.WithActor("test-user-id", nil), + ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}, + withMetrics(m), + withActor("test-user-id", nil), ) makeRequest := func() *http.Response { @@ -569,14 +568,14 @@ func TestCircuitBreaker_PerModelIsolation(t *testing.T) { }, nil) ctx := t.Context() - ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}, - integrationtest.WithMetrics(m), - integrationtest.WithActor("test-user-id", nil), + ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}, + withMetrics(m), + withActor("test-user-id", nil), ) makeRequest := func(model string) *http.Response { body := fmt.Sprintf(`{"model":"%s","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`, model) - req := integrationtest.CreateAnthropicMessagesReq(t, ts.URL, []byte(body)) + req := createAnthropicMessagesReq(t, ts.URL, []byte(body)) req.Header.Set("x-api-key", "test") req.Header.Set("anthropic-version", "2023-06-01") resp, err := http.DefaultClient.Do(req) diff --git a/internal/integrationtest/metrics_test.go b/internal/integrationtest/metrics_test.go index 8242221..642c7d7 100644 --- a/internal/integrationtest/metrics_test.go +++ b/internal/integrationtest/metrics_test.go @@ -1,4 +1,4 @@ -package integrationtest_test +package integrationtest import ( "context" @@ -11,7 +11,6 @@ import ( "github.com/coder/aibridge" "github.com/coder/aibridge/config" "github.com/coder/aibridge/fixtures" - "github.com/coder/aibridge/internal/integrationtest" "github.com/coder/aibridge/metrics" "github.com/coder/aibridge/provider" 
"github.com/prometheus/client_golang/prometheus" @@ -35,7 +34,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "ant_simple", fixture: fixtures.AntSimple, - reqFunc: integrationtest.CreateAnthropicMessagesReq, + reqFunc: createAnthropicMessagesReq, expectStatus: metrics.InterceptionCountStatusCompleted, expectModel: "claude-sonnet-4-0", expectRoute: "/v1/messages", @@ -44,7 +43,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "ant_error", fixture: fixtures.AntNonStreamError, - reqFunc: integrationtest.CreateAnthropicMessagesReq, + reqFunc: createAnthropicMessagesReq, expectStatus: metrics.InterceptionCountStatusFailed, expectModel: "claude-sonnet-4-0", expectRoute: "/v1/messages", @@ -54,7 +53,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "oai_chat_simple", fixture: fixtures.OaiChatSimple, - reqFunc: integrationtest.CreateOpenAIChatCompletionsReq, + reqFunc: createOpenAIChatCompletionsReq, expectStatus: metrics.InterceptionCountStatusCompleted, expectModel: "gpt-4.1", expectRoute: "/v1/chat/completions", @@ -63,7 +62,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "oai_chat_error", fixture: fixtures.OaiChatNonStreamError, - reqFunc: integrationtest.CreateOpenAIChatCompletionsReq, + reqFunc: createOpenAIChatCompletionsReq, expectStatus: metrics.InterceptionCountStatusFailed, expectModel: "gpt-4.1", expectRoute: "/v1/chat/completions", @@ -73,7 +72,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "oai_responses_blocking_simple", fixture: fixtures.OaiResponsesBlockingSimple, - reqFunc: integrationtest.CreateOpenAIResponsesReq, + reqFunc: createOpenAIResponsesReq, expectStatus: metrics.InterceptionCountStatusCompleted, expectModel: "gpt-4o-mini", expectRoute: "/v1/responses", @@ -82,7 +81,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "oai_responses_blocking_error", fixture: fixtures.OaiResponsesBlockingHttpErr, - reqFunc: integrationtest.CreateOpenAIResponsesReq, + reqFunc: createOpenAIResponsesReq, 
expectStatus: metrics.InterceptionCountStatusFailed, expectModel: "gpt-4o-mini", expectRoute: "/v1/responses", @@ -92,7 +91,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "oai_responses_streaming_simple", fixture: fixtures.OaiResponsesStreamingSimple, - reqFunc: integrationtest.CreateOpenAIResponsesReq, + reqFunc: createOpenAIResponsesReq, expectStatus: metrics.InterceptionCountStatusCompleted, expectModel: "gpt-4o-mini", expectRoute: "/v1/responses", @@ -101,7 +100,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "oai_responses_streaming_error", fixture: fixtures.OaiResponsesStreamingHttpErr, - reqFunc: integrationtest.CreateOpenAIResponsesReq, + reqFunc: createOpenAIResponsesReq, expectStatus: metrics.InterceptionCountStatusFailed, expectModel: "gpt-4o-mini", expectRoute: "/v1/responses", @@ -118,19 +117,19 @@ func TestMetrics_Interception(t *testing.T) { t.Cleanup(cancel) fix := fixtures.Parse(t, tc.fixture) - upstream := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix)) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) upstream.AllowOverflow = tc.allowOverflow m := aibridge.NewMetrics(prometheus.NewRegistry()) var prov aibridge.Provider if tc.expectProvider == config.ProviderAnthropic { - prov = provider.NewAnthropic(integrationtest.AnthropicCfg(upstream.URL, integrationtest.APIKey), nil) + prov = provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), nil) } else { - prov = provider.NewOpenAI(integrationtest.OpenAICfg(upstream.URL, integrationtest.APIKey)) + prov = provider.NewOpenAI(openAICfg(upstream.URL, apiKey)) } - ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}, - integrationtest.WithMetrics(m), - integrationtest.WithWrappedRecorder(), + ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}, + withMetrics(m), + withWrappedRecorder(), ) req := tc.reqFunc(t, ts.URL, fix.Request()) @@ -140,7 +139,7 @@ func TestMetrics_Interception(t *testing.T) { _, _ = 
io.ReadAll(resp.Body) count := promtest.ToFloat64(m.InterceptionCount.WithLabelValues( - tc.expectProvider, tc.expectModel, tc.expectStatus, tc.expectRoute, "POST", integrationtest.DefaultActorID)) + tc.expectProvider, tc.expectModel, tc.expectStatus, tc.expectRoute, "POST", defaultActorID)) require.Equal(t, 1.0, count) require.Equal(t, 1, promtest.CollectAndCount(m.InterceptionDuration)) require.Equal(t, 1, promtest.CollectAndCount(m.InterceptionCount)) @@ -165,17 +164,17 @@ func TestMetrics_InterceptionsInflight(t *testing.T) { t.Cleanup(srv.Close) m := aibridge.NewMetrics(prometheus.NewRegistry()) - prov := provider.NewAnthropic(integrationtest.AnthropicCfg(srv.URL, integrationtest.APIKey), nil) - ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}, - integrationtest.WithMetrics(m), - integrationtest.WithWrappedRecorder(), + prov := provider.NewAnthropic(anthropicCfg(srv.URL, apiKey), nil) + ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}, + withMetrics(m), + withWrappedRecorder(), ) // Make request in background. 
doneCh := make(chan struct{}) go func() { defer close(doneCh) - req := integrationtest.CreateAnthropicMessagesReq(t, ts.URL, fix.Request()) + req := createAnthropicMessagesReq(t, ts.URL, fix.Request()) resp, err := http.DefaultClient.Do(req) if err == nil { defer resp.Body.Close() @@ -213,10 +212,10 @@ func TestMetrics_PassthroughCount(t *testing.T) { t.Cleanup(upstream.Close) m := aibridge.NewMetrics(prometheus.NewRegistry()) - prov := provider.NewOpenAI(integrationtest.OpenAICfg(upstream.URL, integrationtest.APIKey)) - ts := integrationtest.NewBridgeTestServer(t, t.Context(), []aibridge.Provider{prov}, - integrationtest.WithMetrics(m), - integrationtest.WithWrappedRecorder(), + prov := provider.NewOpenAI(openAICfg(upstream.URL, apiKey)) + ts := newBridgeTestServer(t, t.Context(), []aibridge.Provider{prov}, + withMetrics(m), + withWrappedRecorder(), ) req, err := http.NewRequestWithContext(t.Context(), "GET", ts.URL+"/openai/v1/models", nil) @@ -239,16 +238,16 @@ func TestMetrics_PromptCount(t *testing.T) { t.Cleanup(cancel) fix := fixtures.Parse(t, fixtures.OaiChatSimple) - upstream := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix)) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) m := aibridge.NewMetrics(prometheus.NewRegistry()) - prov := provider.NewOpenAI(integrationtest.OpenAICfg(upstream.URL, integrationtest.APIKey)) - ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}, - integrationtest.WithMetrics(m), - integrationtest.WithWrappedRecorder(), + prov := provider.NewOpenAI(openAICfg(upstream.URL, apiKey)) + ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}, + withMetrics(m), + withWrappedRecorder(), ) - req := integrationtest.CreateOpenAIChatCompletionsReq(t, ts.URL, fix.Request()) + req := createOpenAIChatCompletionsReq(t, ts.URL, fix.Request()) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) @@ -256,7 +255,7 @@ 
func TestMetrics_PromptCount(t *testing.T) { _, _ = io.ReadAll(resp.Body) prompts := promtest.ToFloat64(m.PromptCount.WithLabelValues( - config.ProviderOpenAI, "gpt-4.1", integrationtest.DefaultActorID)) + config.ProviderOpenAI, "gpt-4.1", defaultActorID)) require.Equal(t, 1.0, prompts) } @@ -267,16 +266,16 @@ func TestMetrics_NonInjectedToolUseCount(t *testing.T) { t.Cleanup(cancel) fix := fixtures.Parse(t, fixtures.OaiChatSingleBuiltinTool) - upstream := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix)) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) m := aibridge.NewMetrics(prometheus.NewRegistry()) - prov := provider.NewOpenAI(integrationtest.OpenAICfg(upstream.URL, integrationtest.APIKey)) - ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}, - integrationtest.WithMetrics(m), - integrationtest.WithWrappedRecorder(), + prov := provider.NewOpenAI(openAICfg(upstream.URL, apiKey)) + ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}, + withMetrics(m), + withWrappedRecorder(), ) - req := integrationtest.CreateOpenAIChatCompletionsReq(t, ts.URL, fix.Request()) + req := createOpenAIChatCompletionsReq(t, ts.URL, fix.Request()) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) @@ -296,20 +295,20 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) { // First request returns the tool invocation, the second returns the mocked response to the tool result. 
fix := fixtures.Parse(t, fixtures.AntSingleInjectedTool) - upstream := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix), integrationtest.NewFixtureToolResponse(fix)) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix), newFixtureToolResponse(fix)) m := aibridge.NewMetrics(prometheus.NewRegistry()) - prov := provider.NewAnthropic(integrationtest.AnthropicCfg(upstream.URL, integrationtest.APIKey), nil) + prov := provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), nil) // Setup mocked MCP server & tools. - mockMCP := integrationtest.SetupMCPForTest(t, integrationtest.DefaultTracer) + mockMCP := setupMCPForTest(t, defaultTracer) - ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}, - integrationtest.WithMetrics(m), - integrationtest.WithMCP(mockMCP), + ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}, + withMetrics(m), + withMCP(mockMCP), ) - req := integrationtest.CreateAnthropicMessagesReq(t, ts.URL, fix.Request()) + req := createAnthropicMessagesReq(t, ts.URL, fix.Request()) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) @@ -328,6 +327,6 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) { actualServerURL := *recorder.ToolUsages()[0].ServerURL count := promtest.ToFloat64(m.InjectedToolUseCount.WithLabelValues( - config.ProviderAnthropic, "claude-sonnet-4-20250514", actualServerURL, integrationtest.MockToolName)) + config.ProviderAnthropic, "claude-sonnet-4-20250514", actualServerURL, mockToolName)) require.Equal(t, 1.0, count) } diff --git a/internal/integrationtest/mockmcp.go b/internal/integrationtest/mockmcp.go index 67c06ec..401ddea 100644 --- a/internal/integrationtest/mockmcp.go +++ b/internal/integrationtest/mockmcp.go @@ -21,33 +21,33 @@ import ( "go.opentelemetry.io/otel/trace/noop" ) -// MockToolName is the primary mock tool name used in MCP tests. 
-const MockToolName = "coder_list_workspaces" +// mockToolName is the primary mock tool name used in MCP tests. +const mockToolName = "coder_list_workspaces" -// MockMCP wraps a real mcp.ServerProxier with test assertion helpers. +// mockMCP wraps a real mcp.ServerProxier with test assertion helpers. // Implements mcp.ServerProxier so it can be passed directly to NewRequestBridge. -type MockMCP struct { +type mockMCP struct { mcp.ServerProxier calls *callAccumulator } -// GetCallsByTool returns recorded arguments for a given tool name. -func (m *MockMCP) GetCallsByTool(name string) []any { +// getCallsByTool returns recorded arguments for a given tool name. +func (m *mockMCP) getCallsByTool(name string) []any { return m.calls.getCallsByTool(name) } -// SetToolError configures a tool to return an error when invoked. -func (m *MockMCP) SetToolError(tool, errMsg string) { +// setToolError configures a tool to return an error when invoked. +func (m *mockMCP) setToolError(tool, errMsg string) { m.calls.setToolError(tool, errMsg) } -// SetupMCPForTest creates a ready-to-use MCP server with proxy named "coder". -func SetupMCPForTest(t *testing.T, tracer trace.Tracer) *MockMCP { +// setupMCPForTest creates a ready-to-use MCP server with proxy named "coder". 
+func setupMCPForTest(t *testing.T, tracer trace.Tracer) *mockMCP { t.Helper() - return SetupMCPForTestWithName(t, "coder", tracer) + return setupMCPForTestWithName(t, "coder", tracer) } -func SetupMCPForTestWithName(t *testing.T, name string, tracer trace.Tracer) *MockMCP { +func setupMCPForTestWithName(t *testing.T, name string, tracer trace.Tracer) *mockMCP { t.Helper() srv, acc := createMockMCPSrv(t) @@ -75,10 +75,10 @@ func SetupMCPForTestWithName(t *testing.T, name string, tracer trace.Tracer) *Mo require.NoError(t, mgr.Init(ctx)) require.NotEmpty(t, mgr.ListTools(), "mock MCP server should expose tools after init") - return &MockMCP{ServerProxier: mgr, calls: acc} + return &mockMCP{ServerProxier: mgr, calls: acc} } -func NewNoopMCPManager() mcp.ServerProxier { +func newNoopMCPManager() mcp.ServerProxier { return mcp.NewServerProxyManager(nil, noop.NewTracerProvider().Tracer("")) } @@ -134,7 +134,7 @@ func createMockMCPSrv(t *testing.T) (http.Handler, *callAccumulator) { acc := newCallAccumulator() - for _, name := range []string{MockToolName, "coder_list_templates", "coder_template_version_parameters", "coder_get_authenticated_user", "coder_create_workspace_build", "coder_delete_template"} { + for _, name := range []string{mockToolName, "coder_list_templates", "coder_template_version_parameters", "coder_get_authenticated_user", "coder_create_workspace_build", "coder_delete_template"} { tool := mcplib.NewTool(name, mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", name)), ) diff --git a/internal/integrationtest/requests.go b/internal/integrationtest/requests.go index 92fe817..738ac00 100644 --- a/internal/integrationtest/requests.go +++ b/internal/integrationtest/requests.go @@ -9,11 +9,11 @@ import ( "github.com/stretchr/testify/require" ) -// APIKey is the default API key used across integration tests. -const APIKey = "api-key" +// apiKey is the default API key used across integration tests. 
+const apiKey = "api-key" -// CreateAnthropicMessagesReq builds an HTTP request targeting the Anthropic messages endpoint. -func CreateAnthropicMessagesReq(t *testing.T, baseURL string, input []byte) *http.Request { +// createAnthropicMessagesReq builds an HTTP request targeting the Anthropic messages endpoint. +func createAnthropicMessagesReq(t *testing.T, baseURL string, input []byte) *http.Request { t.Helper() req, err := http.NewRequestWithContext(t.Context(), "POST", baseURL+"/anthropic/v1/messages", bytes.NewReader(input)) @@ -23,8 +23,8 @@ func CreateAnthropicMessagesReq(t *testing.T, baseURL string, input []byte) *htt return req } -// CreateOpenAIChatCompletionsReq builds an HTTP request targeting the OpenAI chat completions endpoint. -func CreateOpenAIChatCompletionsReq(t *testing.T, baseURL string, input []byte) *http.Request { +// createOpenAIChatCompletionsReq builds an HTTP request targeting the OpenAI chat completions endpoint. +func createOpenAIChatCompletionsReq(t *testing.T, baseURL string, input []byte) *http.Request { t.Helper() req, err := http.NewRequestWithContext(t.Context(), "POST", baseURL+"/openai/v1/chat/completions", bytes.NewReader(input)) @@ -34,8 +34,8 @@ func CreateOpenAIChatCompletionsReq(t *testing.T, baseURL string, input []byte) return req } -// CreateOpenAIResponsesReq builds an HTTP request targeting the OpenAI responses endpoint. -func CreateOpenAIResponsesReq(t *testing.T, baseURL string, input []byte) *http.Request { +// createOpenAIResponsesReq builds an HTTP request targeting the OpenAI responses endpoint. +func createOpenAIResponsesReq(t *testing.T, baseURL string, input []byte) *http.Request { t.Helper() req, err := http.NewRequestWithContext(t.Context(), "POST", baseURL+"/openai/v1/responses", bytes.NewReader(input)) @@ -44,16 +44,16 @@ func CreateOpenAIResponsesReq(t *testing.T, baseURL string, input []byte) *http. return req } -// OpenAICfg creates a minimal OpenAI config for testing. 
-func OpenAICfg(url, key string) config.OpenAI { +// openAICfg creates a minimal OpenAI config for testing. +func openAICfg(url, key string) config.OpenAI { return config.OpenAI{ BaseURL: url, Key: key, } } -// AnthropicCfg creates a minimal Anthropic config for testing. -func AnthropicCfg(url, key string) config.Anthropic { +// anthropicCfg creates a minimal Anthropic config for testing. +func anthropicCfg(url, key string) config.Anthropic { return config.Anthropic{ BaseURL: url, Key: key, diff --git a/internal/integrationtest/responses_test.go b/internal/integrationtest/responses_test.go index b178384..84ba5f3 100644 --- a/internal/integrationtest/responses_test.go +++ b/internal/integrationtest/responses_test.go @@ -1,4 +1,4 @@ -package integrationtest_test +package integrationtest import ( "context" @@ -16,7 +16,6 @@ import ( "github.com/coder/aibridge" "github.com/coder/aibridge/config" "github.com/coder/aibridge/fixtures" - "github.com/coder/aibridge/internal/integrationtest" "github.com/coder/aibridge/provider" "github.com/coder/aibridge/recorder" "github.com/openai/openai-go/v3/responses" @@ -333,12 +332,12 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { t.Cleanup(cancel) fix := fixtures.Parse(t, tc.fixture) - upstream := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix)) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - prov := provider.NewOpenAI(integrationtest.OpenAICfg(upstream.URL, integrationtest.APIKey)) - ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}, integrationtest.WithWrappedRecorder()) + prov := provider.NewOpenAI(openAICfg(upstream.URL, apiKey)) + ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}, withWrappedRecorder()) - req := integrationtest.CreateOpenAIResponsesReq(t, ts.URL, fix.Request()) + req := createOpenAIResponsesReq(t, ts.URL, fix.Request()) req.Header.Set("User-Agent", tc.userAgent) client := &http.Client{} @@ -358,7 +357,7 @@ func 
TestResponsesOutputMatchesUpstream(t *testing.T) { interceptions := ts.Recorder.RecordedInterceptions() require.Len(t, interceptions, 1) intc := interceptions[0] - require.Equal(t, intc.InitiatorID, integrationtest.DefaultActorID) + require.Equal(t, intc.InitiatorID, defaultActorID) require.Equal(t, intc.Provider, config.ProviderOpenAI) require.Equal(t, intc.Model, tc.expectModel) require.Equal(t, tc.userAgent, intc.UserAgent) @@ -427,12 +426,12 @@ func TestResponsesBackgroundModeForbidden(t *testing.T) { })) t.Cleanup(upstream.Close) - prov := provider.NewOpenAI(integrationtest.OpenAICfg(upstream.URL, integrationtest.APIKey)) - ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}) + prov := provider.NewOpenAI(openAICfg(upstream.URL, apiKey)) + ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}) // Create a request with background mode enabled reqBytes := responsesRequestBytes(t, tc.streaming, keyVal{"background", true}) - req := integrationtest.CreateOpenAIResponsesReq(t, ts.URL, reqBytes) + req := createOpenAIResponsesReq(t, ts.URL, reqBytes) client := &http.Client{} resp, err := client.Do(req) @@ -540,10 +539,10 @@ func TestResponsesParallelToolsOverwritten(t *testing.T) { })) t.Cleanup(upstream.Close) - prov := provider.NewOpenAI(integrationtest.OpenAICfg(upstream.URL, integrationtest.APIKey)) - ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}) + prov := provider.NewOpenAI(openAICfg(upstream.URL, apiKey)) + ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}) - req := integrationtest.CreateOpenAIResponsesReq(t, ts.URL, []byte(tc.request)) + req := createOpenAIResponsesReq(t, ts.URL, []byte(tc.request)) client := &http.Client{} resp, err := client.Do(req) @@ -600,11 +599,11 @@ func TestClientAndConnectionError(t *testing.T) { ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - prov := provider.NewOpenAI(integrationtest.OpenAICfg(tc.addr, integrationtest.APIKey)) - 
ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}, integrationtest.WithWrappedRecorder()) + prov := provider.NewOpenAI(openAICfg(tc.addr, apiKey)) + ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}, withWrappedRecorder()) reqBytes := responsesRequestBytes(t, tc.streaming) - req := integrationtest.CreateOpenAIResponsesReq(t, ts.URL, reqBytes) + req := createOpenAIResponsesReq(t, ts.URL, reqBytes) client := &http.Client{} resp, err := client.Do(req) @@ -683,11 +682,11 @@ func TestUpstreamError(t *testing.T) { })) t.Cleanup(upstream.Close) - prov := provider.NewOpenAI(integrationtest.OpenAICfg(upstream.URL, integrationtest.APIKey)) - ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}) + prov := provider.NewOpenAI(openAICfg(upstream.URL, apiKey)) + ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}) reqBytes := responsesRequestBytes(t, tc.streaming) - req := integrationtest.CreateOpenAIResponsesReq(t, ts.URL, reqBytes) + req := createOpenAIResponsesReq(t, ts.URL, reqBytes) client := &http.Client{} resp, err := client.Do(req) @@ -858,18 +857,18 @@ func TestResponsesInjectedTool(t *testing.T) { // Setup mock server for multi-turn interaction. // First request → tool call response, second → tool response. fix := fixtures.Parse(t, tc.fixture) - upstream := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix), integrationtest.NewFixtureToolResponse(fix)) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix), newFixtureToolResponse(fix)) // Setup MCP server proxies (with mock tools). 
- mockMCP := integrationtest.SetupMCPForTest(t, integrationtest.DefaultTracer) + mockMCP := setupMCPForTest(t, defaultTracer) if tc.expectToolError != "" { - mockMCP.SetToolError(tc.mcpToolName, tc.expectToolError) + mockMCP.setToolError(tc.mcpToolName, tc.expectToolError) } - prov := provider.NewOpenAI(integrationtest.OpenAICfg(upstream.URL, integrationtest.APIKey)) - ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}, integrationtest.WithMCP(mockMCP)) + prov := provider.NewOpenAI(openAICfg(upstream.URL, apiKey)) + ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}, withMCP(mockMCP)) - req := integrationtest.CreateOpenAIResponsesReq(t, ts.URL, fix.Request()) + req := createOpenAIResponsesReq(t, ts.URL, fix.Request()) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) defer resp.Body.Close() @@ -884,7 +883,7 @@ func TestResponsesInjectedTool(t *testing.T) { }, time.Second*10, time.Millisecond*50) // Verify the injected tool was invoked via MCP. - invocations := mockMCP.GetCallsByTool(tc.mcpToolName) + invocations := mockMCP.getCallsByTool(tc.mcpToolName) require.Len(t, invocations, 1, "expected MCP tool to be invoked once") // Verify the injected tool usage was recorded. 
diff --git a/internal/integrationtest/trace_test.go b/internal/integrationtest/trace_test.go index 41dd968..eb71a6d 100644 --- a/internal/integrationtest/trace_test.go +++ b/internal/integrationtest/trace_test.go @@ -1,4 +1,4 @@ -package integrationtest_test +package integrationtest import ( "context" @@ -12,7 +12,6 @@ import ( "github.com/coder/aibridge" "github.com/coder/aibridge/config" "github.com/coder/aibridge/fixtures" - "github.com/coder/aibridge/internal/integrationtest" "github.com/coder/aibridge/provider" "github.com/coder/aibridge/tracing" "github.com/stretchr/testify/assert" @@ -99,21 +98,21 @@ func TestTraceAnthropic(t *testing.T) { tracer := tp.Tracer(t.Name()) defer func() { _ = tp.Shutdown(t.Context()) }() - upstream := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix)) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) var bedrockCfg *config.AWSBedrock if tc.bedrock { bedrockCfg = testBedrockCfg(upstream.URL) } - prov := provider.NewAnthropic(integrationtest.AnthropicCfg(upstream.URL, integrationtest.APIKey), bedrockCfg) - ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}, - integrationtest.WithTracer(tracer), - integrationtest.WithWrappedRecorder(), + prov := provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), bedrockCfg) + ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}, + withTracer(tracer), + withWrappedRecorder(), ) reqBody, err := sjson.SetBytes(fixtureReqBody, "stream", tc.streaming) require.NoError(t, err) - req := integrationtest.CreateAnthropicMessagesReq(t, ts.URL, reqBody) + req := createAnthropicMessagesReq(t, ts.URL, reqBody) client := &http.Client{} resp, err := client.Do(req) require.NoError(t, err) @@ -140,7 +139,7 @@ func TestTraceAnthropic(t *testing.T) { attribute.String(tracing.InterceptionID, intcID), attribute.String(tracing.Provider, config.ProviderAnthropic), attribute.String(tracing.Model, model), - attribute.String(tracing.InitiatorID, 
integrationtest.DefaultActorID), + attribute.String(tracing.InitiatorID, defaultActorID), attribute.Bool(tracing.Streaming, tc.streaming), attribute.Bool(tracing.IsBedrock, tc.bedrock), } @@ -219,21 +218,21 @@ func TestTraceAnthropicErr(t *testing.T) { defer func() { _ = tp.Shutdown(t.Context()) }() fix := fixtures.Parse(t, tc.fixture) - upstream := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix)) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) var bedrockCfg *config.AWSBedrock if tc.bedrock { bedrockCfg = testBedrockCfg(upstream.URL) } - prov := provider.NewAnthropic(integrationtest.AnthropicCfg(upstream.URL, integrationtest.APIKey), bedrockCfg) - ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}, - integrationtest.WithTracer(tracer), - integrationtest.WithWrappedRecorder(), + prov := provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), bedrockCfg) + ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}, + withTracer(tracer), + withWrappedRecorder(), ) reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) require.NoError(t, err) - req := integrationtest.CreateAnthropicMessagesReq(t, ts.URL, reqBody) + req := createAnthropicMessagesReq(t, ts.URL, reqBody) client := &http.Client{} resp, err := client.Do(req) require.NoError(t, err) @@ -268,7 +267,7 @@ func TestTraceAnthropicErr(t *testing.T) { attribute.String(tracing.InterceptionID, intcID), attribute.String(tracing.Provider, config.ProviderAnthropic), attribute.String(tracing.Model, model), - attribute.String(tracing.InitiatorID, integrationtest.DefaultActorID), + attribute.String(tracing.InitiatorID, defaultActorID), attribute.Bool(tracing.Streaming, tc.streaming), attribute.Bool(tracing.IsBedrock, tc.bedrock), } @@ -297,7 +296,7 @@ func TestInjectedToolsTrace(t *testing.T) { streaming: false, fixture: fixtures.AntSingleInjectedTool, providerFn: newAnthropicProvider, - createReqFn: 
integrationtest.CreateAnthropicMessagesReq, + createReqFn: createAnthropicMessagesReq, expectModel: "claude-sonnet-4-20250514", expectPath: "/anthropic/v1/messages", expectProvider: config.ProviderAnthropic, @@ -307,7 +306,7 @@ func TestInjectedToolsTrace(t *testing.T) { streaming: true, fixture: fixtures.AntSingleInjectedTool, providerFn: newAnthropicProvider, - createReqFn: integrationtest.CreateAnthropicMessagesReq, + createReqFn: createAnthropicMessagesReq, expectModel: "claude-sonnet-4-20250514", expectPath: "/anthropic/v1/messages", expectProvider: config.ProviderAnthropic, @@ -318,7 +317,7 @@ func TestInjectedToolsTrace(t *testing.T) { bedrock: true, fixture: fixtures.AntSingleInjectedTool, providerFn: newBedrockProvider, - createReqFn: integrationtest.CreateAnthropicMessagesReq, + createReqFn: createAnthropicMessagesReq, expectModel: "beddel", expectPath: "/anthropic/v1/messages", expectProvider: config.ProviderAnthropic, @@ -329,7 +328,7 @@ func TestInjectedToolsTrace(t *testing.T) { bedrock: true, fixture: fixtures.AntSingleInjectedTool, providerFn: newBedrockProvider, - createReqFn: integrationtest.CreateAnthropicMessagesReq, + createReqFn: createAnthropicMessagesReq, expectModel: "beddel", expectPath: "/anthropic/v1/messages", expectProvider: config.ProviderAnthropic, @@ -339,7 +338,7 @@ func TestInjectedToolsTrace(t *testing.T) { streaming: false, fixture: fixtures.OaiChatSingleInjectedTool, providerFn: newOpenAIProvider, - createReqFn: integrationtest.CreateOpenAIChatCompletionsReq, + createReqFn: createOpenAIChatCompletionsReq, expectModel: "gpt-4.1", expectPath: "/openai/v1/chat/completions", expectProvider: config.ProviderOpenAI, @@ -349,7 +348,7 @@ func TestInjectedToolsTrace(t *testing.T) { streaming: true, fixture: fixtures.OaiChatSingleInjectedTool, providerFn: newOpenAIProvider, - createReqFn: integrationtest.CreateOpenAIChatCompletionsReq, + createReqFn: createOpenAIChatCompletionsReq, expectModel: "gpt-4.1", expectPath: 
"/openai/v1/chat/completions", expectProvider: config.ProviderOpenAI, @@ -373,7 +372,7 @@ func TestInjectedToolsTrace(t *testing.T) { } recorderClient, mockMCP, resp := setupInjectedToolTest( - t, tc.fixture, tc.streaming, tc.providerFn, tracer, integrationtest.DefaultActorID, + t, tc.fixture, tc.streaming, tc.providerFn, tracer, defaultActorID, tc.createReqFn, validatorFn, ) defer resp.Body.Close() @@ -388,7 +387,7 @@ func TestInjectedToolsTrace(t *testing.T) { attribute.String(tracing.InterceptionID, intcID), attribute.String(tracing.Provider, tc.expectProvider), attribute.String(tracing.Model, tc.expectModel), - attribute.String(tracing.InitiatorID, integrationtest.DefaultActorID), + attribute.String(tracing.InitiatorID, defaultActorID), attribute.String(tracing.MCPInput, `{"owner":"admin"}`), attribute.String(tracing.MCPToolName, "coder_list_workspaces"), attribute.String(tracing.MCPServerName, tool.ServerName), @@ -418,7 +417,7 @@ func TestTraceOpenAI(t *testing.T) { fixture: fixtures.OaiChatSimple, streaming: true, expectPath: "/openai/v1/chat/completions", - reqFunc: integrationtest.CreateOpenAIChatCompletionsReq, + reqFunc: createOpenAIChatCompletionsReq, expect: []expectTrace{ {"Intercept", 1, codes.Unset}, {"Intercept.CreateInterceptor", 1, codes.Unset}, @@ -433,7 +432,7 @@ func TestTraceOpenAI(t *testing.T) { { name: "trace_openai_chat_blocking", fixture: fixtures.OaiChatSimple, - reqFunc: integrationtest.CreateOpenAIChatCompletionsReq, + reqFunc: createOpenAIChatCompletionsReq, streaming: false, expectPath: "/openai/v1/chat/completions", expect: []expectTrace{ @@ -452,7 +451,7 @@ func TestTraceOpenAI(t *testing.T) { fixture: fixtures.OaiResponsesStreamingSimple, streaming: true, expectPath: "/openai/v1/responses", - reqFunc: integrationtest.CreateOpenAIResponsesReq, + reqFunc: createOpenAIResponsesReq, expect: []expectTrace{ {"Intercept", 1, codes.Unset}, {"Intercept.CreateInterceptor", 1, codes.Unset}, @@ -469,7 +468,7 @@ func TestTraceOpenAI(t 
*testing.T) { fixture: fixtures.OaiResponsesBlockingSimple, streaming: false, expectPath: "/openai/v1/responses", - reqFunc: integrationtest.CreateOpenAIResponsesReq, + reqFunc: createOpenAIResponsesReq, expect: []expectTrace{ {"Intercept", 1, codes.Unset}, {"Intercept.CreateInterceptor", 1, codes.Unset}, @@ -494,11 +493,11 @@ func TestTraceOpenAI(t *testing.T) { defer func() { _ = tp.Shutdown(t.Context()) }() fix := fixtures.Parse(t, tc.fixture) - upstream := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix)) - prov := provider.NewOpenAI(integrationtest.OpenAICfg(upstream.URL, integrationtest.APIKey)) - ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}, - integrationtest.WithTracer(tracer), - integrationtest.WithWrappedRecorder(), + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + prov := provider.NewOpenAI(openAICfg(upstream.URL, apiKey)) + ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}, + withTracer(tracer), + withWrappedRecorder(), ) reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) @@ -526,7 +525,7 @@ func TestTraceOpenAI(t *testing.T) { attribute.String(tracing.InterceptionID, intcID), attribute.String(tracing.Provider, config.ProviderOpenAI), attribute.String(tracing.Model, gjson.Get(string(reqBody), "model").Str), - attribute.String(tracing.InitiatorID, integrationtest.DefaultActorID), + attribute.String(tracing.InitiatorID, defaultActorID), attribute.Bool(tracing.Streaming, tc.streaming), } verifyTraces(t, sr, tc.expect, attrs) @@ -550,7 +549,7 @@ func TestTraceOpenAIErr(t *testing.T) { fixture: fixtures.OaiChatMidStreamError, streaming: true, expectPath: "/openai/v1/chat/completions", - reqFunc: integrationtest.CreateOpenAIChatCompletionsReq, + reqFunc: createOpenAIChatCompletionsReq, expectCode: http.StatusOK, expect: []expectTrace{ {"Intercept", 1, codes.Error}, @@ -567,7 +566,7 @@ func TestTraceOpenAIErr(t *testing.T) { fixture: 
fixtures.OaiChatNonStreamError, streaming: false, expectPath: "/openai/v1/chat/completions", - reqFunc: integrationtest.CreateOpenAIChatCompletionsReq, + reqFunc: createOpenAIChatCompletionsReq, expectCode: http.StatusBadRequest, expect: []expectTrace{ {"Intercept", 1, codes.Error}, @@ -583,7 +582,7 @@ func TestTraceOpenAIErr(t *testing.T) { streaming: true, fixture: fixtures.OaiResponsesStreamingWrongResponseFormat, expectPath: "/openai/v1/responses", - reqFunc: integrationtest.CreateOpenAIResponsesReq, + reqFunc: createOpenAIResponsesReq, expectCode: http.StatusOK, expect: []expectTrace{ {"Intercept", 1, codes.Error}, @@ -600,7 +599,7 @@ func TestTraceOpenAIErr(t *testing.T) { fixture: fixtures.OaiResponsesBlockingWrongResponseFormat, streaming: false, expectPath: "/openai/v1/responses", - reqFunc: integrationtest.CreateOpenAIResponsesReq, + reqFunc: createOpenAIResponsesReq, // Fixture returns http 200 response with wrong body // responses forward received response as is so // expected code == 200 even though ProcessRequest @@ -622,7 +621,7 @@ func TestTraceOpenAIErr(t *testing.T) { allowOverflow: true, // 429 error causes retries expectPath: "/openai/v1/responses", - reqFunc: integrationtest.CreateOpenAIResponsesReq, + reqFunc: createOpenAIResponsesReq, expectCode: http.StatusTooManyRequests, expect: []expectTrace{ {"Intercept", 1, codes.Error}, @@ -639,7 +638,7 @@ func TestTraceOpenAIErr(t *testing.T) { streaming: false, expectPath: "/openai/v1/responses", - reqFunc: integrationtest.CreateOpenAIResponsesReq, + reqFunc: createOpenAIResponsesReq, expectCode: http.StatusUnauthorized, expect: []expectTrace{ {"Intercept", 1, codes.Error}, @@ -664,12 +663,12 @@ func TestTraceOpenAIErr(t *testing.T) { fix := fixtures.Parse(t, tc.fixture) - mockAPI := integrationtest.NewMockUpstream(t, ctx, integrationtest.NewFixtureResponse(fix)) + mockAPI := newMockUpstream(t, ctx, newFixtureResponse(fix)) mockAPI.AllowOverflow = tc.allowOverflow - prov := 
provider.NewOpenAI(integrationtest.OpenAICfg(mockAPI.URL, integrationtest.APIKey)) - ts := integrationtest.NewBridgeTestServer(t, ctx, []aibridge.Provider{prov}, - integrationtest.WithTracer(tracer), - integrationtest.WithWrappedRecorder(), + prov := provider.NewOpenAI(openAICfg(mockAPI.URL, apiKey)) + ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}, + withTracer(tracer), + withWrappedRecorder(), ) reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) @@ -698,7 +697,7 @@ func TestTraceOpenAIErr(t *testing.T) { attribute.String(tracing.InterceptionID, intcID), attribute.String(tracing.Provider, config.ProviderOpenAI), attribute.String(tracing.Model, gjson.Get(string(reqBody), "model").Str), - attribute.String(tracing.InitiatorID, integrationtest.DefaultActorID), + attribute.String(tracing.InitiatorID, defaultActorID), attribute.Bool(tracing.Streaming, tc.streaming), } verifyTraces(t, sr, tc.expect, attrs) @@ -711,17 +710,17 @@ func TestTracePassthrough(t *testing.T) { fix := fixtures.Parse(t, fixtures.OaiChatFallthrough) - upstream := integrationtest.NewMockUpstream(t, t.Context(), integrationtest.NewFixtureResponse(fix)) + upstream := newMockUpstream(t, t.Context(), newFixtureResponse(fix)) sr := tracetest.NewSpanRecorder() tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) tracer := tp.Tracer(t.Name()) defer func() { _ = tp.Shutdown(t.Context()) }() - prov := provider.NewOpenAI(integrationtest.OpenAICfg(upstream.URL, integrationtest.APIKey)) - ts := integrationtest.NewBridgeTestServer(t, t.Context(), []aibridge.Provider{prov}, - integrationtest.WithTracer(tracer), - integrationtest.WithWrappedRecorder(), + prov := provider.NewOpenAI(openAICfg(upstream.URL, apiKey)) + ts := newBridgeTestServer(t, t.Context(), []aibridge.Provider{prov}, + withTracer(tracer), + withWrappedRecorder(), ) req, err := http.NewRequestWithContext(t.Context(), "GET", ts.URL+"/openai/v1/models", nil) @@ -753,7 +752,7 @@ func 
TestNewServerProxyManagerTraces(t *testing.T) { defer func() { _ = tp.Shutdown(t.Context()) }() serverName := "serverName" - mockMCP := integrationtest.SetupMCPForTestWithName(t, serverName, tracer) + mockMCP := setupMCPForTestWithName(t, serverName, tracer) tool := mockMCP.ListTools()[0] require.Len(t, sr.Ended(), 3) diff --git a/internal/integrationtest/upstream.go b/internal/integrationtest/upstream.go index 356ec94..4f0ac7f 100644 --- a/internal/integrationtest/upstream.go +++ b/internal/integrationtest/upstream.go @@ -23,10 +23,10 @@ import ( "github.com/tidwall/gjson" ) -// UpstreamResponse defines a single response that MockUpstream will replay -// for one incoming request. Use [NewFixtureResponse] or [NewFixtureToolResponse] to +// upstreamResponse defines a single response that mockUpstream will replay +// for one incoming request. Use [newFixtureResponse] or [newFixtureToolResponse] to // construct one from a parsed txtar archive. -type UpstreamResponse struct { +type upstreamResponse struct { Streaming []byte // returned when the request has "stream": true. Blocking []byte // returned for non-streaming requests. @@ -35,11 +35,11 @@ type UpstreamResponse struct { OnRequest func(r *http.Request, body []byte) } -// NewFixtureResponse creates an UpstreamResponse from a parsed fixture archive. +// newFixtureResponse creates an upstreamResponse from a parsed fixture archive. // It reads whichever of 'streaming' and 'non-streaming' sections exist; // not every fixture has both (e.g. error fixtures may only define one). -func NewFixtureResponse(fix fixtures.Fixture) UpstreamResponse { - var resp UpstreamResponse +func newFixtureResponse(fix fixtures.Fixture) upstreamResponse { + var resp upstreamResponse if fix.Has(fixtures.SectionStreaming) { resp.Streaming = fix.Streaming() } @@ -49,11 +49,11 @@ func NewFixtureResponse(fix fixtures.Fixture) UpstreamResponse { return resp } -// NewFixtureToolResponse creates an UpstreamResponse from the tool-call fixture files. 
+// newFixtureToolResponse creates an upstreamResponse from the tool-call fixture files. // It reads whichever of 'streaming/tool-call' and 'non-streaming/tool-call' // sections exist. -func NewFixtureToolResponse(fix fixtures.Fixture) UpstreamResponse { - var resp UpstreamResponse +func newFixtureToolResponse(fix fixtures.Fixture) upstreamResponse { + var resp upstreamResponse if fix.Has(fixtures.SectionStreamingToolCall) { resp.Streaming = fix.StreamingToolCall() } @@ -63,18 +63,18 @@ func NewFixtureToolResponse(fix fixtures.Fixture) UpstreamResponse { return resp } -// ReceivedRequest captures the details of a single request handled by MockUpstream. -type ReceivedRequest struct { +// receivedRequest captures the details of a single request handled by mockUpstream. +type receivedRequest struct { Method string Path string Header http.Header Body []byte } -// MockUpstream replays txtar fixture responses, validates incoming request +// mockUpstream replays txtar fixture responses, validates incoming request // bodies, and counts calls. It stands in for a real AI provider API // (Anthropic, OpenAI) during integration tests. -type MockUpstream struct { +type mockUpstream struct { *httptest.Server // Calls is incremented atomically on every request. @@ -92,31 +92,31 @@ type MockUpstream struct { AllowOverflow bool mu sync.Mutex - requests []ReceivedRequest + requests []receivedRequest t *testing.T - responses []UpstreamResponse + responses []upstreamResponse } -// ReceivedRequests returns a copy of all requests received so far. -func (ms *MockUpstream) ReceivedRequests() []ReceivedRequest { +// receivedRequests returns a copy of all requests received so far. +func (ms *mockUpstream) receivedRequests() []receivedRequest { ms.mu.Lock() defer ms.mu.Unlock() - return append([]ReceivedRequest(nil), ms.requests...) + return append([]receivedRequest(nil), ms.requests...) 
} -// NewMockUpstream creates a started httptest.Server that replays fixture +// newMockUpstream creates a started httptest.Server that replays fixture // responses. Responses are returned in order: first call → first response. // The test fails if the number of requests doesn't match the number of // responses (when AllowOverflow is not set, default). // -// srv := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) // simple -// srv := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix), testutil.NewFixtureToolResponse(fix)) // multi-turn -func NewMockUpstream(t *testing.T, ctx context.Context, responses ...UpstreamResponse) *MockUpstream { +// srv := testutil.newMockUpstream(t, ctx, testutil.newFixtureResponse(fix)) // simple +// srv := testutil.newMockUpstream(t, ctx, testutil.newFixtureResponse(fix), testutil.newFixtureToolResponse(fix)) // multi-turn +func newMockUpstream(t *testing.T, ctx context.Context, responses ...upstreamResponse) *mockUpstream { t.Helper() - require.NotEmpty(t, responses, "at least one UpstreamResponse required") + require.NotEmpty(t, responses, "at least one upstreamResponse required") - ms := &MockUpstream{ + ms := &mockUpstream{ t: t, responses: responses, } @@ -141,7 +141,7 @@ func NewMockUpstream(t *testing.T, ctx context.Context, responses ...UpstreamRes return ms } -func (ms *MockUpstream) handle(w http.ResponseWriter, r *http.Request) { +func (ms *mockUpstream) handle(w http.ResponseWriter, r *http.Request) { call := int(ms.Calls.Add(1) - 1) body, err := io.ReadAll(r.Body) @@ -149,7 +149,7 @@ func (ms *MockUpstream) handle(w http.ResponseWriter, r *http.Request) { require.NoError(ms.t, err) ms.mu.Lock() - ms.requests = append(ms.requests, ReceivedRequest{ + ms.requests = append(ms.requests, receivedRequest{ Method: r.Method, Path: r.URL.Path, Header: r.Header.Clone(), @@ -186,7 +186,7 @@ func (ms *MockUpstream) handle(w http.ResponseWriter, r *http.Request) { _, _ = w.Write(resp.Blocking) } -func 
(ms *MockUpstream) responseForCall(call int) UpstreamResponse { +func (ms *mockUpstream) responseForCall(call int) upstreamResponse { if call >= len(ms.responses) { if ms.AllowOverflow { return ms.responses[len(ms.responses)-1] @@ -203,7 +203,7 @@ func isStreaming(body []byte, urlPath string) bool { return gjson.GetBytes(body, "stream").Bool() || strings.HasSuffix(urlPath, "invoke-with-response-stream") } -func (ms *MockUpstream) writeSSE(w http.ResponseWriter, data []byte) { +func (ms *mockUpstream) writeSSE(w http.ResponseWriter, data []byte) { ms.t.Helper() w.Header().Set("Content-Type", "text/event-stream") @@ -237,7 +237,7 @@ func isRawHTTPResponse(data []byte) bool { // writeRawHTTPResponse parses data as a complete HTTP response and replays it, // copying the status code, headers, and body to w. This supports error fixtures // that contain full HTTP responses (e.g. "HTTP/2.0 400 Bad Request\r\n..."). -func (ms *MockUpstream) writeRawHTTPResponse(w http.ResponseWriter, r *http.Request, data []byte) { +func (ms *mockUpstream) writeRawHTTPResponse(w http.ResponseWriter, r *http.Request, data []byte) { ms.t.Helper() resp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(data)), r) From 23eb779622c15c77a167afae251205b2ce697d22 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Thu, 5 Mar 2026 17:14:10 +0000 Subject: [PATCH 07/32] refactor: split bridge_test.go into per-interceptor test files --- internal/integrationtest/bedrock_test.go | 136 +++ internal/integrationtest/bridge_test.go | 816 ------------------ .../integrationtest/chatcompletions_test.go | 320 +++++++ internal/integrationtest/messages_test.go | 418 +++++++++ 4 files changed, 874 insertions(+), 816 deletions(-) create mode 100644 internal/integrationtest/bedrock_test.go create mode 100644 internal/integrationtest/chatcompletions_test.go create mode 100644 internal/integrationtest/messages_test.go diff --git a/internal/integrationtest/bedrock_test.go 
b/internal/integrationtest/bedrock_test.go new file mode 100644 index 0000000..3709185 --- /dev/null +++ b/internal/integrationtest/bedrock_test.go @@ -0,0 +1,136 @@ +package integrationtest + +import ( + "context" + "fmt" + "io" + "net/http" + "strings" + "testing" + "time" + + "github.com/coder/aibridge" + "github.com/coder/aibridge/config" + "github.com/coder/aibridge/fixtures" + "github.com/coder/aibridge/provider" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +// testBedrockCfg returns a test AWS Bedrock config pointing at the given URL. +func testBedrockCfg(url string) *config.AWSBedrock { + return &config.AWSBedrock{ + Region: "us-west-2", + AccessKey: "test-access-key", + AccessKeySecret: "test-secret-key", + Model: "beddel", // This model should override the request's given one. + SmallFastModel: "modrock", // Unused but needed for validation. + BaseURL: url, + } +} + +func newBedrockProvider(addr string) aibridge.Provider { + return provider.NewAnthropic(anthropicCfg(addr, apiKey), testBedrockCfg(addr)) +} + +func TestAWSBedrockIntegration(t *testing.T) { + t.Parallel() + + t.Run("invalid config", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + // Invalid bedrock config - missing region & base url + bedrockCfg := &config.AWSBedrock{ + Region: "", + AccessKey: "test-key", + AccessKeySecret: "test-secret", + Model: "test-model", + SmallFastModel: "test-haiku", + } + + ts := newBridgeTestServer(t, ctx, + []aibridge.Provider{provider.NewAnthropic(anthropicCfg("http://unused", apiKey), bedrockCfg)}, + withLogger(newLogger(t)), + ) + + req := createAnthropicMessagesReq(t, ts.URL, fixtures.Request(t, fixtures.AntSingleBuiltinTool)) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, http.StatusInternalServerError, resp.StatusCode) + body, err := 
io.ReadAll(resp.Body) + require.NoError(t, err) + require.Contains(t, string(body), "create anthropic client") + require.Contains(t, string(body), "region or base url required") + }) + + t.Run("/v1/messages", func(t *testing.T) { + for _, streaming := range []bool{true, false} { + t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), streaming), func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + fix := fixtures.Parse(t, fixtures.AntSingleBuiltinTool) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + + // We define region here to validate that with Region & BaseURL defined, the latter takes precedence. + bedrockCfg := &config.AWSBedrock{ + Region: "us-west-2", + AccessKey: "test-access-key", + AccessKeySecret: "test-secret-key", + Model: "danthropic", // This model should override the request's given one. + SmallFastModel: "danthropic-mini", // Unused but needed for validation. + BaseURL: upstream.URL, // Use the mock server. + } + + ts := newBridgeTestServer(t, ctx, + []aibridge.Provider{provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), bedrockCfg)}, + withLogger(newLogger(t)), + ) + + // Make API call to aibridge for Anthropic /v1/messages, which will be routed via AWS Bedrock. + // We override the AWS Bedrock client to route requests through our mock server. + reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) + require.NoError(t, err) + req := createAnthropicMessagesReq(t, ts.URL, reqBody) + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // For streaming responses, consume the body to allow the stream to complete. + if streaming { + // Read the streaming response. + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + } + + // Verify that Bedrock-specific model name was used in the request to the mock server + // and the interception data. 
+ received := upstream.receivedRequests() + require.Len(t, received, 1) + + // The Anthropic SDK's Bedrock middleware extracts "model" and "stream" + // from the JSON body and encodes them in the URL path. + // See: https://github.com/anthropics/anthropic-sdk-go/blob/4d669338f2041f3c60640b6dd317c4895dc71cd4/bedrock/bedrock.go#L247-L248 + pathParts := strings.Split(received[0].Path, "/") + require.True(t, len(pathParts) >= 3 && pathParts[1] == "model", "unexpected path: %s", received[0].Path) + require.Equal(t, bedrockCfg.Model, pathParts[2]) + require.False(t, gjson.GetBytes(received[0].Body, "model").Exists(), "model should be stripped from body") + require.False(t, gjson.GetBytes(received[0].Body, "stream").Exists(), "stream should be stripped from body") + + interceptions := ts.Recorder.RecordedInterceptions() + require.Len(t, interceptions, 1) + require.Equal(t, interceptions[0].Model, bedrockCfg.Model) + ts.Recorder.VerifyAllInterceptionsEnded(t) + }) + } + }) +} diff --git a/internal/integrationtest/bridge_test.go b/internal/integrationtest/bridge_test.go index 5ebbef0..cbc1ee8 100644 --- a/internal/integrationtest/bridge_test.go +++ b/internal/integrationtest/bridge_test.go @@ -13,16 +13,13 @@ import ( "testing" "time" - "cdr.dev/slog/v3/sloggers/slogtest" "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/packages/ssestream" - "github.com/anthropics/anthropic-sdk-go/shared/constant" "github.com/coder/aibridge" "github.com/coder/aibridge/config" "github.com/coder/aibridge/fixtures" "github.com/coder/aibridge/intercept" "github.com/coder/aibridge/internal/testutil" - "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/provider" "github.com/coder/aibridge/recorder" "github.com/google/uuid" @@ -36,18 +33,6 @@ import ( "go.uber.org/goleak" ) -// testBedrockCfg returns a test AWS Bedrock config pointing at the given URL. 
-func testBedrockCfg(url string) *config.AWSBedrock { - return &config.AWSBedrock{ - Region: "us-west-2", - AccessKey: "test-access-key", - AccessKeySecret: "test-secret-key", - Model: "beddel", // This model should override the request's given one. - SmallFastModel: "modrock", // Unused but needed for validation. - BaseURL: url, - } -} - type ( providerFunc func(addr string) aibridge.Provider createRequestFunc func(*testing.T, string, []byte) *http.Request @@ -61,376 +46,10 @@ func newOpenAIProvider(addr string) aibridge.Provider { return provider.NewOpenAI(openAICfg(addr, apiKey)) } -func newBedrockProvider(addr string) aibridge.Provider { - return provider.NewAnthropic(anthropicCfg(addr, apiKey), testBedrockCfg(addr)) -} - func TestMain(m *testing.M) { goleak.VerifyTestMain(m) } -func TestAnthropicMessages(t *testing.T) { - t.Parallel() - - t.Run("single builtin tool", func(t *testing.T) { - t.Parallel() - - cases := []struct { - streaming bool - expectedInputTokens int - expectedOutputTokens int - expectedToolCallID string - }{ - { - streaming: true, - expectedInputTokens: 2, - expectedOutputTokens: 66, - expectedToolCallID: "toolu_01RX68weRSquLx6HUTj65iBo", - }, - { - streaming: false, - expectedInputTokens: 5, - expectedOutputTokens: 84, - expectedToolCallID: "toolu_01AusGgY5aKFhzWrFBv9JfHq", - }, - } - - for _, tc := range cases { - t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), tc.streaming), func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) - t.Cleanup(cancel) - - fix := fixtures.Parse(t, fixtures.AntSingleBuiltinTool) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - - ts := newBridgeTestServer(t, ctx, - []aibridge.Provider{provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), nil)}, - ) - - // Make API call to aibridge for Anthropic /v1/messages - reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) - require.NoError(t, err) - req := createAnthropicMessagesReq(t, 
ts.URL, reqBody) - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) - require.Equal(t, http.StatusOK, resp.StatusCode) - defer resp.Body.Close() - - // Response-specific checks. - if tc.streaming { - sp := aibridge.NewSSEParser() - require.NoError(t, sp.Parse(resp.Body)) - - // Ensure the message starts and completes, at a minimum. - assert.Contains(t, sp.AllEvents(), "message_start") - assert.Contains(t, sp.AllEvents(), "message_stop") - } - - expectedTokenRecordings := 1 - if tc.streaming { - // One for message_start, one for message_delta. - expectedTokenRecordings = 2 - } - tokenUsages := ts.Recorder.RecordedTokenUsages() - require.Len(t, tokenUsages, expectedTokenRecordings) - - assert.EqualValues(t, tc.expectedInputTokens, calculateTotalInputTokens(tokenUsages), "input tokens miscalculated") - assert.EqualValues(t, tc.expectedOutputTokens, calculateTotalOutputTokens(tokenUsages), "output tokens miscalculated") - - toolUsages := ts.Recorder.RecordedToolUsages() - require.Len(t, toolUsages, 1) - assert.Equal(t, "Read", toolUsages[0].Tool) - assert.Equal(t, tc.expectedToolCallID, toolUsages[0].ToolCallID) - require.IsType(t, json.RawMessage{}, toolUsages[0].Args) - var args map[string]any - require.NoError(t, json.Unmarshal(toolUsages[0].Args.(json.RawMessage), &args)) - require.Contains(t, args, "file_path") - assert.Equal(t, "/tmp/blah/foo", args["file_path"]) - - promptUsages := ts.Recorder.RecordedPromptUsages() - require.Len(t, promptUsages, 1) - assert.Equal(t, "read the foo file", promptUsages[0].Prompt) - - ts.Recorder.VerifyAllInterceptionsEnded(t) - }) - } - }) -} - -func TestAWSBedrockIntegration(t *testing.T) { - t.Parallel() - - t.Run("invalid config", func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) - t.Cleanup(cancel) - - // Invalid bedrock config - missing region & base url - bedrockCfg := &config.AWSBedrock{ - Region: "", - AccessKey: "test-key", - 
AccessKeySecret: "test-secret", - Model: "test-model", - SmallFastModel: "test-haiku", - } - - ts := newBridgeTestServer(t, ctx, - []aibridge.Provider{provider.NewAnthropic(anthropicCfg("http://unused", apiKey), bedrockCfg)}, - withLogger(newLogger(t, &slogtest.Options{IgnoreErrors: true})), - ) - - req := createAnthropicMessagesReq(t, ts.URL, fixtures.Request(t, fixtures.AntSingleBuiltinTool)) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) - defer resp.Body.Close() - - require.Equal(t, http.StatusInternalServerError, resp.StatusCode) - body, err := io.ReadAll(resp.Body) - require.NoError(t, err) - require.Contains(t, string(body), "create anthropic client") - require.Contains(t, string(body), "region or base url required") - }) - - t.Run("/v1/messages", func(t *testing.T) { - for _, streaming := range []bool{true, false} { - t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), streaming), func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) - t.Cleanup(cancel) - - fix := fixtures.Parse(t, fixtures.AntSingleBuiltinTool) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - - // We define region here to validate that with Region & BaseURL defined, the latter takes precedence. - bedrockCfg := &config.AWSBedrock{ - Region: "us-west-2", - AccessKey: "test-access-key", - AccessKeySecret: "test-secret-key", - Model: "danthropic", // This model should override the request's given one. - SmallFastModel: "danthropic-mini", // Unused but needed for validation. - BaseURL: upstream.URL, // Use the mock server. - } - - ts := newBridgeTestServer(t, ctx, - []aibridge.Provider{provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), bedrockCfg)}, - withLogger(newLogger(t, &slogtest.Options{IgnoreErrors: true})), - ) - - // Make API call to aibridge for Anthropic /v1/messages, which will be routed via AWS Bedrock. - // We override the AWS Bedrock client to route requests through our mock server. 
- reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) - require.NoError(t, err) - req := createAnthropicMessagesReq(t, ts.URL, reqBody) - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) - defer resp.Body.Close() - - // For streaming responses, consume the body to allow the stream to complete. - if streaming { - // Read the streaming response. - _, err = io.ReadAll(resp.Body) - require.NoError(t, err) - } - - // Verify that Bedrock-specific model name was used in the request to the mock server - // and the interception data. - received := upstream.receivedRequests() - require.Len(t, received, 1) - - // The Anthropic SDK's Bedrock middleware extracts "model" and "stream" - // from the JSON body and encodes them in the URL path. - // See: https://github.com/anthropics/anthropic-sdk-go/blob/4d669338f2041f3c60640b6dd317c4895dc71cd4/bedrock/bedrock.go#L247-L248 - pathParts := strings.Split(received[0].Path, "/") - require.True(t, len(pathParts) >= 3 && pathParts[1] == "model", "unexpected path: %s", received[0].Path) - require.Equal(t, bedrockCfg.Model, pathParts[2]) - require.False(t, gjson.GetBytes(received[0].Body, "model").Exists(), "model should be stripped from body") - require.False(t, gjson.GetBytes(received[0].Body, "stream").Exists(), "stream should be stripped from body") - - interceptions := ts.Recorder.RecordedInterceptions() - require.Len(t, interceptions, 1) - require.Equal(t, interceptions[0].Model, bedrockCfg.Model) - ts.Recorder.VerifyAllInterceptionsEnded(t) - }) - } - }) -} - -func TestOpenAIChatCompletions(t *testing.T) { - t.Parallel() - - t.Run("single builtin tool", func(t *testing.T) { - t.Parallel() - - cases := []struct { - streaming bool - expectedInputTokens, expectedOutputTokens int - expectedToolCallID string - }{ - { - streaming: true, - expectedInputTokens: 60, - expectedOutputTokens: 15, - expectedToolCallID: "call_HjeqP7YeRkoNj0de9e3U4X4B", - }, - { - streaming: false, - 
expectedInputTokens: 60, - expectedOutputTokens: 15, - expectedToolCallID: "call_KjzAbhiZC6nk81tQzL7pwlpc", - }, - } - - for _, tc := range cases { - t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), tc.streaming), func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) - t.Cleanup(cancel) - - fix := fixtures.Parse(t, fixtures.OaiChatSingleBuiltinTool) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - - ts := newBridgeTestServer(t, ctx, - []aibridge.Provider{provider.NewOpenAI(openAICfg(upstream.URL, apiKey))}, - ) - - // Make API call to aibridge for OpenAI /v1/chat/completions - reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) - require.NoError(t, err) - req := createOpenAIChatCompletionsReq(t, ts.URL, reqBody) - - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) - require.Equal(t, http.StatusOK, resp.StatusCode) - defer resp.Body.Close() - - // Response-specific checks. - if tc.streaming { - sp := aibridge.NewSSEParser() - require.NoError(t, sp.Parse(resp.Body)) - - // OpenAI sends all events under the same type. 
- messageEvents := sp.MessageEvents() - assert.NotEmpty(t, messageEvents) - - // OpenAI streaming ends with [DONE] - lastEvent := messageEvents[len(messageEvents)-1] - assert.Equal(t, "[DONE]", lastEvent.Data) - } - - tokenUsages := ts.Recorder.RecordedTokenUsages() - require.Len(t, tokenUsages, 1) - assert.EqualValues(t, tc.expectedInputTokens, calculateTotalInputTokens(tokenUsages), "input tokens miscalculated") - assert.EqualValues(t, tc.expectedOutputTokens, calculateTotalOutputTokens(tokenUsages), "output tokens miscalculated") - - toolUsages := ts.Recorder.RecordedToolUsages() - require.Len(t, toolUsages, 1) - assert.Equal(t, "read_file", toolUsages[0].Tool) - assert.Equal(t, tc.expectedToolCallID, toolUsages[0].ToolCallID) - require.IsType(t, map[string]any{}, toolUsages[0].Args) - require.Contains(t, toolUsages[0].Args, "path") - assert.Equal(t, "README.md", toolUsages[0].Args.(map[string]any)["path"]) - - promptUsages := ts.Recorder.RecordedPromptUsages() - require.Len(t, promptUsages, 1) - assert.Equal(t, "how large is the README.md file in my current path", promptUsages[0].Prompt) - - ts.Recorder.VerifyAllInterceptionsEnded(t) - }) - } - }) - - t.Run("streaming injected tool call edge cases", func(t *testing.T) { - t.Parallel() - - cases := []struct { - name string - fixture []byte - expectedArgs map[string]any - }{ - { - name: "tool call no preamble", - fixture: fixtures.OaiChatStreamingInjectedToolNoPreamble, - expectedArgs: map[string]any{"owner": "me"}, - }, - { - name: "tool call with non-zero index", - fixture: fixtures.OaiChatStreamingInjectedToolNonzeroIndex, - expectedArgs: nil, // No arguments in this fixture - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) - t.Cleanup(cancel) - - // Setup mock server for multi-turn interaction. - // First request → tool call response, second → tool response. 
- fix := fixtures.Parse(t, tc.fixture) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix), newFixtureToolResponse(fix)) - - // Setup MCP proxies with the tool from the fixture - mockMCP := setupMCPForTest(t, defaultTracer) - - ts := newBridgeTestServer(t, ctx, - []aibridge.Provider{provider.NewOpenAI(openAICfg(upstream.URL, apiKey))}, - withMCP(mockMCP), - ) - - // Add the stream param to the request. - reqBody, err := sjson.SetBytes(fix.Request(), "stream", true) - require.NoError(t, err) - req := createOpenAIChatCompletionsReq(t, ts.URL, reqBody) - - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) - require.Equal(t, http.StatusOK, resp.StatusCode) - - // Verify SSE headers are sent correctly - require.Equal(t, "text/event-stream", resp.Header.Get("Content-Type")) - require.Equal(t, "no-cache", resp.Header.Get("Cache-Control")) - require.Equal(t, "keep-alive", resp.Header.Get("Connection")) - - // Consume the full response body to ensure the interception completes - _, err = io.ReadAll(resp.Body) - require.NoError(t, err) - resp.Body.Close() - - // Verify the MCP tool was actually invoked - invocations := mockMCP.getCallsByTool(mockToolName) - require.Len(t, invocations, 1, "expected MCP tool to be invoked") - - // Verify tool was invoked with the expected args (if specified) - if tc.expectedArgs != nil { - expected, err := json.Marshal(tc.expectedArgs) - require.NoError(t, err) - actual, err := json.Marshal(invocations[0]) - require.NoError(t, err) - require.EqualValues(t, expected, actual) - } - - // Verify tool usage was recorded - toolUsages := ts.Recorder.RecordedToolUsages() - require.Len(t, toolUsages, 1) - assert.Equal(t, mockToolName, toolUsages[0].Tool) - - ts.Recorder.VerifyAllInterceptionsEnded(t) - }) - } - }) -} - func TestSimple(t *testing.T) { t.Parallel() @@ -836,262 +455,6 @@ func TestFallthrough(t *testing.T) { } } -func TestAnthropicInjectedTools(t *testing.T) { - t.Parallel() - - for _, streaming := 
range []bool{true, false} { - t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { - t.Parallel() - - // Build the requirements & make the assertions which are common to all providers. - recorderClient, mockMCP, resp := setupInjectedToolTest(t, fixtures.AntSingleInjectedTool, streaming, newAnthropicProvider, defaultTracer, defaultActorID, createAnthropicMessagesReq, anthropicToolResultValidator(t)) - - // Ensure expected tool was invoked with expected input. - toolUsages := recorderClient.RecordedToolUsages() - require.Len(t, toolUsages, 1) - require.Equal(t, mockToolName, toolUsages[0].Tool) - expected, err := json.Marshal(map[string]any{"owner": "admin"}) - require.NoError(t, err) - actual, err := json.Marshal(toolUsages[0].Args) - require.NoError(t, err) - require.EqualValues(t, expected, actual) - invocations := mockMCP.getCallsByTool(mockToolName) - require.Len(t, invocations, 1) - actual, err = json.Marshal(invocations[0]) - require.NoError(t, err) - require.EqualValues(t, expected, actual) - - var ( - content *anthropic.ContentBlockUnion - message anthropic.Message - ) - if streaming { - // Parse the response stream. - decoder := ssestream.NewDecoder(resp) - stream := ssestream.NewStream[anthropic.MessageStreamEventUnion](decoder, nil) - for stream.Next() { - event := stream.Current() - require.NoError(t, message.Accumulate(event), "accumulate event") - } - - require.NoError(t, stream.Err(), "stream error") - require.Len(t, message.Content, 2) - - content = &message.Content[1] - } else { - // Parse & unmarshal the response. - body, err := io.ReadAll(resp.Body) - require.NoError(t, err, "read response body") - - require.NoError(t, json.Unmarshal(body, &message), "unmarshal response") - require.GreaterOrEqual(t, len(message.Content), 1) - - content = &message.Content[0] - } - - // Ensure tool returned expected value. 
- require.NotNil(t, content) - require.Contains(t, content.Text, "dd711d5c-83c6-4c08-a0af-b73055906e8c") // The ID of the workspace to be returned. - - // Check the token usage from the client's perspective. - // - // We overwrite the final message_delta which is relayed to the client to include the - // accumulated tokens but currently the SDK only supports accumulating output tokens - // for message_delta events. - // - // For non-streaming requests the token usage is also overwritten and should be faithfully - // represented in the response. - // - // See https://github.com/anthropics/anthropic-sdk-go/blob/v1.12.0/message.go#L2619-L2622 - if !streaming { - assert.EqualValues(t, 15308, message.Usage.InputTokens) - } - assert.EqualValues(t, 204, message.Usage.OutputTokens) - - // Ensure tokens used during injected tool invocation are accounted for. - tokenUsages := recorderClient.RecordedTokenUsages() - assert.EqualValues(t, 15308, calculateTotalInputTokens(tokenUsages)) - assert.EqualValues(t, 204, calculateTotalOutputTokens(tokenUsages)) - - // Ensure we received exactly one prompt. - promptUsages := recorderClient.RecordedPromptUsages() - require.Len(t, promptUsages, 1) - }) - } -} - -func TestOpenAIInjectedTools(t *testing.T) { - t.Parallel() - - for _, streaming := range []bool{true, false} { - t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { - t.Parallel() - - // Build the requirements & make the assertions which are common to all providers. - recorderClient, mockMCP, resp := setupInjectedToolTest(t, fixtures.OaiChatSingleInjectedTool, streaming, newOpenAIProvider, defaultTracer, defaultActorID, createOpenAIChatCompletionsReq, openaiChatToolResultValidator(t)) - - // Ensure expected tool was invoked with expected input. 
- toolUsages := recorderClient.RecordedToolUsages() - require.Len(t, toolUsages, 1) - require.Equal(t, mockToolName, toolUsages[0].Tool) - expected, err := json.Marshal(map[string]any{"owner": "admin"}) - require.NoError(t, err) - actual, err := json.Marshal(toolUsages[0].Args) - require.NoError(t, err) - require.EqualValues(t, expected, actual) - invocations := mockMCP.getCallsByTool(mockToolName) - require.Len(t, invocations, 1) - actual, err = json.Marshal(invocations[0]) - require.NoError(t, err) - require.EqualValues(t, expected, actual) - - var ( - content *openai.ChatCompletionChoice - message openai.ChatCompletion - ) - if streaming { - // Parse the response stream. - decoder := oaissestream.NewDecoder(resp) - stream := oaissestream.NewStream[openai.ChatCompletionChunk](decoder, nil) - var acc openai.ChatCompletionAccumulator - detectedToolCalls := make(map[string]struct{}) - for stream.Next() { - chunk := stream.Current() - acc.AddChunk(chunk) - - if len(chunk.Choices) == 0 { - continue - } - - for _, c := range chunk.Choices { - if len(c.Delta.ToolCalls) == 0 { - continue - } - - for _, t := range c.Delta.ToolCalls { - if t.Function.Name == "" { - continue - } - - detectedToolCalls[t.Function.Name] = struct{}{} - } - } - } - - // Verify that no injected tool call events (or partials thereof) were sent to the client. - require.Len(t, detectedToolCalls, 0) - - message = acc.ChatCompletion - require.NoError(t, stream.Err(), "stream error") - } else { - // Parse & unmarshal the response. - body, err := io.ReadAll(resp.Body) - require.NoError(t, err, "read response body") - require.NoError(t, json.Unmarshal(body, &message), "unmarshal response") - - // Verify that no injected tools were sent to the client. - require.GreaterOrEqual(t, len(message.Choices), 1) - require.Len(t, message.Choices[0].Message.ToolCalls, 0) - } - - require.GreaterOrEqual(t, len(message.Choices), 1) - content = &message.Choices[0] - - // Ensure tool returned expected value. 
- require.NotNil(t, content) - require.Contains(t, content.Message.Content, "dd711d5c-83c6-4c08-a0af-b73055906e8c") // The ID of the workspace to be returned. - - // Check the token usage from the client's perspective. - // This *should* work but the openai SDK doesn't accumulate the prompt token details :(. - // See https://github.com/openai/openai-go/blob/v2.7.0/streamaccumulator.go#L145-L147. - // assert.EqualValues(t, 5047, message.Usage.PromptTokens-message.Usage.PromptTokensDetails.CachedTokens) - assert.EqualValues(t, 105, message.Usage.CompletionTokens) - - // Ensure tokens used during injected tool invocation are accounted for. - tokenUsages := recorderClient.RecordedTokenUsages() - require.EqualValues(t, 5047, calculateTotalInputTokens(tokenUsages)) - require.EqualValues(t, 105, calculateTotalOutputTokens(tokenUsages)) - - // Ensure we received exactly one prompt. - promptUsages := recorderClient.RecordedPromptUsages() - require.Len(t, promptUsages, 1) - }) - } -} - -// anthropicToolResultValidator returns a request validator that asserts the second -// upstream request contains the assistant's tool_use and user's tool_result messages -// appended by the inner agentic loop. If the raw payload is not kept in sync with -// the structured messages, the second request will be identical to the first. 
-func anthropicToolResultValidator(t *testing.T) func(*http.Request, []byte) { - t.Helper() - - return func(_ *http.Request, raw []byte) { - messages := gjson.GetBytes(raw, "messages").Array() - - // After the agentic loop the messages must contain at minimum: - // [0] original user message - // [N-2] assistant message with tool_use content block - // [N-1] user message with tool_result content block - require.GreaterOrEqual(t, len(messages), 3, - "second upstream request must contain the original message, assistant tool_use, and user tool_result") - - assistantMsg := messages[len(messages)-2] - require.Equal(t, "assistant", assistantMsg.Get("role").Str, - "penultimate message must be from the assistant") - var hasToolUse bool - for _, block := range assistantMsg.Get("content").Array() { - if block.Get("type").Str == "tool_use" { - hasToolUse = true - break - } - } - require.True(t, hasToolUse, "assistant message must contain a tool_use content block") - - toolResultMsg := messages[len(messages)-1] - require.Equal(t, "user", toolResultMsg.Get("role").Str, - "last message must be a user message carrying the tool_result") - var hasToolResult bool - for _, block := range toolResultMsg.Get("content").Array() { - if block.Get("type").Str == "tool_result" { - hasToolResult = true - break - } - } - require.True(t, hasToolResult, "user message must contain a tool_result content block") - } -} - -// openaiChatToolResultValidator returns a request validator that asserts the second -// upstream request contains the assistant's tool_calls and a role=tool result message -// appended by the inner agentic loop. 
-func openaiChatToolResultValidator(t *testing.T) func(*http.Request, []byte) { - t.Helper() - - return func(_ *http.Request, raw []byte) { - messages := gjson.GetBytes(raw, "messages").Array() - - // After the agentic loop the messages must contain at minimum: - // [0] original user message - // [N-2] assistant message with tool_calls array - // [N-1] message with role=tool - require.GreaterOrEqual(t, len(messages), 3, - "second upstream request must contain the original message, assistant tool_calls, and tool result") - - assistantMsg := messages[len(messages)-2] - require.Equal(t, "assistant", assistantMsg.Get("role").Str, - "penultimate message must be from the assistant") - require.NotEmpty(t, len(assistantMsg.Get("tool_calls").Array()), - "assistant message must contain a tool_calls array") - - toolResultMsg := messages[len(messages)-1] - require.Equal(t, "tool", toolResultMsg.Get("role").Str, - "last message must have role=tool") - require.NotEmpty(t, toolResultMsg.Get("tool_call_id").Str, - "tool result message must have a tool_call_id") - } -} - // setupInjectedToolTest abstracts common setup required for injected-tool integration tests. func setupInjectedToolTest( t *testing.T, @@ -1374,183 +737,6 @@ func TestStableRequestEncoding(t *testing.T) { } } -// TestAnthropicToolChoiceParallelDisabled verifies that parallel tool use is -// correctly disabled based on the tool_choice parameter in the request. -// See https://github.com/coder/aibridge/issues/2 -func TestAnthropicToolChoiceParallelDisabled(t *testing.T) { - t.Parallel() - - var ( - toolChoiceAuto = string(constant.ValueOf[constant.Auto]()) - toolChoiceAny = string(constant.ValueOf[constant.Any]()) - toolChoiceNone = string(constant.ValueOf[constant.None]()) - toolChoiceTool = string(constant.ValueOf[constant.Tool]()) - ) - - cases := []struct { - name string - toolChoice any // nil, or map with "type" key. 
- withInjectedTools bool - expectDisableParallel bool - expectToolChoiceTypeInRequest string - }{ - // With injected tools - disable_parallel_tool_use should be set. - { - name: "with injected tools: no tool_choice defined defaults to auto", - toolChoice: nil, - withInjectedTools: true, - expectDisableParallel: true, - expectToolChoiceTypeInRequest: toolChoiceAuto, - }, - { - name: "with injected tools: tool_choice auto", - toolChoice: map[string]any{"type": toolChoiceAuto}, - withInjectedTools: true, - expectDisableParallel: true, - expectToolChoiceTypeInRequest: toolChoiceAuto, - }, - { - name: "with injected tools: tool_choice any", - toolChoice: map[string]any{"type": toolChoiceAny}, - withInjectedTools: true, - expectDisableParallel: true, - expectToolChoiceTypeInRequest: toolChoiceAny, - }, - { - name: "with injected tools: tool_choice tool", - toolChoice: map[string]any{"type": toolChoiceTool, "name": "some_tool"}, - withInjectedTools: true, - expectDisableParallel: true, - expectToolChoiceTypeInRequest: toolChoiceTool, - }, - { - name: "with injected tools: tool_choice none", - toolChoice: map[string]any{"type": toolChoiceNone}, - withInjectedTools: true, - expectDisableParallel: false, - expectToolChoiceTypeInRequest: toolChoiceNone, - }, - // Without injected tools - disable_parallel_tool_use should NOT be set. 
- { - name: "without injected tools: tool_choice auto", - toolChoice: map[string]any{"type": toolChoiceAuto}, - withInjectedTools: false, - expectDisableParallel: false, - expectToolChoiceTypeInRequest: toolChoiceAuto, - }, - { - name: "without injected tools: tool_choice any", - toolChoice: map[string]any{"type": toolChoiceAny}, - withInjectedTools: false, - expectDisableParallel: false, - expectToolChoiceTypeInRequest: toolChoiceAny, - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) - t.Cleanup(cancel) - - // Setup MCP tools conditionally. - var mcpMgr mcp.ServerProxier - if tc.withInjectedTools { - mcpMgr = setupMCPForTest(t, defaultTracer) - } else { - mcpMgr = newNoopMCPManager() - } - - fix := fixtures.Parse(t, fixtures.AntSimple) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - - ts := newBridgeTestServer(t, ctx, - []aibridge.Provider{provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), nil)}, - withMCP(mcpMgr), - ) - - // Prepare request body with tool_choice set. - reqBody, err := sjson.SetBytes(fix.Request(), "tool_choice", tc.toolChoice) - require.NoError(t, err) - - req := createAnthropicMessagesReq(t, ts.URL, reqBody) - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) - require.Equal(t, http.StatusOK, resp.StatusCode) - _ = resp.Body.Close() - - // Verify tool_choice in the upstream request. - received := upstream.receivedRequests() - require.Len(t, received, 1) - var receivedRequest map[string]any - require.NoError(t, json.Unmarshal(received[0].Body, &receivedRequest)) - toolChoice, ok := receivedRequest["tool_choice"].(map[string]any) - require.True(t, ok, "expected tool_choice in upstream request") - - // Verify the type matches expectation. - assert.Equal(t, tc.expectToolChoiceTypeInRequest, toolChoice["type"]) - - // Verify name is preserved for tool_choice=tool. 
- if tc.expectToolChoiceTypeInRequest == toolChoiceTool { - assert.Equal(t, "some_tool", toolChoice["name"]) - } - - // Verify disable_parallel_tool_use based on expectations. - // See https://platform.claude.com/docs/en/agents-and-tools/tool-use/implement-tool-use#parallel-tool-use - disableParallel, hasDisableParallel := toolChoice["disable_parallel_tool_use"].(bool) - - if tc.expectDisableParallel { - require.True(t, hasDisableParallel, "expected disable_parallel_tool_use in tool_choice") - assert.True(t, disableParallel, "expected disable_parallel_tool_use to be true") - } else { - assert.False(t, hasDisableParallel, "expected disable_parallel_tool_use to not be set") - } - }) - } -} - -func TestThinkingAdaptiveIsPreserved(t *testing.T) { - t.Parallel() - - fix := fixtures.Parse(t, fixtures.AntSimple) - - for _, streaming := range []bool{true, false} { - t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) - t.Cleanup(cancel) - - // Create a mock server that captures the request body sent upstream. - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - - ts := newBridgeTestServer(t, ctx, - []aibridge.Provider{provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), nil)}, - ) - - // Inject adaptive thinking into the fixture request. - reqBody, err := sjson.SetBytes(fix.Request(), "thinking", map[string]string{"type": "adaptive"}) - require.NoError(t, err) - reqBody, err = sjson.SetBytes(reqBody, "stream", streaming) - require.NoError(t, err) - - req := createAnthropicMessagesReq(t, ts.URL, reqBody) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) - require.Equal(t, http.StatusOK, resp.StatusCode) - _, _ = io.ReadAll(resp.Body) - _ = resp.Body.Close() - - // Verify the thinking field was preserved in the upstream request. 
- received := upstream.receivedRequests() - require.Len(t, received, 1) - assert.Equal(t, "adaptive", gjson.GetBytes(received[0].Body, "thinking.type").Str) - }) - } -} - func TestEnvironmentDoNotLeak(t *testing.T) { // NOTE: Cannot use t.Parallel() here because subtests use t.Setenv which requires sequential execution. @@ -1777,5 +963,3 @@ func calculateTotalOutputTokens(in []*recorder.TokenUsageRecord) int64 { } return total } - - diff --git a/internal/integrationtest/chatcompletions_test.go b/internal/integrationtest/chatcompletions_test.go new file mode 100644 index 0000000..f78103d --- /dev/null +++ b/internal/integrationtest/chatcompletions_test.go @@ -0,0 +1,320 @@ +package integrationtest + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "testing" + "time" + + "github.com/coder/aibridge" + "github.com/coder/aibridge/fixtures" + "github.com/coder/aibridge/provider" + "github.com/openai/openai-go/v3" + oaissestream "github.com/openai/openai-go/v3/packages/ssestream" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +func TestOpenAIChatCompletions(t *testing.T) { + t.Parallel() + + t.Run("single builtin tool", func(t *testing.T) { + t.Parallel() + + cases := []struct { + streaming bool + expectedInputTokens, expectedOutputTokens int + expectedToolCallID string + }{ + { + streaming: true, + expectedInputTokens: 60, + expectedOutputTokens: 15, + expectedToolCallID: "call_HjeqP7YeRkoNj0de9e3U4X4B", + }, + { + streaming: false, + expectedInputTokens: 60, + expectedOutputTokens: 15, + expectedToolCallID: "call_KjzAbhiZC6nk81tQzL7pwlpc", + }, + } + + for _, tc := range cases { + t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), tc.streaming), func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + fix := fixtures.Parse(t, fixtures.OaiChatSingleBuiltinTool) + upstream := 
newMockUpstream(t, ctx, newFixtureResponse(fix)) + + ts := newBridgeTestServer(t, ctx, + []aibridge.Provider{provider.NewOpenAI(openAICfg(upstream.URL, apiKey))}, + ) + + // Make API call to aibridge for OpenAI /v1/chat/completions + reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) + require.NoError(t, err) + req := createOpenAIChatCompletionsReq(t, ts.URL, reqBody) + + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + defer resp.Body.Close() + + // Response-specific checks. + if tc.streaming { + sp := aibridge.NewSSEParser() + require.NoError(t, sp.Parse(resp.Body)) + + // OpenAI sends all events under the same type. + messageEvents := sp.MessageEvents() + assert.NotEmpty(t, messageEvents) + + // OpenAI streaming ends with [DONE] + lastEvent := messageEvents[len(messageEvents)-1] + assert.Equal(t, "[DONE]", lastEvent.Data) + } + + tokenUsages := ts.Recorder.RecordedTokenUsages() + require.Len(t, tokenUsages, 1) + assert.EqualValues(t, tc.expectedInputTokens, calculateTotalInputTokens(tokenUsages), "input tokens miscalculated") + assert.EqualValues(t, tc.expectedOutputTokens, calculateTotalOutputTokens(tokenUsages), "output tokens miscalculated") + + toolUsages := ts.Recorder.RecordedToolUsages() + require.Len(t, toolUsages, 1) + assert.Equal(t, "read_file", toolUsages[0].Tool) + assert.Equal(t, tc.expectedToolCallID, toolUsages[0].ToolCallID) + require.IsType(t, map[string]any{}, toolUsages[0].Args) + require.Contains(t, toolUsages[0].Args, "path") + assert.Equal(t, "README.md", toolUsages[0].Args.(map[string]any)["path"]) + + promptUsages := ts.Recorder.RecordedPromptUsages() + require.Len(t, promptUsages, 1) + assert.Equal(t, "how large is the README.md file in my current path", promptUsages[0].Prompt) + + ts.Recorder.VerifyAllInterceptionsEnded(t) + }) + } + }) + + t.Run("streaming injected tool call edge cases", func(t *testing.T) { + t.Parallel() + + cases 
:= []struct { + name string + fixture []byte + expectedArgs map[string]any + }{ + { + name: "tool call no preamble", + fixture: fixtures.OaiChatStreamingInjectedToolNoPreamble, + expectedArgs: map[string]any{"owner": "me"}, + }, + { + name: "tool call with non-zero index", + fixture: fixtures.OaiChatStreamingInjectedToolNonzeroIndex, + expectedArgs: nil, // No arguments in this fixture + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + // Setup mock server for multi-turn interaction. + // First request → tool call response, second → tool response. + fix := fixtures.Parse(t, tc.fixture) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix), newFixtureToolResponse(fix)) + + // Setup MCP proxies with the tool from the fixture + mockMCP := setupMCPForTest(t, defaultTracer) + + ts := newBridgeTestServer(t, ctx, + []aibridge.Provider{provider.NewOpenAI(openAICfg(upstream.URL, apiKey))}, + withMCP(mockMCP), + ) + + // Add the stream param to the request. 
+ reqBody, err := sjson.SetBytes(fix.Request(), "stream", true) + require.NoError(t, err) + req := createOpenAIChatCompletionsReq(t, ts.URL, reqBody) + + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + + // Verify SSE headers are sent correctly + require.Equal(t, "text/event-stream", resp.Header.Get("Content-Type")) + require.Equal(t, "no-cache", resp.Header.Get("Cache-Control")) + require.Equal(t, "keep-alive", resp.Header.Get("Connection")) + + // Consume the full response body to ensure the interception completes + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + resp.Body.Close() + + // Verify the MCP tool was actually invoked + invocations := mockMCP.getCallsByTool(mockToolName) + require.Len(t, invocations, 1, "expected MCP tool to be invoked") + + // Verify tool was invoked with the expected args (if specified) + if tc.expectedArgs != nil { + expected, err := json.Marshal(tc.expectedArgs) + require.NoError(t, err) + actual, err := json.Marshal(invocations[0]) + require.NoError(t, err) + require.EqualValues(t, expected, actual) + } + + // Verify tool usage was recorded + toolUsages := ts.Recorder.RecordedToolUsages() + require.Len(t, toolUsages, 1) + assert.Equal(t, mockToolName, toolUsages[0].Tool) + + ts.Recorder.VerifyAllInterceptionsEnded(t) + }) + } + }) +} + +func TestOpenAIInjectedTools(t *testing.T) { + t.Parallel() + + for _, streaming := range []bool{true, false} { + t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { + t.Parallel() + + // Build the requirements & make the assertions which are common to all providers. + recorderClient, mockMCP, resp := setupInjectedToolTest(t, fixtures.OaiChatSingleInjectedTool, streaming, newOpenAIProvider, defaultTracer, defaultActorID, createOpenAIChatCompletionsReq, openaiChatToolResultValidator(t)) + + // Ensure expected tool was invoked with expected input. 
+ toolUsages := recorderClient.RecordedToolUsages() + require.Len(t, toolUsages, 1) + require.Equal(t, mockToolName, toolUsages[0].Tool) + expected, err := json.Marshal(map[string]any{"owner": "admin"}) + require.NoError(t, err) + actual, err := json.Marshal(toolUsages[0].Args) + require.NoError(t, err) + require.EqualValues(t, expected, actual) + invocations := mockMCP.getCallsByTool(mockToolName) + require.Len(t, invocations, 1) + actual, err = json.Marshal(invocations[0]) + require.NoError(t, err) + require.EqualValues(t, expected, actual) + + var ( + content *openai.ChatCompletionChoice + message openai.ChatCompletion + ) + if streaming { + // Parse the response stream. + decoder := oaissestream.NewDecoder(resp) + stream := oaissestream.NewStream[openai.ChatCompletionChunk](decoder, nil) + var acc openai.ChatCompletionAccumulator + detectedToolCalls := make(map[string]struct{}) + for stream.Next() { + chunk := stream.Current() + acc.AddChunk(chunk) + + if len(chunk.Choices) == 0 { + continue + } + + for _, c := range chunk.Choices { + if len(c.Delta.ToolCalls) == 0 { + continue + } + + for _, t := range c.Delta.ToolCalls { + if t.Function.Name == "" { + continue + } + + detectedToolCalls[t.Function.Name] = struct{}{} + } + } + } + + // Verify that no injected tool call events (or partials thereof) were sent to the client. + require.Len(t, detectedToolCalls, 0) + + message = acc.ChatCompletion + require.NoError(t, stream.Err(), "stream error") + } else { + // Parse & unmarshal the response. + body, err := io.ReadAll(resp.Body) + require.NoError(t, err, "read response body") + require.NoError(t, json.Unmarshal(body, &message), "unmarshal response") + + // Verify that no injected tools were sent to the client. + require.GreaterOrEqual(t, len(message.Choices), 1) + require.Len(t, message.Choices[0].Message.ToolCalls, 0) + } + + require.GreaterOrEqual(t, len(message.Choices), 1) + content = &message.Choices[0] + + // Ensure tool returned expected value. 
+ require.NotNil(t, content) + require.Contains(t, content.Message.Content, "dd711d5c-83c6-4c08-a0af-b73055906e8c") // The ID of the workspace to be returned. + + // Check the token usage from the client's perspective. + // This *should* work but the openai SDK doesn't accumulate the prompt token details :(. + // See https://github.com/openai/openai-go/blob/v2.7.0/streamaccumulator.go#L145-L147. + // assert.EqualValues(t, 5047, message.Usage.PromptTokens-message.Usage.PromptTokensDetails.CachedTokens) + assert.EqualValues(t, 105, message.Usage.CompletionTokens) + + // Ensure tokens used during injected tool invocation are accounted for. + tokenUsages := recorderClient.RecordedTokenUsages() + require.EqualValues(t, 5047, calculateTotalInputTokens(tokenUsages)) + require.EqualValues(t, 105, calculateTotalOutputTokens(tokenUsages)) + + // Ensure we received exactly one prompt. + promptUsages := recorderClient.RecordedPromptUsages() + require.Len(t, promptUsages, 1) + }) + } +} + +// openaiChatToolResultValidator returns a request validator that asserts the second +// upstream request contains the assistant's tool_calls and a role=tool result message +// appended by the inner agentic loop. 
+func openaiChatToolResultValidator(t *testing.T) func(*http.Request, []byte) { + t.Helper() + + return func(_ *http.Request, raw []byte) { + messages := gjson.GetBytes(raw, "messages").Array() + + // After the agentic loop the messages must contain at minimum: + // [0] original user message + // [N-2] assistant message with tool_calls array + // [N-1] message with role=tool + require.GreaterOrEqual(t, len(messages), 3, + "second upstream request must contain the original message, assistant tool_calls, and tool result") + + assistantMsg := messages[len(messages)-2] + require.Equal(t, "assistant", assistantMsg.Get("role").Str, + "penultimate message must be from the assistant") + require.NotEmpty(t, len(assistantMsg.Get("tool_calls").Array()), + "assistant message must contain a tool_calls array") + + toolResultMsg := messages[len(messages)-1] + require.Equal(t, "tool", toolResultMsg.Get("role").Str, + "last message must have role=tool") + require.NotEmpty(t, toolResultMsg.Get("tool_call_id").Str, + "tool result message must have a tool_call_id") + } +} diff --git a/internal/integrationtest/messages_test.go b/internal/integrationtest/messages_test.go new file mode 100644 index 0000000..91c3ddd --- /dev/null +++ b/internal/integrationtest/messages_test.go @@ -0,0 +1,418 @@ +package integrationtest + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "testing" + "time" + + "github.com/anthropics/anthropic-sdk-go" + "github.com/anthropics/anthropic-sdk-go/packages/ssestream" + "github.com/anthropics/anthropic-sdk-go/shared/constant" + "github.com/coder/aibridge" + "github.com/coder/aibridge/fixtures" + "github.com/coder/aibridge/mcp" + "github.com/coder/aibridge/provider" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +func TestAnthropicMessages(t *testing.T) { + t.Parallel() + + t.Run("single builtin tool", func(t *testing.T) { + t.Parallel() + + cases := 
[]struct { + streaming bool + expectedInputTokens int + expectedOutputTokens int + expectedToolCallID string + }{ + { + streaming: true, + expectedInputTokens: 2, + expectedOutputTokens: 66, + expectedToolCallID: "toolu_01RX68weRSquLx6HUTj65iBo", + }, + { + streaming: false, + expectedInputTokens: 5, + expectedOutputTokens: 84, + expectedToolCallID: "toolu_01AusGgY5aKFhzWrFBv9JfHq", + }, + } + + for _, tc := range cases { + t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), tc.streaming), func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + fix := fixtures.Parse(t, fixtures.AntSingleBuiltinTool) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + + ts := newBridgeTestServer(t, ctx, + []aibridge.Provider{provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), nil)}, + ) + + // Make API call to aibridge for Anthropic /v1/messages + reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) + require.NoError(t, err) + req := createAnthropicMessagesReq(t, ts.URL, reqBody) + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + defer resp.Body.Close() + + // Response-specific checks. + if tc.streaming { + sp := aibridge.NewSSEParser() + require.NoError(t, sp.Parse(resp.Body)) + + // Ensure the message starts and completes, at a minimum. + assert.Contains(t, sp.AllEvents(), "message_start") + assert.Contains(t, sp.AllEvents(), "message_stop") + } + + expectedTokenRecordings := 1 + if tc.streaming { + // One for message_start, one for message_delta. 
+ expectedTokenRecordings = 2 + } + tokenUsages := ts.Recorder.RecordedTokenUsages() + require.Len(t, tokenUsages, expectedTokenRecordings) + + assert.EqualValues(t, tc.expectedInputTokens, calculateTotalInputTokens(tokenUsages), "input tokens miscalculated") + assert.EqualValues(t, tc.expectedOutputTokens, calculateTotalOutputTokens(tokenUsages), "output tokens miscalculated") + + toolUsages := ts.Recorder.RecordedToolUsages() + require.Len(t, toolUsages, 1) + assert.Equal(t, "Read", toolUsages[0].Tool) + assert.Equal(t, tc.expectedToolCallID, toolUsages[0].ToolCallID) + require.IsType(t, json.RawMessage{}, toolUsages[0].Args) + var args map[string]any + require.NoError(t, json.Unmarshal(toolUsages[0].Args.(json.RawMessage), &args)) + require.Contains(t, args, "file_path") + assert.Equal(t, "/tmp/blah/foo", args["file_path"]) + + promptUsages := ts.Recorder.RecordedPromptUsages() + require.Len(t, promptUsages, 1) + assert.Equal(t, "read the foo file", promptUsages[0].Prompt) + + ts.Recorder.VerifyAllInterceptionsEnded(t) + }) + } + }) +} + +func TestAnthropicInjectedTools(t *testing.T) { + t.Parallel() + + for _, streaming := range []bool{true, false} { + t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { + t.Parallel() + + // Build the requirements & make the assertions which are common to all providers. + recorderClient, mockMCP, resp := setupInjectedToolTest(t, fixtures.AntSingleInjectedTool, streaming, newAnthropicProvider, defaultTracer, defaultActorID, createAnthropicMessagesReq, anthropicToolResultValidator(t)) + + // Ensure expected tool was invoked with expected input. 
+ toolUsages := recorderClient.RecordedToolUsages() + require.Len(t, toolUsages, 1) + require.Equal(t, mockToolName, toolUsages[0].Tool) + expected, err := json.Marshal(map[string]any{"owner": "admin"}) + require.NoError(t, err) + actual, err := json.Marshal(toolUsages[0].Args) + require.NoError(t, err) + require.EqualValues(t, expected, actual) + invocations := mockMCP.getCallsByTool(mockToolName) + require.Len(t, invocations, 1) + actual, err = json.Marshal(invocations[0]) + require.NoError(t, err) + require.EqualValues(t, expected, actual) + + var ( + content *anthropic.ContentBlockUnion + message anthropic.Message + ) + if streaming { + // Parse the response stream. + decoder := ssestream.NewDecoder(resp) + stream := ssestream.NewStream[anthropic.MessageStreamEventUnion](decoder, nil) + for stream.Next() { + event := stream.Current() + require.NoError(t, message.Accumulate(event), "accumulate event") + } + + require.NoError(t, stream.Err(), "stream error") + require.Len(t, message.Content, 2) + + content = &message.Content[1] + } else { + // Parse & unmarshal the response. + body, err := io.ReadAll(resp.Body) + require.NoError(t, err, "read response body") + + require.NoError(t, json.Unmarshal(body, &message), "unmarshal response") + require.GreaterOrEqual(t, len(message.Content), 1) + + content = &message.Content[0] + } + + // Ensure tool returned expected value. + require.NotNil(t, content) + require.Contains(t, content.Text, "dd711d5c-83c6-4c08-a0af-b73055906e8c") // The ID of the workspace to be returned. + + // Check the token usage from the client's perspective. + // + // We overwrite the final message_delta which is relayed to the client to include the + // accumulated tokens but currently the SDK only supports accumulating output tokens + // for message_delta events. + // + // For non-streaming requests the token usage is also overwritten and should be faithfully + // represented in the response. 
+ // + // See https://github.com/anthropics/anthropic-sdk-go/blob/v1.12.0/message.go#L2619-L2622 + if !streaming { + assert.EqualValues(t, 15308, message.Usage.InputTokens) + } + assert.EqualValues(t, 204, message.Usage.OutputTokens) + + // Ensure tokens used during injected tool invocation are accounted for. + tokenUsages := recorderClient.RecordedTokenUsages() + assert.EqualValues(t, 15308, calculateTotalInputTokens(tokenUsages)) + assert.EqualValues(t, 204, calculateTotalOutputTokens(tokenUsages)) + + // Ensure we received exactly one prompt. + promptUsages := recorderClient.RecordedPromptUsages() + require.Len(t, promptUsages, 1) + }) + } +} + +// anthropicToolResultValidator returns a request validator that asserts the second +// upstream request contains the assistant's tool_use and user's tool_result messages +// appended by the inner agentic loop. If the raw payload is not kept in sync with +// the structured messages, the second request will be identical to the first. +func anthropicToolResultValidator(t *testing.T) func(*http.Request, []byte) { + t.Helper() + + return func(_ *http.Request, raw []byte) { + messages := gjson.GetBytes(raw, "messages").Array() + + // After the agentic loop the messages must contain at minimum: + // [0] original user message + // [N-2] assistant message with tool_use content block + // [N-1] user message with tool_result content block + require.GreaterOrEqual(t, len(messages), 3, + "second upstream request must contain the original message, assistant tool_use, and user tool_result") + + assistantMsg := messages[len(messages)-2] + require.Equal(t, "assistant", assistantMsg.Get("role").Str, + "penultimate message must be from the assistant") + var hasToolUse bool + for _, block := range assistantMsg.Get("content").Array() { + if block.Get("type").Str == "tool_use" { + hasToolUse = true + break + } + } + require.True(t, hasToolUse, "assistant message must contain a tool_use content block") + + toolResultMsg := 
messages[len(messages)-1] + require.Equal(t, "user", toolResultMsg.Get("role").Str, + "last message must be a user message carrying the tool_result") + var hasToolResult bool + for _, block := range toolResultMsg.Get("content").Array() { + if block.Get("type").Str == "tool_result" { + hasToolResult = true + break + } + } + require.True(t, hasToolResult, "user message must contain a tool_result content block") + } +} + +// TestAnthropicToolChoiceParallelDisabled verifies that parallel tool use is +// correctly disabled based on the tool_choice parameter in the request. +// See https://github.com/coder/aibridge/issues/2 +func TestAnthropicToolChoiceParallelDisabled(t *testing.T) { + t.Parallel() + + var ( + toolChoiceAuto = string(constant.ValueOf[constant.Auto]()) + toolChoiceAny = string(constant.ValueOf[constant.Any]()) + toolChoiceNone = string(constant.ValueOf[constant.None]()) + toolChoiceTool = string(constant.ValueOf[constant.Tool]()) + ) + + cases := []struct { + name string + toolChoice any // nil, or map with "type" key. + withInjectedTools bool + expectDisableParallel bool + expectToolChoiceTypeInRequest string + }{ + // With injected tools - disable_parallel_tool_use should be set. 
+ { + name: "with injected tools: no tool_choice defined defaults to auto", + toolChoice: nil, + withInjectedTools: true, + expectDisableParallel: true, + expectToolChoiceTypeInRequest: toolChoiceAuto, + }, + { + name: "with injected tools: tool_choice auto", + toolChoice: map[string]any{"type": toolChoiceAuto}, + withInjectedTools: true, + expectDisableParallel: true, + expectToolChoiceTypeInRequest: toolChoiceAuto, + }, + { + name: "with injected tools: tool_choice any", + toolChoice: map[string]any{"type": toolChoiceAny}, + withInjectedTools: true, + expectDisableParallel: true, + expectToolChoiceTypeInRequest: toolChoiceAny, + }, + { + name: "with injected tools: tool_choice tool", + toolChoice: map[string]any{"type": toolChoiceTool, "name": "some_tool"}, + withInjectedTools: true, + expectDisableParallel: true, + expectToolChoiceTypeInRequest: toolChoiceTool, + }, + { + name: "with injected tools: tool_choice none", + toolChoice: map[string]any{"type": toolChoiceNone}, + withInjectedTools: true, + expectDisableParallel: false, + expectToolChoiceTypeInRequest: toolChoiceNone, + }, + // Without injected tools - disable_parallel_tool_use should NOT be set. + { + name: "without injected tools: tool_choice auto", + toolChoice: map[string]any{"type": toolChoiceAuto}, + withInjectedTools: false, + expectDisableParallel: false, + expectToolChoiceTypeInRequest: toolChoiceAuto, + }, + { + name: "without injected tools: tool_choice any", + toolChoice: map[string]any{"type": toolChoiceAny}, + withInjectedTools: false, + expectDisableParallel: false, + expectToolChoiceTypeInRequest: toolChoiceAny, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + // Setup MCP tools conditionally. 
+ var mcpMgr mcp.ServerProxier + if tc.withInjectedTools { + mcpMgr = setupMCPForTest(t, defaultTracer) + } else { + mcpMgr = newNoopMCPManager() + } + + fix := fixtures.Parse(t, fixtures.AntSimple) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + + ts := newBridgeTestServer(t, ctx, + []aibridge.Provider{provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), nil)}, + withMCP(mcpMgr), + ) + + // Prepare request body with tool_choice set. + reqBody, err := sjson.SetBytes(fix.Request(), "tool_choice", tc.toolChoice) + require.NoError(t, err) + + req := createAnthropicMessagesReq(t, ts.URL, reqBody) + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + _ = resp.Body.Close() + + // Verify tool_choice in the upstream request. + received := upstream.receivedRequests() + require.Len(t, received, 1) + var receivedRequest map[string]any + require.NoError(t, json.Unmarshal(received[0].Body, &receivedRequest)) + toolChoice, ok := receivedRequest["tool_choice"].(map[string]any) + require.True(t, ok, "expected tool_choice in upstream request") + + // Verify the type matches expectation. + assert.Equal(t, tc.expectToolChoiceTypeInRequest, toolChoice["type"]) + + // Verify name is preserved for tool_choice=tool. + if tc.expectToolChoiceTypeInRequest == toolChoiceTool { + assert.Equal(t, "some_tool", toolChoice["name"]) + } + + // Verify disable_parallel_tool_use based on expectations. 
+ // See https://platform.claude.com/docs/en/agents-and-tools/tool-use/implement-tool-use#parallel-tool-use + disableParallel, hasDisableParallel := toolChoice["disable_parallel_tool_use"].(bool) + + if tc.expectDisableParallel { + require.True(t, hasDisableParallel, "expected disable_parallel_tool_use in tool_choice") + assert.True(t, disableParallel, "expected disable_parallel_tool_use to be true") + } else { + assert.False(t, hasDisableParallel, "expected disable_parallel_tool_use to not be set") + } + }) + } +} + +func TestThinkingAdaptiveIsPreserved(t *testing.T) { + t.Parallel() + + fix := fixtures.Parse(t, fixtures.AntSimple) + + for _, streaming := range []bool{true, false} { + t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + // Create a mock server that captures the request body sent upstream. + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + + ts := newBridgeTestServer(t, ctx, + []aibridge.Provider{provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), nil)}, + ) + + // Inject adaptive thinking into the fixture request. + reqBody, err := sjson.SetBytes(fix.Request(), "thinking", map[string]string{"type": "adaptive"}) + require.NoError(t, err) + reqBody, err = sjson.SetBytes(reqBody, "stream", streaming) + require.NoError(t, err) + + req := createAnthropicMessagesReq(t, ts.URL, reqBody) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + _, _ = io.ReadAll(resp.Body) + _ = resp.Body.Close() + + // Verify the thinking field was preserved in the upstream request. 
+ received := upstream.receivedRequests() + require.Len(t, received, 1) + assert.Equal(t, "adaptive", gjson.GetBytes(received[0].Body, "thinking.type").Str) + }) + } +} From 11f6f37f608ca0a3f0fbe6fdc39c750b7b03f12b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Thu, 5 Mar 2026 17:14:55 +0000 Subject: [PATCH 08/32] fixup: apply unexport renames to remaining helper files --- internal/integrationtest/apidump_test.go | 16 ---------- internal/integrationtest/bridge.go | 11 ++----- internal/integrationtest/mockmcp.go | 4 ++- internal/integrationtest/requests.go | 37 ++++++++++++++---------- internal/integrationtest/trace_test.go | 2 -- internal/integrationtest/upstream.go | 4 +-- 6 files changed, 29 insertions(+), 45 deletions(-) diff --git a/internal/integrationtest/apidump_test.go b/internal/integrationtest/apidump_test.go index 52d8e9f..4c8503b 100644 --- a/internal/integrationtest/apidump_test.go +++ b/internal/integrationtest/apidump_test.go @@ -21,22 +21,6 @@ import ( "github.com/stretchr/testify/require" ) -func openaiCfgWithAPIDump(url, key, dumpDir string) config.OpenAI { - return config.OpenAI{ - BaseURL: url, - Key: key, - APIDumpDir: dumpDir, - } -} - -func anthropicCfgWithAPIDump(url, key, dumpDir string) config.Anthropic { - return config.Anthropic{ - BaseURL: url, - Key: key, - APIDumpDir: dumpDir, - } -} - func TestAPIDump(t *testing.T) { t.Parallel() diff --git a/internal/integrationtest/bridge.go b/internal/integrationtest/bridge.go index 62f2755..845c839 100644 --- a/internal/integrationtest/bridge.go +++ b/internal/integrationtest/bridge.go @@ -26,16 +26,9 @@ const defaultActorID = "ae235cc1-9f8f-417d-a636-a7b170bac62e" var defaultTracer = otel.Tracer("integrationtest") // newLogger creates a test logger at Debug level. -// Eliminates the repeated slogtest.Make(t, &slogtest.Options{...}).Leveled(slog.LevelDebug) pattern. 
-func newLogger(t *testing.T, opts ...*slogtest.Options) slog.Logger { +func newLogger(t *testing.T) slog.Logger { t.Helper() - var o *slogtest.Options - if len(opts) > 0 { - o = opts[0] - } else { - o = &slogtest.Options{} - } - return slogtest.Make(t, o).Leveled(slog.LevelDebug) + return slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) } // bridgeTestServer wraps an httptest.Server running a RequestBridge. diff --git a/internal/integrationtest/mockmcp.go b/internal/integrationtest/mockmcp.go index 401ddea..eba25dd 100644 --- a/internal/integrationtest/mockmcp.go +++ b/internal/integrationtest/mockmcp.go @@ -59,7 +59,9 @@ func setupMCPForTestWithName(t *testing.T, name string, tracer trace.Tracer) *mo // which can break when httptest.Server calls CloseIdleConnections in parallel // resulting in error `init MCP client: failed to send initialized notification: failed to send request: failed to send request: Post "http://127.0.0.1:43843": net/http: HTTP/1.x transport connection broken: http: CloseIdleConnections called` // https://github.com/golang/go/blob/44ec057a3e89482cf775f5eaaf03b0b5fcab1fa4/src/net/http/httptest/server.go#L268 - httpClient := &http.Client{Transport: &http.Transport{}} + httpTransport := &http.Transport{} + t.Cleanup(httpTransport.CloseIdleConnections) + httpClient := &http.Client{Transport: httpTransport} proxy, err := mcp.NewStreamableHTTPServerProxy(name, mcpSrv.URL, nil, nil, nil, logger, tracer, transport.WithHTTPBasicClient(httpClient)) require.NoError(t, err) diff --git a/internal/integrationtest/requests.go b/internal/integrationtest/requests.go index 738ac00..1ae4679 100644 --- a/internal/integrationtest/requests.go +++ b/internal/integrationtest/requests.go @@ -12,36 +12,31 @@ import ( // apiKey is the default API key used across integration tests. const apiKey = "api-key" -// createAnthropicMessagesReq builds an HTTP request targeting the Anthropic messages endpoint. 
-func createAnthropicMessagesReq(t *testing.T, baseURL string, input []byte) *http.Request { +func createJSONReq(t *testing.T, method, baseURL, path string, input []byte) *http.Request { t.Helper() - req, err := http.NewRequestWithContext(t.Context(), "POST", baseURL+"/anthropic/v1/messages", bytes.NewReader(input)) + req, err := http.NewRequestWithContext(t.Context(), method, baseURL+path, bytes.NewReader(input)) require.NoError(t, err) req.Header.Set("Content-Type", "application/json") - return req } +// createAnthropicMessagesReq builds an HTTP request targeting the Anthropic messages endpoint. +func createAnthropicMessagesReq(t *testing.T, baseURL string, input []byte) *http.Request { + t.Helper() + return createJSONReq(t, http.MethodPost, baseURL, "/anthropic/v1/messages", input) +} + // createOpenAIChatCompletionsReq builds an HTTP request targeting the OpenAI chat completions endpoint. func createOpenAIChatCompletionsReq(t *testing.T, baseURL string, input []byte) *http.Request { t.Helper() - - req, err := http.NewRequestWithContext(t.Context(), "POST", baseURL+"/openai/v1/chat/completions", bytes.NewReader(input)) - require.NoError(t, err) - req.Header.Set("Content-Type", "application/json") - - return req + return createJSONReq(t, http.MethodPost, baseURL, "/openai/v1/chat/completions", input) } // createOpenAIResponsesReq builds an HTTP request targeting the OpenAI responses endpoint. func createOpenAIResponsesReq(t *testing.T, baseURL string, input []byte) *http.Request { t.Helper() - - req, err := http.NewRequestWithContext(t.Context(), "POST", baseURL+"/openai/v1/responses", bytes.NewReader(input)) - require.NoError(t, err) - req.Header.Set("Content-Type", "application/json") - return req + return createJSONReq(t, http.MethodPost, baseURL, "/openai/v1/responses", input) } // openAICfg creates a minimal OpenAI config for testing. 
@@ -52,6 +47,12 @@ func openAICfg(url, key string) config.OpenAI { } } +func openaiCfgWithAPIDump(url, key, dumpDir string) config.OpenAI { + cfg := openAICfg(url, key) + cfg.APIDumpDir = dumpDir + return cfg +} + // anthropicCfg creates a minimal Anthropic config for testing. func anthropicCfg(url, key string) config.Anthropic { return config.Anthropic{ @@ -59,3 +60,9 @@ func anthropicCfg(url, key string) config.Anthropic { Key: key, } } + +func anthropicCfgWithAPIDump(url, key, dumpDir string) config.Anthropic { + cfg := anthropicCfg(url, key) + cfg.APIDumpDir = dumpDir + return cfg +} diff --git a/internal/integrationtest/trace_test.go b/internal/integrationtest/trace_test.go index eb71a6d..c7b8397 100644 --- a/internal/integrationtest/trace_test.go +++ b/internal/integrationtest/trace_test.go @@ -794,5 +794,3 @@ func verifyTraces(t *testing.T, spanRecorder *tracetest.SpanRecorder, expect []e } } } - - diff --git a/internal/integrationtest/upstream.go b/internal/integrationtest/upstream.go index 4f0ac7f..a658b05 100644 --- a/internal/integrationtest/upstream.go +++ b/internal/integrationtest/upstream.go @@ -110,8 +110,8 @@ func (ms *mockUpstream) receivedRequests() []receivedRequest { // The test fails if the number of requests doesn't match the number of // responses (when AllowOverflow is not set, default). 
// -// srv := testutil.newMockUpstream(t, ctx, testutil.newFixtureResponse(fix)) // simple -// srv := testutil.newMockUpstream(t, ctx, testutil.newFixtureResponse(fix), testutil.newFixtureToolResponse(fix)) // multi-turn +// srv := newMockUpstream(t, ctx, newFixtureResponse(fix)) // simple +// srv := newMockUpstream(t, ctx, newFixtureResponse(fix), newFixtureToolResponse(fix)) // multi-turn func newMockUpstream(t *testing.T, ctx context.Context, responses ...upstreamResponse) *mockUpstream { t.Helper() require.NotEmpty(t, responses, "at least one upstreamResponse required") From 82ea36884e7da06464ed2566b4ae15c0583c5724 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Thu, 5 Mar 2026 18:33:16 +0000 Subject: [PATCH 09/32] remove providerFunc --- internal/integrationtest/apidump_test.go | 46 +-- internal/integrationtest/bedrock_test.go | 29 +- internal/integrationtest/bridge.go | 184 +++++++++++- internal/integrationtest/bridge_test.go | 264 +++++------------- .../integrationtest/chatcompletions_test.go | 14 +- .../integrationtest/circuit_breaker_test.go | 52 ++-- internal/integrationtest/messages_test.go | 20 +- internal/integrationtest/metrics_test.go | 52 ++-- internal/integrationtest/requests.go | 32 --- internal/integrationtest/responses_test.go | 31 +- internal/integrationtest/trace_test.go | 124 ++++---- 11 files changed, 395 insertions(+), 453 deletions(-) diff --git a/internal/integrationtest/apidump_test.go b/internal/integrationtest/apidump_test.go index 4c8503b..fa441d6 100644 --- a/internal/integrationtest/apidump_test.go +++ b/internal/integrationtest/apidump_test.go @@ -25,34 +25,34 @@ func TestAPIDump(t *testing.T) { t.Parallel() cases := []struct { - name string - fixture []byte - providersFunc func(addr, dumpDir string) []aibridge.Provider - createRequestFunc createRequestFunc + name string + fixture []byte + newProvider func(addr, dumpDir string) aibridge.Provider + path string }{ { name: "anthropic", fixture: 
fixtures.AntSimple, - providersFunc: func(addr, dumpDir string) []aibridge.Provider { - return []aibridge.Provider{provider.NewAnthropic(anthropicCfgWithAPIDump(addr, apiKey, dumpDir), nil)} + newProvider: func(addr, dumpDir string) aibridge.Provider { + return provider.NewAnthropic(anthropicCfgWithAPIDump(addr, apiKey, dumpDir), nil) }, - createRequestFunc: createAnthropicMessagesReq, + path: pathAnthropicMessages, }, { name: "openai_chat_completions", fixture: fixtures.OaiChatSimple, - providersFunc: func(addr, dumpDir string) []aibridge.Provider { - return []aibridge.Provider{provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir))} + newProvider: func(addr, dumpDir string) aibridge.Provider { + return provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir)) }, - createRequestFunc: createOpenAIChatCompletionsReq, + path: pathOpenAIChatCompletions, }, { name: "openai_responses", fixture: fixtures.OaiResponsesBlockingSimple, - providersFunc: func(addr, dumpDir string) []aibridge.Provider { - return []aibridge.Provider{provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir))} + newProvider: func(addr, dumpDir string) aibridge.Provider { + return provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir)) }, - createRequestFunc: createOpenAIResponsesReq, + path: pathOpenAIResponses, }, } @@ -70,10 +70,11 @@ func TestAPIDump(t *testing.T) { // Create temp dir for API dumps. 
dumpDir := t.TempDir() - providers := tc.providersFunc(srv.URL, dumpDir) - ts := newBridgeTestServer(t, ctx, providers) + ts := newBridgeTestServer(t, ctx, srv.URL, + withCustomProvider(tc.newProvider(srv.URL, dumpDir)), + ) - req := tc.createRequestFunc(t, ts.URL, fix.Request()) + req := ts.newRequest(t, tc.path, fix.Request()) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) @@ -148,13 +149,13 @@ func TestAPIDumpPassthrough(t *testing.T) { cases := []struct { name string - providerFunc func(addr string, dumpDir string) aibridge.Provider + newProvider func(addr string, dumpDir string) aibridge.Provider requestPath string expectDumpName string }{ { name: "anthropic", - providerFunc: func(addr string, dumpDir string) aibridge.Provider { + newProvider: func(addr string, dumpDir string) aibridge.Provider { return provider.NewAnthropic(anthropicCfgWithAPIDump(addr, apiKey, dumpDir), nil) }, requestPath: "/anthropic/v1/models", @@ -162,7 +163,7 @@ func TestAPIDumpPassthrough(t *testing.T) { }, { name: "openai", - providerFunc: func(addr string, dumpDir string) aibridge.Provider { + newProvider: func(addr string, dumpDir string) aibridge.Provider { return provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir)) }, requestPath: "/openai/v1/models", @@ -170,7 +171,7 @@ func TestAPIDumpPassthrough(t *testing.T) { }, { name: "copilot", - providerFunc: func(addr string, dumpDir string) aibridge.Provider { + newProvider: func(addr string, dumpDir string) aibridge.Provider { return provider.NewCopilot(config.Copilot{BaseURL: addr, APIDumpDir: dumpDir}) }, requestPath: "/copilot/models", @@ -193,8 +194,9 @@ func TestAPIDumpPassthrough(t *testing.T) { dumpDir := t.TempDir() - prov := tc.providerFunc(upstream.URL, dumpDir) - ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}) + ts := newBridgeTestServer(t, ctx, upstream.URL, + withCustomProvider(tc.newProvider(upstream.URL, dumpDir)), + ) req, err := 
http.NewRequestWithContext(ctx, http.MethodGet, ts.URL+tc.requestPath, nil) require.NoError(t, err) diff --git a/internal/integrationtest/bedrock_test.go b/internal/integrationtest/bedrock_test.go index 3709185..62cb112 100644 --- a/internal/integrationtest/bedrock_test.go +++ b/internal/integrationtest/bedrock_test.go @@ -9,7 +9,6 @@ import ( "testing" "time" - "github.com/coder/aibridge" "github.com/coder/aibridge/config" "github.com/coder/aibridge/fixtures" "github.com/coder/aibridge/provider" @@ -18,22 +17,6 @@ import ( "github.com/tidwall/sjson" ) -// testBedrockCfg returns a test AWS Bedrock config pointing at the given URL. -func testBedrockCfg(url string) *config.AWSBedrock { - return &config.AWSBedrock{ - Region: "us-west-2", - AccessKey: "test-access-key", - AccessKeySecret: "test-secret-key", - Model: "beddel", // This model should override the request's given one. - SmallFastModel: "modrock", // Unused but needed for validation. - BaseURL: url, - } -} - -func newBedrockProvider(addr string) aibridge.Provider { - return provider.NewAnthropic(anthropicCfg(addr, apiKey), testBedrockCfg(addr)) -} - func TestAWSBedrockIntegration(t *testing.T) { t.Parallel() @@ -52,12 +35,12 @@ func TestAWSBedrockIntegration(t *testing.T) { SmallFastModel: "test-haiku", } - ts := newBridgeTestServer(t, ctx, - []aibridge.Provider{provider.NewAnthropic(anthropicCfg("http://unused", apiKey), bedrockCfg)}, + ts := newBridgeTestServer(t, ctx, "http://unused", + withCustomProvider(provider.NewAnthropic(anthropicCfg("http://unused", apiKey), bedrockCfg)), withLogger(newLogger(t)), ) - req := createAnthropicMessagesReq(t, ts.URL, fixtures.Request(t, fixtures.AntSingleBuiltinTool)) + req := ts.newRequest(t, pathAnthropicMessages, fixtures.Request(t, fixtures.AntSingleBuiltinTool)) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) defer resp.Body.Close() @@ -90,8 +73,8 @@ func TestAWSBedrockIntegration(t *testing.T) { BaseURL: upstream.URL, // Use the mock server. 
} - ts := newBridgeTestServer(t, ctx, - []aibridge.Provider{provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), bedrockCfg)}, + ts := newBridgeTestServer(t, ctx, upstream.URL, + withCustomProvider(provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), bedrockCfg)), withLogger(newLogger(t)), ) @@ -99,7 +82,7 @@ func TestAWSBedrockIntegration(t *testing.T) { // We override the AWS Bedrock client to route requests through our mock server. reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) require.NoError(t, err) - req := createAnthropicMessagesReq(t, ts.URL, reqBody) + req := ts.newRequest(t, pathAnthropicMessages, reqBody) client := &http.Client{} resp, err := client.Do(req) require.NoError(t, err) diff --git a/internal/integrationtest/bridge.go b/internal/integrationtest/bridge.go index 845c839..e074e36 100644 --- a/internal/integrationtest/bridge.go +++ b/internal/integrationtest/bridge.go @@ -1,24 +1,164 @@ package integrationtest import ( + "bytes" "context" "net" + "net/http" "net/http/httptest" "testing" + "time" "cdr.dev/slog/v3" "cdr.dev/slog/v3/sloggers/slogtest" "github.com/coder/aibridge" + "github.com/coder/aibridge/config" aibcontext "github.com/coder/aibridge/context" + "github.com/coder/aibridge/fixtures" "github.com/coder/aibridge/internal/testutil" "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/metrics" + "github.com/coder/aibridge/provider" "github.com/coder/aibridge/recorder" "github.com/stretchr/testify/require" + "github.com/tidwall/sjson" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/trace" ) +// Well-known bridged-route paths used by integration tests. +const ( + pathAnthropicMessages = "/anthropic/v1/messages" + pathOpenAIChatCompletions = "/openai/v1/chat/completions" + pathOpenAIResponses = "/openai/v1/responses" +) + +// providerBedrock identifies a Bedrock provider in [withProvider]. 
There is no +// config-level constant for Bedrock because it re-uses the Anthropic provider +// with an AWS Bedrock configuration. +const providerBedrock = "bedrock" + +// testBedrockCfg returns a test AWS Bedrock config pointing at the given URL. +func testBedrockCfg(url string) *config.AWSBedrock { + return &config.AWSBedrock{ + Region: "us-west-2", + AccessKey: "test-access-key", + AccessKeySecret: "test-secret-key", + Model: "beddel", // This model should override the request's given one. + SmallFastModel: "modrock", // Unused but needed for validation. + BaseURL: url, + } +} + +// newDefaultProvider creates a Provider with default test configuration. +func newDefaultProvider(providerType, addr string) aibridge.Provider { + switch providerType { + case config.ProviderAnthropic: + return provider.NewAnthropic(anthropicCfg(addr, apiKey), nil) + case config.ProviderOpenAI: + return provider.NewOpenAI(openAICfg(addr, apiKey)) + case providerBedrock: + return provider.NewAnthropic(anthropicCfg(addr, apiKey), testBedrockCfg(addr)) + default: + panic("unknown provider type: " + providerType) + } +} + +// withProvider adds a default-configured provider of the given type. +// When any provider option is used, the default "all providers" set is not created. +func withProvider(providerType string) bridgeOption { + return func(c *bridgeConfig) { + c.providerBuilders = append(c.providerBuilders, func(addr string) aibridge.Provider { + return newDefaultProvider(providerType, addr) + }) + } +} + +// withCustomProvider adds a pre-built provider. The upstream URL passed to +// [newBridgeTestServer] is ignored for this provider. +// When any provider option is used, the default "all providers" set is not created. 
+func withCustomProvider(p aibridge.Provider) bridgeOption { + return func(c *bridgeConfig) { + c.providerBuilders = append(c.providerBuilders, func(string) aibridge.Provider { + return p + }) + } +} + +// setupInjectedToolTest abstracts common setup required for injected-tool integration tests. +// Extra bridge options (e.g. [withProvider]) are appended after the built-in +// MCP / tracer / actor options. When no provider option is given the default +// provider set (all providers) is used. +func setupInjectedToolTest( + t *testing.T, + fixture []byte, + streaming bool, + tracer trace.Tracer, + actorID string, + path string, + toolRequestValidatorFn func(*http.Request, []byte), + opts ...bridgeOption, +) (*testutil.MockRecorder, *mockMCP, *http.Response) { + t.Helper() + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + fix := fixtures.Parse(t, fixture) + + // Setup mock server for multi-turn interaction. + // First request → tool call response, second → tool response. + firstResp := newFixtureResponse(fix) + toolResp := newFixtureToolResponse(fix) + toolResp.OnRequest = toolRequestValidatorFn + upstream := newMockUpstream(t, ctx, firstResp, toolResp) + + mockMCP := setupMCPForTest(t, tracer) + + allOpts := []bridgeOption{ + withMCP(mockMCP), + withTracer(tracer), + withActor(actorID, nil), + } + allOpts = append(allOpts, opts...) + ts := newBridgeTestServer(t, ctx, upstream.URL, allOpts...) + + // Add the stream param to the request. + reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) + require.NoError(t, err) + + req := ts.newRequest(t, path, reqBody) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + t.Cleanup(func() { + _ = resp.Body.Close() + }) + + // We must ALWAYS have 2 calls to the bridge for injected tool tests. 
+ require.Eventually(t, func() bool { + return upstream.Calls.Load() == 2 + }, time.Second*10, time.Millisecond*50) + + return ts.Recorder, mockMCP, resp +} + +func calculateTotalInputTokens(in []*recorder.TokenUsageRecord) int64 { + var total int64 + for _, el := range in { + total += el.Input + } + return total +} + +func calculateTotalOutputTokens(in []*recorder.TokenUsageRecord) int64 { + var total int64 + for _, el := range in { + total += el.Output + } + return total +} + // defaultActorID is the actor ID used by default in test servers. const defaultActorID = "ae235cc1-9f8f-417d-a636-a7b170bac62e" @@ -38,18 +178,29 @@ type bridgeTestServer struct { Bridge *aibridge.RequestBridge } +// newRequest creates a JSON POST request targeting the given path on this server. +func (s *bridgeTestServer) newRequest(t *testing.T, path string, body []byte) *http.Request { + t.Helper() + + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, s.URL+path, bytes.NewReader(body)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + return req +} + // bridgeOption configures a [bridgeTestServer]. type bridgeOption func(*bridgeConfig) type bridgeConfig struct { - metrics *metrics.Metrics - tracer trace.Tracer - mcpProxy mcp.ServerProxier - userID string - metadata recorder.Metadata - logger slog.Logger - loggerSet bool - wrapRecorder bool + providerBuilders []func(upstreamURL string) aibridge.Provider + metrics *metrics.Metrics + tracer trace.Tracer + mcpProxy mcp.ServerProxier + userID string + metadata recorder.Metadata + logger slog.Logger + loggerSet bool + wrapRecorder bool } // withMetrics sets the Prometheus metrics for the bridge. 
@@ -85,6 +236,7 @@ func withWrappedRecorder() bridgeOption { // newBridgeTestServer creates a fully configured test server running // a RequestBridge with sensible defaults: +// - All standard providers (unless withProvider / withCustomProvider) // - MockRecorder (raw, unless withWrappedRecorder) // - NoopMCPManager (unless withMCP) // - slogtest debug logger (unless withLogger) @@ -93,7 +245,7 @@ func withWrappedRecorder() bridgeOption { func newBridgeTestServer( t *testing.T, ctx context.Context, - providers []aibridge.Provider, + upstreamURL string, opts ...bridgeOption, ) *bridgeTestServer { t.Helper() @@ -114,6 +266,20 @@ func newBridgeTestServer( cfg.mcpProxy = newNoopMCPManager() } + // Resolve providers: use explicit builders when provided, otherwise + // create default providers for every supported type. + var providers []aibridge.Provider + if len(cfg.providerBuilders) > 0 { + for _, b := range cfg.providerBuilders { + providers = append(providers, b(upstreamURL)) + } + } else { + providers = []aibridge.Provider{ + newDefaultProvider(config.ProviderAnthropic, upstreamURL), + newDefaultProvider(config.ProviderOpenAI, upstreamURL), + } + } + mockRec := &testutil.MockRecorder{} var rec aibridge.Recorder = mockRec if cfg.wrapRecorder { diff --git a/internal/integrationtest/bridge_test.go b/internal/integrationtest/bridge_test.go index cbc1ee8..0b23640 100644 --- a/internal/integrationtest/bridge_test.go +++ b/internal/integrationtest/bridge_test.go @@ -19,7 +19,6 @@ import ( "github.com/coder/aibridge/config" "github.com/coder/aibridge/fixtures" "github.com/coder/aibridge/intercept" - "github.com/coder/aibridge/internal/testutil" "github.com/coder/aibridge/provider" "github.com/coder/aibridge/recorder" "github.com/google/uuid" @@ -29,23 +28,9 @@ import ( "github.com/stretchr/testify/require" "github.com/tidwall/gjson" "github.com/tidwall/sjson" - "go.opentelemetry.io/otel/trace" "go.uber.org/goleak" ) -type ( - providerFunc func(addr string) aibridge.Provider 
- createRequestFunc func(*testing.T, string, []byte) *http.Request -) - -func newAnthropicProvider(addr string) aibridge.Provider { - return provider.NewAnthropic(anthropicCfg(addr, apiKey), nil) -} - -func newOpenAIProvider(addr string) aibridge.Provider { - return provider.NewOpenAI(openAICfg(addr, apiKey)) -} - func TestMain(m *testing.M) { goleak.VerifyTestMain(m) } @@ -116,9 +101,8 @@ func TestSimple(t *testing.T) { fixture []byte basePath string expectedPath string - providerFn providerFunc getResponseIDFunc func(streaming bool, resp *http.Response) (string, error) - createRequest createRequestFunc + path string expectedMsgID string userAgent string expectedClient aibridge.Client @@ -128,9 +112,8 @@ func TestSimple(t *testing.T) { fixture: fixtures.AntSimple, basePath: "", expectedPath: "/v1/messages", - providerFn: newAnthropicProvider, getResponseIDFunc: getAnthropicResponseID, - createRequest: createAnthropicMessagesReq, + path: pathAnthropicMessages, expectedMsgID: "msg_01Pvyf26bY17RcjmWfJsXGBn", userAgent: "claude-cli/2.0.67 (external, cli)", expectedClient: aibridge.ClientClaudeCode, @@ -140,9 +123,8 @@ func TestSimple(t *testing.T) { fixture: fixtures.OaiChatSimple, basePath: "", expectedPath: "/chat/completions", - providerFn: newOpenAIProvider, getResponseIDFunc: getOpenAIResponseID, - createRequest: createOpenAIChatCompletionsReq, + path: pathOpenAIChatCompletions, expectedMsgID: "chatcmpl-BwoiPTGRbKkY5rncfaM0s9KtWrq5N", userAgent: "codex_cli_rs/0.87.0 (Mac OS 26.2.0; arm64)", expectedClient: aibridge.ClientCodex, @@ -152,9 +134,8 @@ func TestSimple(t *testing.T) { fixture: fixtures.AntSimple, basePath: "/api", expectedPath: "/api/v1/messages", - providerFn: newAnthropicProvider, getResponseIDFunc: getAnthropicResponseID, - createRequest: createAnthropicMessagesReq, + path: pathAnthropicMessages, expectedMsgID: "msg_01Pvyf26bY17RcjmWfJsXGBn", userAgent: "GitHubCopilotChat/0.37.2026011603", expectedClient: aibridge.ClientCopilotVSC, @@ -164,9 +145,8 
@@ func TestSimple(t *testing.T) { fixture: fixtures.OaiChatSimple, basePath: "/api", expectedPath: "/api/chat/completions", - providerFn: newOpenAIProvider, getResponseIDFunc: getOpenAIResponseID, - createRequest: createOpenAIChatCompletionsReq, + path: pathOpenAIChatCompletions, expectedMsgID: "chatcmpl-BwoiPTGRbKkY5rncfaM0s9KtWrq5N", userAgent: "Zed/0.219.4+stable.119.abc123 (macos; aarch64)", expectedClient: aibridge.ClientZed, @@ -187,17 +167,14 @@ func TestSimple(t *testing.T) { fix := fixtures.Parse(t, tc.fixture) upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - ts := newBridgeTestServer(t, ctx, - []aibridge.Provider{tc.providerFn(upstream.URL + tc.basePath)}, - ) + ts := newBridgeTestServer(t, ctx, upstream.URL+tc.basePath) // When: calling the "API server" with the fixture's request body. reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) require.NoError(t, err) - req := tc.createRequest(t, ts.URL, reqBody) + req := ts.newRequest(t, tc.path, reqBody) req.Header.Set("User-Agent", tc.userAgent) - client := &http.Client{} - resp, err := client.Do(req) + resp, err := http.DefaultClient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) defer resp.Body.Close() @@ -234,6 +211,7 @@ func TestSimple(t *testing.T) { // Validate user agent and client have been recorded. 
interceptions := ts.Recorder.RecordedInterceptions() require.Len(t, interceptions, 1, "expected exactly one interception, got: %v", interceptions) + assert.Equal(t, id, interceptions[0].ID) assert.Equal(t, tc.userAgent, interceptions[0].UserAgent) assert.Equal(t, string(tc.expectedClient), interceptions[0].Client) @@ -375,7 +353,7 @@ func TestFallthrough(t *testing.T) { basePath string requestPath string expectedUpstreamPath string - providerFn providerFunc + authHeader string }{ { name: "ant_empty_base_url_path", @@ -384,7 +362,7 @@ func TestFallthrough(t *testing.T) { basePath: "", requestPath: "/anthropic/v1/models", expectedUpstreamPath: "/v1/models", - providerFn: newAnthropicProvider, + authHeader: "X-Api-Key", }, { name: "oai_empty_base_url_path", @@ -393,7 +371,7 @@ func TestFallthrough(t *testing.T) { basePath: "", requestPath: "/openai/v1/models", expectedUpstreamPath: "/models", - providerFn: newOpenAIProvider, + authHeader: "Authorization", }, { name: "ant_some_base_url_path", @@ -402,7 +380,7 @@ func TestFallthrough(t *testing.T) { basePath: "/api", requestPath: "/anthropic/v1/models", expectedUpstreamPath: "/api/v1/models", - providerFn: newAnthropicProvider, + authHeader: "X-Api-Key", }, { name: "oai_some_base_url_path", @@ -411,7 +389,7 @@ func TestFallthrough(t *testing.T) { basePath: "/api", requestPath: "/openai/v1/models", expectedUpstreamPath: "/api/models", - providerFn: newOpenAIProvider, + authHeader: "Authorization", }, } @@ -421,10 +399,7 @@ func TestFallthrough(t *testing.T) { fix := fixtures.Parse(t, tc.fixture) upstream := newMockUpstream(t, t.Context(), newFixtureResponse(fix)) - p := tc.providerFn(upstream.URL + tc.basePath) - ts := newBridgeTestServer(t, t.Context(), - []aibridge.Provider{p}, - ) + ts := newBridgeTestServer(t, t.Context(), upstream.URL+tc.basePath) req, err := http.NewRequestWithContext(t.Context(), "GET", fmt.Sprintf("%s%s", ts.URL, tc.requestPath), nil) require.NoError(t, err) @@ -440,7 +415,7 @@ func 
TestFallthrough(t *testing.T) { received := upstream.receivedRequests() require.Len(t, received, 1) require.Equal(t, tc.expectedUpstreamPath, received[0].Path) - require.Contains(t, received[0].Header.Get(p.AuthHeader()), apiKey) + require.Contains(t, received[0].Header.Get(tc.authHeader), apiKey) gotBytes, err := io.ReadAll(resp.Body) require.NoError(t, err) @@ -455,61 +430,6 @@ func TestFallthrough(t *testing.T) { } } -// setupInjectedToolTest abstracts common setup required for injected-tool integration tests. -func setupInjectedToolTest( - t *testing.T, - fixture []byte, - streaming bool, - providerFn providerFunc, - tracer trace.Tracer, - actorID string, - createRequestFn func(*testing.T, string, []byte) *http.Request, - toolRequestValidatorFn func(*http.Request, []byte), -) (*testutil.MockRecorder, *mockMCP, *http.Response) { - t.Helper() - - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) - t.Cleanup(cancel) - - fix := fixtures.Parse(t, fixture) - - // Setup mock server for multi-turn interaction. - // First request → tool call response, second → tool response. - firstResp := newFixtureResponse(fix) - toolResp := newFixtureToolResponse(fix) - toolResp.OnRequest = toolRequestValidatorFn - upstream := newMockUpstream(t, ctx, firstResp, toolResp) - - mockMCP := setupMCPForTest(t, tracer) - - ts := newBridgeTestServer(t, ctx, - []aibridge.Provider{providerFn(upstream.URL)}, - withMCP(mockMCP), - withTracer(tracer), - withActor(actorID, nil), - ) - - // Add the stream param to the request. - reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) - require.NoError(t, err) - - req := createRequestFn(t, ts.URL, reqBody) - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) - require.Equal(t, http.StatusOK, resp.StatusCode) - t.Cleanup(func() { - _ = resp.Body.Close() - }) - - // We must ALWAYS have 2 calls to the bridge for injected tool tests. 
- require.Eventually(t, func() bool { - return upstream.Calls.Load() == 2 - }, time.Second*10, time.Millisecond*50) - - return ts.Recorder, mockMCP, resp -} - func TestErrorHandling(t *testing.T) { t.Parallel() @@ -518,15 +438,13 @@ func TestErrorHandling(t *testing.T) { cases := []struct { name string fixture []byte - createRequestFunc createRequestFunc - providerFn providerFunc + path string responseHandlerFn func(resp *http.Response) }{ { - name: config.ProviderAnthropic, - fixture: fixtures.AntNonStreamError, - createRequestFunc: createAnthropicMessagesReq, - providerFn: newAnthropicProvider, + name: config.ProviderAnthropic, + fixture: fixtures.AntNonStreamError, + path: pathAnthropicMessages, responseHandlerFn: func(resp *http.Response) { require.Equal(t, http.StatusBadRequest, resp.StatusCode) body, err := io.ReadAll(resp.Body) @@ -537,10 +455,9 @@ func TestErrorHandling(t *testing.T) { }, }, { - name: config.ProviderOpenAI, - fixture: fixtures.OaiChatNonStreamError, - createRequestFunc: createOpenAIChatCompletionsReq, - providerFn: newOpenAIProvider, + name: config.ProviderOpenAI, + fixture: fixtures.OaiChatNonStreamError, + path: pathOpenAIChatCompletions, responseHandlerFn: func(resp *http.Response) { require.Equal(t, http.StatusBadRequest, resp.StatusCode) body, err := io.ReadAll(resp.Body) @@ -568,15 +485,13 @@ func TestErrorHandling(t *testing.T) { fix := fixtures.Parse(t, tc.fixture) upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - ts := newBridgeTestServer(t, ctx, - []aibridge.Provider{tc.providerFn(upstream.URL)}, - ) + ts := newBridgeTestServer(t, ctx, upstream.URL) // Add the stream param to the request. 
reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) require.NoError(t, err) - req := tc.createRequestFunc(t, ts.URL, reqBody) + req := ts.newRequest(t, tc.path, reqBody) resp, err := http.DefaultClient.Do(req) t.Cleanup(func() { _ = resp.Body.Close() }) require.NoError(t, err) @@ -594,15 +509,13 @@ func TestErrorHandling(t *testing.T) { cases := []struct { name string fixture []byte - createRequestFunc createRequestFunc - providerFn providerFunc + path string responseHandlerFn func(resp *http.Response) }{ { - name: config.ProviderAnthropic, - fixture: fixtures.AntMidStreamError, - createRequestFunc: createAnthropicMessagesReq, - providerFn: newAnthropicProvider, + name: config.ProviderAnthropic, + fixture: fixtures.AntMidStreamError, + path: pathAnthropicMessages, responseHandlerFn: func(resp *http.Response) { // Server responds first with 200 OK then starts streaming. require.Equal(t, http.StatusOK, resp.StatusCode) @@ -614,10 +527,9 @@ func TestErrorHandling(t *testing.T) { }, }, { - name: config.ProviderOpenAI, - fixture: fixtures.OaiChatMidStreamError, - createRequestFunc: createOpenAIChatCompletionsReq, - providerFn: newOpenAIProvider, + name: config.ProviderOpenAI, + fixture: fixtures.OaiChatMidStreamError, + path: pathOpenAIChatCompletions, responseHandlerFn: func(resp *http.Response) { // Server responds first with 200 OK then starts streaming. 
require.Equal(t, http.StatusOK, resp.StatusCode) @@ -647,11 +559,9 @@ func TestErrorHandling(t *testing.T) { upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) upstream.StatusCode = http.StatusInternalServerError - ts := newBridgeTestServer(t, ctx, - []aibridge.Provider{tc.providerFn(upstream.URL)}, - ) + ts := newBridgeTestServer(t, ctx, upstream.URL) - req := tc.createRequestFunc(t, ts.URL, fix.Request()) + req := ts.newRequest(t, tc.path, fix.Request()) resp, err := http.DefaultClient.Do(req) t.Cleanup(func() { _ = resp.Body.Close() }) require.NoError(t, err) @@ -672,22 +582,19 @@ func TestStableRequestEncoding(t *testing.T) { t.Parallel() cases := []struct { - name string - fixture []byte - createRequestFunc createRequestFunc - providerFn providerFunc + name string + fixture []byte + path string }{ { - name: config.ProviderAnthropic, - fixture: fixtures.AntSimple, - createRequestFunc: createAnthropicMessagesReq, - providerFn: newAnthropicProvider, + name: config.ProviderAnthropic, + fixture: fixtures.AntSimple, + path: pathAnthropicMessages, }, { - name: config.ProviderOpenAI, - fixture: fixtures.OaiChatSimple, - createRequestFunc: createOpenAIChatCompletionsReq, - providerFn: newOpenAIProvider, + name: config.ProviderOpenAI, + fixture: fixtures.OaiChatSimple, + path: pathOpenAIChatCompletions, }, } @@ -711,16 +618,14 @@ func TestStableRequestEncoding(t *testing.T) { } upstream := newMockUpstream(t, ctx, responses...) - ts := newBridgeTestServer(t, ctx, - []aibridge.Provider{tc.providerFn(upstream.URL)}, + ts := newBridgeTestServer(t, ctx, upstream.URL, withMCP(mockMCP), ) // Make multiple requests and verify they all have identical payloads. 
for range count { - req := tc.createRequestFunc(t, ts.URL, fix.Request()) - client := &http.Client{} - resp, err := client.Do(req) + req := ts.newRequest(t, tc.path, fix.Request()) + resp, err := http.DefaultClient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) _ = resp.Body.Close() @@ -743,28 +648,25 @@ func TestEnvironmentDoNotLeak(t *testing.T) { // Test that environment variables containing API keys/tokens are not leaked to upstream requests. // See https://github.com/coder/aibridge/issues/60. testCases := []struct { - name string - fixture []byte - providerFn providerFunc - createRequest createRequestFunc - envVars map[string]string - headerName string + name string + fixture []byte + path string + envVars map[string]string + headerName string }{ { - name: config.ProviderAnthropic, - fixture: fixtures.AntSimple, - providerFn: newAnthropicProvider, - createRequest: createAnthropicMessagesReq, + name: config.ProviderAnthropic, + fixture: fixtures.AntSimple, + path: pathAnthropicMessages, envVars: map[string]string{ "ANTHROPIC_AUTH_TOKEN": "should-not-leak", }, headerName: "Authorization", // We only send through the X-Api-Key, so this one should not be present. 
}, { - name: config.ProviderOpenAI, - fixture: fixtures.OaiChatSimple, - providerFn: newOpenAIProvider, - createRequest: createOpenAIChatCompletionsReq, + name: config.ProviderOpenAI, + fixture: fixtures.OaiChatSimple, + path: pathOpenAIChatCompletions, envVars: map[string]string{ "OPENAI_ORG_ID": "should-not-leak", }, @@ -788,13 +690,10 @@ func TestEnvironmentDoNotLeak(t *testing.T) { t.Setenv(key, val) } - ts := newBridgeTestServer(t, ctx, - []aibridge.Provider{tc.providerFn(upstream.URL)}, - ) + ts := newBridgeTestServer(t, ctx, upstream.URL) - req := tc.createRequest(t, ts.URL, fix.Request()) - client := &http.Client{} - resp, err := client.Do(req) + req := ts.newRequest(t, tc.path, fix.Request()) + resp, err := http.DefaultClient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) defer resp.Body.Close() @@ -814,14 +713,14 @@ func TestActorHeaders(t *testing.T) { cases := []struct { name string - createRequest createRequestFunc + path string createProviderFn func(url, key string, sendHeaders bool) aibridge.Provider fixture []byte streaming bool }{ { - name: "openai/v1/chat/completions", - createRequest: createOpenAIChatCompletionsReq, + name: "openai/v1/chat/completions", + path: pathOpenAIChatCompletions, createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { cfg := openAICfg(url, key) cfg.SendActorHeaders = sendHeaders @@ -832,7 +731,7 @@ func TestActorHeaders(t *testing.T) { }, { name: "openai/v1/chat/completions", - createRequest: createOpenAIChatCompletionsReq, + path: pathOpenAIChatCompletions, createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { cfg := openAICfg(url, key) cfg.SendActorHeaders = sendHeaders @@ -843,7 +742,7 @@ func TestActorHeaders(t *testing.T) { }, { name: "openai/v1/responses", - createRequest: createOpenAIResponsesReq, + path: pathOpenAIResponses, createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { cfg := openAICfg(url, key) 
cfg.SendActorHeaders = sendHeaders @@ -854,7 +753,7 @@ func TestActorHeaders(t *testing.T) { }, { name: "openai/v1/responses", - createRequest: createOpenAIResponsesReq, + path: pathOpenAIResponses, createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { cfg := openAICfg(url, key) cfg.SendActorHeaders = sendHeaders @@ -865,7 +764,7 @@ func TestActorHeaders(t *testing.T) { }, { name: "anthropic/v1/messages", - createRequest: createAnthropicMessagesReq, + path: pathAnthropicMessages, createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { cfg := anthropicCfg(url, key) cfg.SendActorHeaders = sendHeaders @@ -876,7 +775,7 @@ func TestActorHeaders(t *testing.T) { }, { name: "anthropic/v1/messages", - createRequest: createAnthropicMessagesReq, + path: pathAnthropicMessages, createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { cfg := anthropicCfg(url, key) cfg.SendActorHeaders = sendHeaders @@ -907,11 +806,9 @@ func TestActorHeaders(t *testing.T) { srv.Start() t.Cleanup(srv.Close) - p := tc.createProviderFn(srv.URL, apiKey, send) - metadataKey := "Username" - ts := newBridgeTestServer(t, ctx, - []aibridge.Provider{p}, + ts := newBridgeTestServer(t, ctx, srv.URL, + withCustomProvider(tc.createProviderFn(srv.URL, apiKey, send)), withActor(defaultActorID, recorder.Metadata{ metadataKey: actorUsername, }), @@ -921,9 +818,8 @@ func TestActorHeaders(t *testing.T) { reqBody, err := sjson.SetBytes(fixtures.Request(t, tc.fixture), "stream", tc.streaming) require.NoError(t, err) - req := tc.createRequest(t, ts.URL, reqBody) - client := &http.Client{} - resp, err := client.Do(req) + req := ts.newRequest(t, tc.path, reqBody) + resp, err := http.DefaultClient.Do(req) require.NoError(t, err) require.NotEmpty(t, receivedHeaders) defer resp.Body.Close() @@ -947,19 +843,3 @@ func TestActorHeaders(t *testing.T) { } } } - -func calculateTotalInputTokens(in []*recorder.TokenUsageRecord) int64 { - var total int64 - for _, el 
:= range in { - total += el.Input - } - return total -} - -func calculateTotalOutputTokens(in []*recorder.TokenUsageRecord) int64 { - var total int64 - for _, el := range in { - total += el.Output - } - return total -} diff --git a/internal/integrationtest/chatcompletions_test.go b/internal/integrationtest/chatcompletions_test.go index f78103d..240d288 100644 --- a/internal/integrationtest/chatcompletions_test.go +++ b/internal/integrationtest/chatcompletions_test.go @@ -11,7 +11,6 @@ import ( "github.com/coder/aibridge" "github.com/coder/aibridge/fixtures" - "github.com/coder/aibridge/provider" "github.com/openai/openai-go/v3" oaissestream "github.com/openai/openai-go/v3/packages/ssestream" "github.com/stretchr/testify/assert" @@ -55,14 +54,12 @@ func TestOpenAIChatCompletions(t *testing.T) { fix := fixtures.Parse(t, fixtures.OaiChatSingleBuiltinTool) upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - ts := newBridgeTestServer(t, ctx, - []aibridge.Provider{provider.NewOpenAI(openAICfg(upstream.URL, apiKey))}, - ) + ts := newBridgeTestServer(t, ctx, upstream.URL) // Make API call to aibridge for OpenAI /v1/chat/completions reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) require.NoError(t, err) - req := createOpenAIChatCompletionsReq(t, ts.URL, reqBody) + req := ts.newRequest(t, pathOpenAIChatCompletions, reqBody) client := &http.Client{} resp, err := client.Do(req) @@ -141,15 +138,14 @@ func TestOpenAIChatCompletions(t *testing.T) { // Setup MCP proxies with the tool from the fixture mockMCP := setupMCPForTest(t, defaultTracer) - ts := newBridgeTestServer(t, ctx, - []aibridge.Provider{provider.NewOpenAI(openAICfg(upstream.URL, apiKey))}, + ts := newBridgeTestServer(t, ctx, upstream.URL, withMCP(mockMCP), ) // Add the stream param to the request. 
reqBody, err := sjson.SetBytes(fix.Request(), "stream", true) require.NoError(t, err) - req := createOpenAIChatCompletionsReq(t, ts.URL, reqBody) + req := ts.newRequest(t, pathOpenAIChatCompletions, reqBody) client := &http.Client{} resp, err := client.Do(req) @@ -198,7 +194,7 @@ func TestOpenAIInjectedTools(t *testing.T) { t.Parallel() // Build the requirements & make the assertions which are common to all providers. - recorderClient, mockMCP, resp := setupInjectedToolTest(t, fixtures.OaiChatSingleInjectedTool, streaming, newOpenAIProvider, defaultTracer, defaultActorID, createOpenAIChatCompletionsReq, openaiChatToolResultValidator(t)) + recorderClient, mockMCP, resp := setupInjectedToolTest(t, fixtures.OaiChatSingleInjectedTool, streaming, defaultTracer, defaultActorID, pathOpenAIChatCompletions, openaiChatToolResultValidator(t)) // Ensure expected tool was invoked with expected input. toolUsages := recorderClient.RecordedToolUsages() diff --git a/internal/integrationtest/circuit_breaker_test.go b/internal/integrationtest/circuit_breaker_test.go index 014828f..37f7d8c 100644 --- a/internal/integrationtest/circuit_breaker_test.go +++ b/internal/integrationtest/circuit_breaker_test.go @@ -11,7 +11,6 @@ import ( "testing" "time" - "github.com/coder/aibridge" "github.com/coder/aibridge/config" "github.com/coder/aibridge/metrics" "github.com/coder/aibridge/provider" @@ -46,7 +45,7 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { successBody string requestBody string setupHeaders func(req *http.Request) - createRequest func(t *testing.T, baseURL string, input []byte) *http.Request + path string createProvider func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider expectProvider string expectEndpoint string @@ -66,7 +65,7 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { req.Header.Set("x-api-key", "test") req.Header.Set("anthropic-version", "2023-06-01") }, - createRequest: createAnthropicMessagesReq, + path: 
pathAnthropicMessages, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewAnthropic(config.Anthropic{ BaseURL: baseURL, @@ -86,7 +85,7 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { setupHeaders: func(req *http.Request) { req.Header.Set("Authorization", "Bearer test-key") }, - createRequest: createOpenAIChatCompletionsReq, + path: pathOpenAIChatCompletions, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewOpenAI(config.OpenAI{ BaseURL: baseURL, @@ -130,16 +129,16 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { Timeout: 50 * time.Millisecond, MaxRequests: 1, } - prov := tc.createProvider(mockUpstream.URL, cbConfig) ctx := t.Context() - ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}, + ts := newBridgeTestServer(t, ctx, mockUpstream.URL, + withCustomProvider(tc.createProvider(mockUpstream.URL, cbConfig)), withMetrics(m), withActor("test-user-id", nil), ) makeRequest := func() *http.Response { - req := tc.createRequest(t, ts.URL, []byte(tc.requestBody)) + req := ts.newRequest(t, tc.path, []byte(tc.requestBody)) tc.setupHeaders(req) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) @@ -216,7 +215,7 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { errorBody string requestBody string setupHeaders func(req *http.Request) - createRequest func(t *testing.T, baseURL string, input []byte) *http.Request + path string createProvider func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider expectProvider string expectEndpoint string @@ -235,7 +234,7 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { req.Header.Set("x-api-key", "test") req.Header.Set("anthropic-version", "2023-06-01") }, - createRequest: createAnthropicMessagesReq, + path: pathAnthropicMessages, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return 
provider.NewAnthropic(config.Anthropic{ BaseURL: baseURL, @@ -254,7 +253,7 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { setupHeaders: func(req *http.Request) { req.Header.Set("Authorization", "Bearer test-key") }, - createRequest: createOpenAIChatCompletionsReq, + path: pathOpenAIChatCompletions, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewOpenAI(config.OpenAI{ BaseURL: baseURL, @@ -289,16 +288,16 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { Timeout: 50 * time.Millisecond, MaxRequests: 1, } - prov := tc.createProvider(mockUpstream.URL, cbConfig) ctx := t.Context() - ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}, + ts := newBridgeTestServer(t, ctx, mockUpstream.URL, + withCustomProvider(tc.createProvider(mockUpstream.URL, cbConfig)), withMetrics(m), withActor("test-user-id", nil), ) makeRequest := func() *http.Response { - req := tc.createRequest(t, ts.URL, []byte(tc.requestBody)) + req := ts.newRequest(t, tc.path, []byte(tc.requestBody)) tc.setupHeaders(req) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) @@ -356,7 +355,7 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { successBody string requestBody string setupHeaders func(req *http.Request) - createRequest func(t *testing.T, baseURL string, input []byte) *http.Request + path string createProvider func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider expectProvider string expectEndpoint string @@ -376,7 +375,7 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { req.Header.Set("x-api-key", "test") req.Header.Set("anthropic-version", "2023-06-01") }, - createRequest: createAnthropicMessagesReq, + path: pathAnthropicMessages, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewAnthropic(config.Anthropic{ BaseURL: baseURL, @@ -396,7 +395,7 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t 
*testing.T) { setupHeaders: func(req *http.Request) { req.Header.Set("Authorization", "Bearer test-key") }, - createRequest: createOpenAIChatCompletionsReq, + path: pathOpenAIChatCompletions, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewOpenAI(config.OpenAI{ BaseURL: baseURL, @@ -441,16 +440,16 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { Timeout: 50 * time.Millisecond, MaxRequests: maxRequests, // Allow only 2 concurrent requests in half-open } - prov := tc.createProvider(mockUpstream.URL, cbConfig) ctx := t.Context() - ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}, + ts := newBridgeTestServer(t, ctx, mockUpstream.URL, + withCustomProvider(tc.createProvider(mockUpstream.URL, cbConfig)), withMetrics(m), withActor("test-user-id", nil), ) makeRequest := func() *http.Response { - req := tc.createRequest(t, ts.URL, []byte(tc.requestBody)) + req := ts.newRequest(t, tc.path, []byte(tc.requestBody)) tc.setupHeaders(req) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) @@ -561,21 +560,20 @@ func TestCircuitBreaker_PerModelIsolation(t *testing.T) { Timeout: 500 * time.Millisecond, MaxRequests: 1, } - prov := provider.NewAnthropic(config.Anthropic{ - BaseURL: mockUpstream.URL, - Key: "test-key", - CircuitBreaker: cbConfig, - }, nil) - ctx := t.Context() - ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}, + ts := newBridgeTestServer(t, ctx, mockUpstream.URL, + withCustomProvider(provider.NewAnthropic(config.Anthropic{ + BaseURL: mockUpstream.URL, + Key: "test-key", + CircuitBreaker: cbConfig, + }, nil)), withMetrics(m), withActor("test-user-id", nil), ) makeRequest := func(model string) *http.Response { body := fmt.Sprintf(`{"model":"%s","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`, model) - req := createAnthropicMessagesReq(t, ts.URL, []byte(body)) + req := ts.newRequest(t, pathAnthropicMessages, []byte(body)) 
req.Header.Set("x-api-key", "test") req.Header.Set("anthropic-version", "2023-06-01") resp, err := http.DefaultClient.Do(req) diff --git a/internal/integrationtest/messages_test.go b/internal/integrationtest/messages_test.go index 91c3ddd..80c0881 100644 --- a/internal/integrationtest/messages_test.go +++ b/internal/integrationtest/messages_test.go @@ -15,7 +15,6 @@ import ( "github.com/coder/aibridge" "github.com/coder/aibridge/fixtures" "github.com/coder/aibridge/mcp" - "github.com/coder/aibridge/provider" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/tidwall/gjson" @@ -58,14 +57,12 @@ func TestAnthropicMessages(t *testing.T) { fix := fixtures.Parse(t, fixtures.AntSingleBuiltinTool) upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - ts := newBridgeTestServer(t, ctx, - []aibridge.Provider{provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), nil)}, - ) + ts := newBridgeTestServer(t, ctx, upstream.URL) // Make API call to aibridge for Anthropic /v1/messages reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) require.NoError(t, err) - req := createAnthropicMessagesReq(t, ts.URL, reqBody) + req := ts.newRequest(t, pathAnthropicMessages, reqBody) client := &http.Client{} resp, err := client.Do(req) require.NoError(t, err) @@ -121,7 +118,7 @@ func TestAnthropicInjectedTools(t *testing.T) { t.Parallel() // Build the requirements & make the assertions which are common to all providers. - recorderClient, mockMCP, resp := setupInjectedToolTest(t, fixtures.AntSingleInjectedTool, streaming, newAnthropicProvider, defaultTracer, defaultActorID, createAnthropicMessagesReq, anthropicToolResultValidator(t)) + recorderClient, mockMCP, resp := setupInjectedToolTest(t, fixtures.AntSingleInjectedTool, streaming, defaultTracer, defaultActorID, pathAnthropicMessages, anthropicToolResultValidator(t)) // Ensure expected tool was invoked with expected input. 
toolUsages := recorderClient.RecordedToolUsages() @@ -331,8 +328,7 @@ func TestAnthropicToolChoiceParallelDisabled(t *testing.T) { fix := fixtures.Parse(t, fixtures.AntSimple) upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - ts := newBridgeTestServer(t, ctx, - []aibridge.Provider{provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), nil)}, + ts := newBridgeTestServer(t, ctx, upstream.URL, withMCP(mcpMgr), ) @@ -340,7 +336,7 @@ func TestAnthropicToolChoiceParallelDisabled(t *testing.T) { reqBody, err := sjson.SetBytes(fix.Request(), "tool_choice", tc.toolChoice) require.NoError(t, err) - req := createAnthropicMessagesReq(t, ts.URL, reqBody) + req := ts.newRequest(t, pathAnthropicMessages, reqBody) client := &http.Client{} resp, err := client.Do(req) require.NoError(t, err) @@ -392,9 +388,7 @@ func TestThinkingAdaptiveIsPreserved(t *testing.T) { // Create a mock server that captures the request body sent upstream. upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - ts := newBridgeTestServer(t, ctx, - []aibridge.Provider{provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), nil)}, - ) + ts := newBridgeTestServer(t, ctx, upstream.URL) // Inject adaptive thinking into the fixture request. 
reqBody, err := sjson.SetBytes(fix.Request(), "thinking", map[string]string{"type": "adaptive"}) @@ -402,7 +396,7 @@ func TestThinkingAdaptiveIsPreserved(t *testing.T) { reqBody, err = sjson.SetBytes(reqBody, "stream", streaming) require.NoError(t, err) - req := createAnthropicMessagesReq(t, ts.URL, reqBody) + req := ts.newRequest(t, pathAnthropicMessages, reqBody) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) diff --git a/internal/integrationtest/metrics_test.go b/internal/integrationtest/metrics_test.go index 642c7d7..3e0f002 100644 --- a/internal/integrationtest/metrics_test.go +++ b/internal/integrationtest/metrics_test.go @@ -12,7 +12,6 @@ import ( "github.com/coder/aibridge/config" "github.com/coder/aibridge/fixtures" "github.com/coder/aibridge/metrics" - "github.com/coder/aibridge/provider" "github.com/prometheus/client_golang/prometheus" promtest "github.com/prometheus/client_golang/prometheus/testutil" "github.com/stretchr/testify/require" @@ -24,7 +23,7 @@ func TestMetrics_Interception(t *testing.T) { cases := []struct { name string fixture []byte - reqFunc func(*testing.T, string, []byte) *http.Request + path string expectStatus string expectModel string expectRoute string @@ -34,7 +33,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "ant_simple", fixture: fixtures.AntSimple, - reqFunc: createAnthropicMessagesReq, + path: pathAnthropicMessages, expectStatus: metrics.InterceptionCountStatusCompleted, expectModel: "claude-sonnet-4-0", expectRoute: "/v1/messages", @@ -43,7 +42,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "ant_error", fixture: fixtures.AntNonStreamError, - reqFunc: createAnthropicMessagesReq, + path: pathAnthropicMessages, expectStatus: metrics.InterceptionCountStatusFailed, expectModel: "claude-sonnet-4-0", expectRoute: "/v1/messages", @@ -53,7 +52,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "oai_chat_simple", fixture: 
fixtures.OaiChatSimple, - reqFunc: createOpenAIChatCompletionsReq, + path: pathOpenAIChatCompletions, expectStatus: metrics.InterceptionCountStatusCompleted, expectModel: "gpt-4.1", expectRoute: "/v1/chat/completions", @@ -62,7 +61,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "oai_chat_error", fixture: fixtures.OaiChatNonStreamError, - reqFunc: createOpenAIChatCompletionsReq, + path: pathOpenAIChatCompletions, expectStatus: metrics.InterceptionCountStatusFailed, expectModel: "gpt-4.1", expectRoute: "/v1/chat/completions", @@ -72,7 +71,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "oai_responses_blocking_simple", fixture: fixtures.OaiResponsesBlockingSimple, - reqFunc: createOpenAIResponsesReq, + path: pathOpenAIResponses, expectStatus: metrics.InterceptionCountStatusCompleted, expectModel: "gpt-4o-mini", expectRoute: "/v1/responses", @@ -81,7 +80,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "oai_responses_blocking_error", fixture: fixtures.OaiResponsesBlockingHttpErr, - reqFunc: createOpenAIResponsesReq, + path: pathOpenAIResponses, expectStatus: metrics.InterceptionCountStatusFailed, expectModel: "gpt-4o-mini", expectRoute: "/v1/responses", @@ -91,7 +90,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "oai_responses_streaming_simple", fixture: fixtures.OaiResponsesStreamingSimple, - reqFunc: createOpenAIResponsesReq, + path: pathOpenAIResponses, expectStatus: metrics.InterceptionCountStatusCompleted, expectModel: "gpt-4o-mini", expectRoute: "/v1/responses", @@ -100,7 +99,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "oai_responses_streaming_error", fixture: fixtures.OaiResponsesStreamingHttpErr, - reqFunc: createOpenAIResponsesReq, + path: pathOpenAIResponses, expectStatus: metrics.InterceptionCountStatusFailed, expectModel: "gpt-4o-mini", expectRoute: "/v1/responses", @@ -121,18 +120,12 @@ func TestMetrics_Interception(t *testing.T) { upstream.AllowOverflow = tc.allowOverflow m := 
aibridge.NewMetrics(prometheus.NewRegistry()) - var prov aibridge.Provider - if tc.expectProvider == config.ProviderAnthropic { - prov = provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), nil) - } else { - prov = provider.NewOpenAI(openAICfg(upstream.URL, apiKey)) - } - ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}, + ts := newBridgeTestServer(t, ctx, upstream.URL, withMetrics(m), withWrappedRecorder(), ) - req := tc.reqFunc(t, ts.URL, fix.Request()) + req := ts.newRequest(t, tc.path, fix.Request()) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) defer resp.Body.Close() @@ -164,8 +157,7 @@ func TestMetrics_InterceptionsInflight(t *testing.T) { t.Cleanup(srv.Close) m := aibridge.NewMetrics(prometheus.NewRegistry()) - prov := provider.NewAnthropic(anthropicCfg(srv.URL, apiKey), nil) - ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}, + ts := newBridgeTestServer(t, ctx, srv.URL, withMetrics(m), withWrappedRecorder(), ) @@ -174,7 +166,7 @@ func TestMetrics_InterceptionsInflight(t *testing.T) { doneCh := make(chan struct{}) go func() { defer close(doneCh) - req := createAnthropicMessagesReq(t, ts.URL, fix.Request()) + req := ts.newRequest(t, pathAnthropicMessages, fix.Request()) resp, err := http.DefaultClient.Do(req) if err == nil { defer resp.Body.Close() @@ -212,8 +204,7 @@ func TestMetrics_PassthroughCount(t *testing.T) { t.Cleanup(upstream.Close) m := aibridge.NewMetrics(prometheus.NewRegistry()) - prov := provider.NewOpenAI(openAICfg(upstream.URL, apiKey)) - ts := newBridgeTestServer(t, t.Context(), []aibridge.Provider{prov}, + ts := newBridgeTestServer(t, t.Context(), upstream.URL, withMetrics(m), withWrappedRecorder(), ) @@ -241,13 +232,12 @@ func TestMetrics_PromptCount(t *testing.T) { upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) m := aibridge.NewMetrics(prometheus.NewRegistry()) - prov := provider.NewOpenAI(openAICfg(upstream.URL, apiKey)) - ts := newBridgeTestServer(t, ctx, 
[]aibridge.Provider{prov}, + ts := newBridgeTestServer(t, ctx, upstream.URL, withMetrics(m), withWrappedRecorder(), ) - req := createOpenAIChatCompletionsReq(t, ts.URL, fix.Request()) + req := ts.newRequest(t, pathOpenAIChatCompletions, fix.Request()) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) @@ -269,13 +259,12 @@ func TestMetrics_NonInjectedToolUseCount(t *testing.T) { upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) m := aibridge.NewMetrics(prometheus.NewRegistry()) - prov := provider.NewOpenAI(openAICfg(upstream.URL, apiKey)) - ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}, + ts := newBridgeTestServer(t, ctx, upstream.URL, withMetrics(m), withWrappedRecorder(), ) - req := createOpenAIChatCompletionsReq(t, ts.URL, fix.Request()) + req := ts.newRequest(t, pathOpenAIChatCompletions, fix.Request()) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) @@ -298,17 +287,16 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) { upstream := newMockUpstream(t, ctx, newFixtureResponse(fix), newFixtureToolResponse(fix)) m := aibridge.NewMetrics(prometheus.NewRegistry()) - prov := provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), nil) // Setup mocked MCP server & tools. 
mockMCP := setupMCPForTest(t, defaultTracer) - ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}, + ts := newBridgeTestServer(t, ctx, upstream.URL, withMetrics(m), withMCP(mockMCP), ) - req := createAnthropicMessagesReq(t, ts.URL, fix.Request()) + req := ts.newRequest(t, pathAnthropicMessages, fix.Request()) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) diff --git a/internal/integrationtest/requests.go b/internal/integrationtest/requests.go index 1ae4679..ef57edf 100644 --- a/internal/integrationtest/requests.go +++ b/internal/integrationtest/requests.go @@ -1,44 +1,12 @@ package integrationtest import ( - "bytes" - "net/http" - "testing" - "github.com/coder/aibridge/config" - "github.com/stretchr/testify/require" ) // apiKey is the default API key used across integration tests. const apiKey = "api-key" -func createJSONReq(t *testing.T, method, baseURL, path string, input []byte) *http.Request { - t.Helper() - - req, err := http.NewRequestWithContext(t.Context(), method, baseURL+path, bytes.NewReader(input)) - require.NoError(t, err) - req.Header.Set("Content-Type", "application/json") - return req -} - -// createAnthropicMessagesReq builds an HTTP request targeting the Anthropic messages endpoint. -func createAnthropicMessagesReq(t *testing.T, baseURL string, input []byte) *http.Request { - t.Helper() - return createJSONReq(t, http.MethodPost, baseURL, "/anthropic/v1/messages", input) -} - -// createOpenAIChatCompletionsReq builds an HTTP request targeting the OpenAI chat completions endpoint. -func createOpenAIChatCompletionsReq(t *testing.T, baseURL string, input []byte) *http.Request { - t.Helper() - return createJSONReq(t, http.MethodPost, baseURL, "/openai/v1/chat/completions", input) -} - -// createOpenAIResponsesReq builds an HTTP request targeting the OpenAI responses endpoint. 
-func createOpenAIResponsesReq(t *testing.T, baseURL string, input []byte) *http.Request { - t.Helper() - return createJSONReq(t, http.MethodPost, baseURL, "/openai/v1/responses", input) -} - // openAICfg creates a minimal OpenAI config for testing. func openAICfg(url, key string) config.OpenAI { return config.OpenAI{ diff --git a/internal/integrationtest/responses_test.go b/internal/integrationtest/responses_test.go index 84ba5f3..88f277b 100644 --- a/internal/integrationtest/responses_test.go +++ b/internal/integrationtest/responses_test.go @@ -334,10 +334,9 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { fix := fixtures.Parse(t, tc.fixture) upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - prov := provider.NewOpenAI(openAICfg(upstream.URL, apiKey)) - ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}, withWrappedRecorder()) + ts := newBridgeTestServer(t, ctx, upstream.URL, withWrappedRecorder()) - req := createOpenAIResponsesReq(t, ts.URL, fix.Request()) + req := ts.newRequest(t, pathOpenAIResponses, fix.Request()) req.Header.Set("User-Agent", tc.userAgent) client := &http.Client{} @@ -426,12 +425,11 @@ func TestResponsesBackgroundModeForbidden(t *testing.T) { })) t.Cleanup(upstream.Close) - prov := provider.NewOpenAI(openAICfg(upstream.URL, apiKey)) - ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}) + ts := newBridgeTestServer(t, ctx, upstream.URL) // Create a request with background mode enabled reqBytes := responsesRequestBytes(t, tc.streaming, keyVal{"background", true}) - req := createOpenAIResponsesReq(t, ts.URL, reqBytes) + req := ts.newRequest(t, pathOpenAIResponses, reqBytes) client := &http.Client{} resp, err := client.Do(req) @@ -539,10 +537,9 @@ func TestResponsesParallelToolsOverwritten(t *testing.T) { })) t.Cleanup(upstream.Close) - prov := provider.NewOpenAI(openAICfg(upstream.URL, apiKey)) - ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}) + ts := newBridgeTestServer(t, ctx, upstream.URL) - 
req := createOpenAIResponsesReq(t, ts.URL, []byte(tc.request)) + req := ts.newRequest(t, pathOpenAIResponses, []byte(tc.request)) client := &http.Client{} resp, err := client.Do(req) @@ -599,11 +596,11 @@ func TestClientAndConnectionError(t *testing.T) { ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - prov := provider.NewOpenAI(openAICfg(tc.addr, apiKey)) - ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}, withWrappedRecorder()) + // tc.addr may be an intentionally invalid URL; use withCustomProvider. + ts := newBridgeTestServer(t, ctx, tc.addr, withCustomProvider(provider.NewOpenAI(openAICfg(tc.addr, apiKey))), withWrappedRecorder()) reqBytes := responsesRequestBytes(t, tc.streaming) - req := createOpenAIResponsesReq(t, ts.URL, reqBytes) + req := ts.newRequest(t, pathOpenAIResponses, reqBytes) client := &http.Client{} resp, err := client.Do(req) @@ -682,11 +679,10 @@ func TestUpstreamError(t *testing.T) { })) t.Cleanup(upstream.Close) - prov := provider.NewOpenAI(openAICfg(upstream.URL, apiKey)) - ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}) + ts := newBridgeTestServer(t, ctx, upstream.URL) reqBytes := responsesRequestBytes(t, tc.streaming) - req := createOpenAIResponsesReq(t, ts.URL, reqBytes) + req := ts.newRequest(t, pathOpenAIResponses, reqBytes) client := &http.Client{} resp, err := client.Do(req) @@ -865,10 +861,9 @@ func TestResponsesInjectedTool(t *testing.T) { mockMCP.setToolError(tc.mcpToolName, tc.expectToolError) } - prov := provider.NewOpenAI(openAICfg(upstream.URL, apiKey)) - ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}, withMCP(mockMCP)) + ts := newBridgeTestServer(t, ctx, upstream.URL, withMCP(mockMCP)) - req := createOpenAIResponsesReq(t, ts.URL, fix.Request()) + req := ts.newRequest(t, pathOpenAIResponses, fix.Request()) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) defer resp.Body.Close() diff --git a/internal/integrationtest/trace_test.go 
b/internal/integrationtest/trace_test.go index c7b8397..a2e74b2 100644 --- a/internal/integrationtest/trace_test.go +++ b/internal/integrationtest/trace_test.go @@ -9,10 +9,8 @@ import ( "testing" "time" - "github.com/coder/aibridge" "github.com/coder/aibridge/config" "github.com/coder/aibridge/fixtures" - "github.com/coder/aibridge/provider" "github.com/coder/aibridge/tracing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -100,19 +98,18 @@ func TestTraceAnthropic(t *testing.T) { upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - var bedrockCfg *config.AWSBedrock - if tc.bedrock { - bedrockCfg = testBedrockCfg(upstream.URL) - } - prov := provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), bedrockCfg) - ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}, + opts := []bridgeOption{ withTracer(tracer), withWrappedRecorder(), - ) + } + if tc.bedrock { + opts = append(opts, withProvider(providerBedrock)) + } + ts := newBridgeTestServer(t, ctx, upstream.URL, opts...) reqBody, err := sjson.SetBytes(fixtureReqBody, "stream", tc.streaming) require.NoError(t, err) - req := createAnthropicMessagesReq(t, ts.URL, reqBody) + req := ts.newRequest(t, pathAnthropicMessages, reqBody) client := &http.Client{} resp, err := client.Do(req) require.NoError(t, err) @@ -220,19 +217,18 @@ func TestTraceAnthropicErr(t *testing.T) { fix := fixtures.Parse(t, tc.fixture) upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - var bedrockCfg *config.AWSBedrock - if tc.bedrock { - bedrockCfg = testBedrockCfg(upstream.URL) - } - prov := provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), bedrockCfg) - ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}, + opts := []bridgeOption{ withTracer(tracer), withWrappedRecorder(), - ) + } + if tc.bedrock { + opts = append(opts, withProvider(providerBedrock)) + } + ts := newBridgeTestServer(t, ctx, upstream.URL, opts...) 
reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) require.NoError(t, err) - req := createAnthropicMessagesReq(t, ts.URL, reqBody) + req := ts.newRequest(t, pathAnthropicMessages, reqBody) client := &http.Client{} resp, err := client.Do(req) require.NoError(t, err) @@ -285,30 +281,25 @@ func TestInjectedToolsTrace(t *testing.T) { streaming bool bedrock bool fixture []byte - providerFn providerFunc - createReqFn func(*testing.T, string, []byte) *http.Request + path string expectModel string - expectPath string expectProvider string + opts []bridgeOption }{ { name: "anthr_blocking", streaming: false, fixture: fixtures.AntSingleInjectedTool, - providerFn: newAnthropicProvider, - createReqFn: createAnthropicMessagesReq, + path: pathAnthropicMessages, expectModel: "claude-sonnet-4-20250514", - expectPath: "/anthropic/v1/messages", expectProvider: config.ProviderAnthropic, }, { name: "anthr_streaming", streaming: true, fixture: fixtures.AntSingleInjectedTool, - providerFn: newAnthropicProvider, - createReqFn: createAnthropicMessagesReq, + path: pathAnthropicMessages, expectModel: "claude-sonnet-4-20250514", - expectPath: "/anthropic/v1/messages", expectProvider: config.ProviderAnthropic, }, { @@ -316,41 +307,35 @@ func TestInjectedToolsTrace(t *testing.T) { streaming: false, bedrock: true, fixture: fixtures.AntSingleInjectedTool, - providerFn: newBedrockProvider, - createReqFn: createAnthropicMessagesReq, + path: pathAnthropicMessages, expectModel: "beddel", - expectPath: "/anthropic/v1/messages", expectProvider: config.ProviderAnthropic, + opts: []bridgeOption{withProvider(providerBedrock)}, }, { name: "bedrock_streaming", streaming: true, bedrock: true, fixture: fixtures.AntSingleInjectedTool, - providerFn: newBedrockProvider, - createReqFn: createAnthropicMessagesReq, + path: pathAnthropicMessages, expectModel: "beddel", - expectPath: "/anthropic/v1/messages", expectProvider: config.ProviderAnthropic, + opts: 
[]bridgeOption{withProvider(providerBedrock)}, }, { name: "openai_blocking", streaming: false, fixture: fixtures.OaiChatSingleInjectedTool, - providerFn: newOpenAIProvider, - createReqFn: createOpenAIChatCompletionsReq, + path: pathOpenAIChatCompletions, expectModel: "gpt-4.1", - expectPath: "/openai/v1/chat/completions", expectProvider: config.ProviderOpenAI, }, { name: "openai_streaming", streaming: true, fixture: fixtures.OaiChatSingleInjectedTool, - providerFn: newOpenAIProvider, - createReqFn: createOpenAIChatCompletionsReq, + path: pathOpenAIChatCompletions, expectModel: "gpt-4.1", - expectPath: "/openai/v1/chat/completions", expectProvider: config.ProviderOpenAI, }, } @@ -372,8 +357,8 @@ func TestInjectedToolsTrace(t *testing.T) { } recorderClient, mockMCP, resp := setupInjectedToolTest( - t, tc.fixture, tc.streaming, tc.providerFn, tracer, defaultActorID, - tc.createReqFn, validatorFn, + t, tc.fixture, tc.streaming, tracer, defaultActorID, + tc.path, validatorFn, tc.opts..., ) defer resp.Body.Close() @@ -383,7 +368,7 @@ func TestInjectedToolsTrace(t *testing.T) { tool := mockMCP.ListTools()[0] attrs := []attribute.KeyValue{ - attribute.String(tracing.RequestPath, tc.expectPath), + attribute.String(tracing.RequestPath, tc.path), attribute.String(tracing.InterceptionID, intcID), attribute.String(tracing.Provider, tc.expectProvider), attribute.String(tracing.Model, tc.expectModel), @@ -408,16 +393,15 @@ func TestTraceOpenAI(t *testing.T) { name string fixture []byte streaming bool - expectPath string - reqFunc func(t *testing.T, baseURL string, input []byte) *http.Request + path string + expect []expectTrace }{ { name: "trace_openai_chat_streaming", fixture: fixtures.OaiChatSimple, streaming: true, - expectPath: "/openai/v1/chat/completions", - reqFunc: createOpenAIChatCompletionsReq, + path: pathOpenAIChatCompletions, expect: []expectTrace{ {"Intercept", 1, codes.Unset}, {"Intercept.CreateInterceptor", 1, codes.Unset}, @@ -432,9 +416,8 @@ func 
TestTraceOpenAI(t *testing.T) { { name: "trace_openai_chat_blocking", fixture: fixtures.OaiChatSimple, - reqFunc: createOpenAIChatCompletionsReq, streaming: false, - expectPath: "/openai/v1/chat/completions", + path: pathOpenAIChatCompletions, expect: []expectTrace{ {"Intercept", 1, codes.Unset}, {"Intercept.CreateInterceptor", 1, codes.Unset}, @@ -450,8 +433,7 @@ func TestTraceOpenAI(t *testing.T) { name: "trace_openai_responses_streaming", fixture: fixtures.OaiResponsesStreamingSimple, streaming: true, - expectPath: "/openai/v1/responses", - reqFunc: createOpenAIResponsesReq, + path: pathOpenAIResponses, expect: []expectTrace{ {"Intercept", 1, codes.Unset}, {"Intercept.CreateInterceptor", 1, codes.Unset}, @@ -467,8 +449,7 @@ func TestTraceOpenAI(t *testing.T) { name: "trace_openai_responses_blocking", fixture: fixtures.OaiResponsesBlockingSimple, streaming: false, - expectPath: "/openai/v1/responses", - reqFunc: createOpenAIResponsesReq, + path: pathOpenAIResponses, expect: []expectTrace{ {"Intercept", 1, codes.Unset}, {"Intercept.CreateInterceptor", 1, codes.Unset}, @@ -494,15 +475,14 @@ func TestTraceOpenAI(t *testing.T) { fix := fixtures.Parse(t, tc.fixture) upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - prov := provider.NewOpenAI(openAICfg(upstream.URL, apiKey)) - ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}, + ts := newBridgeTestServer(t, ctx, upstream.URL, withTracer(tracer), withWrappedRecorder(), ) reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) require.NoError(t, err) - req := tc.reqFunc(t, ts.URL, reqBody) + req := ts.newRequest(t, tc.path, reqBody) client := &http.Client{} resp, err := client.Do(req) require.NoError(t, err) @@ -521,7 +501,7 @@ func TestTraceOpenAI(t *testing.T) { require.Len(t, sr.Ended(), totalCount) attrs := []attribute.KeyValue{ - attribute.String(tracing.RequestPath, tc.expectPath), + attribute.String(tracing.RequestPath, tc.path), attribute.String(tracing.InterceptionID, 
intcID), attribute.String(tracing.Provider, config.ProviderOpenAI), attribute.String(tracing.Model, gjson.Get(string(reqBody), "model").Str), @@ -539,8 +519,8 @@ func TestTraceOpenAIErr(t *testing.T) { fixture []byte streaming bool allowOverflow bool - expectPath string - reqFunc func(t *testing.T, baseURL string, input []byte) *http.Request + path string + expect []expectTrace expectCode int }{ @@ -548,8 +528,7 @@ func TestTraceOpenAIErr(t *testing.T) { name: "trace_openai_chat_streaming_error", fixture: fixtures.OaiChatMidStreamError, streaming: true, - expectPath: "/openai/v1/chat/completions", - reqFunc: createOpenAIChatCompletionsReq, + path: pathOpenAIChatCompletions, expectCode: http.StatusOK, expect: []expectTrace{ {"Intercept", 1, codes.Error}, @@ -565,8 +544,7 @@ func TestTraceOpenAIErr(t *testing.T) { name: "trace_openai_chat_blocking_error", fixture: fixtures.OaiChatNonStreamError, streaming: false, - expectPath: "/openai/v1/chat/completions", - reqFunc: createOpenAIChatCompletionsReq, + path: pathOpenAIChatCompletions, expectCode: http.StatusBadRequest, expect: []expectTrace{ {"Intercept", 1, codes.Error}, @@ -581,8 +559,7 @@ func TestTraceOpenAIErr(t *testing.T) { name: "trace_openai_responses_streaming_error", streaming: true, fixture: fixtures.OaiResponsesStreamingWrongResponseFormat, - expectPath: "/openai/v1/responses", - reqFunc: createOpenAIResponsesReq, + path: pathOpenAIResponses, expectCode: http.StatusOK, expect: []expectTrace{ {"Intercept", 1, codes.Error}, @@ -598,8 +575,7 @@ func TestTraceOpenAIErr(t *testing.T) { name: "trace_openai_responses_blocking_error", fixture: fixtures.OaiResponsesBlockingWrongResponseFormat, streaming: false, - expectPath: "/openai/v1/responses", - reqFunc: createOpenAIResponsesReq, + path: pathOpenAIResponses, // Fixture returns http 200 response with wrong body // responses forward received response as is so // expected code == 200 even though ProcessRequest @@ -620,8 +596,7 @@ func TestTraceOpenAIErr(t 
*testing.T) { streaming: true, allowOverflow: true, // 429 error causes retries - expectPath: "/openai/v1/responses", - reqFunc: createOpenAIResponsesReq, + path: pathOpenAIResponses, expectCode: http.StatusTooManyRequests, expect: []expectTrace{ {"Intercept", 1, codes.Error}, @@ -637,8 +612,7 @@ func TestTraceOpenAIErr(t *testing.T) { fixture: fixtures.OaiResponsesBlockingHttpErr, streaming: false, - expectPath: "/openai/v1/responses", - reqFunc: createOpenAIResponsesReq, + path: pathOpenAIResponses, expectCode: http.StatusUnauthorized, expect: []expectTrace{ {"Intercept", 1, codes.Error}, @@ -665,15 +639,14 @@ func TestTraceOpenAIErr(t *testing.T) { mockAPI := newMockUpstream(t, ctx, newFixtureResponse(fix)) mockAPI.AllowOverflow = tc.allowOverflow - prov := provider.NewOpenAI(openAICfg(mockAPI.URL, apiKey)) - ts := newBridgeTestServer(t, ctx, []aibridge.Provider{prov}, + ts := newBridgeTestServer(t, ctx, mockAPI.URL, withTracer(tracer), withWrappedRecorder(), ) reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) require.NoError(t, err) - req := tc.reqFunc(t, ts.URL, reqBody) + req := ts.newRequest(t, tc.path, reqBody) client := &http.Client{} resp, err := client.Do(req) require.NoError(t, err) @@ -693,7 +666,7 @@ func TestTraceOpenAIErr(t *testing.T) { require.Len(t, sr.Ended(), totalCount) attrs := []attribute.KeyValue{ - attribute.String(tracing.RequestPath, tc.expectPath), + attribute.String(tracing.RequestPath, tc.path), attribute.String(tracing.InterceptionID, intcID), attribute.String(tracing.Provider, config.ProviderOpenAI), attribute.String(tracing.Model, gjson.Get(string(reqBody), "model").Str), @@ -717,8 +690,7 @@ func TestTracePassthrough(t *testing.T) { tracer := tp.Tracer(t.Name()) defer func() { _ = tp.Shutdown(t.Context()) }() - prov := provider.NewOpenAI(openAICfg(upstream.URL, apiKey)) - ts := newBridgeTestServer(t, t.Context(), []aibridge.Provider{prov}, + ts := newBridgeTestServer(t, t.Context(), upstream.URL, 
withTracer(tracer), withWrappedRecorder(), ) From e3d0f14ac7c9363c4197a3adbfcb397bda7ea10d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Fri, 6 Mar 2026 10:55:46 +0000 Subject: [PATCH 10/32] merge tests back into bridge_test.go --- internal/integrationtest/bedrock_test.go | 119 --- internal/integrationtest/bridge.go | 31 + internal/integrationtest/bridge_test.go | 790 ++++++++++++++++++ .../integrationtest/chatcompletions_test.go | 316 ------- internal/integrationtest/messages_test.go | 412 --------- .../{upstream.go => mockupstream.go} | 0 internal/integrationtest/requests.go | 36 - 7 files changed, 821 insertions(+), 883 deletions(-) delete mode 100644 internal/integrationtest/bedrock_test.go delete mode 100644 internal/integrationtest/chatcompletions_test.go delete mode 100644 internal/integrationtest/messages_test.go rename internal/integrationtest/{upstream.go => mockupstream.go} (100%) delete mode 100644 internal/integrationtest/requests.go diff --git a/internal/integrationtest/bedrock_test.go b/internal/integrationtest/bedrock_test.go deleted file mode 100644 index 62cb112..0000000 --- a/internal/integrationtest/bedrock_test.go +++ /dev/null @@ -1,119 +0,0 @@ -package integrationtest - -import ( - "context" - "fmt" - "io" - "net/http" - "strings" - "testing" - "time" - - "github.com/coder/aibridge/config" - "github.com/coder/aibridge/fixtures" - "github.com/coder/aibridge/provider" - "github.com/stretchr/testify/require" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" -) - -func TestAWSBedrockIntegration(t *testing.T) { - t.Parallel() - - t.Run("invalid config", func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) - t.Cleanup(cancel) - - // Invalid bedrock config - missing region & base url - bedrockCfg := &config.AWSBedrock{ - Region: "", - AccessKey: "test-key", - AccessKeySecret: "test-secret", - Model: "test-model", - SmallFastModel: "test-haiku", - } - - ts := 
newBridgeTestServer(t, ctx, "http://unused", - withCustomProvider(provider.NewAnthropic(anthropicCfg("http://unused", apiKey), bedrockCfg)), - withLogger(newLogger(t)), - ) - - req := ts.newRequest(t, pathAnthropicMessages, fixtures.Request(t, fixtures.AntSingleBuiltinTool)) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) - defer resp.Body.Close() - - require.Equal(t, http.StatusInternalServerError, resp.StatusCode) - body, err := io.ReadAll(resp.Body) - require.NoError(t, err) - require.Contains(t, string(body), "create anthropic client") - require.Contains(t, string(body), "region or base url required") - }) - - t.Run("/v1/messages", func(t *testing.T) { - for _, streaming := range []bool{true, false} { - t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), streaming), func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) - t.Cleanup(cancel) - - fix := fixtures.Parse(t, fixtures.AntSingleBuiltinTool) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - - // We define region here to validate that with Region & BaseURL defined, the latter takes precedence. - bedrockCfg := &config.AWSBedrock{ - Region: "us-west-2", - AccessKey: "test-access-key", - AccessKeySecret: "test-secret-key", - Model: "danthropic", // This model should override the request's given one. - SmallFastModel: "danthropic-mini", // Unused but needed for validation. - BaseURL: upstream.URL, // Use the mock server. - } - - ts := newBridgeTestServer(t, ctx, upstream.URL, - withCustomProvider(provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), bedrockCfg)), - withLogger(newLogger(t)), - ) - - // Make API call to aibridge for Anthropic /v1/messages, which will be routed via AWS Bedrock. - // We override the AWS Bedrock client to route requests through our mock server. 
- reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) - require.NoError(t, err) - req := ts.newRequest(t, pathAnthropicMessages, reqBody) - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) - defer resp.Body.Close() - - // For streaming responses, consume the body to allow the stream to complete. - if streaming { - // Read the streaming response. - _, err = io.ReadAll(resp.Body) - require.NoError(t, err) - } - - // Verify that Bedrock-specific model name was used in the request to the mock server - // and the interception data. - received := upstream.receivedRequests() - require.Len(t, received, 1) - - // The Anthropic SDK's Bedrock middleware extracts "model" and "stream" - // from the JSON body and encodes them in the URL path. - // See: https://github.com/anthropics/anthropic-sdk-go/blob/4d669338f2041f3c60640b6dd317c4895dc71cd4/bedrock/bedrock.go#L247-L248 - pathParts := strings.Split(received[0].Path, "/") - require.True(t, len(pathParts) >= 3 && pathParts[1] == "model", "unexpected path: %s", received[0].Path) - require.Equal(t, bedrockCfg.Model, pathParts[2]) - require.False(t, gjson.GetBytes(received[0].Body, "model").Exists(), "model should be stripped from body") - require.False(t, gjson.GetBytes(received[0].Body, "stream").Exists(), "stream should be stripped from body") - - interceptions := ts.Recorder.RecordedInterceptions() - require.Len(t, interceptions, 1) - require.Equal(t, interceptions[0].Model, bedrockCfg.Model) - ts.Recorder.VerifyAllInterceptionsEnded(t) - }) - } - }) -} diff --git a/internal/integrationtest/bridge.go b/internal/integrationtest/bridge.go index e074e36..0ad4ac7 100644 --- a/internal/integrationtest/bridge.go +++ b/internal/integrationtest/bridge.go @@ -50,6 +50,37 @@ func testBedrockCfg(url string) *config.AWSBedrock { } } +// apiKey is the default API key used across integration tests. +const apiKey = "api-key" + +// openAICfg creates a minimal OpenAI config for testing. 
+func openAICfg(url, key string) config.OpenAI { + return config.OpenAI{ + BaseURL: url, + Key: key, + } +} + +func openaiCfgWithAPIDump(url, key, dumpDir string) config.OpenAI { + cfg := openAICfg(url, key) + cfg.APIDumpDir = dumpDir + return cfg +} + +// anthropicCfg creates a minimal Anthropic config for testing. +func anthropicCfg(url, key string) config.Anthropic { + return config.Anthropic{ + BaseURL: url, + Key: key, + } +} + +func anthropicCfgWithAPIDump(url, key, dumpDir string) config.Anthropic { + cfg := anthropicCfg(url, key) + cfg.APIDumpDir = dumpDir + return cfg +} + // newDefaultProvider creates a Provider with default test configuration. func newDefaultProvider(providerType, addr string) aibridge.Provider { switch providerType { diff --git a/internal/integrationtest/bridge_test.go b/internal/integrationtest/bridge_test.go index 0b23640..f6271ae 100644 --- a/internal/integrationtest/bridge_test.go +++ b/internal/integrationtest/bridge_test.go @@ -15,10 +15,12 @@ import ( "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/packages/ssestream" + "github.com/anthropics/anthropic-sdk-go/shared/constant" "github.com/coder/aibridge" "github.com/coder/aibridge/config" "github.com/coder/aibridge/fixtures" "github.com/coder/aibridge/intercept" + "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/provider" "github.com/coder/aibridge/recorder" "github.com/google/uuid" @@ -843,3 +845,791 @@ func TestActorHeaders(t *testing.T) { } } } + +func TestAnthropicMessages(t *testing.T) { + t.Parallel() + + t.Run("single builtin tool", func(t *testing.T) { + t.Parallel() + + cases := []struct { + streaming bool + expectedInputTokens int + expectedOutputTokens int + expectedToolCallID string + }{ + { + streaming: true, + expectedInputTokens: 2, + expectedOutputTokens: 66, + expectedToolCallID: "toolu_01RX68weRSquLx6HUTj65iBo", + }, + { + streaming: false, + expectedInputTokens: 5, + expectedOutputTokens: 84, + expectedToolCallID: 
"toolu_01AusGgY5aKFhzWrFBv9JfHq", + }, + } + + for _, tc := range cases { + t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), tc.streaming), func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + fix := fixtures.Parse(t, fixtures.AntSingleBuiltinTool) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + + ts := newBridgeTestServer(t, ctx, upstream.URL) + + // Make API call to aibridge for Anthropic /v1/messages + reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) + require.NoError(t, err) + req := ts.newRequest(t, pathAnthropicMessages, reqBody) + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + defer resp.Body.Close() + + // Response-specific checks. + if tc.streaming { + sp := aibridge.NewSSEParser() + require.NoError(t, sp.Parse(resp.Body)) + + // Ensure the message starts and completes, at a minimum. + assert.Contains(t, sp.AllEvents(), "message_start") + assert.Contains(t, sp.AllEvents(), "message_stop") + } + + expectedTokenRecordings := 1 + if tc.streaming { + // One for message_start, one for message_delta. 
+ expectedTokenRecordings = 2 + } + tokenUsages := ts.Recorder.RecordedTokenUsages() + require.Len(t, tokenUsages, expectedTokenRecordings) + + assert.EqualValues(t, tc.expectedInputTokens, calculateTotalInputTokens(tokenUsages), "input tokens miscalculated") + assert.EqualValues(t, tc.expectedOutputTokens, calculateTotalOutputTokens(tokenUsages), "output tokens miscalculated") + + toolUsages := ts.Recorder.RecordedToolUsages() + require.Len(t, toolUsages, 1) + assert.Equal(t, "Read", toolUsages[0].Tool) + assert.Equal(t, tc.expectedToolCallID, toolUsages[0].ToolCallID) + require.IsType(t, json.RawMessage{}, toolUsages[0].Args) + var args map[string]any + require.NoError(t, json.Unmarshal(toolUsages[0].Args.(json.RawMessage), &args)) + require.Contains(t, args, "file_path") + assert.Equal(t, "/tmp/blah/foo", args["file_path"]) + + promptUsages := ts.Recorder.RecordedPromptUsages() + require.Len(t, promptUsages, 1) + assert.Equal(t, "read the foo file", promptUsages[0].Prompt) + + ts.Recorder.VerifyAllInterceptionsEnded(t) + }) + } + }) +} + +func TestAnthropicInjectedTools(t *testing.T) { + t.Parallel() + + for _, streaming := range []bool{true, false} { + t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { + t.Parallel() + + // Build the requirements & make the assertions which are common to all providers. + recorderClient, mockMCP, resp := setupInjectedToolTest(t, fixtures.AntSingleInjectedTool, streaming, defaultTracer, defaultActorID, pathAnthropicMessages, anthropicToolResultValidator(t)) + + // Ensure expected tool was invoked with expected input. 
+ toolUsages := recorderClient.RecordedToolUsages() + require.Len(t, toolUsages, 1) + require.Equal(t, mockToolName, toolUsages[0].Tool) + expected, err := json.Marshal(map[string]any{"owner": "admin"}) + require.NoError(t, err) + actual, err := json.Marshal(toolUsages[0].Args) + require.NoError(t, err) + require.EqualValues(t, expected, actual) + invocations := mockMCP.getCallsByTool(mockToolName) + require.Len(t, invocations, 1) + actual, err = json.Marshal(invocations[0]) + require.NoError(t, err) + require.EqualValues(t, expected, actual) + + var ( + content *anthropic.ContentBlockUnion + message anthropic.Message + ) + if streaming { + // Parse the response stream. + decoder := ssestream.NewDecoder(resp) + stream := ssestream.NewStream[anthropic.MessageStreamEventUnion](decoder, nil) + for stream.Next() { + event := stream.Current() + require.NoError(t, message.Accumulate(event), "accumulate event") + } + + require.NoError(t, stream.Err(), "stream error") + require.Len(t, message.Content, 2) + + content = &message.Content[1] + } else { + // Parse & unmarshal the response. + body, err := io.ReadAll(resp.Body) + require.NoError(t, err, "read response body") + + require.NoError(t, json.Unmarshal(body, &message), "unmarshal response") + require.GreaterOrEqual(t, len(message.Content), 1) + + content = &message.Content[0] + } + + // Ensure tool returned expected value. + require.NotNil(t, content) + require.Contains(t, content.Text, "dd711d5c-83c6-4c08-a0af-b73055906e8c") // The ID of the workspace to be returned. + + // Check the token usage from the client's perspective. + // + // We overwrite the final message_delta which is relayed to the client to include the + // accumulated tokens but currently the SDK only supports accumulating output tokens + // for message_delta events. + // + // For non-streaming requests the token usage is also overwritten and should be faithfully + // represented in the response. 
+ // + // See https://github.com/anthropics/anthropic-sdk-go/blob/v1.12.0/message.go#L2619-L2622 + if !streaming { + assert.EqualValues(t, 15308, message.Usage.InputTokens) + } + assert.EqualValues(t, 204, message.Usage.OutputTokens) + + // Ensure tokens used during injected tool invocation are accounted for. + tokenUsages := recorderClient.RecordedTokenUsages() + assert.EqualValues(t, 15308, calculateTotalInputTokens(tokenUsages)) + assert.EqualValues(t, 204, calculateTotalOutputTokens(tokenUsages)) + + // Ensure we received exactly one prompt. + promptUsages := recorderClient.RecordedPromptUsages() + require.Len(t, promptUsages, 1) + }) + } +} + +// anthropicToolResultValidator returns a request validator that asserts the second +// upstream request contains the assistant's tool_use and user's tool_result messages +// appended by the inner agentic loop. If the raw payload is not kept in sync with +// the structured messages, the second request will be identical to the first. +func anthropicToolResultValidator(t *testing.T) func(*http.Request, []byte) { + t.Helper() + + return func(_ *http.Request, raw []byte) { + messages := gjson.GetBytes(raw, "messages").Array() + + // After the agentic loop the messages must contain at minimum: + // [0] original user message + // [N-2] assistant message with tool_use content block + // [N-1] user message with tool_result content block + require.GreaterOrEqual(t, len(messages), 3, + "second upstream request must contain the original message, assistant tool_use, and user tool_result") + + assistantMsg := messages[len(messages)-2] + require.Equal(t, "assistant", assistantMsg.Get("role").Str, + "penultimate message must be from the assistant") + var hasToolUse bool + for _, block := range assistantMsg.Get("content").Array() { + if block.Get("type").Str == "tool_use" { + hasToolUse = true + break + } + } + require.True(t, hasToolUse, "assistant message must contain a tool_use content block") + + toolResultMsg := 
messages[len(messages)-1] + require.Equal(t, "user", toolResultMsg.Get("role").Str, + "last message must be a user message carrying the tool_result") + var hasToolResult bool + for _, block := range toolResultMsg.Get("content").Array() { + if block.Get("type").Str == "tool_result" { + hasToolResult = true + break + } + } + require.True(t, hasToolResult, "user message must contain a tool_result content block") + } +} + +// TestAnthropicToolChoiceParallelDisabled verifies that parallel tool use is +// correctly disabled based on the tool_choice parameter in the request. +// See https://github.com/coder/aibridge/issues/2 +func TestAnthropicToolChoiceParallelDisabled(t *testing.T) { + t.Parallel() + + var ( + toolChoiceAuto = string(constant.ValueOf[constant.Auto]()) + toolChoiceAny = string(constant.ValueOf[constant.Any]()) + toolChoiceNone = string(constant.ValueOf[constant.None]()) + toolChoiceTool = string(constant.ValueOf[constant.Tool]()) + ) + + cases := []struct { + name string + toolChoice any // nil, or map with "type" key. + withInjectedTools bool + expectDisableParallel bool + expectToolChoiceTypeInRequest string + }{ + // With injected tools - disable_parallel_tool_use should be set. 
+ { + name: "with injected tools: no tool_choice defined defaults to auto", + toolChoice: nil, + withInjectedTools: true, + expectDisableParallel: true, + expectToolChoiceTypeInRequest: toolChoiceAuto, + }, + { + name: "with injected tools: tool_choice auto", + toolChoice: map[string]any{"type": toolChoiceAuto}, + withInjectedTools: true, + expectDisableParallel: true, + expectToolChoiceTypeInRequest: toolChoiceAuto, + }, + { + name: "with injected tools: tool_choice any", + toolChoice: map[string]any{"type": toolChoiceAny}, + withInjectedTools: true, + expectDisableParallel: true, + expectToolChoiceTypeInRequest: toolChoiceAny, + }, + { + name: "with injected tools: tool_choice tool", + toolChoice: map[string]any{"type": toolChoiceTool, "name": "some_tool"}, + withInjectedTools: true, + expectDisableParallel: true, + expectToolChoiceTypeInRequest: toolChoiceTool, + }, + { + name: "with injected tools: tool_choice none", + toolChoice: map[string]any{"type": toolChoiceNone}, + withInjectedTools: true, + expectDisableParallel: false, + expectToolChoiceTypeInRequest: toolChoiceNone, + }, + // Without injected tools - disable_parallel_tool_use should NOT be set. + { + name: "without injected tools: tool_choice auto", + toolChoice: map[string]any{"type": toolChoiceAuto}, + withInjectedTools: false, + expectDisableParallel: false, + expectToolChoiceTypeInRequest: toolChoiceAuto, + }, + { + name: "without injected tools: tool_choice any", + toolChoice: map[string]any{"type": toolChoiceAny}, + withInjectedTools: false, + expectDisableParallel: false, + expectToolChoiceTypeInRequest: toolChoiceAny, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + // Setup MCP tools conditionally. 
+ var mcpMgr mcp.ServerProxier + if tc.withInjectedTools { + mcpMgr = setupMCPForTest(t, defaultTracer) + } else { + mcpMgr = newNoopMCPManager() + } + + fix := fixtures.Parse(t, fixtures.AntSimple) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + + ts := newBridgeTestServer(t, ctx, upstream.URL, + withMCP(mcpMgr), + ) + + // Prepare request body with tool_choice set. + reqBody, err := sjson.SetBytes(fix.Request(), "tool_choice", tc.toolChoice) + require.NoError(t, err) + + req := ts.newRequest(t, pathAnthropicMessages, reqBody) + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + _ = resp.Body.Close() + + // Verify tool_choice in the upstream request. + received := upstream.receivedRequests() + require.Len(t, received, 1) + var receivedRequest map[string]any + require.NoError(t, json.Unmarshal(received[0].Body, &receivedRequest)) + toolChoice, ok := receivedRequest["tool_choice"].(map[string]any) + require.True(t, ok, "expected tool_choice in upstream request") + + // Verify the type matches expectation. + assert.Equal(t, tc.expectToolChoiceTypeInRequest, toolChoice["type"]) + + // Verify name is preserved for tool_choice=tool. + if tc.expectToolChoiceTypeInRequest == toolChoiceTool { + assert.Equal(t, "some_tool", toolChoice["name"]) + } + + // Verify disable_parallel_tool_use based on expectations. 
+ // See https://platform.claude.com/docs/en/agents-and-tools/tool-use/implement-tool-use#parallel-tool-use + disableParallel, hasDisableParallel := toolChoice["disable_parallel_tool_use"].(bool) + + if tc.expectDisableParallel { + require.True(t, hasDisableParallel, "expected disable_parallel_tool_use in tool_choice") + assert.True(t, disableParallel, "expected disable_parallel_tool_use to be true") + } else { + assert.False(t, hasDisableParallel, "expected disable_parallel_tool_use to not be set") + } + }) + } +} + +func TestThinkingAdaptiveIsPreserved(t *testing.T) { + t.Parallel() + + fix := fixtures.Parse(t, fixtures.AntSimple) + + for _, streaming := range []bool{true, false} { + t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + // Create a mock server that captures the request body sent upstream. + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + + ts := newBridgeTestServer(t, ctx, upstream.URL) + + // Inject adaptive thinking into the fixture request. + reqBody, err := sjson.SetBytes(fix.Request(), "thinking", map[string]string{"type": "adaptive"}) + require.NoError(t, err) + reqBody, err = sjson.SetBytes(reqBody, "stream", streaming) + require.NoError(t, err) + + req := ts.newRequest(t, pathAnthropicMessages, reqBody) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + _, _ = io.ReadAll(resp.Body) + _ = resp.Body.Close() + + // Verify the thinking field was preserved in the upstream request. 
+ received := upstream.receivedRequests() + require.Len(t, received, 1) + assert.Equal(t, "adaptive", gjson.GetBytes(received[0].Body, "thinking.type").Str) + }) + } +} + +func TestOpenAIChatCompletions(t *testing.T) { + t.Parallel() + + t.Run("single builtin tool", func(t *testing.T) { + t.Parallel() + + cases := []struct { + streaming bool + expectedInputTokens, expectedOutputTokens int + expectedToolCallID string + }{ + { + streaming: true, + expectedInputTokens: 60, + expectedOutputTokens: 15, + expectedToolCallID: "call_HjeqP7YeRkoNj0de9e3U4X4B", + }, + { + streaming: false, + expectedInputTokens: 60, + expectedOutputTokens: 15, + expectedToolCallID: "call_KjzAbhiZC6nk81tQzL7pwlpc", + }, + } + + for _, tc := range cases { + t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), tc.streaming), func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + fix := fixtures.Parse(t, fixtures.OaiChatSingleBuiltinTool) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + + ts := newBridgeTestServer(t, ctx, upstream.URL) + + // Make API call to aibridge for OpenAI /v1/chat/completions + reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) + require.NoError(t, err) + req := ts.newRequest(t, pathOpenAIChatCompletions, reqBody) + + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + defer resp.Body.Close() + + // Response-specific checks. + if tc.streaming { + sp := aibridge.NewSSEParser() + require.NoError(t, sp.Parse(resp.Body)) + + // OpenAI sends all events under the same type. 
+ messageEvents := sp.MessageEvents() + assert.NotEmpty(t, messageEvents) + + // OpenAI streaming ends with [DONE] + lastEvent := messageEvents[len(messageEvents)-1] + assert.Equal(t, "[DONE]", lastEvent.Data) + } + + tokenUsages := ts.Recorder.RecordedTokenUsages() + require.Len(t, tokenUsages, 1) + assert.EqualValues(t, tc.expectedInputTokens, calculateTotalInputTokens(tokenUsages), "input tokens miscalculated") + assert.EqualValues(t, tc.expectedOutputTokens, calculateTotalOutputTokens(tokenUsages), "output tokens miscalculated") + + toolUsages := ts.Recorder.RecordedToolUsages() + require.Len(t, toolUsages, 1) + assert.Equal(t, "read_file", toolUsages[0].Tool) + assert.Equal(t, tc.expectedToolCallID, toolUsages[0].ToolCallID) + require.IsType(t, map[string]any{}, toolUsages[0].Args) + require.Contains(t, toolUsages[0].Args, "path") + assert.Equal(t, "README.md", toolUsages[0].Args.(map[string]any)["path"]) + + promptUsages := ts.Recorder.RecordedPromptUsages() + require.Len(t, promptUsages, 1) + assert.Equal(t, "how large is the README.md file in my current path", promptUsages[0].Prompt) + + ts.Recorder.VerifyAllInterceptionsEnded(t) + }) + } + }) + + t.Run("streaming injected tool call edge cases", func(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + fixture []byte + expectedArgs map[string]any + }{ + { + name: "tool call no preamble", + fixture: fixtures.OaiChatStreamingInjectedToolNoPreamble, + expectedArgs: map[string]any{"owner": "me"}, + }, + { + name: "tool call with non-zero index", + fixture: fixtures.OaiChatStreamingInjectedToolNonzeroIndex, + expectedArgs: nil, // No arguments in this fixture + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + // Setup mock server for multi-turn interaction. + // First request → tool call response, second → tool response. 
+ fix := fixtures.Parse(t, tc.fixture) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix), newFixtureToolResponse(fix)) + + // Setup MCP proxies with the tool from the fixture + mockMCP := setupMCPForTest(t, defaultTracer) + + ts := newBridgeTestServer(t, ctx, upstream.URL, + withMCP(mockMCP), + ) + + // Add the stream param to the request. + reqBody, err := sjson.SetBytes(fix.Request(), "stream", true) + require.NoError(t, err) + req := ts.newRequest(t, pathOpenAIChatCompletions, reqBody) + + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + + // Verify SSE headers are sent correctly + require.Equal(t, "text/event-stream", resp.Header.Get("Content-Type")) + require.Equal(t, "no-cache", resp.Header.Get("Cache-Control")) + require.Equal(t, "keep-alive", resp.Header.Get("Connection")) + + // Consume the full response body to ensure the interception completes + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + resp.Body.Close() + + // Verify the MCP tool was actually invoked + invocations := mockMCP.getCallsByTool(mockToolName) + require.Len(t, invocations, 1, "expected MCP tool to be invoked") + + // Verify tool was invoked with the expected args (if specified) + if tc.expectedArgs != nil { + expected, err := json.Marshal(tc.expectedArgs) + require.NoError(t, err) + actual, err := json.Marshal(invocations[0]) + require.NoError(t, err) + require.EqualValues(t, expected, actual) + } + + // Verify tool usage was recorded + toolUsages := ts.Recorder.RecordedToolUsages() + require.Len(t, toolUsages, 1) + assert.Equal(t, mockToolName, toolUsages[0].Tool) + + ts.Recorder.VerifyAllInterceptionsEnded(t) + }) + } + }) +} + +func TestOpenAIInjectedTools(t *testing.T) { + t.Parallel() + + for _, streaming := range []bool{true, false} { + t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { + t.Parallel() + + // Build the requirements & make the assertions which 
are common to all providers. + recorderClient, mockMCP, resp := setupInjectedToolTest(t, fixtures.OaiChatSingleInjectedTool, streaming, defaultTracer, defaultActorID, pathOpenAIChatCompletions, openaiChatToolResultValidator(t)) + + // Ensure expected tool was invoked with expected input. + toolUsages := recorderClient.RecordedToolUsages() + require.Len(t, toolUsages, 1) + require.Equal(t, mockToolName, toolUsages[0].Tool) + expected, err := json.Marshal(map[string]any{"owner": "admin"}) + require.NoError(t, err) + actual, err := json.Marshal(toolUsages[0].Args) + require.NoError(t, err) + require.EqualValues(t, expected, actual) + invocations := mockMCP.getCallsByTool(mockToolName) + require.Len(t, invocations, 1) + actual, err = json.Marshal(invocations[0]) + require.NoError(t, err) + require.EqualValues(t, expected, actual) + + var ( + content *openai.ChatCompletionChoice + message openai.ChatCompletion + ) + if streaming { + // Parse the response stream. + decoder := oaissestream.NewDecoder(resp) + stream := oaissestream.NewStream[openai.ChatCompletionChunk](decoder, nil) + var acc openai.ChatCompletionAccumulator + detectedToolCalls := make(map[string]struct{}) + for stream.Next() { + chunk := stream.Current() + acc.AddChunk(chunk) + + if len(chunk.Choices) == 0 { + continue + } + + for _, c := range chunk.Choices { + if len(c.Delta.ToolCalls) == 0 { + continue + } + + for _, t := range c.Delta.ToolCalls { + if t.Function.Name == "" { + continue + } + + detectedToolCalls[t.Function.Name] = struct{}{} + } + } + } + + // Verify that no injected tool call events (or partials thereof) were sent to the client. + require.Len(t, detectedToolCalls, 0) + + message = acc.ChatCompletion + require.NoError(t, stream.Err(), "stream error") + } else { + // Parse & unmarshal the response. 
+ body, err := io.ReadAll(resp.Body) + require.NoError(t, err, "read response body") + require.NoError(t, json.Unmarshal(body, &message), "unmarshal response") + + // Verify that no injected tools were sent to the client. + require.GreaterOrEqual(t, len(message.Choices), 1) + require.Len(t, message.Choices[0].Message.ToolCalls, 0) + } + + require.GreaterOrEqual(t, len(message.Choices), 1) + content = &message.Choices[0] + + // Ensure tool returned expected value. + require.NotNil(t, content) + require.Contains(t, content.Message.Content, "dd711d5c-83c6-4c08-a0af-b73055906e8c") // The ID of the workspace to be returned. + + // Check the token usage from the client's perspective. + // This *should* work but the openai SDK doesn't accumulate the prompt token details :(. + // See https://github.com/openai/openai-go/blob/v2.7.0/streamaccumulator.go#L145-L147. + // assert.EqualValues(t, 5047, message.Usage.PromptTokens-message.Usage.PromptTokensDetails.CachedTokens) + assert.EqualValues(t, 105, message.Usage.CompletionTokens) + + // Ensure tokens used during injected tool invocation are accounted for. + tokenUsages := recorderClient.RecordedTokenUsages() + require.EqualValues(t, 5047, calculateTotalInputTokens(tokenUsages)) + require.EqualValues(t, 105, calculateTotalOutputTokens(tokenUsages)) + + // Ensure we received exactly one prompt. + promptUsages := recorderClient.RecordedPromptUsages() + require.Len(t, promptUsages, 1) + }) + } +} + +// openaiChatToolResultValidator returns a request validator that asserts the second +// upstream request contains the assistant's tool_calls and a role=tool result message +// appended by the inner agentic loop. 
+func openaiChatToolResultValidator(t *testing.T) func(*http.Request, []byte) { + t.Helper() + + return func(_ *http.Request, raw []byte) { + messages := gjson.GetBytes(raw, "messages").Array() + + // After the agentic loop the messages must contain at minimum: + // [0] original user message + // [N-2] assistant message with tool_calls array + // [N-1] message with role=tool + require.GreaterOrEqual(t, len(messages), 3, + "second upstream request must contain the original message, assistant tool_calls, and tool result") + + assistantMsg := messages[len(messages)-2] + require.Equal(t, "assistant", assistantMsg.Get("role").Str, + "penultimate message must be from the assistant") + require.NotEmpty(t, len(assistantMsg.Get("tool_calls").Array()), + "assistant message must contain a tool_calls array") + + toolResultMsg := messages[len(messages)-1] + require.Equal(t, "tool", toolResultMsg.Get("role").Str, + "last message must have role=tool") + require.NotEmpty(t, toolResultMsg.Get("tool_call_id").Str, + "tool result message must have a tool_call_id") + } +} + +func TestAWSBedrockIntegration(t *testing.T) { + t.Parallel() + + t.Run("invalid config", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + // Invalid bedrock config - missing region & base url + bedrockCfg := &config.AWSBedrock{ + Region: "", + AccessKey: "test-key", + AccessKeySecret: "test-secret", + Model: "test-model", + SmallFastModel: "test-haiku", + } + + ts := newBridgeTestServer(t, ctx, "http://unused", + withCustomProvider(provider.NewAnthropic(anthropicCfg("http://unused", apiKey), bedrockCfg)), + withLogger(newLogger(t)), + ) + + req := ts.newRequest(t, pathAnthropicMessages, fixtures.Request(t, fixtures.AntSingleBuiltinTool)) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, http.StatusInternalServerError, resp.StatusCode) + body, err := io.ReadAll(resp.Body) 
+ require.NoError(t, err) + require.Contains(t, string(body), "create anthropic client") + require.Contains(t, string(body), "region or base url required") + }) + + t.Run("/v1/messages", func(t *testing.T) { + for _, streaming := range []bool{true, false} { + t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), streaming), func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + fix := fixtures.Parse(t, fixtures.AntSingleBuiltinTool) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + + // We define region here to validate that with Region & BaseURL defined, the latter takes precedence. + bedrockCfg := &config.AWSBedrock{ + Region: "us-west-2", + AccessKey: "test-access-key", + AccessKeySecret: "test-secret-key", + Model: "danthropic", // This model should override the request's given one. + SmallFastModel: "danthropic-mini", // Unused but needed for validation. + BaseURL: upstream.URL, // Use the mock server. + } + + ts := newBridgeTestServer(t, ctx, upstream.URL, + withCustomProvider(provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), bedrockCfg)), + withLogger(newLogger(t)), + ) + + // Make API call to aibridge for Anthropic /v1/messages, which will be routed via AWS Bedrock. + // We override the AWS Bedrock client to route requests through our mock server. + reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) + require.NoError(t, err) + req := ts.newRequest(t, pathAnthropicMessages, reqBody) + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // For streaming responses, consume the body to allow the stream to complete. + if streaming { + // Read the streaming response. + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + } + + // Verify that Bedrock-specific model name was used in the request to the mock server + // and the interception data. 
+ received := upstream.receivedRequests() + require.Len(t, received, 1) + + // The Anthropic SDK's Bedrock middleware extracts "model" and "stream" + // from the JSON body and encodes them in the URL path. + // See: https://github.com/anthropics/anthropic-sdk-go/blob/4d669338f2041f3c60640b6dd317c4895dc71cd4/bedrock/bedrock.go#L247-L248 + pathParts := strings.Split(received[0].Path, "/") + require.True(t, len(pathParts) >= 3 && pathParts[1] == "model", "unexpected path: %s", received[0].Path) + require.Equal(t, bedrockCfg.Model, pathParts[2]) + require.False(t, gjson.GetBytes(received[0].Body, "model").Exists(), "model should be stripped from body") + require.False(t, gjson.GetBytes(received[0].Body, "stream").Exists(), "stream should be stripped from body") + + interceptions := ts.Recorder.RecordedInterceptions() + require.Len(t, interceptions, 1) + require.Equal(t, interceptions[0].Model, bedrockCfg.Model) + ts.Recorder.VerifyAllInterceptionsEnded(t) + }) + } + }) +} + diff --git a/internal/integrationtest/chatcompletions_test.go b/internal/integrationtest/chatcompletions_test.go deleted file mode 100644 index 240d288..0000000 --- a/internal/integrationtest/chatcompletions_test.go +++ /dev/null @@ -1,316 +0,0 @@ -package integrationtest - -import ( - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "testing" - "time" - - "github.com/coder/aibridge" - "github.com/coder/aibridge/fixtures" - "github.com/openai/openai-go/v3" - oaissestream "github.com/openai/openai-go/v3/packages/ssestream" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" -) - -func TestOpenAIChatCompletions(t *testing.T) { - t.Parallel() - - t.Run("single builtin tool", func(t *testing.T) { - t.Parallel() - - cases := []struct { - streaming bool - expectedInputTokens, expectedOutputTokens int - expectedToolCallID string - }{ - { - streaming: true, - expectedInputTokens: 60, - expectedOutputTokens: 15, - 
expectedToolCallID: "call_HjeqP7YeRkoNj0de9e3U4X4B", - }, - { - streaming: false, - expectedInputTokens: 60, - expectedOutputTokens: 15, - expectedToolCallID: "call_KjzAbhiZC6nk81tQzL7pwlpc", - }, - } - - for _, tc := range cases { - t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), tc.streaming), func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) - t.Cleanup(cancel) - - fix := fixtures.Parse(t, fixtures.OaiChatSingleBuiltinTool) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - - ts := newBridgeTestServer(t, ctx, upstream.URL) - - // Make API call to aibridge for OpenAI /v1/chat/completions - reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) - require.NoError(t, err) - req := ts.newRequest(t, pathOpenAIChatCompletions, reqBody) - - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) - require.Equal(t, http.StatusOK, resp.StatusCode) - defer resp.Body.Close() - - // Response-specific checks. - if tc.streaming { - sp := aibridge.NewSSEParser() - require.NoError(t, sp.Parse(resp.Body)) - - // OpenAI sends all events under the same type. 
- messageEvents := sp.MessageEvents() - assert.NotEmpty(t, messageEvents) - - // OpenAI streaming ends with [DONE] - lastEvent := messageEvents[len(messageEvents)-1] - assert.Equal(t, "[DONE]", lastEvent.Data) - } - - tokenUsages := ts.Recorder.RecordedTokenUsages() - require.Len(t, tokenUsages, 1) - assert.EqualValues(t, tc.expectedInputTokens, calculateTotalInputTokens(tokenUsages), "input tokens miscalculated") - assert.EqualValues(t, tc.expectedOutputTokens, calculateTotalOutputTokens(tokenUsages), "output tokens miscalculated") - - toolUsages := ts.Recorder.RecordedToolUsages() - require.Len(t, toolUsages, 1) - assert.Equal(t, "read_file", toolUsages[0].Tool) - assert.Equal(t, tc.expectedToolCallID, toolUsages[0].ToolCallID) - require.IsType(t, map[string]any{}, toolUsages[0].Args) - require.Contains(t, toolUsages[0].Args, "path") - assert.Equal(t, "README.md", toolUsages[0].Args.(map[string]any)["path"]) - - promptUsages := ts.Recorder.RecordedPromptUsages() - require.Len(t, promptUsages, 1) - assert.Equal(t, "how large is the README.md file in my current path", promptUsages[0].Prompt) - - ts.Recorder.VerifyAllInterceptionsEnded(t) - }) - } - }) - - t.Run("streaming injected tool call edge cases", func(t *testing.T) { - t.Parallel() - - cases := []struct { - name string - fixture []byte - expectedArgs map[string]any - }{ - { - name: "tool call no preamble", - fixture: fixtures.OaiChatStreamingInjectedToolNoPreamble, - expectedArgs: map[string]any{"owner": "me"}, - }, - { - name: "tool call with non-zero index", - fixture: fixtures.OaiChatStreamingInjectedToolNonzeroIndex, - expectedArgs: nil, // No arguments in this fixture - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) - t.Cleanup(cancel) - - // Setup mock server for multi-turn interaction. - // First request → tool call response, second → tool response. 
- fix := fixtures.Parse(t, tc.fixture) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix), newFixtureToolResponse(fix)) - - // Setup MCP proxies with the tool from the fixture - mockMCP := setupMCPForTest(t, defaultTracer) - - ts := newBridgeTestServer(t, ctx, upstream.URL, - withMCP(mockMCP), - ) - - // Add the stream param to the request. - reqBody, err := sjson.SetBytes(fix.Request(), "stream", true) - require.NoError(t, err) - req := ts.newRequest(t, pathOpenAIChatCompletions, reqBody) - - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) - require.Equal(t, http.StatusOK, resp.StatusCode) - - // Verify SSE headers are sent correctly - require.Equal(t, "text/event-stream", resp.Header.Get("Content-Type")) - require.Equal(t, "no-cache", resp.Header.Get("Cache-Control")) - require.Equal(t, "keep-alive", resp.Header.Get("Connection")) - - // Consume the full response body to ensure the interception completes - _, err = io.ReadAll(resp.Body) - require.NoError(t, err) - resp.Body.Close() - - // Verify the MCP tool was actually invoked - invocations := mockMCP.getCallsByTool(mockToolName) - require.Len(t, invocations, 1, "expected MCP tool to be invoked") - - // Verify tool was invoked with the expected args (if specified) - if tc.expectedArgs != nil { - expected, err := json.Marshal(tc.expectedArgs) - require.NoError(t, err) - actual, err := json.Marshal(invocations[0]) - require.NoError(t, err) - require.EqualValues(t, expected, actual) - } - - // Verify tool usage was recorded - toolUsages := ts.Recorder.RecordedToolUsages() - require.Len(t, toolUsages, 1) - assert.Equal(t, mockToolName, toolUsages[0].Tool) - - ts.Recorder.VerifyAllInterceptionsEnded(t) - }) - } - }) -} - -func TestOpenAIInjectedTools(t *testing.T) { - t.Parallel() - - for _, streaming := range []bool{true, false} { - t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { - t.Parallel() - - // Build the requirements & make the assertions which 
are common to all providers. - recorderClient, mockMCP, resp := setupInjectedToolTest(t, fixtures.OaiChatSingleInjectedTool, streaming, defaultTracer, defaultActorID, pathOpenAIChatCompletions, openaiChatToolResultValidator(t)) - - // Ensure expected tool was invoked with expected input. - toolUsages := recorderClient.RecordedToolUsages() - require.Len(t, toolUsages, 1) - require.Equal(t, mockToolName, toolUsages[0].Tool) - expected, err := json.Marshal(map[string]any{"owner": "admin"}) - require.NoError(t, err) - actual, err := json.Marshal(toolUsages[0].Args) - require.NoError(t, err) - require.EqualValues(t, expected, actual) - invocations := mockMCP.getCallsByTool(mockToolName) - require.Len(t, invocations, 1) - actual, err = json.Marshal(invocations[0]) - require.NoError(t, err) - require.EqualValues(t, expected, actual) - - var ( - content *openai.ChatCompletionChoice - message openai.ChatCompletion - ) - if streaming { - // Parse the response stream. - decoder := oaissestream.NewDecoder(resp) - stream := oaissestream.NewStream[openai.ChatCompletionChunk](decoder, nil) - var acc openai.ChatCompletionAccumulator - detectedToolCalls := make(map[string]struct{}) - for stream.Next() { - chunk := stream.Current() - acc.AddChunk(chunk) - - if len(chunk.Choices) == 0 { - continue - } - - for _, c := range chunk.Choices { - if len(c.Delta.ToolCalls) == 0 { - continue - } - - for _, t := range c.Delta.ToolCalls { - if t.Function.Name == "" { - continue - } - - detectedToolCalls[t.Function.Name] = struct{}{} - } - } - } - - // Verify that no injected tool call events (or partials thereof) were sent to the client. - require.Len(t, detectedToolCalls, 0) - - message = acc.ChatCompletion - require.NoError(t, stream.Err(), "stream error") - } else { - // Parse & unmarshal the response. 
- body, err := io.ReadAll(resp.Body) - require.NoError(t, err, "read response body") - require.NoError(t, json.Unmarshal(body, &message), "unmarshal response") - - // Verify that no injected tools were sent to the client. - require.GreaterOrEqual(t, len(message.Choices), 1) - require.Len(t, message.Choices[0].Message.ToolCalls, 0) - } - - require.GreaterOrEqual(t, len(message.Choices), 1) - content = &message.Choices[0] - - // Ensure tool returned expected value. - require.NotNil(t, content) - require.Contains(t, content.Message.Content, "dd711d5c-83c6-4c08-a0af-b73055906e8c") // The ID of the workspace to be returned. - - // Check the token usage from the client's perspective. - // This *should* work but the openai SDK doesn't accumulate the prompt token details :(. - // See https://github.com/openai/openai-go/blob/v2.7.0/streamaccumulator.go#L145-L147. - // assert.EqualValues(t, 5047, message.Usage.PromptTokens-message.Usage.PromptTokensDetails.CachedTokens) - assert.EqualValues(t, 105, message.Usage.CompletionTokens) - - // Ensure tokens used during injected tool invocation are accounted for. - tokenUsages := recorderClient.RecordedTokenUsages() - require.EqualValues(t, 5047, calculateTotalInputTokens(tokenUsages)) - require.EqualValues(t, 105, calculateTotalOutputTokens(tokenUsages)) - - // Ensure we received exactly one prompt. - promptUsages := recorderClient.RecordedPromptUsages() - require.Len(t, promptUsages, 1) - }) - } -} - -// openaiChatToolResultValidator returns a request validator that asserts the second -// upstream request contains the assistant's tool_calls and a role=tool result message -// appended by the inner agentic loop. 
-func openaiChatToolResultValidator(t *testing.T) func(*http.Request, []byte) { - t.Helper() - - return func(_ *http.Request, raw []byte) { - messages := gjson.GetBytes(raw, "messages").Array() - - // After the agentic loop the messages must contain at minimum: - // [0] original user message - // [N-2] assistant message with tool_calls array - // [N-1] message with role=tool - require.GreaterOrEqual(t, len(messages), 3, - "second upstream request must contain the original message, assistant tool_calls, and tool result") - - assistantMsg := messages[len(messages)-2] - require.Equal(t, "assistant", assistantMsg.Get("role").Str, - "penultimate message must be from the assistant") - require.NotEmpty(t, len(assistantMsg.Get("tool_calls").Array()), - "assistant message must contain a tool_calls array") - - toolResultMsg := messages[len(messages)-1] - require.Equal(t, "tool", toolResultMsg.Get("role").Str, - "last message must have role=tool") - require.NotEmpty(t, toolResultMsg.Get("tool_call_id").Str, - "tool result message must have a tool_call_id") - } -} diff --git a/internal/integrationtest/messages_test.go b/internal/integrationtest/messages_test.go deleted file mode 100644 index 80c0881..0000000 --- a/internal/integrationtest/messages_test.go +++ /dev/null @@ -1,412 +0,0 @@ -package integrationtest - -import ( - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "testing" - "time" - - "github.com/anthropics/anthropic-sdk-go" - "github.com/anthropics/anthropic-sdk-go/packages/ssestream" - "github.com/anthropics/anthropic-sdk-go/shared/constant" - "github.com/coder/aibridge" - "github.com/coder/aibridge/fixtures" - "github.com/coder/aibridge/mcp" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" -) - -func TestAnthropicMessages(t *testing.T) { - t.Parallel() - - t.Run("single builtin tool", func(t *testing.T) { - t.Parallel() - - cases := []struct { - streaming bool - 
expectedInputTokens int - expectedOutputTokens int - expectedToolCallID string - }{ - { - streaming: true, - expectedInputTokens: 2, - expectedOutputTokens: 66, - expectedToolCallID: "toolu_01RX68weRSquLx6HUTj65iBo", - }, - { - streaming: false, - expectedInputTokens: 5, - expectedOutputTokens: 84, - expectedToolCallID: "toolu_01AusGgY5aKFhzWrFBv9JfHq", - }, - } - - for _, tc := range cases { - t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), tc.streaming), func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) - t.Cleanup(cancel) - - fix := fixtures.Parse(t, fixtures.AntSingleBuiltinTool) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - - ts := newBridgeTestServer(t, ctx, upstream.URL) - - // Make API call to aibridge for Anthropic /v1/messages - reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) - require.NoError(t, err) - req := ts.newRequest(t, pathAnthropicMessages, reqBody) - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) - require.Equal(t, http.StatusOK, resp.StatusCode) - defer resp.Body.Close() - - // Response-specific checks. - if tc.streaming { - sp := aibridge.NewSSEParser() - require.NoError(t, sp.Parse(resp.Body)) - - // Ensure the message starts and completes, at a minimum. - assert.Contains(t, sp.AllEvents(), "message_start") - assert.Contains(t, sp.AllEvents(), "message_stop") - } - - expectedTokenRecordings := 1 - if tc.streaming { - // One for message_start, one for message_delta. 
- expectedTokenRecordings = 2 - } - tokenUsages := ts.Recorder.RecordedTokenUsages() - require.Len(t, tokenUsages, expectedTokenRecordings) - - assert.EqualValues(t, tc.expectedInputTokens, calculateTotalInputTokens(tokenUsages), "input tokens miscalculated") - assert.EqualValues(t, tc.expectedOutputTokens, calculateTotalOutputTokens(tokenUsages), "output tokens miscalculated") - - toolUsages := ts.Recorder.RecordedToolUsages() - require.Len(t, toolUsages, 1) - assert.Equal(t, "Read", toolUsages[0].Tool) - assert.Equal(t, tc.expectedToolCallID, toolUsages[0].ToolCallID) - require.IsType(t, json.RawMessage{}, toolUsages[0].Args) - var args map[string]any - require.NoError(t, json.Unmarshal(toolUsages[0].Args.(json.RawMessage), &args)) - require.Contains(t, args, "file_path") - assert.Equal(t, "/tmp/blah/foo", args["file_path"]) - - promptUsages := ts.Recorder.RecordedPromptUsages() - require.Len(t, promptUsages, 1) - assert.Equal(t, "read the foo file", promptUsages[0].Prompt) - - ts.Recorder.VerifyAllInterceptionsEnded(t) - }) - } - }) -} - -func TestAnthropicInjectedTools(t *testing.T) { - t.Parallel() - - for _, streaming := range []bool{true, false} { - t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { - t.Parallel() - - // Build the requirements & make the assertions which are common to all providers. - recorderClient, mockMCP, resp := setupInjectedToolTest(t, fixtures.AntSingleInjectedTool, streaming, defaultTracer, defaultActorID, pathAnthropicMessages, anthropicToolResultValidator(t)) - - // Ensure expected tool was invoked with expected input. 
- toolUsages := recorderClient.RecordedToolUsages() - require.Len(t, toolUsages, 1) - require.Equal(t, mockToolName, toolUsages[0].Tool) - expected, err := json.Marshal(map[string]any{"owner": "admin"}) - require.NoError(t, err) - actual, err := json.Marshal(toolUsages[0].Args) - require.NoError(t, err) - require.EqualValues(t, expected, actual) - invocations := mockMCP.getCallsByTool(mockToolName) - require.Len(t, invocations, 1) - actual, err = json.Marshal(invocations[0]) - require.NoError(t, err) - require.EqualValues(t, expected, actual) - - var ( - content *anthropic.ContentBlockUnion - message anthropic.Message - ) - if streaming { - // Parse the response stream. - decoder := ssestream.NewDecoder(resp) - stream := ssestream.NewStream[anthropic.MessageStreamEventUnion](decoder, nil) - for stream.Next() { - event := stream.Current() - require.NoError(t, message.Accumulate(event), "accumulate event") - } - - require.NoError(t, stream.Err(), "stream error") - require.Len(t, message.Content, 2) - - content = &message.Content[1] - } else { - // Parse & unmarshal the response. - body, err := io.ReadAll(resp.Body) - require.NoError(t, err, "read response body") - - require.NoError(t, json.Unmarshal(body, &message), "unmarshal response") - require.GreaterOrEqual(t, len(message.Content), 1) - - content = &message.Content[0] - } - - // Ensure tool returned expected value. - require.NotNil(t, content) - require.Contains(t, content.Text, "dd711d5c-83c6-4c08-a0af-b73055906e8c") // The ID of the workspace to be returned. - - // Check the token usage from the client's perspective. - // - // We overwrite the final message_delta which is relayed to the client to include the - // accumulated tokens but currently the SDK only supports accumulating output tokens - // for message_delta events. - // - // For non-streaming requests the token usage is also overwritten and should be faithfully - // represented in the response. 
- // - // See https://github.com/anthropics/anthropic-sdk-go/blob/v1.12.0/message.go#L2619-L2622 - if !streaming { - assert.EqualValues(t, 15308, message.Usage.InputTokens) - } - assert.EqualValues(t, 204, message.Usage.OutputTokens) - - // Ensure tokens used during injected tool invocation are accounted for. - tokenUsages := recorderClient.RecordedTokenUsages() - assert.EqualValues(t, 15308, calculateTotalInputTokens(tokenUsages)) - assert.EqualValues(t, 204, calculateTotalOutputTokens(tokenUsages)) - - // Ensure we received exactly one prompt. - promptUsages := recorderClient.RecordedPromptUsages() - require.Len(t, promptUsages, 1) - }) - } -} - -// anthropicToolResultValidator returns a request validator that asserts the second -// upstream request contains the assistant's tool_use and user's tool_result messages -// appended by the inner agentic loop. If the raw payload is not kept in sync with -// the structured messages, the second request will be identical to the first. -func anthropicToolResultValidator(t *testing.T) func(*http.Request, []byte) { - t.Helper() - - return func(_ *http.Request, raw []byte) { - messages := gjson.GetBytes(raw, "messages").Array() - - // After the agentic loop the messages must contain at minimum: - // [0] original user message - // [N-2] assistant message with tool_use content block - // [N-1] user message with tool_result content block - require.GreaterOrEqual(t, len(messages), 3, - "second upstream request must contain the original message, assistant tool_use, and user tool_result") - - assistantMsg := messages[len(messages)-2] - require.Equal(t, "assistant", assistantMsg.Get("role").Str, - "penultimate message must be from the assistant") - var hasToolUse bool - for _, block := range assistantMsg.Get("content").Array() { - if block.Get("type").Str == "tool_use" { - hasToolUse = true - break - } - } - require.True(t, hasToolUse, "assistant message must contain a tool_use content block") - - toolResultMsg := 
messages[len(messages)-1] - require.Equal(t, "user", toolResultMsg.Get("role").Str, - "last message must be a user message carrying the tool_result") - var hasToolResult bool - for _, block := range toolResultMsg.Get("content").Array() { - if block.Get("type").Str == "tool_result" { - hasToolResult = true - break - } - } - require.True(t, hasToolResult, "user message must contain a tool_result content block") - } -} - -// TestAnthropicToolChoiceParallelDisabled verifies that parallel tool use is -// correctly disabled based on the tool_choice parameter in the request. -// See https://github.com/coder/aibridge/issues/2 -func TestAnthropicToolChoiceParallelDisabled(t *testing.T) { - t.Parallel() - - var ( - toolChoiceAuto = string(constant.ValueOf[constant.Auto]()) - toolChoiceAny = string(constant.ValueOf[constant.Any]()) - toolChoiceNone = string(constant.ValueOf[constant.None]()) - toolChoiceTool = string(constant.ValueOf[constant.Tool]()) - ) - - cases := []struct { - name string - toolChoice any // nil, or map with "type" key. - withInjectedTools bool - expectDisableParallel bool - expectToolChoiceTypeInRequest string - }{ - // With injected tools - disable_parallel_tool_use should be set. 
- { - name: "with injected tools: no tool_choice defined defaults to auto", - toolChoice: nil, - withInjectedTools: true, - expectDisableParallel: true, - expectToolChoiceTypeInRequest: toolChoiceAuto, - }, - { - name: "with injected tools: tool_choice auto", - toolChoice: map[string]any{"type": toolChoiceAuto}, - withInjectedTools: true, - expectDisableParallel: true, - expectToolChoiceTypeInRequest: toolChoiceAuto, - }, - { - name: "with injected tools: tool_choice any", - toolChoice: map[string]any{"type": toolChoiceAny}, - withInjectedTools: true, - expectDisableParallel: true, - expectToolChoiceTypeInRequest: toolChoiceAny, - }, - { - name: "with injected tools: tool_choice tool", - toolChoice: map[string]any{"type": toolChoiceTool, "name": "some_tool"}, - withInjectedTools: true, - expectDisableParallel: true, - expectToolChoiceTypeInRequest: toolChoiceTool, - }, - { - name: "with injected tools: tool_choice none", - toolChoice: map[string]any{"type": toolChoiceNone}, - withInjectedTools: true, - expectDisableParallel: false, - expectToolChoiceTypeInRequest: toolChoiceNone, - }, - // Without injected tools - disable_parallel_tool_use should NOT be set. - { - name: "without injected tools: tool_choice auto", - toolChoice: map[string]any{"type": toolChoiceAuto}, - withInjectedTools: false, - expectDisableParallel: false, - expectToolChoiceTypeInRequest: toolChoiceAuto, - }, - { - name: "without injected tools: tool_choice any", - toolChoice: map[string]any{"type": toolChoiceAny}, - withInjectedTools: false, - expectDisableParallel: false, - expectToolChoiceTypeInRequest: toolChoiceAny, - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) - t.Cleanup(cancel) - - // Setup MCP tools conditionally. 
- var mcpMgr mcp.ServerProxier - if tc.withInjectedTools { - mcpMgr = setupMCPForTest(t, defaultTracer) - } else { - mcpMgr = newNoopMCPManager() - } - - fix := fixtures.Parse(t, fixtures.AntSimple) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - - ts := newBridgeTestServer(t, ctx, upstream.URL, - withMCP(mcpMgr), - ) - - // Prepare request body with tool_choice set. - reqBody, err := sjson.SetBytes(fix.Request(), "tool_choice", tc.toolChoice) - require.NoError(t, err) - - req := ts.newRequest(t, pathAnthropicMessages, reqBody) - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) - require.Equal(t, http.StatusOK, resp.StatusCode) - _ = resp.Body.Close() - - // Verify tool_choice in the upstream request. - received := upstream.receivedRequests() - require.Len(t, received, 1) - var receivedRequest map[string]any - require.NoError(t, json.Unmarshal(received[0].Body, &receivedRequest)) - toolChoice, ok := receivedRequest["tool_choice"].(map[string]any) - require.True(t, ok, "expected tool_choice in upstream request") - - // Verify the type matches expectation. - assert.Equal(t, tc.expectToolChoiceTypeInRequest, toolChoice["type"]) - - // Verify name is preserved for tool_choice=tool. - if tc.expectToolChoiceTypeInRequest == toolChoiceTool { - assert.Equal(t, "some_tool", toolChoice["name"]) - } - - // Verify disable_parallel_tool_use based on expectations. 
- // See https://platform.claude.com/docs/en/agents-and-tools/tool-use/implement-tool-use#parallel-tool-use - disableParallel, hasDisableParallel := toolChoice["disable_parallel_tool_use"].(bool) - - if tc.expectDisableParallel { - require.True(t, hasDisableParallel, "expected disable_parallel_tool_use in tool_choice") - assert.True(t, disableParallel, "expected disable_parallel_tool_use to be true") - } else { - assert.False(t, hasDisableParallel, "expected disable_parallel_tool_use to not be set") - } - }) - } -} - -func TestThinkingAdaptiveIsPreserved(t *testing.T) { - t.Parallel() - - fix := fixtures.Parse(t, fixtures.AntSimple) - - for _, streaming := range []bool{true, false} { - t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) - t.Cleanup(cancel) - - // Create a mock server that captures the request body sent upstream. - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - - ts := newBridgeTestServer(t, ctx, upstream.URL) - - // Inject adaptive thinking into the fixture request. - reqBody, err := sjson.SetBytes(fix.Request(), "thinking", map[string]string{"type": "adaptive"}) - require.NoError(t, err) - reqBody, err = sjson.SetBytes(reqBody, "stream", streaming) - require.NoError(t, err) - - req := ts.newRequest(t, pathAnthropicMessages, reqBody) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) - require.Equal(t, http.StatusOK, resp.StatusCode) - _, _ = io.ReadAll(resp.Body) - _ = resp.Body.Close() - - // Verify the thinking field was preserved in the upstream request. 
- received := upstream.receivedRequests() - require.Len(t, received, 1) - assert.Equal(t, "adaptive", gjson.GetBytes(received[0].Body, "thinking.type").Str) - }) - } -} diff --git a/internal/integrationtest/upstream.go b/internal/integrationtest/mockupstream.go similarity index 100% rename from internal/integrationtest/upstream.go rename to internal/integrationtest/mockupstream.go diff --git a/internal/integrationtest/requests.go b/internal/integrationtest/requests.go deleted file mode 100644 index ef57edf..0000000 --- a/internal/integrationtest/requests.go +++ /dev/null @@ -1,36 +0,0 @@ -package integrationtest - -import ( - "github.com/coder/aibridge/config" -) - -// apiKey is the default API key used across integration tests. -const apiKey = "api-key" - -// openAICfg creates a minimal OpenAI config for testing. -func openAICfg(url, key string) config.OpenAI { - return config.OpenAI{ - BaseURL: url, - Key: key, - } -} - -func openaiCfgWithAPIDump(url, key, dumpDir string) config.OpenAI { - cfg := openAICfg(url, key) - cfg.APIDumpDir = dumpDir - return cfg -} - -// anthropicCfg creates a minimal Anthropic config for testing. 
-func anthropicCfg(url, key string) config.Anthropic { - return config.Anthropic{ - BaseURL: url, - Key: key, - } -} - -func anthropicCfgWithAPIDump(url, key, dumpDir string) config.Anthropic { - cfg := anthropicCfg(url, key) - cfg.APIDumpDir = dumpDir - return cfg -} From 741133c9478c2d918f8e94e9ecd3436678c6005b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Fri, 6 Mar 2026 11:27:46 +0000 Subject: [PATCH 11/32] bridge.go -> setupbridge.go + helpers.go --- internal/integrationtest/bridge_test.go | 18 +- internal/integrationtest/helpers.go | 55 +++++ internal/integrationtest/mockmcp.go | 14 +- internal/integrationtest/mockupstream.go | 56 ++--- .../{bridge.go => setupbridge.go} | 199 ++++++------------ internal/testutil/mock_recorder.go | 17 ++ 6 files changed, 181 insertions(+), 178 deletions(-) create mode 100644 internal/integrationtest/helpers.go rename internal/integrationtest/{bridge.go => setupbridge.go} (79%) diff --git a/internal/integrationtest/bridge_test.go b/internal/integrationtest/bridge_test.go index f6271ae..d2f7171 100644 --- a/internal/integrationtest/bridge_test.go +++ b/internal/integrationtest/bridge_test.go @@ -912,8 +912,8 @@ func TestAnthropicMessages(t *testing.T) { tokenUsages := ts.Recorder.RecordedTokenUsages() require.Len(t, tokenUsages, expectedTokenRecordings) - assert.EqualValues(t, tc.expectedInputTokens, calculateTotalInputTokens(tokenUsages), "input tokens miscalculated") - assert.EqualValues(t, tc.expectedOutputTokens, calculateTotalOutputTokens(tokenUsages), "output tokens miscalculated") + assert.EqualValues(t, tc.expectedInputTokens, ts.Recorder.TotalInputTokens(), "input tokens miscalculated") + assert.EqualValues(t, tc.expectedOutputTokens, ts.Recorder.TotalOutputTokens(), "output tokens miscalculated") toolUsages := ts.Recorder.RecordedToolUsages() require.Len(t, toolUsages, 1) @@ -1008,9 +1008,8 @@ func TestAnthropicInjectedTools(t *testing.T) { assert.EqualValues(t, 204, 
message.Usage.OutputTokens) // Ensure tokens used during injected tool invocation are accounted for. - tokenUsages := recorderClient.RecordedTokenUsages() - assert.EqualValues(t, 15308, calculateTotalInputTokens(tokenUsages)) - assert.EqualValues(t, 204, calculateTotalOutputTokens(tokenUsages)) + assert.EqualValues(t, 15308, recorderClient.TotalInputTokens()) + assert.EqualValues(t, 204, recorderClient.TotalOutputTokens()) // Ensure we received exactly one prompt. promptUsages := recorderClient.RecordedPromptUsages() @@ -1300,8 +1299,8 @@ func TestOpenAIChatCompletions(t *testing.T) { tokenUsages := ts.Recorder.RecordedTokenUsages() require.Len(t, tokenUsages, 1) - assert.EqualValues(t, tc.expectedInputTokens, calculateTotalInputTokens(tokenUsages), "input tokens miscalculated") - assert.EqualValues(t, tc.expectedOutputTokens, calculateTotalOutputTokens(tokenUsages), "output tokens miscalculated") + assert.EqualValues(t, tc.expectedInputTokens, ts.Recorder.TotalInputTokens(), "input tokens miscalculated") + assert.EqualValues(t, tc.expectedOutputTokens, ts.Recorder.TotalOutputTokens(), "output tokens miscalculated") toolUsages := ts.Recorder.RecordedToolUsages() require.Len(t, toolUsages, 1) @@ -1491,9 +1490,8 @@ func TestOpenAIInjectedTools(t *testing.T) { assert.EqualValues(t, 105, message.Usage.CompletionTokens) // Ensure tokens used during injected tool invocation are accounted for. - tokenUsages := recorderClient.RecordedTokenUsages() - require.EqualValues(t, 5047, calculateTotalInputTokens(tokenUsages)) - require.EqualValues(t, 105, calculateTotalOutputTokens(tokenUsages)) + require.EqualValues(t, 5047, recorderClient.TotalInputTokens()) + require.EqualValues(t, 105, recorderClient.TotalOutputTokens()) // Ensure we received exactly one prompt. 
promptUsages := recorderClient.RecordedPromptUsages() diff --git a/internal/integrationtest/helpers.go b/internal/integrationtest/helpers.go new file mode 100644 index 0000000..e604f31 --- /dev/null +++ b/internal/integrationtest/helpers.go @@ -0,0 +1,55 @@ +package integrationtest + +import ( + "testing" + + "cdr.dev/slog/v3" + "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/aibridge/config" +) + +// anthropicCfg creates a minimal Anthropic config for testing. +func anthropicCfg(url, key string) config.Anthropic { + return config.Anthropic{ + BaseURL: url, + Key: key, + } +} + +func anthropicCfgWithAPIDump(url, key, dumpDir string) config.Anthropic { + cfg := anthropicCfg(url, key) + cfg.APIDumpDir = dumpDir + return cfg +} + +// testBedrockCfg returns a test AWS Bedrock config pointing at the given URL. +func testBedrockCfg(url string) *config.AWSBedrock { + return &config.AWSBedrock{ + Region: "us-west-2", + AccessKey: "test-access-key", + AccessKeySecret: "test-secret-key", + Model: "beddel", // This model should override the request's given one. + SmallFastModel: "modrock", // Unused but needed for validation. + BaseURL: url, + } +} + +// openAICfg creates a minimal OpenAI config for testing. +func openAICfg(url, key string) config.OpenAI { + return config.OpenAI{ + BaseURL: url, + Key: key, + } +} + +func openaiCfgWithAPIDump(url, key, dumpDir string) config.OpenAI { + cfg := openAICfg(url, key) + cfg.APIDumpDir = dumpDir + return cfg +} + +// newLogger creates a test logger at Debug level. 
+func newLogger(t *testing.T) slog.Logger { + t.Helper() + return slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) +} diff --git a/internal/integrationtest/mockmcp.go b/internal/integrationtest/mockmcp.go index eba25dd..e2f3dcb 100644 --- a/internal/integrationtest/mockmcp.go +++ b/internal/integrationtest/mockmcp.go @@ -31,6 +31,13 @@ type mockMCP struct { calls *callAccumulator } +// callAccumulator tracks all tool invocations by name and each instance's arguments. +type callAccumulator struct { + calls map[string][]any + callsMu sync.Mutex + toolErrors map[string]string +} + // getCallsByTool returns recorded arguments for a given tool name. func (m *mockMCP) getCallsByTool(name string) []any { return m.calls.getCallsByTool(name) @@ -84,13 +91,6 @@ func newNoopMCPManager() mcp.ServerProxier { return mcp.NewServerProxyManager(nil, noop.NewTracerProvider().Tracer("")) } -// callAccumulator tracks all tool invocations by name and each instance's arguments. -type callAccumulator struct { - calls map[string][]any - callsMu sync.Mutex - toolErrors map[string]string -} - func newCallAccumulator() *callAccumulator { return &callAccumulator{ calls: make(map[string][]any), diff --git a/internal/integrationtest/mockupstream.go b/internal/integrationtest/mockupstream.go index a658b05..c0dbec0 100644 --- a/internal/integrationtest/mockupstream.go +++ b/internal/integrationtest/mockupstream.go @@ -35,34 +35,6 @@ type upstreamResponse struct { OnRequest func(r *http.Request, body []byte) } -// newFixtureResponse creates an upstreamResponse from a parsed fixture archive. -// It reads whichever of 'streaming' and 'non-streaming' sections exist; -// not every fixture has both (e.g. error fixtures may only define one). 
-func newFixtureResponse(fix fixtures.Fixture) upstreamResponse { - var resp upstreamResponse - if fix.Has(fixtures.SectionStreaming) { - resp.Streaming = fix.Streaming() - } - if fix.Has(fixtures.SectionNonStreaming) { - resp.Blocking = fix.NonStreaming() - } - return resp -} - -// newFixtureToolResponse creates an upstreamResponse from the tool-call fixture files. -// It reads whichever of 'streaming/tool-call' and 'non-streaming/tool-call' -// sections exist. -func newFixtureToolResponse(fix fixtures.Fixture) upstreamResponse { - var resp upstreamResponse - if fix.Has(fixtures.SectionStreamingToolCall) { - resp.Streaming = fix.StreamingToolCall() - } - if fix.Has(fixtures.SectionNonStreamToolCall) { - resp.Blocking = fix.NonStreamingToolCall() - } - return resp -} - // receivedRequest captures the details of a single request handled by mockUpstream. type receivedRequest struct { Method string @@ -98,6 +70,34 @@ type mockUpstream struct { responses []upstreamResponse } +// newFixtureResponse creates an upstreamResponse from a parsed fixture archive. +// It reads whichever of 'streaming' and 'non-streaming' sections exist; +// not every fixture has both (e.g. error fixtures may only define one). +func newFixtureResponse(fix fixtures.Fixture) upstreamResponse { + var resp upstreamResponse + if fix.Has(fixtures.SectionStreaming) { + resp.Streaming = fix.Streaming() + } + if fix.Has(fixtures.SectionNonStreaming) { + resp.Blocking = fix.NonStreaming() + } + return resp +} + +// newFixtureToolResponse creates an upstreamResponse from the tool-call fixture files. +// It reads whichever of 'streaming/tool-call' and 'non-streaming/tool-call' +// sections exist. 
+func newFixtureToolResponse(fix fixtures.Fixture) upstreamResponse { + var resp upstreamResponse + if fix.Has(fixtures.SectionStreamingToolCall) { + resp.Streaming = fix.StreamingToolCall() + } + if fix.Has(fixtures.SectionNonStreamToolCall) { + resp.Blocking = fix.NonStreamingToolCall() + } + return resp +} + // receivedRequests returns a copy of all requests received so far. func (ms *mockUpstream) receivedRequests() []receivedRequest { ms.mu.Lock() diff --git a/internal/integrationtest/bridge.go b/internal/integrationtest/setupbridge.go similarity index 79% rename from internal/integrationtest/bridge.go rename to internal/integrationtest/setupbridge.go index 0ad4ac7..b2b7301 100644 --- a/internal/integrationtest/bridge.go +++ b/internal/integrationtest/setupbridge.go @@ -10,7 +10,6 @@ import ( "time" "cdr.dev/slog/v3" - "cdr.dev/slog/v3/sloggers/slogtest" "github.com/coder/aibridge" "github.com/coder/aibridge/config" aibcontext "github.com/coder/aibridge/context" @@ -26,59 +25,53 @@ import ( "go.opentelemetry.io/otel/trace" ) -// Well-known bridged-route paths used by integration tests. const ( + // Well-known bridged-route paths used by integration tests. pathAnthropicMessages = "/anthropic/v1/messages" pathOpenAIChatCompletions = "/openai/v1/chat/completions" pathOpenAIResponses = "/openai/v1/responses" -) -// providerBedrock identifies a Bedrock provider in [withProvider]. There is no -// config-level constant for Bedrock because it re-uses the Anthropic provider -// with an AWS Bedrock configuration. -const providerBedrock = "bedrock" - -// testBedrockCfg returns a test AWS Bedrock config pointing at the given URL. -func testBedrockCfg(url string) *config.AWSBedrock { - return &config.AWSBedrock{ - Region: "us-west-2", - AccessKey: "test-access-key", - AccessKeySecret: "test-secret-key", - Model: "beddel", // This model should override the request's given one. - SmallFastModel: "modrock", // Unused but needed for validation. 
- BaseURL: url, - } -} + // providerBedrock identifies a Bedrock provider in [withProvider]. + providerBedrock = "bedrock" -// apiKey is the default API key used across integration tests. -const apiKey = "api-key" + // defaults + apiKey = "api-key" + defaultActorID = "ae235cc1-9f8f-417d-a636-a7b170bac62e" +) -// openAICfg creates a minimal OpenAI config for testing. -func openAICfg(url, key string) config.OpenAI { - return config.OpenAI{ - BaseURL: url, - Key: key, - } -} +// defaultTracer is the default OTel tracer used in integration tests. +var defaultTracer = otel.Tracer("integrationtest") -func openaiCfgWithAPIDump(url, key, dumpDir string) config.OpenAI { - cfg := openAICfg(url, key) - cfg.APIDumpDir = dumpDir - return cfg +// bridgeTestServer wraps an httptest.Server running a RequestBridge. +type bridgeTestServer struct { + *httptest.Server + Recorder *testutil.MockRecorder + Bridge *aibridge.RequestBridge } -// anthropicCfg creates a minimal Anthropic config for testing. -func anthropicCfg(url, key string) config.Anthropic { - return config.Anthropic{ - BaseURL: url, - Key: key, - } +// bridgeOption configures a [bridgeTestServer]. +type bridgeOption func(*bridgeConfig) + +type bridgeConfig struct { + providerBuilders []func(upstreamURL string) aibridge.Provider + metrics *metrics.Metrics + tracer trace.Tracer + mcpProxy mcp.ServerProxier + userID string + metadata recorder.Metadata + logger slog.Logger + loggerSet bool + wrapRecorder bool } -func anthropicCfgWithAPIDump(url, key, dumpDir string) config.Anthropic { - cfg := anthropicCfg(url, key) - cfg.APIDumpDir = dumpDir - return cfg +// newRequest creates a JSON POST request targeting the given path on this server. 
+func (s *bridgeTestServer) newRequest(t *testing.T, path string, body []byte) *http.Request { + t.Helper() + + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, s.URL+path, bytes.NewReader(body)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + return req } // newDefaultProvider creates a Provider with default test configuration. @@ -116,6 +109,37 @@ func withCustomProvider(p aibridge.Provider) bridgeOption { } } +// withMetrics sets the Prometheus metrics for the bridge. +func withMetrics(m *metrics.Metrics) bridgeOption { + return func(c *bridgeConfig) { c.metrics = m } +} + +// withTracer overrides the default tracer. +func withTracer(t trace.Tracer) bridgeOption { + return func(c *bridgeConfig) { c.tracer = t } +} + +// withMCP sets the MCP server proxier (default: NoopMCPManager). +func withMCP(p mcp.ServerProxier) bridgeOption { + return func(c *bridgeConfig) { c.mcpProxy = p } +} + +// withActor sets the actor ID and metadata for the BaseContext. +func withActor(id string, md recorder.Metadata) bridgeOption { + return func(c *bridgeConfig) { c.userID = id; c.metadata = md } +} + +// withLogger overrides the default slogtest debug logger. +func withLogger(l slog.Logger) bridgeOption { + return func(c *bridgeConfig) { c.logger = l; c.loggerSet = true } +} + +// withWrappedRecorder wraps the MockRecorder through aibridge.NewRecorder +// (the production recorder wrapper). Use when testing the recorder pipeline. +func withWrappedRecorder() bridgeOption { + return func(c *bridgeConfig) { c.wrapRecorder = true } +} + // setupInjectedToolTest abstracts common setup required for injected-tool integration tests. // Extra bridge options (e.g. [withProvider]) are appended after the built-in // MCP / tracer / actor options. 
When no provider option is given the default @@ -174,97 +198,6 @@ func setupInjectedToolTest( return ts.Recorder, mockMCP, resp } -func calculateTotalInputTokens(in []*recorder.TokenUsageRecord) int64 { - var total int64 - for _, el := range in { - total += el.Input - } - return total -} - -func calculateTotalOutputTokens(in []*recorder.TokenUsageRecord) int64 { - var total int64 - for _, el := range in { - total += el.Output - } - return total -} - -// defaultActorID is the actor ID used by default in test servers. -const defaultActorID = "ae235cc1-9f8f-417d-a636-a7b170bac62e" - -// defaultTracer is the default OTel tracer used in integration tests. -var defaultTracer = otel.Tracer("integrationtest") - -// newLogger creates a test logger at Debug level. -func newLogger(t *testing.T) slog.Logger { - t.Helper() - return slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) -} - -// bridgeTestServer wraps an httptest.Server running a RequestBridge. -type bridgeTestServer struct { - *httptest.Server - Recorder *testutil.MockRecorder - Bridge *aibridge.RequestBridge -} - -// newRequest creates a JSON POST request targeting the given path on this server. -func (s *bridgeTestServer) newRequest(t *testing.T, path string, body []byte) *http.Request { - t.Helper() - - req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, s.URL+path, bytes.NewReader(body)) - require.NoError(t, err) - req.Header.Set("Content-Type", "application/json") - return req -} - -// bridgeOption configures a [bridgeTestServer]. -type bridgeOption func(*bridgeConfig) - -type bridgeConfig struct { - providerBuilders []func(upstreamURL string) aibridge.Provider - metrics *metrics.Metrics - tracer trace.Tracer - mcpProxy mcp.ServerProxier - userID string - metadata recorder.Metadata - logger slog.Logger - loggerSet bool - wrapRecorder bool -} - -// withMetrics sets the Prometheus metrics for the bridge. 
-func withMetrics(m *metrics.Metrics) bridgeOption { - return func(c *bridgeConfig) { c.metrics = m } -} - -// withTracer overrides the default tracer. -func withTracer(t trace.Tracer) bridgeOption { - return func(c *bridgeConfig) { c.tracer = t } -} - -// withMCP sets the MCP server proxier (default: NoopMCPManager). -func withMCP(p mcp.ServerProxier) bridgeOption { - return func(c *bridgeConfig) { c.mcpProxy = p } -} - -// withActor sets the actor ID and metadata for the BaseContext. -func withActor(id string, md recorder.Metadata) bridgeOption { - return func(c *bridgeConfig) { c.userID = id; c.metadata = md } -} - -// withLogger overrides the default slogtest debug logger. -func withLogger(l slog.Logger) bridgeOption { - return func(c *bridgeConfig) { c.logger = l; c.loggerSet = true } -} - -// withWrappedRecorder wraps the MockRecorder through aibridge.NewRecorder -// (the production recorder wrapper). Use when testing the recorder pipeline. -func withWrappedRecorder() bridgeOption { - return func(c *bridgeConfig) { c.wrapRecorder = true } -} - // newBridgeTestServer creates a fully configured test server running // a RequestBridge with sensible defaults: // - All standard providers (unless withProvider / withCustomProvider) diff --git a/internal/testutil/mock_recorder.go b/internal/testutil/mock_recorder.go index ac39006..a256a64 100644 --- a/internal/testutil/mock_recorder.go +++ b/internal/testutil/mock_recorder.go @@ -73,6 +73,23 @@ func (m *MockRecorder) RecordedTokenUsages() []*recorder.TokenUsageRecord { defer m.mu.Unlock() return slices.Clone(m.tokenUsages) } +// TotalInputTokens returns the sum of input tokens across all recorded token usages. +func (m *MockRecorder) TotalInputTokens() int64 { + var total int64 + for _, el := range m.RecordedTokenUsages() { + total += el.Input + } + return total +} + +// TotalOutputTokens returns the sum of output tokens across all recorded token usages. 
+func (m *MockRecorder) TotalOutputTokens() int64 { + var total int64 + for _, el := range m.RecordedTokenUsages() { + total += el.Output + } + return total +} // RecordedPromptUsages returns a copy of recorded prompt usages in a thread-safe manner. // Note: This is a shallow clone (see RecordedTokenUsages for details). From c0541dfbe9b762b1439f5a2f9ad7a4cf5adedeb6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Fri, 6 Mar 2026 12:20:32 +0000 Subject: [PATCH 12/32] remove withWrappedRecorder --- internal/integrationtest/helpers.go | 12 ++-- internal/integrationtest/metrics_test.go | 23 +++---- internal/integrationtest/responses_test.go | 4 +- internal/integrationtest/setupbridge.go | 24 ++------ internal/integrationtest/trace_test.go | 71 ++++++++++------------ 5 files changed, 56 insertions(+), 78 deletions(-) diff --git a/internal/integrationtest/helpers.go b/internal/integrationtest/helpers.go index e604f31..84bd64d 100644 --- a/internal/integrationtest/helpers.go +++ b/internal/integrationtest/helpers.go @@ -9,21 +9,21 @@ import ( ) // anthropicCfg creates a minimal Anthropic config for testing. -func anthropicCfg(url, key string) config.Anthropic { +func anthropicCfg(url string, key string) config.Anthropic { return config.Anthropic{ BaseURL: url, Key: key, } } -func anthropicCfgWithAPIDump(url, key, dumpDir string) config.Anthropic { +func anthropicCfgWithAPIDump(url string, key string, dumpDir string) config.Anthropic { cfg := anthropicCfg(url, key) cfg.APIDumpDir = dumpDir return cfg } -// testBedrockCfg returns a test AWS Bedrock config pointing at the given URL. -func testBedrockCfg(url string) *config.AWSBedrock { +// bedrockCfg returns a test AWS Bedrock config pointing at the given URL. 
+func bedrockCfg(url string) *config.AWSBedrock { return &config.AWSBedrock{ Region: "us-west-2", AccessKey: "test-access-key", @@ -35,14 +35,14 @@ func testBedrockCfg(url string) *config.AWSBedrock { } // openAICfg creates a minimal OpenAI config for testing. -func openAICfg(url, key string) config.OpenAI { +func openAICfg(url string, key string) config.OpenAI { return config.OpenAI{ BaseURL: url, Key: key, } } -func openaiCfgWithAPIDump(url, key, dumpDir string) config.OpenAI { +func openaiCfgWithAPIDump(url string, key string, dumpDir string) config.OpenAI { cfg := openAICfg(url, key) cfg.APIDumpDir = dumpDir return cfg diff --git a/internal/integrationtest/metrics_test.go b/internal/integrationtest/metrics_test.go index 3e0f002..dffa15c 100644 --- a/internal/integrationtest/metrics_test.go +++ b/internal/integrationtest/metrics_test.go @@ -23,7 +23,7 @@ func TestMetrics_Interception(t *testing.T) { cases := []struct { name string fixture []byte - path string + path string expectStatus string expectModel string expectRoute string @@ -33,7 +33,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "ant_simple", fixture: fixtures.AntSimple, - path: pathAnthropicMessages, + path: pathAnthropicMessages, expectStatus: metrics.InterceptionCountStatusCompleted, expectModel: "claude-sonnet-4-0", expectRoute: "/v1/messages", @@ -42,7 +42,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "ant_error", fixture: fixtures.AntNonStreamError, - path: pathAnthropicMessages, + path: pathAnthropicMessages, expectStatus: metrics.InterceptionCountStatusFailed, expectModel: "claude-sonnet-4-0", expectRoute: "/v1/messages", @@ -52,7 +52,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "oai_chat_simple", fixture: fixtures.OaiChatSimple, - path: pathOpenAIChatCompletions, + path: pathOpenAIChatCompletions, expectStatus: metrics.InterceptionCountStatusCompleted, expectModel: "gpt-4.1", expectRoute: "/v1/chat/completions", @@ -61,7 +61,7 @@ func 
TestMetrics_Interception(t *testing.T) { { name: "oai_chat_error", fixture: fixtures.OaiChatNonStreamError, - path: pathOpenAIChatCompletions, + path: pathOpenAIChatCompletions, expectStatus: metrics.InterceptionCountStatusFailed, expectModel: "gpt-4.1", expectRoute: "/v1/chat/completions", @@ -71,7 +71,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "oai_responses_blocking_simple", fixture: fixtures.OaiResponsesBlockingSimple, - path: pathOpenAIResponses, + path: pathOpenAIResponses, expectStatus: metrics.InterceptionCountStatusCompleted, expectModel: "gpt-4o-mini", expectRoute: "/v1/responses", @@ -80,7 +80,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "oai_responses_blocking_error", fixture: fixtures.OaiResponsesBlockingHttpErr, - path: pathOpenAIResponses, + path: pathOpenAIResponses, expectStatus: metrics.InterceptionCountStatusFailed, expectModel: "gpt-4o-mini", expectRoute: "/v1/responses", @@ -90,7 +90,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "oai_responses_streaming_simple", fixture: fixtures.OaiResponsesStreamingSimple, - path: pathOpenAIResponses, + path: pathOpenAIResponses, expectStatus: metrics.InterceptionCountStatusCompleted, expectModel: "gpt-4o-mini", expectRoute: "/v1/responses", @@ -99,7 +99,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "oai_responses_streaming_error", fixture: fixtures.OaiResponsesStreamingHttpErr, - path: pathOpenAIResponses, + path: pathOpenAIResponses, expectStatus: metrics.InterceptionCountStatusFailed, expectModel: "gpt-4o-mini", expectRoute: "/v1/responses", @@ -122,7 +122,6 @@ func TestMetrics_Interception(t *testing.T) { m := aibridge.NewMetrics(prometheus.NewRegistry()) ts := newBridgeTestServer(t, ctx, upstream.URL, withMetrics(m), - withWrappedRecorder(), ) req := ts.newRequest(t, tc.path, fix.Request()) @@ -159,7 +158,6 @@ func TestMetrics_InterceptionsInflight(t *testing.T) { m := aibridge.NewMetrics(prometheus.NewRegistry()) ts := newBridgeTestServer(t, ctx, 
srv.URL, withMetrics(m), - withWrappedRecorder(), ) // Make request in background. @@ -206,7 +204,6 @@ func TestMetrics_PassthroughCount(t *testing.T) { m := aibridge.NewMetrics(prometheus.NewRegistry()) ts := newBridgeTestServer(t, t.Context(), upstream.URL, withMetrics(m), - withWrappedRecorder(), ) req, err := http.NewRequestWithContext(t.Context(), "GET", ts.URL+"/openai/v1/models", nil) @@ -234,7 +231,6 @@ func TestMetrics_PromptCount(t *testing.T) { m := aibridge.NewMetrics(prometheus.NewRegistry()) ts := newBridgeTestServer(t, ctx, upstream.URL, withMetrics(m), - withWrappedRecorder(), ) req := ts.newRequest(t, pathOpenAIChatCompletions, fix.Request()) @@ -261,7 +257,6 @@ func TestMetrics_NonInjectedToolUseCount(t *testing.T) { m := aibridge.NewMetrics(prometheus.NewRegistry()) ts := newBridgeTestServer(t, ctx, upstream.URL, withMetrics(m), - withWrappedRecorder(), ) req := ts.newRequest(t, pathOpenAIChatCompletions, fix.Request()) diff --git a/internal/integrationtest/responses_test.go b/internal/integrationtest/responses_test.go index 88f277b..91a978c 100644 --- a/internal/integrationtest/responses_test.go +++ b/internal/integrationtest/responses_test.go @@ -334,7 +334,7 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { fix := fixtures.Parse(t, tc.fixture) upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - ts := newBridgeTestServer(t, ctx, upstream.URL, withWrappedRecorder()) + ts := newBridgeTestServer(t, ctx, upstream.URL) req := ts.newRequest(t, pathOpenAIResponses, fix.Request()) req.Header.Set("User-Agent", tc.userAgent) @@ -597,7 +597,7 @@ func TestClientAndConnectionError(t *testing.T) { t.Cleanup(cancel) // tc.addr may be an intentionally invalid URL; use withCustomProvider. 
- ts := newBridgeTestServer(t, ctx, tc.addr, withCustomProvider(provider.NewOpenAI(openAICfg(tc.addr, apiKey))), withWrappedRecorder()) + ts := newBridgeTestServer(t, ctx, tc.addr, withCustomProvider(provider.NewOpenAI(openAICfg(tc.addr, apiKey)))) reqBytes := responsesRequestBytes(t, tc.streaming) req := ts.newRequest(t, pathOpenAIResponses, reqBytes) diff --git a/internal/integrationtest/setupbridge.go b/internal/integrationtest/setupbridge.go index b2b7301..d371804 100644 --- a/internal/integrationtest/setupbridge.go +++ b/internal/integrationtest/setupbridge.go @@ -26,12 +26,12 @@ import ( ) const ( - // Well-known bridged-route paths used by integration tests. pathAnthropicMessages = "/anthropic/v1/messages" pathOpenAIChatCompletions = "/openai/v1/chat/completions" pathOpenAIResponses = "/openai/v1/responses" // providerBedrock identifies a Bedrock provider in [withProvider]. + // other providers use config.Provider* constants. providerBedrock = "bedrock" // defaults @@ -39,7 +39,6 @@ const ( defaultActorID = "ae235cc1-9f8f-417d-a636-a7b170bac62e" ) -// defaultTracer is the default OTel tracer used in integration tests. var defaultTracer = otel.Tracer("integrationtest") // bridgeTestServer wraps an httptest.Server running a RequestBridge. @@ -49,7 +48,6 @@ type bridgeTestServer struct { Bridge *aibridge.RequestBridge } -// bridgeOption configures a [bridgeTestServer]. type bridgeOption func(*bridgeConfig) type bridgeConfig struct { @@ -61,7 +59,6 @@ type bridgeConfig struct { metadata recorder.Metadata logger slog.Logger loggerSet bool - wrapRecorder bool } // newRequest creates a JSON POST request targeting the given path on this server. @@ -75,14 +72,14 @@ func (s *bridgeTestServer) newRequest(t *testing.T, path string, body []byte) *h } // newDefaultProvider creates a Provider with default test configuration. 
-func newDefaultProvider(providerType, addr string) aibridge.Provider { +func newDefaultProvider(providerType string, addr string) aibridge.Provider { switch providerType { case config.ProviderAnthropic: return provider.NewAnthropic(anthropicCfg(addr, apiKey), nil) case config.ProviderOpenAI: return provider.NewOpenAI(openAICfg(addr, apiKey)) case providerBedrock: - return provider.NewAnthropic(anthropicCfg(addr, apiKey), testBedrockCfg(addr)) + return provider.NewAnthropic(anthropicCfg(addr, apiKey), bedrockCfg(addr)) default: panic("unknown provider type: " + providerType) } @@ -134,12 +131,6 @@ func withLogger(l slog.Logger) bridgeOption { return func(c *bridgeConfig) { c.logger = l; c.loggerSet = true } } -// withWrappedRecorder wraps the MockRecorder through aibridge.NewRecorder -// (the production recorder wrapper). Use when testing the recorder pipeline. -func withWrappedRecorder() bridgeOption { - return func(c *bridgeConfig) { c.wrapRecorder = true } -} - // setupInjectedToolTest abstracts common setup required for injected-tool integration tests. // Extra bridge options (e.g. [withProvider]) are appended after the built-in // MCP / tracer / actor options. 
When no provider option is given the default @@ -201,7 +192,6 @@ func setupInjectedToolTest( // newBridgeTestServer creates a fully configured test server running // a RequestBridge with sensible defaults: // - All standard providers (unless withProvider / withCustomProvider) -// - MockRecorder (raw, unless withWrappedRecorder) // - NoopMCPManager (unless withMCP) // - slogtest debug logger (unless withLogger) // - defaultTracer (unless withTracer) @@ -246,11 +236,9 @@ func newBridgeTestServer( mockRec := &testutil.MockRecorder{} var rec aibridge.Recorder = mockRec - if cfg.wrapRecorder { - rec = aibridge.NewRecorder(cfg.logger, cfg.tracer, func() (aibridge.Recorder, error) { - return mockRec, nil - }) - } + rec = aibridge.NewRecorder(cfg.logger, cfg.tracer, func() (aibridge.Recorder, error) { + return mockRec, nil + }) bridge, err := aibridge.NewRequestBridge( ctx, providers, rec, cfg.mcpProxy, diff --git a/internal/integrationtest/trace_test.go b/internal/integrationtest/trace_test.go index a2e74b2..d8625f5 100644 --- a/internal/integrationtest/trace_test.go +++ b/internal/integrationtest/trace_test.go @@ -100,7 +100,6 @@ func TestTraceAnthropic(t *testing.T) { opts := []bridgeOption{ withTracer(tracer), - withWrappedRecorder(), } if tc.bedrock { opts = append(opts, withProvider(providerBedrock)) @@ -219,7 +218,6 @@ func TestTraceAnthropicErr(t *testing.T) { opts := []bridgeOption{ withTracer(tracer), - withWrappedRecorder(), } if tc.bedrock { opts = append(opts, withProvider(providerBedrock)) @@ -390,18 +388,18 @@ func TestInjectedToolsTrace(t *testing.T) { func TestTraceOpenAI(t *testing.T) { cases := []struct { - name string - fixture []byte - streaming bool - path string - - expect []expectTrace + name string + fixture []byte + streaming bool + path string + + expect []expectTrace }{ { - name: "trace_openai_chat_streaming", - fixture: fixtures.OaiChatSimple, - streaming: true, - path: pathOpenAIChatCompletions, + name: "trace_openai_chat_streaming", + 
fixture: fixtures.OaiChatSimple, + streaming: true, + path: pathOpenAIChatCompletions, expect: []expectTrace{ {"Intercept", 1, codes.Unset}, {"Intercept.CreateInterceptor", 1, codes.Unset}, @@ -414,10 +412,10 @@ func TestTraceOpenAI(t *testing.T) { }, }, { - name: "trace_openai_chat_blocking", - fixture: fixtures.OaiChatSimple, - streaming: false, - path: pathOpenAIChatCompletions, + name: "trace_openai_chat_blocking", + fixture: fixtures.OaiChatSimple, + streaming: false, + path: pathOpenAIChatCompletions, expect: []expectTrace{ {"Intercept", 1, codes.Unset}, {"Intercept.CreateInterceptor", 1, codes.Unset}, @@ -430,10 +428,10 @@ func TestTraceOpenAI(t *testing.T) { }, }, { - name: "trace_openai_responses_streaming", - fixture: fixtures.OaiResponsesStreamingSimple, - streaming: true, - path: pathOpenAIResponses, + name: "trace_openai_responses_streaming", + fixture: fixtures.OaiResponsesStreamingSimple, + streaming: true, + path: pathOpenAIResponses, expect: []expectTrace{ {"Intercept", 1, codes.Unset}, {"Intercept.CreateInterceptor", 1, codes.Unset}, @@ -446,10 +444,10 @@ func TestTraceOpenAI(t *testing.T) { }, }, { - name: "trace_openai_responses_blocking", - fixture: fixtures.OaiResponsesBlockingSimple, - streaming: false, - path: pathOpenAIResponses, + name: "trace_openai_responses_blocking", + fixture: fixtures.OaiResponsesBlockingSimple, + streaming: false, + path: pathOpenAIResponses, expect: []expectTrace{ {"Intercept", 1, codes.Unset}, {"Intercept.CreateInterceptor", 1, codes.Unset}, @@ -477,7 +475,6 @@ func TestTraceOpenAI(t *testing.T) { upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) ts := newBridgeTestServer(t, ctx, upstream.URL, withTracer(tracer), - withWrappedRecorder(), ) reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) @@ -521,14 +518,14 @@ func TestTraceOpenAIErr(t *testing.T) { allowOverflow bool path string - expect []expectTrace - expectCode int + expect []expectTrace + expectCode int }{ { name: 
"trace_openai_chat_streaming_error", fixture: fixtures.OaiChatMidStreamError, streaming: true, - path: pathOpenAIChatCompletions, + path: pathOpenAIChatCompletions, expectCode: http.StatusOK, expect: []expectTrace{ {"Intercept", 1, codes.Error}, @@ -544,7 +541,7 @@ func TestTraceOpenAIErr(t *testing.T) { name: "trace_openai_chat_blocking_error", fixture: fixtures.OaiChatNonStreamError, streaming: false, - path: pathOpenAIChatCompletions, + path: pathOpenAIChatCompletions, expectCode: http.StatusBadRequest, expect: []expectTrace{ {"Intercept", 1, codes.Error}, @@ -559,7 +556,7 @@ func TestTraceOpenAIErr(t *testing.T) { name: "trace_openai_responses_streaming_error", streaming: true, fixture: fixtures.OaiResponsesStreamingWrongResponseFormat, - path: pathOpenAIResponses, + path: pathOpenAIResponses, expectCode: http.StatusOK, expect: []expectTrace{ {"Intercept", 1, codes.Error}, @@ -572,10 +569,10 @@ func TestTraceOpenAIErr(t *testing.T) { }, }, { - name: "trace_openai_responses_blocking_error", - fixture: fixtures.OaiResponsesBlockingWrongResponseFormat, - streaming: false, - path: pathOpenAIResponses, + name: "trace_openai_responses_blocking_error", + fixture: fixtures.OaiResponsesBlockingWrongResponseFormat, + streaming: false, + path: pathOpenAIResponses, // Fixture returns http 200 response with wrong body // responses forward received response as is so // expected code == 200 even though ProcessRequest @@ -596,7 +593,7 @@ func TestTraceOpenAIErr(t *testing.T) { streaming: true, allowOverflow: true, // 429 error causes retries - path: pathOpenAIResponses, + path: pathOpenAIResponses, expectCode: http.StatusTooManyRequests, expect: []expectTrace{ {"Intercept", 1, codes.Error}, @@ -612,7 +609,7 @@ func TestTraceOpenAIErr(t *testing.T) { fixture: fixtures.OaiResponsesBlockingHttpErr, streaming: false, - path: pathOpenAIResponses, + path: pathOpenAIResponses, expectCode: http.StatusUnauthorized, expect: []expectTrace{ {"Intercept", 1, codes.Error}, @@ -641,7 +638,6 
@@ func TestTraceOpenAIErr(t *testing.T) { mockAPI.AllowOverflow = tc.allowOverflow ts := newBridgeTestServer(t, ctx, mockAPI.URL, withTracer(tracer), - withWrappedRecorder(), ) reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) @@ -692,7 +688,6 @@ func TestTracePassthrough(t *testing.T) { ts := newBridgeTestServer(t, t.Context(), upstream.URL, withTracer(tracer), - withWrappedRecorder(), ) req, err := http.NewRequestWithContext(t.Context(), "GET", ts.URL+"/openai/v1/models", nil) From e7aaf920a3bd43997923d402ca026026d940a837 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Fri, 6 Mar 2026 12:47:28 +0000 Subject: [PATCH 13/32] remove actorID arg from setupInjectedToolTest --- internal/integrationtest/bridge_test.go | 4 +- internal/integrationtest/setupbridge.go | 143 ++++++++++++------------ internal/integrationtest/trace_test.go | 2 +- 3 files changed, 74 insertions(+), 75 deletions(-) diff --git a/internal/integrationtest/bridge_test.go b/internal/integrationtest/bridge_test.go index d2f7171..f72fabc 100644 --- a/internal/integrationtest/bridge_test.go +++ b/internal/integrationtest/bridge_test.go @@ -943,7 +943,7 @@ func TestAnthropicInjectedTools(t *testing.T) { t.Parallel() // Build the requirements & make the assertions which are common to all providers. - recorderClient, mockMCP, resp := setupInjectedToolTest(t, fixtures.AntSingleInjectedTool, streaming, defaultTracer, defaultActorID, pathAnthropicMessages, anthropicToolResultValidator(t)) + recorderClient, mockMCP, resp := setupInjectedToolTest(t, fixtures.AntSingleInjectedTool, streaming, defaultTracer, pathAnthropicMessages, anthropicToolResultValidator(t)) // Ensure expected tool was invoked with expected input. toolUsages := recorderClient.RecordedToolUsages() @@ -1410,7 +1410,7 @@ func TestOpenAIInjectedTools(t *testing.T) { t.Parallel() // Build the requirements & make the assertions which are common to all providers. 
- recorderClient, mockMCP, resp := setupInjectedToolTest(t, fixtures.OaiChatSingleInjectedTool, streaming, defaultTracer, defaultActorID, pathOpenAIChatCompletions, openaiChatToolResultValidator(t)) + recorderClient, mockMCP, resp := setupInjectedToolTest(t, fixtures.OaiChatSingleInjectedTool, streaming, defaultTracer, pathOpenAIChatCompletions, openaiChatToolResultValidator(t)) // Ensure expected tool was invoked with expected input. toolUsages := recorderClient.RecordedToolUsages() diff --git a/internal/integrationtest/setupbridge.go b/internal/integrationtest/setupbridge.go index d371804..67aa7c7 100644 --- a/internal/integrationtest/setupbridge.go +++ b/internal/integrationtest/setupbridge.go @@ -71,20 +71,6 @@ func (s *bridgeTestServer) newRequest(t *testing.T, path string, body []byte) *h return req } -// newDefaultProvider creates a Provider with default test configuration. -func newDefaultProvider(providerType string, addr string) aibridge.Provider { - switch providerType { - case config.ProviderAnthropic: - return provider.NewAnthropic(anthropicCfg(addr, apiKey), nil) - case config.ProviderOpenAI: - return provider.NewOpenAI(openAICfg(addr, apiKey)) - case providerBedrock: - return provider.NewAnthropic(anthropicCfg(addr, apiKey), bedrockCfg(addr)) - default: - panic("unknown provider type: " + providerType) - } -} - // withProvider adds a default-configured provider of the given type. // When any provider option is used, the default "all providers" set is not created. func withProvider(providerType string) bridgeOption { @@ -131,64 +117,6 @@ func withLogger(l slog.Logger) bridgeOption { return func(c *bridgeConfig) { c.logger = l; c.loggerSet = true } } -// setupInjectedToolTest abstracts common setup required for injected-tool integration tests. -// Extra bridge options (e.g. [withProvider]) are appended after the built-in -// MCP / tracer / actor options. When no provider option is given the default -// provider set (all providers) is used. 
-func setupInjectedToolTest( - t *testing.T, - fixture []byte, - streaming bool, - tracer trace.Tracer, - actorID string, - path string, - toolRequestValidatorFn func(*http.Request, []byte), - opts ...bridgeOption, -) (*testutil.MockRecorder, *mockMCP, *http.Response) { - t.Helper() - - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) - t.Cleanup(cancel) - - fix := fixtures.Parse(t, fixture) - - // Setup mock server for multi-turn interaction. - // First request → tool call response, second → tool response. - firstResp := newFixtureResponse(fix) - toolResp := newFixtureToolResponse(fix) - toolResp.OnRequest = toolRequestValidatorFn - upstream := newMockUpstream(t, ctx, firstResp, toolResp) - - mockMCP := setupMCPForTest(t, tracer) - - allOpts := []bridgeOption{ - withMCP(mockMCP), - withTracer(tracer), - withActor(actorID, nil), - } - allOpts = append(allOpts, opts...) - ts := newBridgeTestServer(t, ctx, upstream.URL, allOpts...) - - // Add the stream param to the request. - reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) - require.NoError(t, err) - - req := ts.newRequest(t, path, reqBody) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) - require.Equal(t, http.StatusOK, resp.StatusCode) - t.Cleanup(func() { - _ = resp.Body.Close() - }) - - // We must ALWAYS have 2 calls to the bridge for injected tool tests. - require.Eventually(t, func() bool { - return upstream.Calls.Load() == 2 - }, time.Second*10, time.Millisecond*50) - - return ts.Recorder, mockMCP, resp -} - // newBridgeTestServer creates a fully configured test server running // a RequestBridge with sensible defaults: // - All standard providers (unless withProvider / withCustomProvider) @@ -260,3 +188,74 @@ func newBridgeTestServer( Bridge: bridge, } } + +// setupInjectedToolTest abstracts common setup required for injected-tool integration tests. +// Extra bridge options (e.g. 
[withProvider]) are appended after the built-in +// MCP / tracer / actor options. When no provider option is given the default +// provider set (all providers) is used. +func setupInjectedToolTest( + t *testing.T, + fixture []byte, + streaming bool, + tracer trace.Tracer, + path string, + toolRequestValidatorFn func(*http.Request, []byte), + opts ...bridgeOption, +) (*testutil.MockRecorder, *mockMCP, *http.Response) { + t.Helper() + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + fix := fixtures.Parse(t, fixture) + + // Setup mock server for multi-turn interaction. + // First request → tool call response, second → tool response. + firstResp := newFixtureResponse(fix) + toolResp := newFixtureToolResponse(fix) + toolResp.OnRequest = toolRequestValidatorFn + upstream := newMockUpstream(t, ctx, firstResp, toolResp) + + mockMCP := setupMCPForTest(t, tracer) + + allOpts := []bridgeOption{ + withMCP(mockMCP), + withTracer(tracer), + withActor(defaultActorID, nil), + } + allOpts = append(allOpts, opts...) + ts := newBridgeTestServer(t, ctx, upstream.URL, allOpts...) + + // Add the stream param to the request. + reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) + require.NoError(t, err) + + req := ts.newRequest(t, path, reqBody) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + t.Cleanup(func() { + _ = resp.Body.Close() + }) + + // We must ALWAYS have 2 calls to the bridge for injected tool tests. + require.Eventually(t, func() bool { + return upstream.Calls.Load() == 2 + }, time.Second*10, time.Millisecond*50) + + return ts.Recorder, mockMCP, resp +} + +// newDefaultProvider creates a Provider with default test configuration. 
+func newDefaultProvider(providerType string, addr string) aibridge.Provider { + switch providerType { + case config.ProviderAnthropic: + return provider.NewAnthropic(anthropicCfg(addr, apiKey), nil) + case config.ProviderOpenAI: + return provider.NewOpenAI(openAICfg(addr, apiKey)) + case providerBedrock: + return provider.NewAnthropic(anthropicCfg(addr, apiKey), bedrockCfg(addr)) + default: + panic("unknown provider type: " + providerType) + } +} diff --git a/internal/integrationtest/trace_test.go b/internal/integrationtest/trace_test.go index d8625f5..eae10a6 100644 --- a/internal/integrationtest/trace_test.go +++ b/internal/integrationtest/trace_test.go @@ -355,7 +355,7 @@ func TestInjectedToolsTrace(t *testing.T) { } recorderClient, mockMCP, resp := setupInjectedToolTest( - t, tc.fixture, tc.streaming, tracer, defaultActorID, + t, tc.fixture, tc.streaming, tracer, tc.path, validatorFn, tc.opts..., ) defer resp.Body.Close() From 94e7ccc9d2e6505d76851e98cc6908763da6d2d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Fri, 6 Mar 2026 12:59:11 +0000 Subject: [PATCH 14/32] use bridgeTestServer.newRequest --- internal/integrationtest/apidump_test.go | 6 ++---- internal/integrationtest/bridge_test.go | 4 +--- internal/integrationtest/metrics_test.go | 6 ++---- internal/integrationtest/setupbridge.go | 18 +++++++++--------- internal/integrationtest/trace_test.go | 6 ++---- 5 files changed, 16 insertions(+), 24 deletions(-) diff --git a/internal/integrationtest/apidump_test.go b/internal/integrationtest/apidump_test.go index fa441d6..1bef469 100644 --- a/internal/integrationtest/apidump_test.go +++ b/internal/integrationtest/apidump_test.go @@ -198,9 +198,7 @@ func TestAPIDumpPassthrough(t *testing.T) { withCustomProvider(tc.newProvider(upstream.URL, dumpDir)), ) - req, err := http.NewRequestWithContext(ctx, http.MethodGet, ts.URL+tc.requestPath, nil) - require.NoError(t, err) - + req := ts.newRequest(t, tc.requestPath, nil) resp, err := 
http.DefaultClient.Do(req) require.NoError(t, err) defer resp.Body.Close() @@ -239,7 +237,7 @@ func TestAPIDumpPassthrough(t *testing.T) { require.NoError(t, err) dumpReq, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(reqDumpData))) require.NoError(t, err) - require.Equal(t, http.MethodGet, dumpReq.Method) + require.Equal(t, http.MethodPost, dumpReq.Method) // Verify response dump. respDumpData, err := os.ReadFile(respDumpFile) diff --git a/internal/integrationtest/bridge_test.go b/internal/integrationtest/bridge_test.go index f72fabc..d3620b9 100644 --- a/internal/integrationtest/bridge_test.go +++ b/internal/integrationtest/bridge_test.go @@ -403,9 +403,7 @@ func TestFallthrough(t *testing.T) { upstream := newMockUpstream(t, t.Context(), newFixtureResponse(fix)) ts := newBridgeTestServer(t, t.Context(), upstream.URL+tc.basePath) - req, err := http.NewRequestWithContext(t.Context(), "GET", fmt.Sprintf("%s%s", ts.URL, tc.requestPath), nil) - require.NoError(t, err) - + req := ts.newRequest(t, tc.requestPath, nil) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) defer resp.Body.Close() diff --git a/internal/integrationtest/metrics_test.go b/internal/integrationtest/metrics_test.go index dffa15c..a9fb9fb 100644 --- a/internal/integrationtest/metrics_test.go +++ b/internal/integrationtest/metrics_test.go @@ -206,16 +206,14 @@ func TestMetrics_PassthroughCount(t *testing.T) { withMetrics(m), ) - req, err := http.NewRequestWithContext(t.Context(), "GET", ts.URL+"/openai/v1/models", nil) - require.NoError(t, err) - + req := ts.newRequest(t, "/openai/v1/models", nil) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) count := promtest.ToFloat64(m.PassthroughCount.WithLabelValues( - config.ProviderOpenAI, "/models", "GET")) + config.ProviderOpenAI, "/models", http.MethodPost)) require.Equal(t, 1.0, count) } diff --git a/internal/integrationtest/setupbridge.go 
b/internal/integrationtest/setupbridge.go index 67aa7c7..9cf066c 100644 --- a/internal/integrationtest/setupbridge.go +++ b/internal/integrationtest/setupbridge.go @@ -41,15 +41,6 @@ const ( var defaultTracer = otel.Tracer("integrationtest") -// bridgeTestServer wraps an httptest.Server running a RequestBridge. -type bridgeTestServer struct { - *httptest.Server - Recorder *testutil.MockRecorder - Bridge *aibridge.RequestBridge -} - -type bridgeOption func(*bridgeConfig) - type bridgeConfig struct { providerBuilders []func(upstreamURL string) aibridge.Provider metrics *metrics.Metrics @@ -61,6 +52,13 @@ type bridgeConfig struct { loggerSet bool } +// bridgeTestServer wraps an httptest.Server running a RequestBridge. +type bridgeTestServer struct { + *httptest.Server + Recorder *testutil.MockRecorder + Bridge *aibridge.RequestBridge +} + // newRequest creates a JSON POST request targeting the given path on this server. func (s *bridgeTestServer) newRequest(t *testing.T, path string, body []byte) *http.Request { t.Helper() @@ -71,6 +69,8 @@ func (s *bridgeTestServer) newRequest(t *testing.T, path string, body []byte) *h return req } +type bridgeOption func(*bridgeConfig) + // withProvider adds a default-configured provider of the given type. // When any provider option is used, the default "all providers" set is not created. 
func withProvider(providerType string) bridgeOption { diff --git a/internal/integrationtest/trace_test.go b/internal/integrationtest/trace_test.go index eae10a6..1d7a7bd 100644 --- a/internal/integrationtest/trace_test.go +++ b/internal/integrationtest/trace_test.go @@ -690,9 +690,7 @@ func TestTracePassthrough(t *testing.T) { withTracer(tracer), ) - req, err := http.NewRequestWithContext(t.Context(), "GET", ts.URL+"/openai/v1/models", nil) - require.NoError(t, err) - + req := ts.newRequest(t, "/openai/v1/models", nil) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) defer resp.Body.Close() @@ -704,7 +702,7 @@ func TestTracePassthrough(t *testing.T) { assert.Equal(t, spans[0].Name(), "Passthrough") want := []attribute.KeyValue{ - attribute.String(tracing.PassthroughMethod, "GET"), + attribute.String(tracing.PassthroughMethod, http.MethodPost), attribute.String(tracing.PassthroughUpstreamURL, upstream.URL+"/models"), attribute.String(tracing.PassthroughURL, "/models"), } From 9c0a25a24ae196a839346611d5a1e65086db326e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Fri, 6 Mar 2026 13:53:24 +0000 Subject: [PATCH 15/32] refactor: update makeRequest call sites in trace_test.go for new *http.Response return type makeRequest now performs the HTTP request internally and returns *http.Response directly. Remove the intermediate http.Client creation, client.Do(req), and require.NoError(t, err) calls at all 5 call sites. 
--- internal/integrationtest/trace_test.go | 52 ++++++++++---------------- 1 file changed, 19 insertions(+), 33 deletions(-) diff --git a/internal/integrationtest/trace_test.go b/internal/integrationtest/trace_test.go index 1d7a7bd..6f39b88 100644 --- a/internal/integrationtest/trace_test.go +++ b/internal/integrationtest/trace_test.go @@ -104,19 +104,16 @@ func TestTraceAnthropic(t *testing.T) { if tc.bedrock { opts = append(opts, withProvider(providerBedrock)) } - ts := newBridgeTestServer(t, ctx, upstream.URL, opts...) + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, opts...) reqBody, err := sjson.SetBytes(fixtureReqBody, "stream", tc.streaming) require.NoError(t, err) - req := ts.newRequest(t, pathAnthropicMessages, reqBody) - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody) require.Equal(t, http.StatusOK, resp.StatusCode) defer resp.Body.Close() - ts.Close() + bridgeServer.Close() - recorder := ts.Recorder + recorder := bridgeServer.Recorder require.Equal(t, 1, len(recorder.RecordedInterceptions())) intcID := recorder.RecordedInterceptions()[0].ID @@ -222,23 +219,20 @@ func TestTraceAnthropicErr(t *testing.T) { if tc.bedrock { opts = append(opts, withProvider(providerBedrock)) } - ts := newBridgeTestServer(t, ctx, upstream.URL, opts...) + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, opts...) 
reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) require.NoError(t, err) - req := ts.newRequest(t, pathAnthropicMessages, reqBody) - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody) if tc.streaming { require.Equal(t, http.StatusOK, resp.StatusCode) } else { require.Equal(t, tc.expectCode, resp.StatusCode) } defer resp.Body.Close() - ts.Close() + bridgeServer.Close() - recorder := ts.Recorder + recorder := bridgeServer.Recorder require.Equal(t, 1, len(recorder.RecordedInterceptions())) intcID := recorder.RecordedInterceptions()[0].ID @@ -473,21 +467,18 @@ func TestTraceOpenAI(t *testing.T) { fix := fixtures.Parse(t, tc.fixture) upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - ts := newBridgeTestServer(t, ctx, upstream.URL, + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, withTracer(tracer), ) reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) require.NoError(t, err) - req := ts.newRequest(t, tc.path, reqBody) - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody) require.Equal(t, http.StatusOK, resp.StatusCode) defer resp.Body.Close() - ts.Close() + bridgeServer.Close() - recorder := ts.Recorder + recorder := bridgeServer.Recorder require.Equal(t, 1, len(recorder.RecordedInterceptions())) intcID := recorder.RecordedInterceptions()[0].ID @@ -636,22 +627,19 @@ func TestTraceOpenAIErr(t *testing.T) { mockAPI := newMockUpstream(t, ctx, newFixtureResponse(fix)) mockAPI.AllowOverflow = tc.allowOverflow - ts := newBridgeTestServer(t, ctx, mockAPI.URL, + bridgeServer := newBridgeTestServer(t, ctx, mockAPI.URL, withTracer(tracer), ) reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) require.NoError(t, err) - req := ts.newRequest(t, tc.path, reqBody) - client := &http.Client{} 
- resp, err := client.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody) require.Equal(t, tc.expectCode, resp.StatusCode) defer resp.Body.Close() - ts.Close() + bridgeServer.Close() - recorder := ts.Recorder + recorder := bridgeServer.Recorder require.Equal(t, 1, len(recorder.RecordedInterceptions())) intcID := recorder.RecordedInterceptions()[0].ID @@ -686,16 +674,14 @@ func TestTracePassthrough(t *testing.T) { tracer := tp.Tracer(t.Name()) defer func() { _ = tp.Shutdown(t.Context()) }() - ts := newBridgeTestServer(t, t.Context(), upstream.URL, + bridgeServer := newBridgeTestServer(t, t.Context(), upstream.URL, withTracer(tracer), ) - req := ts.newRequest(t, "/openai/v1/models", nil) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodGet, "/openai/v1/models", nil) defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) - ts.Close() + bridgeServer.Close() spans := sr.Ended() require.Len(t, spans, 1) From e93d096a35b9c190d2630ef71ab7cca97ae02cd1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Fri, 6 Mar 2026 13:53:47 +0000 Subject: [PATCH 16/32] refactor: update makeRequest call sites in responses_test.go makeRequest now returns *http.Response directly instead of *http.Request. 
Updated all 6 call sites: - Removed intermediate http.DefaultClient.Do / client.Do calls - Removed require.NoError for the HTTP request error - Collapsed header mutation site to pass http.Header as extra arg --- internal/integrationtest/responses_test.go | 65 +++++++--------------- 1 file changed, 21 insertions(+), 44 deletions(-) diff --git a/internal/integrationtest/responses_test.go b/internal/integrationtest/responses_test.go index 91a978c..612a821 100644 --- a/internal/integrationtest/responses_test.go +++ b/internal/integrationtest/responses_test.go @@ -334,14 +334,9 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { fix := fixtures.Parse(t, tc.fixture) upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - ts := newBridgeTestServer(t, ctx, upstream.URL) + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) - req := ts.newRequest(t, pathOpenAIResponses, fix.Request()) - req.Header.Set("User-Agent", tc.userAgent) - client := &http.Client{} - - resp, err := client.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, fix.Request(), http.Header{"User-Agent": {tc.userAgent}}) defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) got, err := io.ReadAll(resp.Body) @@ -353,7 +348,7 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { require.Equal(t, string(fix.NonStreaming()), string(got)) } - interceptions := ts.Recorder.RecordedInterceptions() + interceptions := bridgeServer.Recorder.RecordedInterceptions() require.Len(t, interceptions, 1) intc := interceptions[0] require.Equal(t, intc.InitiatorID, defaultActorID) @@ -362,7 +357,7 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { require.Equal(t, tc.userAgent, intc.UserAgent) require.Equal(t, string(tc.expectedClient), intc.Client) - recordedPrompts := ts.Recorder.RecordedPromptUsages() + recordedPrompts := bridgeServer.Recorder.RecordedPromptUsages() if tc.expectPromptRecorded != "" { require.Len(t, 
recordedPrompts, 1) promptEq := func(pur *recorder.PromptUsageRecord) bool { return pur.Prompt == tc.expectPromptRecorded } @@ -371,7 +366,7 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { require.Empty(t, recordedPrompts) } - recordedTools := ts.Recorder.RecordedToolUsages() + recordedTools := bridgeServer.Recorder.RecordedToolUsages() if tc.expectToolRecorded != nil { require.Len(t, recordedTools, 1) recordedTools[0].InterceptionID = tc.expectToolRecorded.InterceptionID // ignore interception id (interception id is not constant and response doesn't contain it) @@ -381,7 +376,7 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { require.Empty(t, recordedTools) } - recordedTokens := ts.Recorder.RecordedTokenUsages() + recordedTokens := bridgeServer.Recorder.RecordedTokenUsages() if tc.expectTokenUsage != nil { require.Len(t, recordedTokens, 1) recordedTokens[0].InterceptionID = tc.expectTokenUsage.InterceptionID // ignore interception id @@ -425,15 +420,11 @@ func TestResponsesBackgroundModeForbidden(t *testing.T) { })) t.Cleanup(upstream.Close) - ts := newBridgeTestServer(t, ctx, upstream.URL) + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) // Create a request with background mode enabled reqBytes := responsesRequestBytes(t, tc.streaming, keyVal{"background", true}) - req := ts.newRequest(t, pathOpenAIResponses, reqBytes) - client := &http.Client{} - - resp, err := client.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBytes) defer resp.Body.Close() require.Equal(t, "application/json", resp.Header.Get("Content-Type")) @@ -537,15 +528,11 @@ func TestResponsesParallelToolsOverwritten(t *testing.T) { })) t.Cleanup(upstream.Close) - ts := newBridgeTestServer(t, ctx, upstream.URL) - - req := ts.newRequest(t, pathOpenAIResponses, []byte(tc.request)) - client := &http.Client{} + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) - resp, err := client.Do(req) - 
require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, []byte(tc.request)) defer resp.Body.Close() - _, err = io.ReadAll(resp.Body) + _, err := io.ReadAll(resp.Body) require.NoError(t, err) }) } @@ -597,14 +584,10 @@ func TestClientAndConnectionError(t *testing.T) { t.Cleanup(cancel) // tc.addr may be an intentionally invalid URL; use withCustomProvider. - ts := newBridgeTestServer(t, ctx, tc.addr, withCustomProvider(provider.NewOpenAI(openAICfg(tc.addr, apiKey)))) + bridgeServer := newBridgeTestServer(t, ctx, tc.addr, withCustomProvider(provider.NewOpenAI(openAICfg(tc.addr, apiKey)))) reqBytes := responsesRequestBytes(t, tc.streaming) - req := ts.newRequest(t, pathOpenAIResponses, reqBytes) - client := &http.Client{} - - resp, err := client.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBytes) defer resp.Body.Close() require.Equal(t, "application/json", resp.Header.Get("Content-Type")) @@ -613,7 +596,7 @@ func TestClientAndConnectionError(t *testing.T) { body, err := io.ReadAll(resp.Body) require.NoError(t, err) requireResponsesError(t, http.StatusInternalServerError, tc.errContains, body) - require.Empty(t, ts.Recorder.RecordedPromptUsages()) + require.Empty(t, bridgeServer.Recorder.RecordedPromptUsages()) }) } } @@ -679,14 +662,10 @@ func TestUpstreamError(t *testing.T) { })) t.Cleanup(upstream.Close) - ts := newBridgeTestServer(t, ctx, upstream.URL) + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) reqBytes := responsesRequestBytes(t, tc.streaming) - req := ts.newRequest(t, pathOpenAIResponses, reqBytes) - client := &http.Client{} - - resp, err := client.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBytes) defer resp.Body.Close() require.Equal(t, tc.statusCode, resp.StatusCode) @@ -861,11 +840,9 @@ func TestResponsesInjectedTool(t *testing.T) { 
mockMCP.setToolError(tc.mcpToolName, tc.expectToolError) } - ts := newBridgeTestServer(t, ctx, upstream.URL, withMCP(mockMCP)) + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, withMCP(mockMCP)) - req := ts.newRequest(t, pathOpenAIResponses, fix.Request()) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, fix.Request()) defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) @@ -882,7 +859,7 @@ func TestResponsesInjectedTool(t *testing.T) { require.Len(t, invocations, 1, "expected MCP tool to be invoked once") // Verify the injected tool usage was recorded. - toolUsages := ts.Recorder.RecordedToolUsages() + toolUsages := bridgeServer.Recorder.RecordedToolUsages() require.Len(t, toolUsages, 1) require.Equal(t, tc.mcpToolName, toolUsages[0].Tool) require.Equal(t, tc.expectToolArgs, toolUsages[0].Args) @@ -892,11 +869,11 @@ func TestResponsesInjectedTool(t *testing.T) { } // Verify prompt was recorded. - prompts := ts.Recorder.RecordedPromptUsages() + prompts := bridgeServer.Recorder.RecordedPromptUsages() require.Len(t, prompts, 1) require.Equal(t, tc.expectPrompt, prompts[0].Prompt) - tokenUsages := ts.Recorder.RecordedTokenUsages() + tokenUsages := bridgeServer.Recorder.RecordedTokenUsages() require.Len(t, tokenUsages, len(tc.expectTokenUsages)) for i := range tokenUsages { tokenUsages[i].InterceptionID = "" // ignore interception ID and time creation when comparing From 52b17c640768cf02ff0c61da03b02c2c77632540 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Fri, 6 Mar 2026 13:54:59 +0000 Subject: [PATCH 17/32] refactor: update makeRequest call sites in bridge_test.go makeRequest now returns *http.Response directly (performs HTTP request internally). 
Updated all 14 call sites: - Changed req := to resp := - Removed http.DefaultClient.Do(req) and require.NoError(t, err) lines - Removed client := &http.Client{} and client.Do(req) patterns - Converted header mutation case to pass http.Header as extra arg --- internal/integrationtest/bridge_test.go | 138 +++++++++--------------- 1 file changed, 51 insertions(+), 87 deletions(-) diff --git a/internal/integrationtest/bridge_test.go b/internal/integrationtest/bridge_test.go index d3620b9..139aff7 100644 --- a/internal/integrationtest/bridge_test.go +++ b/internal/integrationtest/bridge_test.go @@ -169,15 +169,12 @@ func TestSimple(t *testing.T) { fix := fixtures.Parse(t, tc.fixture) upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - ts := newBridgeTestServer(t, ctx, upstream.URL+tc.basePath) + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL+tc.basePath) // When: calling the "API server" with the fixture's request body. reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) require.NoError(t, err) - req := ts.newRequest(t, tc.path, reqBody) - req.Header.Set("User-Agent", tc.userAgent) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody, http.Header{"User-Agent": {tc.userAgent}}) require.Equal(t, http.StatusOK, resp.StatusCode) defer resp.Body.Close() @@ -195,7 +192,7 @@ func TestSimple(t *testing.T) { resp.Body = io.NopCloser(bytes.NewReader(bodyBytes)) // Then: I expect the prompt to have been tracked. 
- promptUsages := ts.Recorder.RecordedPromptUsages() + promptUsages := bridgeServer.Recorder.RecordedPromptUsages() require.NotEmpty(t, promptUsages, "no prompts tracked") assert.Contains(t, promptUsages[0].Prompt, "how many angels can dance on the head of a pin") @@ -206,18 +203,18 @@ func TestSimple(t *testing.T) { require.NoError(t, err, "failed to retrieve response ID") require.Nilf(t, uuid.Validate(id), "%s is not a valid UUID", id) - tokenUsages := ts.Recorder.RecordedTokenUsages() + tokenUsages := bridgeServer.Recorder.RecordedTokenUsages() require.GreaterOrEqual(t, len(tokenUsages), 1) require.Equal(t, tokenUsages[0].MsgID, tc.expectedMsgID) // Validate user agent and client have been recorded. - interceptions := ts.Recorder.RecordedInterceptions() + interceptions := bridgeServer.Recorder.RecordedInterceptions() require.Len(t, interceptions, 1, "expected exactly one interception, got: %v", interceptions) assert.Equal(t, id, interceptions[0].ID) assert.Equal(t, tc.userAgent, interceptions[0].UserAgent) assert.Equal(t, string(tc.expectedClient), interceptions[0].Client) - ts.Recorder.VerifyAllInterceptionsEnded(t) + bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) }) } }) @@ -401,11 +398,9 @@ func TestFallthrough(t *testing.T) { fix := fixtures.Parse(t, tc.fixture) upstream := newMockUpstream(t, t.Context(), newFixtureResponse(fix)) - ts := newBridgeTestServer(t, t.Context(), upstream.URL+tc.basePath) + bridgeServer := newBridgeTestServer(t, t.Context(), upstream.URL+tc.basePath) - req := ts.newRequest(t, tc.requestPath, nil) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodGet, tc.requestPath, nil) defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) @@ -485,19 +480,17 @@ func TestErrorHandling(t *testing.T) { fix := fixtures.Parse(t, tc.fixture) upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - ts := newBridgeTestServer(t, ctx, upstream.URL) + 
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) // Add the stream param to the request. reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) require.NoError(t, err) - req := ts.newRequest(t, tc.path, reqBody) - resp, err := http.DefaultClient.Do(req) + resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody) t.Cleanup(func() { _ = resp.Body.Close() }) - require.NoError(t, err) tc.responseHandlerFn(resp) - ts.Recorder.VerifyAllInterceptionsEnded(t) + bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) }) } }) @@ -559,16 +552,14 @@ func TestErrorHandling(t *testing.T) { upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) upstream.StatusCode = http.StatusInternalServerError - ts := newBridgeTestServer(t, ctx, upstream.URL) + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) - req := ts.newRequest(t, tc.path, fix.Request()) - resp, err := http.DefaultClient.Do(req) + resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request()) t.Cleanup(func() { _ = resp.Body.Close() }) - require.NoError(t, err) - ts.Close() + bridgeServer.Close() tc.responseHandlerFn(resp) - ts.Recorder.VerifyAllInterceptionsEnded(t) + bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) }) } }) @@ -618,15 +609,13 @@ func TestStableRequestEncoding(t *testing.T) { } upstream := newMockUpstream(t, ctx, responses...) - ts := newBridgeTestServer(t, ctx, upstream.URL, + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, withMCP(mockMCP), ) // Make multiple requests and verify they all have identical payloads. 
for range count { - req := ts.newRequest(t, tc.path, fix.Request()) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request()) require.Equal(t, http.StatusOK, resp.StatusCode) _ = resp.Body.Close() } @@ -690,11 +679,9 @@ func TestEnvironmentDoNotLeak(t *testing.T) { t.Setenv(key, val) } - ts := newBridgeTestServer(t, ctx, upstream.URL) + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) - req := ts.newRequest(t, tc.path, fix.Request()) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request()) require.Equal(t, http.StatusOK, resp.StatusCode) defer resp.Body.Close() @@ -807,7 +794,7 @@ func TestActorHeaders(t *testing.T) { t.Cleanup(srv.Close) metadataKey := "Username" - ts := newBridgeTestServer(t, ctx, srv.URL, + bridgeServer := newBridgeTestServer(t, ctx, srv.URL, withCustomProvider(tc.createProviderFn(srv.URL, apiKey, send)), withActor(defaultActorID, recorder.Metadata{ metadataKey: actorUsername, @@ -818,9 +805,7 @@ func TestActorHeaders(t *testing.T) { reqBody, err := sjson.SetBytes(fixtures.Request(t, tc.fixture), "stream", tc.streaming) require.NoError(t, err) - req := ts.newRequest(t, tc.path, reqBody) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody) require.NotEmpty(t, receivedHeaders) defer resp.Body.Close() @@ -880,15 +865,12 @@ func TestAnthropicMessages(t *testing.T) { fix := fixtures.Parse(t, fixtures.AntSingleBuiltinTool) upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - ts := newBridgeTestServer(t, ctx, upstream.URL) + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) // Make API call to aibridge for Anthropic /v1/messages reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) require.NoError(t, err) - req := ts.newRequest(t, 
pathAnthropicMessages, reqBody) - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody) require.Equal(t, http.StatusOK, resp.StatusCode) defer resp.Body.Close() @@ -907,13 +889,13 @@ func TestAnthropicMessages(t *testing.T) { // One for message_start, one for message_delta. expectedTokenRecordings = 2 } - tokenUsages := ts.Recorder.RecordedTokenUsages() + tokenUsages := bridgeServer.Recorder.RecordedTokenUsages() require.Len(t, tokenUsages, expectedTokenRecordings) - assert.EqualValues(t, tc.expectedInputTokens, ts.Recorder.TotalInputTokens(), "input tokens miscalculated") - assert.EqualValues(t, tc.expectedOutputTokens, ts.Recorder.TotalOutputTokens(), "output tokens miscalculated") + assert.EqualValues(t, tc.expectedInputTokens, bridgeServer.Recorder.TotalInputTokens(), "input tokens miscalculated") + assert.EqualValues(t, tc.expectedOutputTokens, bridgeServer.Recorder.TotalOutputTokens(), "output tokens miscalculated") - toolUsages := ts.Recorder.RecordedToolUsages() + toolUsages := bridgeServer.Recorder.RecordedToolUsages() require.Len(t, toolUsages, 1) assert.Equal(t, "Read", toolUsages[0].Tool) assert.Equal(t, tc.expectedToolCallID, toolUsages[0].ToolCallID) @@ -923,11 +905,11 @@ func TestAnthropicMessages(t *testing.T) { require.Contains(t, args, "file_path") assert.Equal(t, "/tmp/blah/foo", args["file_path"]) - promptUsages := ts.Recorder.RecordedPromptUsages() + promptUsages := bridgeServer.Recorder.RecordedPromptUsages() require.Len(t, promptUsages, 1) assert.Equal(t, "read the foo file", promptUsages[0].Prompt) - ts.Recorder.VerifyAllInterceptionsEnded(t) + bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) }) } }) @@ -1150,7 +1132,7 @@ func TestAnthropicToolChoiceParallelDisabled(t *testing.T) { fix := fixtures.Parse(t, fixtures.AntSimple) upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - ts := newBridgeTestServer(t, ctx, 
upstream.URL, + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, withMCP(mcpMgr), ) @@ -1158,10 +1140,7 @@ func TestAnthropicToolChoiceParallelDisabled(t *testing.T) { reqBody, err := sjson.SetBytes(fix.Request(), "tool_choice", tc.toolChoice) require.NoError(t, err) - req := ts.newRequest(t, pathAnthropicMessages, reqBody) - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody) require.Equal(t, http.StatusOK, resp.StatusCode) _ = resp.Body.Close() @@ -1210,7 +1189,7 @@ func TestThinkingAdaptiveIsPreserved(t *testing.T) { // Create a mock server that captures the request body sent upstream. upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - ts := newBridgeTestServer(t, ctx, upstream.URL) + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) // Inject adaptive thinking into the fixture request. reqBody, err := sjson.SetBytes(fix.Request(), "thinking", map[string]string{"type": "adaptive"}) @@ -1218,9 +1197,7 @@ func TestThinkingAdaptiveIsPreserved(t *testing.T) { reqBody, err = sjson.SetBytes(reqBody, "stream", streaming) require.NoError(t, err) - req := ts.newRequest(t, pathAnthropicMessages, reqBody) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody) require.Equal(t, http.StatusOK, resp.StatusCode) _, _ = io.ReadAll(resp.Body) _ = resp.Body.Close() @@ -1268,16 +1245,12 @@ func TestOpenAIChatCompletions(t *testing.T) { fix := fixtures.Parse(t, fixtures.OaiChatSingleBuiltinTool) upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - ts := newBridgeTestServer(t, ctx, upstream.URL) + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) // Make API call to aibridge for OpenAI /v1/chat/completions reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) require.NoError(t, err) - req := ts.newRequest(t, 
pathOpenAIChatCompletions, reqBody) - - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, reqBody) require.Equal(t, http.StatusOK, resp.StatusCode) defer resp.Body.Close() @@ -1295,12 +1268,12 @@ func TestOpenAIChatCompletions(t *testing.T) { assert.Equal(t, "[DONE]", lastEvent.Data) } - tokenUsages := ts.Recorder.RecordedTokenUsages() + tokenUsages := bridgeServer.Recorder.RecordedTokenUsages() require.Len(t, tokenUsages, 1) - assert.EqualValues(t, tc.expectedInputTokens, ts.Recorder.TotalInputTokens(), "input tokens miscalculated") - assert.EqualValues(t, tc.expectedOutputTokens, ts.Recorder.TotalOutputTokens(), "output tokens miscalculated") + assert.EqualValues(t, tc.expectedInputTokens, bridgeServer.Recorder.TotalInputTokens(), "input tokens miscalculated") + assert.EqualValues(t, tc.expectedOutputTokens, bridgeServer.Recorder.TotalOutputTokens(), "output tokens miscalculated") - toolUsages := ts.Recorder.RecordedToolUsages() + toolUsages := bridgeServer.Recorder.RecordedToolUsages() require.Len(t, toolUsages, 1) assert.Equal(t, "read_file", toolUsages[0].Tool) assert.Equal(t, tc.expectedToolCallID, toolUsages[0].ToolCallID) @@ -1308,11 +1281,11 @@ func TestOpenAIChatCompletions(t *testing.T) { require.Contains(t, toolUsages[0].Args, "path") assert.Equal(t, "README.md", toolUsages[0].Args.(map[string]any)["path"]) - promptUsages := ts.Recorder.RecordedPromptUsages() + promptUsages := bridgeServer.Recorder.RecordedPromptUsages() require.Len(t, promptUsages, 1) assert.Equal(t, "how large is the README.md file in my current path", promptUsages[0].Prompt) - ts.Recorder.VerifyAllInterceptionsEnded(t) + bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) }) } }) @@ -1352,18 +1325,14 @@ func TestOpenAIChatCompletions(t *testing.T) { // Setup MCP proxies with the tool from the fixture mockMCP := setupMCPForTest(t, defaultTracer) - ts := 
newBridgeTestServer(t, ctx, upstream.URL, + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, withMCP(mockMCP), ) // Add the stream param to the request. reqBody, err := sjson.SetBytes(fix.Request(), "stream", true) require.NoError(t, err) - req := ts.newRequest(t, pathOpenAIChatCompletions, reqBody) - - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, reqBody) require.Equal(t, http.StatusOK, resp.StatusCode) // Verify SSE headers are sent correctly @@ -1390,11 +1359,11 @@ func TestOpenAIChatCompletions(t *testing.T) { } // Verify tool usage was recorded - toolUsages := ts.Recorder.RecordedToolUsages() + toolUsages := bridgeServer.Recorder.RecordedToolUsages() require.Len(t, toolUsages, 1) assert.Equal(t, mockToolName, toolUsages[0].Tool) - ts.Recorder.VerifyAllInterceptionsEnded(t) + bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) }) } }) @@ -1546,14 +1515,12 @@ func TestAWSBedrockIntegration(t *testing.T) { SmallFastModel: "test-haiku", } - ts := newBridgeTestServer(t, ctx, "http://unused", + bridgeServer := newBridgeTestServer(t, ctx, "http://unused", withCustomProvider(provider.NewAnthropic(anthropicCfg("http://unused", apiKey), bedrockCfg)), withLogger(newLogger(t)), ) - req := ts.newRequest(t, pathAnthropicMessages, fixtures.Request(t, fixtures.AntSingleBuiltinTool)) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, fixtures.Request(t, fixtures.AntSingleBuiltinTool)) defer resp.Body.Close() require.Equal(t, http.StatusInternalServerError, resp.StatusCode) @@ -1584,7 +1551,7 @@ func TestAWSBedrockIntegration(t *testing.T) { BaseURL: upstream.URL, // Use the mock server. 
} - ts := newBridgeTestServer(t, ctx, upstream.URL, + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, withCustomProvider(provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), bedrockCfg)), withLogger(newLogger(t)), ) @@ -1593,10 +1560,7 @@ func TestAWSBedrockIntegration(t *testing.T) { // We override the AWS Bedrock client to route requests through our mock server. reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) require.NoError(t, err) - req := ts.newRequest(t, pathAnthropicMessages, reqBody) - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody) defer resp.Body.Close() // For streaming responses, consume the body to allow the stream to complete. @@ -1620,10 +1584,10 @@ func TestAWSBedrockIntegration(t *testing.T) { require.False(t, gjson.GetBytes(received[0].Body, "model").Exists(), "model should be stripped from body") require.False(t, gjson.GetBytes(received[0].Body, "stream").Exists(), "stream should be stripped from body") - interceptions := ts.Recorder.RecordedInterceptions() + interceptions := bridgeServer.Recorder.RecordedInterceptions() require.Len(t, interceptions, 1) require.Equal(t, interceptions[0].Model, bedrockCfg.Model) - ts.Recorder.VerifyAllInterceptionsEnded(t) + bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) }) } }) From 97559052e5d19d20b68d62ac0603b11fee19fa31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Fri, 6 Mar 2026 14:26:31 +0000 Subject: [PATCH 18/32] bridgeServer.makeRequest + tracerSetup + TestSessionIDTracking cleanup --- internal/integrationtest/apidump_test.go | 22 ++-- internal/integrationtest/bridge_test.go | 70 +++-------- .../integrationtest/circuit_breaker_test.go | 116 ++++++++---------- internal/integrationtest/metrics_test.go | 40 +++--- internal/integrationtest/setupbridge.go | 26 ++-- internal/integrationtest/trace_test.go | 50 ++++---- 6 files 
changed, 133 insertions(+), 191 deletions(-) diff --git a/internal/integrationtest/apidump_test.go b/internal/integrationtest/apidump_test.go index 1bef469..4640a45 100644 --- a/internal/integrationtest/apidump_test.go +++ b/internal/integrationtest/apidump_test.go @@ -70,25 +70,23 @@ func TestAPIDump(t *testing.T) { // Create temp dir for API dumps. dumpDir := t.TempDir() - ts := newBridgeTestServer(t, ctx, srv.URL, + bridgeServer := newBridgeTestServer(t, ctx, srv.URL, withCustomProvider(tc.newProvider(srv.URL, dumpDir)), ) - req := ts.newRequest(t, tc.path, fix.Request()) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request()) require.Equal(t, http.StatusOK, resp.StatusCode) defer resp.Body.Close() _, _ = io.ReadAll(resp.Body) // Verify dump files were created. - interceptions := ts.Recorder.RecordedInterceptions() + interceptions := bridgeServer.Recorder.RecordedInterceptions() require.Len(t, interceptions, 1) interceptionID := interceptions[0].ID // Find dump files for this interception by walking the dump directory. 
var reqDumpFile, respDumpFile string - err = filepath.Walk(dumpDir, func(path string, info os.FileInfo, err error) error { + err := filepath.Walk(dumpDir, func(path string, info os.FileInfo, err error) error { if err != nil { return err } @@ -137,7 +135,7 @@ func TestAPIDump(t *testing.T) { expectedRespBody := fix.NonStreaming() require.JSONEq(t, string(expectedRespBody), string(dumpRespBody), "response body JSON should match semantically") - ts.Recorder.VerifyAllInterceptionsEnded(t) + bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) }) } } @@ -194,19 +192,17 @@ func TestAPIDumpPassthrough(t *testing.T) { dumpDir := t.TempDir() - ts := newBridgeTestServer(t, ctx, upstream.URL, + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, withCustomProvider(tc.newProvider(upstream.URL, dumpDir)), ) - req := ts.newRequest(t, tc.requestPath, nil) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodGet, tc.requestPath, nil) defer resp.Body.Close() // Find dump files in the passthrough directory. passthroughDir := filepath.Join(dumpDir, tc.name, "passthrough") var reqDumpFile, respDumpFile string - err = filepath.Walk(passthroughDir, func(path string, info os.FileInfo, err error) error { + err := filepath.Walk(passthroughDir, func(path string, info os.FileInfo, err error) error { if err != nil { return err } @@ -237,7 +233,7 @@ func TestAPIDumpPassthrough(t *testing.T) { require.NoError(t, err) dumpReq, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(reqDumpData))) require.NoError(t, err) - require.Equal(t, http.MethodPost, dumpReq.Method) + require.Equal(t, http.MethodGet, dumpReq.Method) // Verify response dump. 
respDumpData, err := os.ReadFile(respDumpFile) diff --git a/internal/integrationtest/bridge_test.go b/internal/integrationtest/bridge_test.go index 139aff7..813cf30 100644 --- a/internal/integrationtest/bridge_test.go +++ b/internal/integrationtest/bridge_test.go @@ -229,8 +229,8 @@ func TestSessionIDTracking(t *testing.T) { fixture []byte expectedClient aibridge.Client sessionID string - configureFunc func(*testing.T, string, aibridge.Recorder) (*aibridge.RequestBridge, error) - createRequest func(t *testing.T, baseURL string, body []byte) *http.Request + header http.Header + mutateBody func(t *testing.T, body []byte) []byte }{ // Session in header. { @@ -238,18 +238,9 @@ func TestSessionIDTracking(t *testing.T) { fixture: fixtures.AntSimple, expectedClient: aibridge.ClientMux, sessionID: "mux-workspace-321", - configureFunc: func(t *testing.T, addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { - t.Helper() - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)} - return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) - }, - createRequest: func(t *testing.T, baseURL string, body []byte) *http.Request { - t.Helper() - req := createAnthropicMessagesReq(t, baseURL, body) - req.Header.Set("User-Agent", "mux/1.0.0") - req.Header.Set("X-Mux-Workspace-Id", "mux-workspace-321") - return req + header: http.Header{ + "User-Agent": []string{"mux/1.0.0"}, + "X-Mux-Workspace-Id": []string{"mux-workspace-321"}, }, }, // Session in body. 
@@ -258,21 +249,16 @@ func TestSessionIDTracking(t *testing.T) { fixture: fixtures.AntSimple, expectedClient: aibridge.ClientClaudeCode, sessionID: "f47ac10b-58cc-4372-a567-0e02b2c3d479", - configureFunc: func(t *testing.T, addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { - t.Helper() - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)} - return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) + header: http.Header{ + "User-Agent": []string{"claude-cli/2.0.67 (external, cli)"}, }, - createRequest: func(t *testing.T, baseURL string, body []byte) *http.Request { + mutateBody: func(t *testing.T, body []byte) []byte { t.Helper() // Claude Code embeds the session ID in metadata.user_id within the body. body, err := sjson.SetBytes(body, "metadata.user_id", "user_abc123_account_456_session_f47ac10b-58cc-4372-a567-0e02b2c3d479") require.NoError(t, err) - req := createAnthropicMessagesReq(t, baseURL, body) - req.Header.Set("User-Agent", "claude-cli/2.0.67 (external, cli)") - return req + return body }, }, // No session. 
@@ -280,17 +266,8 @@ func TestSessionIDTracking(t *testing.T) { name: "zed", fixture: fixtures.AntSimple, expectedClient: aibridge.ClientZed, - configureFunc: func(t *testing.T, addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { - t.Helper() - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)} - return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) - }, - createRequest: func(t *testing.T, baseURL string, body []byte) *http.Request { - t.Helper() - req := createAnthropicMessagesReq(t, baseURL, body) - req.Header.Set("User-Agent", "Zed/0.219.4+stable.119.abc123 (macos; aarch64)") - return req + header: http.Header{ + "User-Agent": []string{"Zed/0.219.4+stable.119.abc123 (macos; aarch64)"}, }, }, } @@ -303,30 +280,23 @@ func TestSessionIDTracking(t *testing.T) { t.Cleanup(cancel) fix := fixtures.Parse(t, tc.fixture) - upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) - - recorderClient := &testutil.MockRecorder{} + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, withProvider(config.ProviderAnthropic)) - b, err := tc.configureFunc(t, upstream.URL, recorderClient) - require.NoError(t, err) - mockSrv := httptest.NewUnstartedServer(b) - t.Cleanup(mockSrv.Close) - mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(ctx, userID, nil) + reqBody := fix.Request() + if tc.mutateBody != nil { + reqBody = tc.mutateBody(t, reqBody) } - mockSrv.Start() - req := tc.createRequest(t, mockSrv.URL, fix.Request()) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody, tc.header) require.Equal(t, http.StatusOK, 
resp.StatusCode) defer resp.Body.Close() // Drain the body to let the stream complete. - _, err = io.ReadAll(resp.Body) + _, err := io.ReadAll(resp.Body) require.NoError(t, err) - interceptions := recorderClient.RecordedInterceptions() + interceptions := bridgeServer.Recorder.RecordedInterceptions() require.Len(t, interceptions, 1, "expected exactly one interception") assert.Equal(t, string(tc.expectedClient), interceptions[0].Client) @@ -337,7 +307,7 @@ func TestSessionIDTracking(t *testing.T) { assert.Equal(t, tc.sessionID, *interceptions[0].ClientSessionID) } - recorderClient.VerifyAllInterceptionsEnded(t) + bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) }) } } diff --git a/internal/integrationtest/circuit_breaker_test.go b/internal/integrationtest/circuit_breaker_test.go index 37f7d8c..9c3ff55 100644 --- a/internal/integrationtest/circuit_breaker_test.go +++ b/internal/integrationtest/circuit_breaker_test.go @@ -44,7 +44,7 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { errorBody string successBody string requestBody string - setupHeaders func(req *http.Request) + headers http.Header path string createProvider func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider expectProvider string @@ -61,9 +61,9 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { errorBody: anthropicRateLimitError, successBody: anthropicSuccessResponse("claude-sonnet-4-20250514"), requestBody: `{"model":"claude-sonnet-4-20250514","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`, - setupHeaders: func(req *http.Request) { - req.Header.Set("x-api-key", "test") - req.Header.Set("anthropic-version", "2023-06-01") + headers: http.Header{ + "x-api-key": {"test"}, + "anthropic-version": {"2023-06-01"}, }, path: pathAnthropicMessages, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { @@ -82,9 +82,7 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { errorBody: openAIRateLimitError, 
successBody: openAISuccessResponse("gpt-4o"), requestBody: `{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}`, - setupHeaders: func(req *http.Request) { - req.Header.Set("Authorization", "Bearer test-key") - }, + headers: http.Header{"Authorization": {"Bearer test-key"}}, path: pathOpenAIChatCompletions, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewOpenAI(config.OpenAI{ @@ -131,18 +129,15 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { } ctx := t.Context() - ts := newBridgeTestServer(t, ctx, mockUpstream.URL, + bridgeServer := newBridgeTestServer(t, ctx, mockUpstream.URL, withCustomProvider(tc.createProvider(mockUpstream.URL, cbConfig)), withMetrics(m), withActor("test-user-id", nil), ) - makeRequest := func() *http.Response { - req := ts.newRequest(t, tc.path, []byte(tc.requestBody)) - tc.setupHeaders(req) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) - _, err = io.ReadAll(resp.Body) + doRequest := func() *http.Response { + resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, []byte(tc.requestBody), tc.headers) + _, err := io.ReadAll(resp.Body) require.NoError(t, err) resp.Body.Close() return resp @@ -151,14 +146,14 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { // Phase 1: Trip the circuit breaker // First FailureThreshold requests hit upstream, get 429 for i := uint32(0); i < cbConfig.FailureThreshold; i++ { - resp := makeRequest() + resp := doRequest() assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) } assert.Equal(t, int32(cbConfig.FailureThreshold), upstreamCalls.Load()) // Phase 2: Verify circuit is open // Request should be blocked by circuit breaker (no upstream call) - resp := makeRequest() + resp := doRequest() assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) assert.Equal(t, int32(cbConfig.FailureThreshold), upstreamCalls.Load(), "No new upstream call when circuit is open") @@ -180,7 +175,7 @@ 
func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { // Phase 4: Recovery - request in half-open state should succeed and close circuit upstreamCallsBefore := upstreamCalls.Load() - resp = makeRequest() + resp = doRequest() assert.Equal(t, http.StatusOK, resp.StatusCode, "Request should succeed in half-open state") assert.Equal(t, upstreamCallsBefore+1, upstreamCalls.Load(), "Request should reach upstream in half-open state") @@ -191,7 +186,7 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { // Phase 5: Verify circuit is fully functional again // Multiple requests should all succeed and reach upstream for i := 0; i < 3; i++ { - resp = makeRequest() + resp = doRequest() assert.Equal(t, http.StatusOK, resp.StatusCode, "Request should succeed after circuit closes") } @@ -214,7 +209,7 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { name string errorBody string requestBody string - setupHeaders func(req *http.Request) + headers http.Header path string createProvider func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider expectProvider string @@ -230,9 +225,9 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { expectModel: "claude-sonnet-4-20250514", errorBody: anthropicRateLimitError, requestBody: `{"model":"claude-sonnet-4-20250514","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`, - setupHeaders: func(req *http.Request) { - req.Header.Set("x-api-key", "test") - req.Header.Set("anthropic-version", "2023-06-01") + headers: http.Header{ + "x-api-key": {"test"}, + "anthropic-version": {"2023-06-01"}, }, path: pathAnthropicMessages, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { @@ -250,9 +245,7 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { expectModel: "gpt-4o", errorBody: openAIRateLimitError, requestBody: `{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}`, - setupHeaders: func(req *http.Request) { - req.Header.Set("Authorization", 
"Bearer test-key") - }, + headers: http.Header{"Authorization": {"Bearer test-key"}}, path: pathOpenAIChatCompletions, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewOpenAI(config.OpenAI{ @@ -290,18 +283,15 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { } ctx := t.Context() - ts := newBridgeTestServer(t, ctx, mockUpstream.URL, + bridgeServer := newBridgeTestServer(t, ctx, mockUpstream.URL, withCustomProvider(tc.createProvider(mockUpstream.URL, cbConfig)), withMetrics(m), withActor("test-user-id", nil), ) - makeRequest := func() *http.Response { - req := ts.newRequest(t, tc.path, []byte(tc.requestBody)) - tc.setupHeaders(req) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) - _, err = io.ReadAll(resp.Body) + doRequest := func() *http.Response { + resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, []byte(tc.requestBody), tc.headers) + _, err := io.ReadAll(resp.Body) require.NoError(t, err) resp.Body.Close() return resp @@ -309,12 +299,12 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { // Phase 1: Trip the circuit for i := uint32(0); i < cbConfig.FailureThreshold; i++ { - resp := makeRequest() + resp := doRequest() assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) } // Verify circuit is open - resp := makeRequest() + resp := doRequest() assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) trips := promtest.ToFloat64(m.CircuitBreakerTrips.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) @@ -325,12 +315,12 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { // Phase 3: Request in half-open state fails, circuit should re-open upstreamCallsBefore := upstreamCalls.Load() - resp = makeRequest() + resp = doRequest() assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode, "Request should fail in half-open state") assert.Equal(t, upstreamCallsBefore+1, upstreamCalls.Load(), "Request should reach upstream in 
half-open state") // Circuit should be open again - next request should be rejected immediately - resp = makeRequest() + resp = doRequest() assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode, "Circuit should be open again after half-open failure") assert.Equal(t, upstreamCallsBefore+1, upstreamCalls.Load(), "Request should NOT reach upstream when circuit re-opens") @@ -354,7 +344,7 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { errorBody string successBody string requestBody string - setupHeaders func(req *http.Request) + headers http.Header path string createProvider func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider expectProvider string @@ -371,9 +361,9 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { errorBody: anthropicRateLimitError, successBody: anthropicSuccessResponse("claude-sonnet-4-20250514"), requestBody: `{"model":"claude-sonnet-4-20250514","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`, - setupHeaders: func(req *http.Request) { - req.Header.Set("x-api-key", "test") - req.Header.Set("anthropic-version", "2023-06-01") + headers: http.Header{ + "x-api-key": {"test"}, + "anthropic-version": {"2023-06-01"}, }, path: pathAnthropicMessages, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { @@ -392,9 +382,7 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { errorBody: openAIRateLimitError, successBody: openAISuccessResponse("gpt-4o"), requestBody: `{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}`, - setupHeaders: func(req *http.Request) { - req.Header.Set("Authorization", "Bearer test-key") - }, + headers: http.Header{"Authorization": {"Bearer test-key"}}, path: pathOpenAIChatCompletions, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewOpenAI(config.OpenAI{ @@ -442,18 +430,15 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { } ctx := 
t.Context() - ts := newBridgeTestServer(t, ctx, mockUpstream.URL, + bridgeServer := newBridgeTestServer(t, ctx, mockUpstream.URL, withCustomProvider(tc.createProvider(mockUpstream.URL, cbConfig)), withMetrics(m), withActor("test-user-id", nil), ) - makeRequest := func() *http.Response { - req := ts.newRequest(t, tc.path, []byte(tc.requestBody)) - tc.setupHeaders(req) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) - _, err = io.ReadAll(resp.Body) + doRequest := func() *http.Response { + resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, []byte(tc.requestBody), tc.headers) + _, err := io.ReadAll(resp.Body) require.NoError(t, err) resp.Body.Close() return resp @@ -461,12 +446,12 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { // Phase 1: Trip the circuit for i := uint32(0); i < cbConfig.FailureThreshold; i++ { - resp := makeRequest() + resp := doRequest() assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) } // Verify circuit is open - resp := makeRequest() + resp := doRequest() assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) // Phase 2: Wait for half-open state and switch upstream to success @@ -483,7 +468,7 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - resp := makeRequest() + resp := doRequest() responses <- resp.StatusCode }() } @@ -561,7 +546,7 @@ func TestCircuitBreaker_PerModelIsolation(t *testing.T) { MaxRequests: 1, } ctx := t.Context() - ts := newBridgeTestServer(t, ctx, mockUpstream.URL, + bridgeServer := newBridgeTestServer(t, ctx, mockUpstream.URL, withCustomProvider(provider.NewAnthropic(config.Anthropic{ BaseURL: mockUpstream.URL, Key: "test-key", @@ -571,14 +556,13 @@ func TestCircuitBreaker_PerModelIsolation(t *testing.T) { withActor("test-user-id", nil), ) - makeRequest := func(model string) *http.Response { + doRequest := func(model string) *http.Response { body := 
fmt.Sprintf(`{"model":"%s","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`, model) - req := ts.newRequest(t, pathAnthropicMessages, []byte(body)) - req.Header.Set("x-api-key", "test") - req.Header.Set("anthropic-version", "2023-06-01") - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) - _, err = io.ReadAll(resp.Body) + resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, []byte(body), http.Header{ + "x-api-key": {"test"}, + "anthropic-version": {"2023-06-01"}, + }) + _, err := io.ReadAll(resp.Body) require.NoError(t, err) resp.Body.Close() return resp @@ -586,13 +570,13 @@ func TestCircuitBreaker_PerModelIsolation(t *testing.T) { // Phase 1: Trip the circuit for sonnet model for i := uint32(0); i < cbConfig.FailureThreshold; i++ { - resp := makeRequest("claude-sonnet-4-20250514") + resp := doRequest("claude-sonnet-4-20250514") assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) } assert.Equal(t, int32(cbConfig.FailureThreshold), sonnetCalls.Load()) // Verify sonnet circuit is open - resp := makeRequest("claude-sonnet-4-20250514") + resp := doRequest("claude-sonnet-4-20250514") assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode, "Sonnet circuit should be open") assert.Equal(t, int32(cbConfig.FailureThreshold), sonnetCalls.Load(), "No new sonnet calls when circuit is open") @@ -604,13 +588,13 @@ func TestCircuitBreaker_PerModelIsolation(t *testing.T) { assert.Equal(t, 1.0, sonnetState, "Sonnet CircuitBreakerState should be 1 (open)") // Phase 2: Haiku model should still work (independent circuit) - resp = makeRequest("claude-3-5-haiku-20241022") + resp = doRequest("claude-3-5-haiku-20241022") assert.Equal(t, http.StatusOK, resp.StatusCode, "Haiku should succeed while sonnet circuit is open") assert.Equal(t, int32(1), haikuCalls.Load(), "Haiku call should reach upstream") // Make multiple haiku requests - all should succeed for i := 0; i < 3; i++ { - resp = 
makeRequest("claude-3-5-haiku-20241022") + resp = doRequest("claude-3-5-haiku-20241022") assert.Equal(t, http.StatusOK, resp.StatusCode, "Haiku should continue to succeed") } assert.Equal(t, int32(4), haikuCalls.Load(), "All haiku calls should reach upstream") @@ -626,7 +610,7 @@ func TestCircuitBreaker_PerModelIsolation(t *testing.T) { time.Sleep(cbConfig.Timeout + 10*time.Millisecond) sonnetShouldFail.Store(false) - resp = makeRequest("claude-sonnet-4-20250514") + resp = doRequest("claude-sonnet-4-20250514") assert.Equal(t, http.StatusOK, resp.StatusCode, "Sonnet should recover after timeout") // Verify sonnet circuit is now closed diff --git a/internal/integrationtest/metrics_test.go b/internal/integrationtest/metrics_test.go index a9fb9fb..3aec692 100644 --- a/internal/integrationtest/metrics_test.go +++ b/internal/integrationtest/metrics_test.go @@ -1,6 +1,7 @@ package integrationtest import ( + "bytes" "context" "io" "net/http" @@ -120,13 +121,11 @@ func TestMetrics_Interception(t *testing.T) { upstream.AllowOverflow = tc.allowOverflow m := aibridge.NewMetrics(prometheus.NewRegistry()) - ts := newBridgeTestServer(t, ctx, upstream.URL, + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, withMetrics(m), ) - req := ts.newRequest(t, tc.path, fix.Request()) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request()) defer resp.Body.Close() _, _ = io.ReadAll(resp.Body) @@ -156,7 +155,7 @@ func TestMetrics_InterceptionsInflight(t *testing.T) { t.Cleanup(srv.Close) m := aibridge.NewMetrics(prometheus.NewRegistry()) - ts := newBridgeTestServer(t, ctx, srv.URL, + bridgeServer := newBridgeTestServer(t, ctx, srv.URL, withMetrics(m), ) @@ -164,7 +163,8 @@ func TestMetrics_InterceptionsInflight(t *testing.T) { doneCh := make(chan struct{}) go func() { defer close(doneCh) - req := ts.newRequest(t, pathAnthropicMessages, fix.Request()) + req, _ := http.NewRequestWithContext(ctx, 
http.MethodPost, bridgeServer.URL+pathAnthropicMessages, bytes.NewReader(fix.Request())) + req.Header.Set("Content-Type", "application/json") resp, err := http.DefaultClient.Do(req) if err == nil { defer resp.Body.Close() @@ -202,18 +202,16 @@ func TestMetrics_PassthroughCount(t *testing.T) { t.Cleanup(upstream.Close) m := aibridge.NewMetrics(prometheus.NewRegistry()) - ts := newBridgeTestServer(t, t.Context(), upstream.URL, + bridgeServer := newBridgeTestServer(t, t.Context(), upstream.URL, withMetrics(m), ) - req := ts.newRequest(t, "/openai/v1/models", nil) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodGet, "/openai/v1/models", nil) defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) count := promtest.ToFloat64(m.PassthroughCount.WithLabelValues( - config.ProviderOpenAI, "/models", http.MethodPost)) + config.ProviderOpenAI, "/models", "GET")) require.Equal(t, 1.0, count) } @@ -227,13 +225,11 @@ func TestMetrics_PromptCount(t *testing.T) { upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) m := aibridge.NewMetrics(prometheus.NewRegistry()) - ts := newBridgeTestServer(t, ctx, upstream.URL, + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, withMetrics(m), ) - req := ts.newRequest(t, pathOpenAIChatCompletions, fix.Request()) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, fix.Request()) require.Equal(t, http.StatusOK, resp.StatusCode) defer resp.Body.Close() _, _ = io.ReadAll(resp.Body) @@ -253,13 +249,11 @@ func TestMetrics_NonInjectedToolUseCount(t *testing.T) { upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) m := aibridge.NewMetrics(prometheus.NewRegistry()) - ts := newBridgeTestServer(t, ctx, upstream.URL, + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, withMetrics(m), ) - req := ts.newRequest(t, pathOpenAIChatCompletions, 
fix.Request()) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, fix.Request()) require.Equal(t, http.StatusOK, resp.StatusCode) defer resp.Body.Close() _, _ = io.ReadAll(resp.Body) @@ -284,14 +278,12 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) { // Setup mocked MCP server & tools. mockMCP := setupMCPForTest(t, defaultTracer) - ts := newBridgeTestServer(t, ctx, upstream.URL, + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, withMetrics(m), withMCP(mockMCP), ) - req := ts.newRequest(t, pathAnthropicMessages, fix.Request()) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, fix.Request()) require.Equal(t, http.StatusOK, resp.StatusCode) defer resp.Body.Close() _, _ = io.ReadAll(resp.Body) @@ -301,7 +293,7 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) { return upstream.Calls.Load() == 2 }, time.Second*10, time.Millisecond*50) - recorder := ts.Recorder + recorder := bridgeServer.Recorder require.Len(t, recorder.ToolUsages(), 1) require.True(t, recorder.ToolUsages()[0].Injected) require.NotNil(t, recorder.ToolUsages()[0].ServerURL) diff --git a/internal/integrationtest/setupbridge.go b/internal/integrationtest/setupbridge.go index 9cf066c..2e15a0d 100644 --- a/internal/integrationtest/setupbridge.go +++ b/internal/integrationtest/setupbridge.go @@ -59,14 +59,24 @@ type bridgeTestServer struct { Bridge *aibridge.RequestBridge } -// newRequest creates a JSON POST request targeting the given path on this server. -func (s *bridgeTestServer) newRequest(t *testing.T, path string, body []byte) *http.Request { +// makeRequest builds and executes an HTTP request against this server. +// Optional headers are applied after the default Content-Type. 
+func (s *bridgeTestServer) makeRequest(t *testing.T, method string, path string, body []byte, header ...http.Header) *http.Response { t.Helper() - req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, s.URL+path, bytes.NewReader(body)) + req, err := http.NewRequestWithContext(t.Context(), method, s.URL+path, bytes.NewReader(body)) require.NoError(t, err) req.Header.Set("Content-Type", "application/json") - return req + for _, h := range header { + for k, vals := range h { + for _, v := range vals { + req.Header.Set(k, v) + } + } + } + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + return resp } type bridgeOption func(*bridgeConfig) @@ -224,15 +234,13 @@ func setupInjectedToolTest( withActor(defaultActorID, nil), } allOpts = append(allOpts, opts...) - ts := newBridgeTestServer(t, ctx, upstream.URL, allOpts...) + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, allOpts...) // Add the stream param to the request. reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) require.NoError(t, err) - req := ts.newRequest(t, path, reqBody) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, path, reqBody) require.Equal(t, http.StatusOK, resp.StatusCode) t.Cleanup(func() { _ = resp.Body.Close() @@ -243,7 +251,7 @@ func setupInjectedToolTest( return upstream.Calls.Load() == 2 }, time.Second*10, time.Millisecond*50) - return ts.Recorder, mockMCP, resp + return bridgeServer.Recorder, mockMCP, resp } // newDefaultProvider creates a Provider with default test configuration. 
diff --git a/internal/integrationtest/trace_test.go b/internal/integrationtest/trace_test.go index 6f39b88..d3e9542 100644 --- a/internal/integrationtest/trace_test.go +++ b/internal/integrationtest/trace_test.go @@ -18,6 +18,7 @@ import ( "github.com/tidwall/sjson" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" + oteltrace "go.opentelemetry.io/otel/trace" sdktrace "go.opentelemetry.io/otel/sdk/trace" "go.opentelemetry.io/otel/sdk/trace/tracetest" ) @@ -29,6 +30,18 @@ type expectTrace struct { status codes.Code } +func setupTracer(t *testing.T) (*tracetest.SpanRecorder, oteltrace.Tracer) { + t.Helper() + + sr := tracetest.NewSpanRecorder() + tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) + t.Cleanup(func() { + _ = tp.Shutdown(t.Context()) + }) + + return sr, tp.Tracer(t.Name()) +} + func TestTraceAnthropic(t *testing.T) { expectNonStreaming := []expectTrace{ {"Intercept", 1, codes.Unset}, @@ -91,10 +104,7 @@ func TestTraceAnthropic(t *testing.T) { ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - sr := tracetest.NewSpanRecorder() - tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) - tracer := tp.Tracer(t.Name()) - defer func() { _ = tp.Shutdown(t.Context()) }() + sr, tracer := setupTracer(t) upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) @@ -205,10 +215,7 @@ func TestTraceAnthropicErr(t *testing.T) { ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - sr := tracetest.NewSpanRecorder() - tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) - tracer := tp.Tracer(t.Name()) - defer func() { _ = tp.Shutdown(t.Context()) }() + sr, tracer := setupTracer(t) fix := fixtures.Parse(t, tc.fixture) upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) @@ -336,10 +343,7 @@ func TestInjectedToolsTrace(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - sr := tracetest.NewSpanRecorder() - tp := 
sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) - tracer := tp.Tracer(t.Name()) - defer func() { _ = tp.Shutdown(t.Context()) }() + sr, tracer := setupTracer(t) var validatorFn func(*http.Request, []byte) if tc.expectProvider == config.ProviderAnthropic { @@ -460,10 +464,7 @@ func TestTraceOpenAI(t *testing.T) { ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - sr := tracetest.NewSpanRecorder() - tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) - tracer := tp.Tracer(t.Name()) - defer func() { _ = tp.Shutdown(t.Context()) }() + sr, tracer := setupTracer(t) fix := fixtures.Parse(t, tc.fixture) upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) @@ -618,10 +619,7 @@ func TestTraceOpenAIErr(t *testing.T) { ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - sr := tracetest.NewSpanRecorder() - tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) - tracer := tp.Tracer(t.Name()) - defer func() { _ = tp.Shutdown(t.Context()) }() + sr, tracer := setupTracer(t) fix := fixtures.Parse(t, tc.fixture) @@ -669,10 +667,7 @@ func TestTracePassthrough(t *testing.T) { upstream := newMockUpstream(t, t.Context(), newFixtureResponse(fix)) - sr := tracetest.NewSpanRecorder() - tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) - tracer := tp.Tracer(t.Name()) - defer func() { _ = tp.Shutdown(t.Context()) }() + sr, tracer := setupTracer(t) bridgeServer := newBridgeTestServer(t, t.Context(), upstream.URL, withTracer(tracer), @@ -688,7 +683,7 @@ func TestTracePassthrough(t *testing.T) { assert.Equal(t, spans[0].Name(), "Passthrough") want := []attribute.KeyValue{ - attribute.String(tracing.PassthroughMethod, http.MethodPost), + attribute.String(tracing.PassthroughMethod, "GET"), attribute.String(tracing.PassthroughUpstreamURL, upstream.URL+"/models"), attribute.String(tracing.PassthroughURL, "/models"), } @@ -697,10 +692,7 @@ func TestTracePassthrough(t *testing.T) 
{ } func TestNewServerProxyManagerTraces(t *testing.T) { - sr := tracetest.NewSpanRecorder() - tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) - tracer := tp.Tracer(t.Name()) - defer func() { _ = tp.Shutdown(t.Context()) }() + sr, tracer := setupTracer(t) serverName := "serverName" mockMCP := setupMCPForTestWithName(t, serverName, tracer) From b48a1536473c06512c2ce9795eeb2fa992b0419d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Fri, 6 Mar 2026 14:28:42 +0000 Subject: [PATCH 19/32] cleanup --- internal/integrationtest/apidump_test.go | 26 ++++++++++++------------ 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/internal/integrationtest/apidump_test.go b/internal/integrationtest/apidump_test.go index 4640a45..2349640 100644 --- a/internal/integrationtest/apidump_test.go +++ b/internal/integrationtest/apidump_test.go @@ -25,15 +25,15 @@ func TestAPIDump(t *testing.T) { t.Parallel() cases := []struct { - name string - fixture []byte - newProvider func(addr, dumpDir string) aibridge.Provider - path string + name string + fixture []byte + providerFunc func(addr, dumpDir string) aibridge.Provider + path string }{ { name: "anthropic", fixture: fixtures.AntSimple, - newProvider: func(addr, dumpDir string) aibridge.Provider { + providerFunc: func(addr, dumpDir string) aibridge.Provider { return provider.NewAnthropic(anthropicCfgWithAPIDump(addr, apiKey, dumpDir), nil) }, path: pathAnthropicMessages, @@ -41,7 +41,7 @@ func TestAPIDump(t *testing.T) { { name: "openai_chat_completions", fixture: fixtures.OaiChatSimple, - newProvider: func(addr, dumpDir string) aibridge.Provider { + providerFunc: func(addr, dumpDir string) aibridge.Provider { return provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir)) }, path: pathOpenAIChatCompletions, @@ -49,7 +49,7 @@ func TestAPIDump(t *testing.T) { { name: "openai_responses", fixture: fixtures.OaiResponsesBlockingSimple, - newProvider: func(addr, dumpDir string) 
aibridge.Provider { + providerFunc: func(addr, dumpDir string) aibridge.Provider { return provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir)) }, path: pathOpenAIResponses, @@ -71,7 +71,7 @@ func TestAPIDump(t *testing.T) { dumpDir := t.TempDir() bridgeServer := newBridgeTestServer(t, ctx, srv.URL, - withCustomProvider(tc.newProvider(srv.URL, dumpDir)), + withCustomProvider(tc.providerFunc(srv.URL, dumpDir)), ) resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request()) @@ -147,13 +147,13 @@ func TestAPIDumpPassthrough(t *testing.T) { cases := []struct { name string - newProvider func(addr string, dumpDir string) aibridge.Provider + providerFunc func(addr string, dumpDir string) aibridge.Provider requestPath string expectDumpName string }{ { name: "anthropic", - newProvider: func(addr string, dumpDir string) aibridge.Provider { + providerFunc: func(addr string, dumpDir string) aibridge.Provider { return provider.NewAnthropic(anthropicCfgWithAPIDump(addr, apiKey, dumpDir), nil) }, requestPath: "/anthropic/v1/models", @@ -161,7 +161,7 @@ func TestAPIDumpPassthrough(t *testing.T) { }, { name: "openai", - newProvider: func(addr string, dumpDir string) aibridge.Provider { + providerFunc: func(addr string, dumpDir string) aibridge.Provider { return provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir)) }, requestPath: "/openai/v1/models", @@ -169,7 +169,7 @@ func TestAPIDumpPassthrough(t *testing.T) { }, { name: "copilot", - newProvider: func(addr string, dumpDir string) aibridge.Provider { + providerFunc: func(addr string, dumpDir string) aibridge.Provider { return provider.NewCopilot(config.Copilot{BaseURL: addr, APIDumpDir: dumpDir}) }, requestPath: "/copilot/models", @@ -193,7 +193,7 @@ func TestAPIDumpPassthrough(t *testing.T) { dumpDir := t.TempDir() bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, - withCustomProvider(tc.newProvider(upstream.URL, dumpDir)), + withCustomProvider(tc.providerFunc(upstream.URL, dumpDir)), 
) resp := bridgeServer.makeRequest(t, http.MethodGet, tc.requestPath, nil) From 8860414976ceb0f88d4c54fca58dad983f340afd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Fri, 6 Mar 2026 14:41:40 +0000 Subject: [PATCH 20/32] reorder tests in bridge_test.go to original order --- internal/integrationtest/bridge_test.go | 2111 +++++++++++------------ 1 file changed, 1055 insertions(+), 1056 deletions(-) diff --git a/internal/integrationtest/bridge_test.go b/internal/integrationtest/bridge_test.go index 813cf30..16fb06e 100644 --- a/internal/integrationtest/bridge_test.go +++ b/internal/integrationtest/bridge_test.go @@ -37,476 +37,285 @@ func TestMain(m *testing.M) { goleak.VerifyTestMain(m) } -func TestSimple(t *testing.T) { +func TestAnthropicMessages(t *testing.T) { t.Parallel() - getAnthropicResponseID := func(streaming bool, resp *http.Response) (string, error) { - if streaming { - decoder := ssestream.NewDecoder(resp) - stream := ssestream.NewStream[anthropic.MessageStreamEventUnion](decoder, nil) - var message anthropic.Message - for stream.Next() { - event := stream.Current() - if err := message.Accumulate(event); err != nil { - return "", fmt.Errorf("accumulate event: %w", err) - } - } - if stream.Err() != nil { - return "", fmt.Errorf("stream error: %w", stream.Err()) - } - return message.ID, nil - } - - body, err := io.ReadAll(resp.Body) - if err != nil { - return "", fmt.Errorf("read body: %w", err) - } - - var message anthropic.Message - if err := json.Unmarshal(body, &message); err != nil { - return "", fmt.Errorf("unmarshal response: %w", err) - } - return message.ID, nil - } - - getOpenAIResponseID := func(streaming bool, resp *http.Response) (string, error) { - if streaming { - // Parse the response stream. 
- decoder := oaissestream.NewDecoder(resp) - stream := oaissestream.NewStream[openai.ChatCompletionChunk](decoder, nil) - var message openai.ChatCompletionAccumulator - for stream.Next() { - chunk := stream.Current() - message.AddChunk(chunk) - } - if stream.Err() != nil { - return "", fmt.Errorf("stream error: %w", stream.Err()) - } - return message.ID, nil - } - - // Parse & unmarshal the response. - body, err := io.ReadAll(resp.Body) - if err != nil { - return "", fmt.Errorf("read body: %w", err) - } + t.Run("single builtin tool", func(t *testing.T) { + t.Parallel() - var message openai.ChatCompletion - if err := json.Unmarshal(body, &message); err != nil { - return "", fmt.Errorf("unmarshal response: %w", err) + cases := []struct { + streaming bool + expectedInputTokens int + expectedOutputTokens int + expectedToolCallID string + }{ + { + streaming: true, + expectedInputTokens: 2, + expectedOutputTokens: 66, + expectedToolCallID: "toolu_01RX68weRSquLx6HUTj65iBo", + }, + { + streaming: false, + expectedInputTokens: 5, + expectedOutputTokens: 84, + expectedToolCallID: "toolu_01AusGgY5aKFhzWrFBv9JfHq", + }, } - return message.ID, nil - } - testCases := []struct { - name string - fixture []byte - basePath string - expectedPath string - getResponseIDFunc func(streaming bool, resp *http.Response) (string, error) - path string - expectedMsgID string - userAgent string - expectedClient aibridge.Client - }{ - { - name: config.ProviderAnthropic, - fixture: fixtures.AntSimple, - basePath: "", - expectedPath: "/v1/messages", - getResponseIDFunc: getAnthropicResponseID, - path: pathAnthropicMessages, - expectedMsgID: "msg_01Pvyf26bY17RcjmWfJsXGBn", - userAgent: "claude-cli/2.0.67 (external, cli)", - expectedClient: aibridge.ClientClaudeCode, - }, - { - name: config.ProviderOpenAI, - fixture: fixtures.OaiChatSimple, - basePath: "", - expectedPath: "/chat/completions", - getResponseIDFunc: getOpenAIResponseID, - path: pathOpenAIChatCompletions, - expectedMsgID: 
"chatcmpl-BwoiPTGRbKkY5rncfaM0s9KtWrq5N", - userAgent: "codex_cli_rs/0.87.0 (Mac OS 26.2.0; arm64)", - expectedClient: aibridge.ClientCodex, - }, - { - name: config.ProviderAnthropic + "_baseURL_path", - fixture: fixtures.AntSimple, - basePath: "/api", - expectedPath: "/api/v1/messages", - getResponseIDFunc: getAnthropicResponseID, - path: pathAnthropicMessages, - expectedMsgID: "msg_01Pvyf26bY17RcjmWfJsXGBn", - userAgent: "GitHubCopilotChat/0.37.2026011603", - expectedClient: aibridge.ClientCopilotVSC, - }, - { - name: config.ProviderOpenAI + "_baseURL_path", - fixture: fixtures.OaiChatSimple, - basePath: "/api", - expectedPath: "/api/chat/completions", - getResponseIDFunc: getOpenAIResponseID, - path: pathOpenAIChatCompletions, - expectedMsgID: "chatcmpl-BwoiPTGRbKkY5rncfaM0s9KtWrq5N", - userAgent: "Zed/0.219.4+stable.119.abc123 (macos; aarch64)", - expectedClient: aibridge.ClientZed, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - for _, streaming := range []bool{true, false} { - t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) - t.Cleanup(cancel) + for _, tc := range cases { + t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), tc.streaming), func(t *testing.T) { + t.Parallel() - fix := fixtures.Parse(t, tc.fixture) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL+tc.basePath) + fix := fixtures.Parse(t, fixtures.AntSingleBuiltinTool) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - // When: calling the "API server" with the fixture's request body. 
- reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) - require.NoError(t, err) - resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody, http.Header{"User-Agent": {tc.userAgent}}) - require.Equal(t, http.StatusOK, resp.StatusCode) - defer resp.Body.Close() + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) - // Then: I expect the upstream request to have the correct path. - received := upstream.receivedRequests() - require.Len(t, received, 1) - require.Equal(t, tc.expectedPath, received[0].Path) + // Make API call to aibridge for Anthropic /v1/messages + reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) + require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody) + require.Equal(t, http.StatusOK, resp.StatusCode) + defer resp.Body.Close() - // Then: I expect a non-empty response. - bodyBytes, err := io.ReadAll(resp.Body) - require.NoError(t, err) - assert.NotEmpty(t, bodyBytes, "should have received response body") + // Response-specific checks. + if tc.streaming { + sp := aibridge.NewSSEParser() + require.NoError(t, sp.Parse(resp.Body)) - // Reset the body after being read. - resp.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + // Ensure the message starts and completes, at a minimum. + assert.Contains(t, sp.AllEvents(), "message_start") + assert.Contains(t, sp.AllEvents(), "message_stop") + } - // Then: I expect the prompt to have been tracked. - promptUsages := bridgeServer.Recorder.RecordedPromptUsages() - require.NotEmpty(t, promptUsages, "no prompts tracked") - assert.Contains(t, promptUsages[0].Prompt, "how many angels can dance on the head of a pin") + expectedTokenRecordings := 1 + if tc.streaming { + // One for message_start, one for message_delta. 
+ expectedTokenRecordings = 2 + } + tokenUsages := bridgeServer.Recorder.RecordedTokenUsages() + require.Len(t, tokenUsages, expectedTokenRecordings) - // Validate that responses have their IDs overridden with a interception ID rather than the original ID from the upstream provider. - // The reason for this is that Bridge may make multiple upstream requests (i.e. to invoke injected tools), and clients will not be expecting - // multiple messages in response to a single request. - id, err := tc.getResponseIDFunc(streaming, resp) - require.NoError(t, err, "failed to retrieve response ID") - require.Nilf(t, uuid.Validate(id), "%s is not a valid UUID", id) + assert.EqualValues(t, tc.expectedInputTokens, bridgeServer.Recorder.TotalInputTokens(), "input tokens miscalculated") + assert.EqualValues(t, tc.expectedOutputTokens, bridgeServer.Recorder.TotalOutputTokens(), "output tokens miscalculated") - tokenUsages := bridgeServer.Recorder.RecordedTokenUsages() - require.GreaterOrEqual(t, len(tokenUsages), 1) - require.Equal(t, tokenUsages[0].MsgID, tc.expectedMsgID) + toolUsages := bridgeServer.Recorder.RecordedToolUsages() + require.Len(t, toolUsages, 1) + assert.Equal(t, "Read", toolUsages[0].Tool) + assert.Equal(t, tc.expectedToolCallID, toolUsages[0].ToolCallID) + require.IsType(t, json.RawMessage{}, toolUsages[0].Args) + var args map[string]any + require.NoError(t, json.Unmarshal(toolUsages[0].Args.(json.RawMessage), &args)) + require.Contains(t, args, "file_path") + assert.Equal(t, "/tmp/blah/foo", args["file_path"]) - // Validate user agent and client have been recorded. 
- interceptions := bridgeServer.Recorder.RecordedInterceptions() - require.Len(t, interceptions, 1, "expected exactly one interception, got: %v", interceptions) - assert.Equal(t, id, interceptions[0].ID) - assert.Equal(t, tc.userAgent, interceptions[0].UserAgent) - assert.Equal(t, string(tc.expectedClient), interceptions[0].Client) + promptUsages := bridgeServer.Recorder.RecordedPromptUsages() + require.Len(t, promptUsages, 1) + assert.Equal(t, "read the foo file", promptUsages[0].Prompt) - bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) - }) - } - }) - } + bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) + }) + } + }) } -func TestSessionIDTracking(t *testing.T) { +func TestAWSBedrockIntegration(t *testing.T) { t.Parallel() - testCases := []struct { - name string - fixture []byte - expectedClient aibridge.Client - sessionID string - header http.Header - mutateBody func(t *testing.T, body []byte) []byte - }{ - // Session in header. - { - name: "mux", - fixture: fixtures.AntSimple, - expectedClient: aibridge.ClientMux, - sessionID: "mux-workspace-321", - header: http.Header{ - "User-Agent": []string{"mux/1.0.0"}, - "X-Mux-Workspace-Id": []string{"mux-workspace-321"}, - }, - }, - // Session in body. - { - name: "claude_code", - fixture: fixtures.AntSimple, - expectedClient: aibridge.ClientClaudeCode, - sessionID: "f47ac10b-58cc-4372-a567-0e02b2c3d479", - header: http.Header{ - "User-Agent": []string{"claude-cli/2.0.67 (external, cli)"}, - }, - mutateBody: func(t *testing.T, body []byte) []byte { - t.Helper() - // Claude Code embeds the session ID in metadata.user_id within the body. - body, err := sjson.SetBytes(body, "metadata.user_id", - "user_abc123_account_456_session_f47ac10b-58cc-4372-a567-0e02b2c3d479") - require.NoError(t, err) - return body - }, - }, - // No session. 
- { - name: "zed", - fixture: fixtures.AntSimple, - expectedClient: aibridge.ClientZed, - header: http.Header{ - "User-Agent": []string{"Zed/0.219.4+stable.119.abc123 (macos; aarch64)"}, - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) - t.Cleanup(cancel) - - fix := fixtures.Parse(t, tc.fixture) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, withProvider(config.ProviderAnthropic)) + t.Run("invalid config", func(t *testing.T) { + t.Parallel() - reqBody := fix.Request() - if tc.mutateBody != nil { - reqBody = tc.mutateBody(t, reqBody) - } + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) - resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody, tc.header) - require.Equal(t, http.StatusOK, resp.StatusCode) - defer resp.Body.Close() + // Invalid bedrock config - missing region & base url + bedrockCfg := &config.AWSBedrock{ + Region: "", + AccessKey: "test-key", + AccessKeySecret: "test-secret", + Model: "test-model", + SmallFastModel: "test-haiku", + } - // Drain the body to let the stream complete. 
- _, err := io.ReadAll(resp.Body) - require.NoError(t, err) + bridgeServer := newBridgeTestServer(t, ctx, "http://unused", + withCustomProvider(provider.NewAnthropic(anthropicCfg("http://unused", apiKey), bedrockCfg)), + withLogger(newLogger(t)), + ) - interceptions := bridgeServer.Recorder.RecordedInterceptions() - require.Len(t, interceptions, 1, "expected exactly one interception") - assert.Equal(t, string(tc.expectedClient), interceptions[0].Client) + resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, fixtures.Request(t, fixtures.AntSingleBuiltinTool)) + defer resp.Body.Close() - if tc.sessionID == "" { - assert.Nil(t, interceptions[0].ClientSessionID, "expected nil session ID for %s", tc.name) - } else { - require.NotNil(t, interceptions[0].ClientSessionID, "expected non-nil session ID for %s", tc.name) - assert.Equal(t, tc.sessionID, *interceptions[0].ClientSessionID) - } + require.Equal(t, http.StatusInternalServerError, resp.StatusCode) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Contains(t, string(body), "create anthropic client") + require.Contains(t, string(body), "region or base url required") + }) - bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) - }) - } -} + t.Run("/v1/messages", func(t *testing.T) { + for _, streaming := range []bool{true, false} { + t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), streaming), func(t *testing.T) { + t.Parallel() -func TestFallthrough(t *testing.T) { - t.Parallel() + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) - testCases := []struct { - name string - providerName string - fixture []byte - basePath string - requestPath string - expectedUpstreamPath string - authHeader string - }{ - { - name: "ant_empty_base_url_path", - providerName: config.ProviderAnthropic, - fixture: fixtures.AntFallthrough, - basePath: "", - requestPath: "/anthropic/v1/models", - expectedUpstreamPath: "/v1/models", - authHeader: "X-Api-Key", - }, - { 
- name: "oai_empty_base_url_path", - providerName: config.ProviderOpenAI, - fixture: fixtures.OaiChatFallthrough, - basePath: "", - requestPath: "/openai/v1/models", - expectedUpstreamPath: "/models", - authHeader: "Authorization", - }, - { - name: "ant_some_base_url_path", - providerName: config.ProviderAnthropic, - fixture: fixtures.AntFallthrough, - basePath: "/api", - requestPath: "/anthropic/v1/models", - expectedUpstreamPath: "/api/v1/models", - authHeader: "X-Api-Key", - }, - { - name: "oai_some_base_url_path", - providerName: config.ProviderOpenAI, - fixture: fixtures.OaiChatFallthrough, - basePath: "/api", - requestPath: "/openai/v1/models", - expectedUpstreamPath: "/api/models", - authHeader: "Authorization", - }, - } + fix := fixtures.Parse(t, fixtures.AntSingleBuiltinTool) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() + // We define region here to validate that with Region & BaseURL defined, the latter takes precedence. + bedrockCfg := &config.AWSBedrock{ + Region: "us-west-2", + AccessKey: "test-access-key", + AccessKeySecret: "test-secret-key", + Model: "danthropic", // This model should override the request's given one. + SmallFastModel: "danthropic-mini", // Unused but needed for validation. + BaseURL: upstream.URL, // Use the mock server. + } - fix := fixtures.Parse(t, tc.fixture) - upstream := newMockUpstream(t, t.Context(), newFixtureResponse(fix)) - bridgeServer := newBridgeTestServer(t, t.Context(), upstream.URL+tc.basePath) + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, + withCustomProvider(provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), bedrockCfg)), + withLogger(newLogger(t)), + ) - resp := bridgeServer.makeRequest(t, http.MethodGet, tc.requestPath, nil) - defer resp.Body.Close() + // Make API call to aibridge for Anthropic /v1/messages, which will be routed via AWS Bedrock. 
+ // We override the AWS Bedrock client to route requests through our mock server. + reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) + require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody) + defer resp.Body.Close() - require.Equal(t, http.StatusOK, resp.StatusCode) + // For streaming responses, consume the body to allow the stream to complete. + if streaming { + // Read the streaming response. + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + } - // Verify upstream received the request at the expected path - // with the API key header. - received := upstream.receivedRequests() - require.Len(t, received, 1) - require.Equal(t, tc.expectedUpstreamPath, received[0].Path) - require.Contains(t, received[0].Header.Get(tc.authHeader), apiKey) + // Verify that Bedrock-specific model name was used in the request to the mock server + // and the interception data. + received := upstream.receivedRequests() + require.Len(t, received, 1) - gotBytes, err := io.ReadAll(resp.Body) - require.NoError(t, err) + // The Anthropic SDK's Bedrock middleware extracts "model" and "stream" + // from the JSON body and encodes them in the URL path. + // See: https://github.com/anthropics/anthropic-sdk-go/blob/4d669338f2041f3c60640b6dd317c4895dc71cd4/bedrock/bedrock.go#L247-L248 + pathParts := strings.Split(received[0].Path, "/") + require.True(t, len(pathParts) >= 3 && pathParts[1] == "model", "unexpected path: %s", received[0].Path) + require.Equal(t, bedrockCfg.Model, pathParts[2]) + require.False(t, gjson.GetBytes(received[0].Body, "model").Exists(), "model should be stripped from body") + require.False(t, gjson.GetBytes(received[0].Body, "stream").Exists(), "stream should be stripped from body") - // Compare JSON bodies for semantic equality. 
- var got any - var exp any - require.NoError(t, json.Unmarshal(gotBytes, &got)) - require.NoError(t, json.Unmarshal(fix.NonStreaming(), &exp)) - require.EqualValues(t, exp, got) - }) - } + interceptions := bridgeServer.Recorder.RecordedInterceptions() + require.Len(t, interceptions, 1) + require.Equal(t, interceptions[0].Model, bedrockCfg.Model) + bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) + }) + } + }) } -func TestErrorHandling(t *testing.T) { +func TestOpenAIChatCompletions(t *testing.T) { t.Parallel() - // Tests that errors which occur *before* a streaming response begins, or in non-streaming requests, are handled as expected. - t.Run("non-stream error", func(t *testing.T) { + t.Run("single builtin tool", func(t *testing.T) { + t.Parallel() + cases := []struct { - name string - fixture []byte - path string - responseHandlerFn func(resp *http.Response) + streaming bool + expectedInputTokens, expectedOutputTokens int + expectedToolCallID string }{ { - name: config.ProviderAnthropic, - fixture: fixtures.AntNonStreamError, - path: pathAnthropicMessages, - responseHandlerFn: func(resp *http.Response) { - require.Equal(t, http.StatusBadRequest, resp.StatusCode) - body, err := io.ReadAll(resp.Body) - require.NoError(t, err) - require.Equal(t, "error", gjson.GetBytes(body, "type").Str) - require.Equal(t, "invalid_request_error", gjson.GetBytes(body, "error.type").Str) - require.Contains(t, gjson.GetBytes(body, "error.message").Str, "prompt is too long") - }, + streaming: true, + expectedInputTokens: 60, + expectedOutputTokens: 15, + expectedToolCallID: "call_HjeqP7YeRkoNj0de9e3U4X4B", }, { - name: config.ProviderOpenAI, - fixture: fixtures.OaiChatNonStreamError, - path: pathOpenAIChatCompletions, - responseHandlerFn: func(resp *http.Response) { - require.Equal(t, http.StatusBadRequest, resp.StatusCode) - body, err := io.ReadAll(resp.Body) - require.NoError(t, err) - require.Equal(t, "context_length_exceeded", gjson.GetBytes(body, "error.code").Str) - 
require.Equal(t, "invalid_request_error", gjson.GetBytes(body, "error.type").Str) - require.Contains(t, gjson.GetBytes(body, "error.message").Str, "Input tokens exceed the configured limit") - }, + streaming: false, + expectedInputTokens: 60, + expectedOutputTokens: 15, + expectedToolCallID: "call_KjzAbhiZC6nk81tQzL7pwlpc", }, } for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { + t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), tc.streaming), func(t *testing.T) { t.Parallel() - for _, streaming := range []bool{true, false} { - t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { - t.Parallel() + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) - t.Cleanup(cancel) + fix := fixtures.Parse(t, fixtures.OaiChatSingleBuiltinTool) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - // Setup mock server. Error fixtures contain raw HTTP - // responses that may cause the bridge to retry. - fix := fixtures.Parse(t, tc.fixture) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) + // Make API call to aibridge for OpenAI /v1/chat/completions + reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) + require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, reqBody) + require.Equal(t, http.StatusOK, resp.StatusCode) + defer resp.Body.Close() - // Add the stream param to the request. - reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) - require.NoError(t, err) + // Response-specific checks. 
+ if tc.streaming { + sp := aibridge.NewSSEParser() + require.NoError(t, sp.Parse(resp.Body)) - resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody) - t.Cleanup(func() { _ = resp.Body.Close() }) + // OpenAI sends all events under the same type. + messageEvents := sp.MessageEvents() + assert.NotEmpty(t, messageEvents) + + // OpenAI streaming ends with [DONE] + lastEvent := messageEvents[len(messageEvents)-1] + assert.Equal(t, "[DONE]", lastEvent.Data) + } + + tokenUsages := bridgeServer.Recorder.RecordedTokenUsages() + require.Len(t, tokenUsages, 1) + assert.EqualValues(t, tc.expectedInputTokens, bridgeServer.Recorder.TotalInputTokens(), "input tokens miscalculated") + assert.EqualValues(t, tc.expectedOutputTokens, bridgeServer.Recorder.TotalOutputTokens(), "output tokens miscalculated") + + toolUsages := bridgeServer.Recorder.RecordedToolUsages() + require.Len(t, toolUsages, 1) + assert.Equal(t, "read_file", toolUsages[0].Tool) + assert.Equal(t, tc.expectedToolCallID, toolUsages[0].ToolCallID) + require.IsType(t, map[string]any{}, toolUsages[0].Args) + require.Contains(t, toolUsages[0].Args, "path") + assert.Equal(t, "README.md", toolUsages[0].Args.(map[string]any)["path"]) + + promptUsages := bridgeServer.Recorder.RecordedPromptUsages() + require.Len(t, promptUsages, 1) + assert.Equal(t, "how large is the README.md file in my current path", promptUsages[0].Prompt) - tc.responseHandlerFn(resp) - bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) - }) - } + bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) }) } }) - // Tests that errors which occur *during* a streaming response are handled as expected. 
- t.Run("mid-stream error", func(t *testing.T) { + t.Run("streaming injected tool call edge cases", func(t *testing.T) { + t.Parallel() + cases := []struct { - name string - fixture []byte - path string - responseHandlerFn func(resp *http.Response) + name string + fixture []byte + expectedArgs map[string]any }{ { - name: config.ProviderAnthropic, - fixture: fixtures.AntMidStreamError, - path: pathAnthropicMessages, - responseHandlerFn: func(resp *http.Response) { - // Server responds first with 200 OK then starts streaming. - require.Equal(t, http.StatusOK, resp.StatusCode) - - sp := aibridge.NewSSEParser() - require.NoError(t, sp.Parse(resp.Body)) - require.Len(t, sp.EventsByType("error"), 1) - require.Contains(t, sp.EventsByType("error")[0].Data, "Overloaded") - }, + name: "tool call no preamble", + fixture: fixtures.OaiChatStreamingInjectedToolNoPreamble, + expectedArgs: map[string]any{"owner": "me"}, }, { - name: config.ProviderOpenAI, - fixture: fixtures.OaiChatMidStreamError, - path: pathOpenAIChatCompletions, - responseHandlerFn: func(resp *http.Response) { - // Server responds first with 200 OK then starts streaming. - require.Equal(t, http.StatusOK, resp.StatusCode) - - sp := aibridge.NewSSEParser() - require.NoError(t, sp.Parse(resp.Body)) - // OpenAI sends all events under the same type. - messageEvents := sp.MessageEvents() - require.NotEmpty(t, messageEvents) - - errEvent := sp.MessageEvents()[len(sp.MessageEvents())-2] // Last event is termination marker ("[DONE]"). - require.NotEmpty(t, errEvent) - require.Contains(t, errEvent.Data, "The server had an error while processing your request. Sorry about that!") - }, + name: "tool call with non-zero index", + fixture: fixtures.OaiChatStreamingInjectedToolNonzeroIndex, + expectedArgs: nil, // No arguments in this fixture }, } @@ -517,372 +326,414 @@ func TestErrorHandling(t *testing.T) { ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - // Setup mock server. 
+ // Setup mock server for multi-turn interaction. + // First request → tool call response, second → tool response. fix := fixtures.Parse(t, tc.fixture) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - upstream.StatusCode = http.StatusInternalServerError + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix), newFixtureToolResponse(fix)) - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) + // Setup MCP proxies with the tool from the fixture + mockMCP := setupMCPForTest(t, defaultTracer) - resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request()) - t.Cleanup(func() { _ = resp.Body.Close() }) - bridgeServer.Close() + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, + withMCP(mockMCP), + ) + + // Add the stream param to the request. + reqBody, err := sjson.SetBytes(fix.Request(), "stream", true) + require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, reqBody) + require.Equal(t, http.StatusOK, resp.StatusCode) + + // Verify SSE headers are sent correctly + require.Equal(t, "text/event-stream", resp.Header.Get("Content-Type")) + require.Equal(t, "no-cache", resp.Header.Get("Cache-Control")) + require.Equal(t, "keep-alive", resp.Header.Get("Connection")) + + // Consume the full response body to ensure the interception completes + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + resp.Body.Close() + + // Verify the MCP tool was actually invoked + invocations := mockMCP.getCallsByTool(mockToolName) + require.Len(t, invocations, 1, "expected MCP tool to be invoked") + + // Verify tool was invoked with the expected args (if specified) + if tc.expectedArgs != nil { + expected, err := json.Marshal(tc.expectedArgs) + require.NoError(t, err) + actual, err := json.Marshal(invocations[0]) + require.NoError(t, err) + require.EqualValues(t, expected, actual) + } + + // Verify tool usage was recorded + toolUsages := bridgeServer.Recorder.RecordedToolUsages() + 
require.Len(t, toolUsages, 1) + assert.Equal(t, mockToolName, toolUsages[0].Tool) - tc.responseHandlerFn(resp) bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) }) } }) } -// TestStableRequestEncoding validates that a given intercepted request and a -// given set of injected tools should result identical payloads. -// -// Should the payload vary, it may subvert any caching mechanisms the provider may have. -func TestStableRequestEncoding(t *testing.T) { +func TestSimple(t *testing.T) { t.Parallel() - cases := []struct { - name string - fixture []byte - path string - }{ - { - name: config.ProviderAnthropic, - fixture: fixtures.AntSimple, - path: pathAnthropicMessages, - }, - { - name: config.ProviderOpenAI, - fixture: fixtures.OaiChatSimple, - path: pathOpenAIChatCompletions, - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) - t.Cleanup(cancel) + getAnthropicResponseID := func(streaming bool, resp *http.Response) (string, error) { + if streaming { + decoder := ssestream.NewDecoder(resp) + stream := ssestream.NewStream[anthropic.MessageStreamEventUnion](decoder, nil) + var message anthropic.Message + for stream.Next() { + event := stream.Current() + if err := message.Accumulate(event); err != nil { + return "", fmt.Errorf("accumulate event: %w", err) + } + } + if stream.Err() != nil { + return "", fmt.Errorf("stream error: %w", stream.Err()) + } + return message.ID, nil + } - // Setup MCP tools. - mockMCP := setupMCPForTest(t, defaultTracer) + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("read body: %w", err) + } - fix := fixtures.Parse(t, tc.fixture) + var message anthropic.Message + if err := json.Unmarshal(body, &message); err != nil { + return "", fmt.Errorf("unmarshal response: %w", err) + } + return message.ID, nil + } - // Create a mock upstream that serves the same blocking response for each request. 
- count := 10 - responses := make([]upstreamResponse, count) - for i := range count { - responses[i] = newFixtureResponse(fix) + getOpenAIResponseID := func(streaming bool, resp *http.Response) (string, error) { + if streaming { + // Parse the response stream. + decoder := oaissestream.NewDecoder(resp) + stream := oaissestream.NewStream[openai.ChatCompletionChunk](decoder, nil) + var message openai.ChatCompletionAccumulator + for stream.Next() { + chunk := stream.Current() + message.AddChunk(chunk) } - upstream := newMockUpstream(t, ctx, responses...) - - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, - withMCP(mockMCP), - ) - - // Make multiple requests and verify they all have identical payloads. - for range count { - resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request()) - require.Equal(t, http.StatusOK, resp.StatusCode) - _ = resp.Body.Close() + if stream.Err() != nil { + return "", fmt.Errorf("stream error: %w", stream.Err()) } + return message.ID, nil + } - // All upstream request bodies should be identical. - received := upstream.receivedRequests() - require.Len(t, received, count) - reference := string(received[0].Body) - for _, r := range received[1:] { - assert.JSONEq(t, reference, string(r.Body)) - } - }) - } -} + // Parse & unmarshal the response. + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("read body: %w", err) + } -func TestEnvironmentDoNotLeak(t *testing.T) { - // NOTE: Cannot use t.Parallel() here because subtests use t.Setenv which requires sequential execution. + var message openai.ChatCompletion + if err := json.Unmarshal(body, &message); err != nil { + return "", fmt.Errorf("unmarshal response: %w", err) + } + return message.ID, nil + } - // Test that environment variables containing API keys/tokens are not leaked to upstream requests. - // See https://github.com/coder/aibridge/issues/60. 
testCases := []struct { - name string - fixture []byte - path string - envVars map[string]string - headerName string + name string + fixture []byte + basePath string + expectedPath string + getResponseIDFunc func(streaming bool, resp *http.Response) (string, error) + path string + expectedMsgID string + userAgent string + expectedClient aibridge.Client }{ { - name: config.ProviderAnthropic, - fixture: fixtures.AntSimple, - path: pathAnthropicMessages, - envVars: map[string]string{ - "ANTHROPIC_AUTH_TOKEN": "should-not-leak", - }, - headerName: "Authorization", // We only send through the X-Api-Key, so this one should not be present. + name: config.ProviderAnthropic, + fixture: fixtures.AntSimple, + basePath: "", + expectedPath: "/v1/messages", + getResponseIDFunc: getAnthropicResponseID, + path: pathAnthropicMessages, + expectedMsgID: "msg_01Pvyf26bY17RcjmWfJsXGBn", + userAgent: "claude-cli/2.0.67 (external, cli)", + expectedClient: aibridge.ClientClaudeCode, + }, + { + name: config.ProviderOpenAI, + fixture: fixtures.OaiChatSimple, + basePath: "", + expectedPath: "/chat/completions", + getResponseIDFunc: getOpenAIResponseID, + path: pathOpenAIChatCompletions, + expectedMsgID: "chatcmpl-BwoiPTGRbKkY5rncfaM0s9KtWrq5N", + userAgent: "codex_cli_rs/0.87.0 (Mac OS 26.2.0; arm64)", + expectedClient: aibridge.ClientCodex, }, { - name: config.ProviderOpenAI, - fixture: fixtures.OaiChatSimple, - path: pathOpenAIChatCompletions, - envVars: map[string]string{ - "OPENAI_ORG_ID": "should-not-leak", - }, - headerName: "OpenAI-Organization", + name: config.ProviderAnthropic + "_baseURL_path", + fixture: fixtures.AntSimple, + basePath: "/api", + expectedPath: "/api/v1/messages", + getResponseIDFunc: getAnthropicResponseID, + path: pathAnthropicMessages, + expectedMsgID: "msg_01Pvyf26bY17RcjmWfJsXGBn", + userAgent: "GitHubCopilotChat/0.37.2026011603", + expectedClient: aibridge.ClientCopilotVSC, + }, + { + name: config.ProviderOpenAI + "_baseURL_path", + fixture: 
fixtures.OaiChatSimple, + basePath: "/api", + expectedPath: "/api/chat/completions", + getResponseIDFunc: getOpenAIResponseID, + path: pathOpenAIChatCompletions, + expectedMsgID: "chatcmpl-BwoiPTGRbKkY5rncfaM0s9KtWrq5N", + userAgent: "Zed/0.219.4+stable.119.abc123 (macos; aarch64)", + expectedClient: aibridge.ClientZed, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - // NOTE: Cannot use t.Parallel() here because t.Setenv requires sequential execution. + t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) - t.Cleanup(cancel) + for _, streaming := range []bool{true, false} { + t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { + t.Parallel() - fix := fixtures.Parse(t, tc.fixture) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) - // Set environment variables that the SDK would automatically read. - // These should NOT leak into upstream requests. - for key, val := range tc.envVars { - t.Setenv(key, val) - } + fix := fixtures.Parse(t, tc.fixture) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL+tc.basePath) - resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request()) - require.Equal(t, http.StatusOK, resp.StatusCode) - defer resp.Body.Close() + // When: calling the "API server" with the fixture's request body. + reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) + require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody, http.Header{"User-Agent": {tc.userAgent}}) + require.Equal(t, http.StatusOK, resp.StatusCode) + defer resp.Body.Close() - // Verify that environment values did not leak. 
-			received := upstream.receivedRequests()
-			require.Len(t, received, 1)
-			require.Empty(t, received[0].Header.Get(tc.headerName))
+				// Then: I expect the upstream request to have the correct path.
+				received := upstream.receivedRequests()
+				require.Len(t, received, 1)
+				require.Equal(t, tc.expectedPath, received[0].Path)
+
+				// Then: I expect a non-empty response.
+				bodyBytes, err := io.ReadAll(resp.Body)
+				require.NoError(t, err)
+				assert.NotEmpty(t, bodyBytes, "should have received response body")
+
+				// Reset the body after being read.
+				resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+
+				// Then: I expect the prompt to have been tracked.
+				promptUsages := bridgeServer.Recorder.RecordedPromptUsages()
+				require.NotEmpty(t, promptUsages, "no prompts tracked")
+				assert.Contains(t, promptUsages[0].Prompt, "how many angels can dance on the head of a pin")
+
+				// Validate that responses have their IDs overridden with an interception ID rather than the original ID from the upstream provider.
+				// The reason for this is that Bridge may make multiple upstream requests (i.e. to invoke injected tools), and clients will not be expecting
+				// multiple messages in response to a single request.
+				id, err := tc.getResponseIDFunc(streaming, resp)
+				require.NoError(t, err, "failed to retrieve response ID")
+				require.Nilf(t, uuid.Validate(id), "%s is not a valid UUID", id)
+
+				tokenUsages := bridgeServer.Recorder.RecordedTokenUsages()
+				require.GreaterOrEqual(t, len(tokenUsages), 1)
+				require.Equal(t, tokenUsages[0].MsgID, tc.expectedMsgID)
+
+				// Validate user agent and client have been recorded.
+ interceptions := bridgeServer.Recorder.RecordedInterceptions() + require.Len(t, interceptions, 1, "expected exactly one interception, got: %v", interceptions) + assert.Equal(t, id, interceptions[0].ID) + assert.Equal(t, tc.userAgent, interceptions[0].UserAgent) + assert.Equal(t, string(tc.expectedClient), interceptions[0].Client) + + bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) + }) + } }) } } -func TestActorHeaders(t *testing.T) { +func TestSessionIDTracking(t *testing.T) { t.Parallel() - actorUsername := "bob" - - cases := []struct { - name string - path string - createProviderFn func(url, key string, sendHeaders bool) aibridge.Provider - fixture []byte - streaming bool + testCases := []struct { + name string + fixture []byte + expectedClient aibridge.Client + sessionID string + header http.Header + mutateBody func(t *testing.T, body []byte) []byte }{ + // Session in header. { - name: "openai/v1/chat/completions", - path: pathOpenAIChatCompletions, - createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { - cfg := openAICfg(url, key) - cfg.SendActorHeaders = sendHeaders - return provider.NewOpenAI(cfg) - }, - fixture: fixtures.OaiChatSimple, - streaming: true, - }, - { - name: "openai/v1/chat/completions", - path: pathOpenAIChatCompletions, - createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { - cfg := openAICfg(url, key) - cfg.SendActorHeaders = sendHeaders - return provider.NewOpenAI(cfg) - }, - fixture: fixtures.OaiChatSimple, - streaming: false, - }, - { - name: "openai/v1/responses", - path: pathOpenAIResponses, - createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { - cfg := openAICfg(url, key) - cfg.SendActorHeaders = sendHeaders - return provider.NewOpenAI(cfg) + name: "mux", + fixture: fixtures.AntSimple, + expectedClient: aibridge.ClientMux, + sessionID: "mux-workspace-321", + header: http.Header{ + "User-Agent": []string{"mux/1.0.0"}, + "X-Mux-Workspace-Id": 
[]string{"mux-workspace-321"}, }, - fixture: fixtures.OaiResponsesStreamingSimple, - streaming: true, }, + // Session in body. { - name: "openai/v1/responses", - path: pathOpenAIResponses, - createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { - cfg := openAICfg(url, key) - cfg.SendActorHeaders = sendHeaders - return provider.NewOpenAI(cfg) + name: "claude_code", + fixture: fixtures.AntSimple, + expectedClient: aibridge.ClientClaudeCode, + sessionID: "f47ac10b-58cc-4372-a567-0e02b2c3d479", + header: http.Header{ + "User-Agent": []string{"claude-cli/2.0.67 (external, cli)"}, }, - fixture: fixtures.OaiResponsesBlockingSimple, - streaming: false, - }, - { - name: "anthropic/v1/messages", - path: pathAnthropicMessages, - createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { - cfg := anthropicCfg(url, key) - cfg.SendActorHeaders = sendHeaders - return provider.NewAnthropic(cfg, nil) + mutateBody: func(t *testing.T, body []byte) []byte { + t.Helper() + // Claude Code embeds the session ID in metadata.user_id within the body. + body, err := sjson.SetBytes(body, "metadata.user_id", + "user_abc123_account_456_session_f47ac10b-58cc-4372-a567-0e02b2c3d479") + require.NoError(t, err) + return body }, - fixture: fixtures.AntSimple, - streaming: true, }, + // No session. 
{ - name: "anthropic/v1/messages", - path: pathAnthropicMessages, - createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { - cfg := anthropicCfg(url, key) - cfg.SendActorHeaders = sendHeaders - return provider.NewAnthropic(cfg, nil) + name: "zed", + fixture: fixtures.AntSimple, + expectedClient: aibridge.ClientZed, + header: http.Header{ + "User-Agent": []string{"Zed/0.219.4+stable.119.abc123 (macos; aarch64)"}, }, - fixture: fixtures.AntSimple, - streaming: false, }, } - for _, tc := range cases { - for _, send := range []bool{true, false} { - t.Run(fmt.Sprintf("%s/streaming=%v/send-headers=%v", tc.name, tc.streaming, send), func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) - t.Cleanup(cancel) - - // Track headers received by the upstream server. - var receivedHeaders http.Header - srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - receivedHeaders = r.Header.Clone() - w.WriteHeader(http.StatusTeapot) - })) - srv.Config.BaseContext = func(_ net.Listener) context.Context { - return ctx - } - srv.Start() - t.Cleanup(srv.Close) - - metadataKey := "Username" - bridgeServer := newBridgeTestServer(t, ctx, srv.URL, - withCustomProvider(tc.createProviderFn(srv.URL, apiKey, send)), - withActor(defaultActorID, recorder.Metadata{ - metadataKey: actorUsername, - }), - ) - - // Add the stream param to the request. - reqBody, err := sjson.SetBytes(fixtures.Request(t, tc.fixture), "stream", tc.streaming) - require.NoError(t, err) - - resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody) - require.NotEmpty(t, receivedHeaders) - defer resp.Body.Close() + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() - // Verify that the actor headers were only received if intended. 
- found := make(map[string][]string) - for k, v := range receivedHeaders { - k = strings.ToLower(k) - if intercept.IsActorHeader(k) { - found[k] = v - } - } + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) - if send { - require.Equal(t, found[strings.ToLower(intercept.ActorIDHeader())], []string{defaultActorID}) - require.Equal(t, found[strings.ToLower(intercept.ActorMetadataHeader(metadataKey))], []string{actorUsername}) - } else { - require.Empty(t, found) - } - }) - } - } -} + fix := fixtures.Parse(t, tc.fixture) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, withProvider(config.ProviderAnthropic)) -func TestAnthropicMessages(t *testing.T) { - t.Parallel() + reqBody := fix.Request() + if tc.mutateBody != nil { + reqBody = tc.mutateBody(t, reqBody) + } - t.Run("single builtin tool", func(t *testing.T) { - t.Parallel() + resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody, tc.header) + require.Equal(t, http.StatusOK, resp.StatusCode) + defer resp.Body.Close() - cases := []struct { - streaming bool - expectedInputTokens int - expectedOutputTokens int - expectedToolCallID string - }{ - { - streaming: true, - expectedInputTokens: 2, - expectedOutputTokens: 66, - expectedToolCallID: "toolu_01RX68weRSquLx6HUTj65iBo", - }, - { - streaming: false, - expectedInputTokens: 5, - expectedOutputTokens: 84, - expectedToolCallID: "toolu_01AusGgY5aKFhzWrFBv9JfHq", - }, - } + // Drain the body to let the stream complete. 
+ _, err := io.ReadAll(resp.Body) + require.NoError(t, err) - for _, tc := range cases { - t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), tc.streaming), func(t *testing.T) { - t.Parallel() + interceptions := bridgeServer.Recorder.RecordedInterceptions() + require.Len(t, interceptions, 1, "expected exactly one interception") + assert.Equal(t, string(tc.expectedClient), interceptions[0].Client) - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) - t.Cleanup(cancel) + if tc.sessionID == "" { + assert.Nil(t, interceptions[0].ClientSessionID, "expected nil session ID for %s", tc.name) + } else { + require.NotNil(t, interceptions[0].ClientSessionID, "expected non-nil session ID for %s", tc.name) + assert.Equal(t, tc.sessionID, *interceptions[0].ClientSessionID) + } - fix := fixtures.Parse(t, fixtures.AntSingleBuiltinTool) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) + }) + } +} - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) +func TestFallthrough(t *testing.T) { + t.Parallel() - // Make API call to aibridge for Anthropic /v1/messages - reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) - require.NoError(t, err) - resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody) - require.Equal(t, http.StatusOK, resp.StatusCode) - defer resp.Body.Close() + testCases := []struct { + name string + providerName string + fixture []byte + basePath string + requestPath string + expectedUpstreamPath string + authHeader string + }{ + { + name: "ant_empty_base_url_path", + providerName: config.ProviderAnthropic, + fixture: fixtures.AntFallthrough, + basePath: "", + requestPath: "/anthropic/v1/models", + expectedUpstreamPath: "/v1/models", + authHeader: "X-Api-Key", + }, + { + name: "oai_empty_base_url_path", + providerName: config.ProviderOpenAI, + fixture: fixtures.OaiChatFallthrough, + basePath: "", + requestPath: "/openai/v1/models", + 
expectedUpstreamPath: "/models", + authHeader: "Authorization", + }, + { + name: "ant_some_base_url_path", + providerName: config.ProviderAnthropic, + fixture: fixtures.AntFallthrough, + basePath: "/api", + requestPath: "/anthropic/v1/models", + expectedUpstreamPath: "/api/v1/models", + authHeader: "X-Api-Key", + }, + { + name: "oai_some_base_url_path", + providerName: config.ProviderOpenAI, + fixture: fixtures.OaiChatFallthrough, + basePath: "/api", + requestPath: "/openai/v1/models", + expectedUpstreamPath: "/api/models", + authHeader: "Authorization", + }, + } - // Response-specific checks. - if tc.streaming { - sp := aibridge.NewSSEParser() - require.NoError(t, sp.Parse(resp.Body)) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() - // Ensure the message starts and completes, at a minimum. - assert.Contains(t, sp.AllEvents(), "message_start") - assert.Contains(t, sp.AllEvents(), "message_stop") - } + fix := fixtures.Parse(t, tc.fixture) + upstream := newMockUpstream(t, t.Context(), newFixtureResponse(fix)) + bridgeServer := newBridgeTestServer(t, t.Context(), upstream.URL+tc.basePath) - expectedTokenRecordings := 1 - if tc.streaming { - // One for message_start, one for message_delta. 
- expectedTokenRecordings = 2 - } - tokenUsages := bridgeServer.Recorder.RecordedTokenUsages() - require.Len(t, tokenUsages, expectedTokenRecordings) + resp := bridgeServer.makeRequest(t, http.MethodGet, tc.requestPath, nil) + defer resp.Body.Close() - assert.EqualValues(t, tc.expectedInputTokens, bridgeServer.Recorder.TotalInputTokens(), "input tokens miscalculated") - assert.EqualValues(t, tc.expectedOutputTokens, bridgeServer.Recorder.TotalOutputTokens(), "output tokens miscalculated") + require.Equal(t, http.StatusOK, resp.StatusCode) - toolUsages := bridgeServer.Recorder.RecordedToolUsages() - require.Len(t, toolUsages, 1) - assert.Equal(t, "Read", toolUsages[0].Tool) - assert.Equal(t, tc.expectedToolCallID, toolUsages[0].ToolCallID) - require.IsType(t, json.RawMessage{}, toolUsages[0].Args) - var args map[string]any - require.NoError(t, json.Unmarshal(toolUsages[0].Args.(json.RawMessage), &args)) - require.Contains(t, args, "file_path") - assert.Equal(t, "/tmp/blah/foo", args["file_path"]) + // Verify upstream received the request at the expected path + // with the API key header. + received := upstream.receivedRequests() + require.Len(t, received, 1) + require.Equal(t, tc.expectedUpstreamPath, received[0].Path) + require.Contains(t, received[0].Header.Get(tc.authHeader), apiKey) - promptUsages := bridgeServer.Recorder.RecordedPromptUsages() - require.Len(t, promptUsages, 1) - assert.Equal(t, "read the foo file", promptUsages[0].Prompt) + gotBytes, err := io.ReadAll(resp.Body) + require.NoError(t, err) - bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) - }) - } - }) + // Compare JSON bodies for semantic equality. 
+			var got any
+			var exp any
+			require.NoError(t, json.Unmarshal(gotBytes, &got))
+			require.NoError(t, json.Unmarshal(fix.NonStreaming(), &exp))
+			require.EqualValues(t, exp, got)
+		})
+	}
 }
 
 func TestAnthropicInjectedTools(t *testing.T) {
@@ -958,8 +809,110 @@ func TestAnthropicInjectedTools(t *testing.T) {
 			assert.EqualValues(t, 204, message.Usage.OutputTokens)
 
 			// Ensure tokens used during injected tool invocation are accounted for.
-			assert.EqualValues(t, 15308, recorderClient.TotalInputTokens())
-			assert.EqualValues(t, 204, recorderClient.TotalOutputTokens())
+			assert.EqualValues(t, 15308, recorderClient.TotalInputTokens())
+			assert.EqualValues(t, 204, recorderClient.TotalOutputTokens())
+
+			// Ensure we received exactly one prompt.
+			promptUsages := recorderClient.RecordedPromptUsages()
+			require.Len(t, promptUsages, 1)
+		})
+	}
+}
+
+// TestOpenAIInjectedTools verifies that the inner agentic loop invokes injected
+// MCP tools for OpenAI chat completions, that injected tool calls are never
+// surfaced to the client, and that token, tool, and prompt usage is recorded
+// correctly in both streaming and non-streaming modes.
+func TestOpenAIInjectedTools(t *testing.T) {
+	t.Parallel()
+
+	for _, streaming := range []bool{true, false} {
+		t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) {
+			t.Parallel()
+
+			// Build the requirements & make the assertions which are common to all providers.
+			recorderClient, mockMCP, resp := setupInjectedToolTest(t, fixtures.OaiChatSingleInjectedTool, streaming, defaultTracer, pathOpenAIChatCompletions, openaiChatToolResultValidator(t))
+
+			// Ensure expected tool was invoked with expected input.
+ toolUsages := recorderClient.RecordedToolUsages() + require.Len(t, toolUsages, 1) + require.Equal(t, mockToolName, toolUsages[0].Tool) + expected, err := json.Marshal(map[string]any{"owner": "admin"}) + require.NoError(t, err) + actual, err := json.Marshal(toolUsages[0].Args) + require.NoError(t, err) + require.EqualValues(t, expected, actual) + invocations := mockMCP.getCallsByTool(mockToolName) + require.Len(t, invocations, 1) + actual, err = json.Marshal(invocations[0]) + require.NoError(t, err) + require.EqualValues(t, expected, actual) + + var ( + content *openai.ChatCompletionChoice + message openai.ChatCompletion + ) + if streaming { + // Parse the response stream. + decoder := oaissestream.NewDecoder(resp) + stream := oaissestream.NewStream[openai.ChatCompletionChunk](decoder, nil) + var acc openai.ChatCompletionAccumulator + detectedToolCalls := make(map[string]struct{}) + for stream.Next() { + chunk := stream.Current() + acc.AddChunk(chunk) + + if len(chunk.Choices) == 0 { + continue + } + + for _, c := range chunk.Choices { + if len(c.Delta.ToolCalls) == 0 { + continue + } + + for _, t := range c.Delta.ToolCalls { + if t.Function.Name == "" { + continue + } + + detectedToolCalls[t.Function.Name] = struct{}{} + } + } + } + + // Verify that no injected tool call events (or partials thereof) were sent to the client. + require.Len(t, detectedToolCalls, 0) + + message = acc.ChatCompletion + require.NoError(t, stream.Err(), "stream error") + } else { + // Parse & unmarshal the response. + body, err := io.ReadAll(resp.Body) + require.NoError(t, err, "read response body") + require.NoError(t, json.Unmarshal(body, &message), "unmarshal response") + + // Verify that no injected tools were sent to the client. + require.GreaterOrEqual(t, len(message.Choices), 1) + require.Len(t, message.Choices[0].Message.ToolCalls, 0) + } + + require.GreaterOrEqual(t, len(message.Choices), 1) + content = &message.Choices[0] + + // Ensure tool returned expected value. 
+			require.NotNil(t, content)
+			require.Contains(t, content.Message.Content, "dd711d5c-83c6-4c08-a0af-b73055906e8c") // The ID of the workspace to be returned.
+
+			// Check the token usage from the client's perspective.
+			// This *should* work but the openai SDK doesn't accumulate the prompt token details :(.
+			// See https://github.com/openai/openai-go/blob/v2.7.0/streamaccumulator.go#L145-L147.
+			// assert.EqualValues(t, 5047, message.Usage.PromptTokens-message.Usage.PromptTokensDetails.CachedTokens)
+			assert.EqualValues(t, 105, message.Usage.CompletionTokens)
+
+			// Ensure tokens used during injected tool invocation are accounted for.
+			require.EqualValues(t, 5047, recorderClient.TotalInputTokens())
+			require.EqualValues(t, 105, recorderClient.TotalOutputTokens())
 
 			// Ensure we received exactly one prompt.
 			promptUsages := recorderClient.RecordedPromptUsages()
@@ -968,10 +921,9 @@ func TestAnthropicInjectedTools(t *testing.T) {
 	}
 }
 
-// anthropicToolResultValidator returns a request validator that asserts the second
-// upstream request contains the assistant's tool_use and user's tool_result messages
-// appended by the inner agentic loop. If the raw payload is not kept in sync with
-// the structured messages, the second request will be identical to the first.
+// anthropicToolResultValidator returns a request validator that asserts the second
+// upstream request contains the assistant's tool_use and user's tool_result messages
+// appended by the inner agentic loop (keeping the raw payload and structured messages in sync).
 func anthropicToolResultValidator(t *testing.T) func(*http.Request, []byte) {
 	t.Helper()
 
@@ -1014,6 +966,239 @@ func anthropicToolResultValidator(t *testing.T) func(*http.Request, []byte) {
 
 // TestAnthropicToolChoiceParallelDisabled verifies that parallel tool use is
 // correctly disabled based on the tool_choice parameter in the request.
// See https://github.com/coder/aibridge/issues/2 +func openaiChatToolResultValidator(t *testing.T) func(*http.Request, []byte) { + t.Helper() + + return func(_ *http.Request, raw []byte) { + messages := gjson.GetBytes(raw, "messages").Array() + + // After the agentic loop the messages must contain at minimum: + // [0] original user message + // [N-2] assistant message with tool_calls array + // [N-1] message with role=tool + require.GreaterOrEqual(t, len(messages), 3, + "second upstream request must contain the original message, assistant tool_calls, and tool result") + + assistantMsg := messages[len(messages)-2] + require.Equal(t, "assistant", assistantMsg.Get("role").Str, + "penultimate message must be from the assistant") + require.NotEmpty(t, len(assistantMsg.Get("tool_calls").Array()), + "assistant message must contain a tool_calls array") + + toolResultMsg := messages[len(messages)-1] + require.Equal(t, "tool", toolResultMsg.Get("role").Str, + "last message must have role=tool") + require.NotEmpty(t, toolResultMsg.Get("tool_call_id").Str, + "tool result message must have a tool_call_id") + } +} + +func TestErrorHandling(t *testing.T) { + t.Parallel() + + // Tests that errors which occur *before* a streaming response begins, or in non-streaming requests, are handled as expected. 
+ t.Run("non-stream error", func(t *testing.T) { + cases := []struct { + name string + fixture []byte + path string + responseHandlerFn func(resp *http.Response) + }{ + { + name: config.ProviderAnthropic, + fixture: fixtures.AntNonStreamError, + path: pathAnthropicMessages, + responseHandlerFn: func(resp *http.Response) { + require.Equal(t, http.StatusBadRequest, resp.StatusCode) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, "error", gjson.GetBytes(body, "type").Str) + require.Equal(t, "invalid_request_error", gjson.GetBytes(body, "error.type").Str) + require.Contains(t, gjson.GetBytes(body, "error.message").Str, "prompt is too long") + }, + }, + { + name: config.ProviderOpenAI, + fixture: fixtures.OaiChatNonStreamError, + path: pathOpenAIChatCompletions, + responseHandlerFn: func(resp *http.Response) { + require.Equal(t, http.StatusBadRequest, resp.StatusCode) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, "context_length_exceeded", gjson.GetBytes(body, "error.code").Str) + require.Equal(t, "invalid_request_error", gjson.GetBytes(body, "error.type").Str) + require.Contains(t, gjson.GetBytes(body, "error.message").Str, "Input tokens exceed the configured limit") + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + for _, streaming := range []bool{true, false} { + t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + // Setup mock server. Error fixtures contain raw HTTP + // responses that may cause the bridge to retry. + fix := fixtures.Parse(t, tc.fixture) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) + + // Add the stream param to the request. 
+ reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) + require.NoError(t, err) + + resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody) + t.Cleanup(func() { _ = resp.Body.Close() }) + + tc.responseHandlerFn(resp) + bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) + }) + } + }) + } + }) + + // Tests that errors which occur *during* a streaming response are handled as expected. + t.Run("mid-stream error", func(t *testing.T) { + cases := []struct { + name string + fixture []byte + path string + responseHandlerFn func(resp *http.Response) + }{ + { + name: config.ProviderAnthropic, + fixture: fixtures.AntMidStreamError, + path: pathAnthropicMessages, + responseHandlerFn: func(resp *http.Response) { + // Server responds first with 200 OK then starts streaming. + require.Equal(t, http.StatusOK, resp.StatusCode) + + sp := aibridge.NewSSEParser() + require.NoError(t, sp.Parse(resp.Body)) + require.Len(t, sp.EventsByType("error"), 1) + require.Contains(t, sp.EventsByType("error")[0].Data, "Overloaded") + }, + }, + { + name: config.ProviderOpenAI, + fixture: fixtures.OaiChatMidStreamError, + path: pathOpenAIChatCompletions, + responseHandlerFn: func(resp *http.Response) { + // Server responds first with 200 OK then starts streaming. + require.Equal(t, http.StatusOK, resp.StatusCode) + + sp := aibridge.NewSSEParser() + require.NoError(t, sp.Parse(resp.Body)) + // OpenAI sends all events under the same type. + messageEvents := sp.MessageEvents() + require.NotEmpty(t, messageEvents) + + errEvent := sp.MessageEvents()[len(sp.MessageEvents())-2] // Last event is termination marker ("[DONE]"). + require.NotEmpty(t, errEvent) + require.Contains(t, errEvent.Data, "The server had an error while processing your request. Sorry about that!") + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + // Setup mock server. 
+ fix := fixtures.Parse(t, tc.fixture) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + upstream.StatusCode = http.StatusInternalServerError + + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) + + resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request()) + t.Cleanup(func() { _ = resp.Body.Close() }) + bridgeServer.Close() + + tc.responseHandlerFn(resp) + bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) + }) + } + }) +} + +// TestStableRequestEncoding validates that a given intercepted request and a +// given set of injected tools should result identical payloads. +// +// Should the payload vary, it may subvert any caching mechanisms the provider may have. +func TestStableRequestEncoding(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + fixture []byte + path string + }{ + { + name: config.ProviderAnthropic, + fixture: fixtures.AntSimple, + path: pathAnthropicMessages, + }, + { + name: config.ProviderOpenAI, + fixture: fixtures.OaiChatSimple, + path: pathOpenAIChatCompletions, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + // Setup MCP tools. + mockMCP := setupMCPForTest(t, defaultTracer) + + fix := fixtures.Parse(t, tc.fixture) + + // Create a mock upstream that serves the same blocking response for each request. + count := 10 + responses := make([]upstreamResponse, count) + for i := range count { + responses[i] = newFixtureResponse(fix) + } + upstream := newMockUpstream(t, ctx, responses...) + + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, + withMCP(mockMCP), + ) + + // Make multiple requests and verify they all have identical payloads. 
+ for range count { + resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request()) + require.Equal(t, http.StatusOK, resp.StatusCode) + _ = resp.Body.Close() + } + + // All upstream request bodies should be identical. + received := upstream.receivedRequests() + require.Len(t, received, count) + reference := string(received[0].Body) + for _, r := range received[1:] { + assert.JSONEq(t, reference, string(r.Body)) + } + }) + } +} + func TestAnthropicToolChoiceParallelDisabled(t *testing.T) { t.Parallel() @@ -1180,386 +1365,200 @@ func TestThinkingAdaptiveIsPreserved(t *testing.T) { } } -func TestOpenAIChatCompletions(t *testing.T) { - t.Parallel() - - t.Run("single builtin tool", func(t *testing.T) { - t.Parallel() - - cases := []struct { - streaming bool - expectedInputTokens, expectedOutputTokens int - expectedToolCallID string - }{ - { - streaming: true, - expectedInputTokens: 60, - expectedOutputTokens: 15, - expectedToolCallID: "call_HjeqP7YeRkoNj0de9e3U4X4B", - }, - { - streaming: false, - expectedInputTokens: 60, - expectedOutputTokens: 15, - expectedToolCallID: "call_KjzAbhiZC6nk81tQzL7pwlpc", - }, - } - - for _, tc := range cases { - t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), tc.streaming), func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) - t.Cleanup(cancel) - - fix := fixtures.Parse(t, fixtures.OaiChatSingleBuiltinTool) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) - - // Make API call to aibridge for OpenAI /v1/chat/completions - reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) - require.NoError(t, err) - resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, reqBody) - require.Equal(t, http.StatusOK, resp.StatusCode) - defer resp.Body.Close() - - // Response-specific checks. 
- if tc.streaming { - sp := aibridge.NewSSEParser() - require.NoError(t, sp.Parse(resp.Body)) - - // OpenAI sends all events under the same type. - messageEvents := sp.MessageEvents() - assert.NotEmpty(t, messageEvents) - - // OpenAI streaming ends with [DONE] - lastEvent := messageEvents[len(messageEvents)-1] - assert.Equal(t, "[DONE]", lastEvent.Data) - } - - tokenUsages := bridgeServer.Recorder.RecordedTokenUsages() - require.Len(t, tokenUsages, 1) - assert.EqualValues(t, tc.expectedInputTokens, bridgeServer.Recorder.TotalInputTokens(), "input tokens miscalculated") - assert.EqualValues(t, tc.expectedOutputTokens, bridgeServer.Recorder.TotalOutputTokens(), "output tokens miscalculated") - - toolUsages := bridgeServer.Recorder.RecordedToolUsages() - require.Len(t, toolUsages, 1) - assert.Equal(t, "read_file", toolUsages[0].Tool) - assert.Equal(t, tc.expectedToolCallID, toolUsages[0].ToolCallID) - require.IsType(t, map[string]any{}, toolUsages[0].Args) - require.Contains(t, toolUsages[0].Args, "path") - assert.Equal(t, "README.md", toolUsages[0].Args.(map[string]any)["path"]) - - promptUsages := bridgeServer.Recorder.RecordedPromptUsages() - require.Len(t, promptUsages, 1) - assert.Equal(t, "how large is the README.md file in my current path", promptUsages[0].Prompt) - - bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) - }) - } - }) - - t.Run("streaming injected tool call edge cases", func(t *testing.T) { - t.Parallel() - - cases := []struct { - name string - fixture []byte - expectedArgs map[string]any - }{ - { - name: "tool call no preamble", - fixture: fixtures.OaiChatStreamingInjectedToolNoPreamble, - expectedArgs: map[string]any{"owner": "me"}, - }, - { - name: "tool call with non-zero index", - fixture: fixtures.OaiChatStreamingInjectedToolNonzeroIndex, - expectedArgs: nil, // No arguments in this fixture - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(t.Context(), 
time.Second*30) - t.Cleanup(cancel) - - // Setup mock server for multi-turn interaction. - // First request → tool call response, second → tool response. - fix := fixtures.Parse(t, tc.fixture) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix), newFixtureToolResponse(fix)) - - // Setup MCP proxies with the tool from the fixture - mockMCP := setupMCPForTest(t, defaultTracer) - - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, - withMCP(mockMCP), - ) - - // Add the stream param to the request. - reqBody, err := sjson.SetBytes(fix.Request(), "stream", true) - require.NoError(t, err) - resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, reqBody) - require.Equal(t, http.StatusOK, resp.StatusCode) - - // Verify SSE headers are sent correctly - require.Equal(t, "text/event-stream", resp.Header.Get("Content-Type")) - require.Equal(t, "no-cache", resp.Header.Get("Cache-Control")) - require.Equal(t, "keep-alive", resp.Header.Get("Connection")) - - // Consume the full response body to ensure the interception completes - _, err = io.ReadAll(resp.Body) - require.NoError(t, err) - resp.Body.Close() - - // Verify the MCP tool was actually invoked - invocations := mockMCP.getCallsByTool(mockToolName) - require.Len(t, invocations, 1, "expected MCP tool to be invoked") - - // Verify tool was invoked with the expected args (if specified) - if tc.expectedArgs != nil { - expected, err := json.Marshal(tc.expectedArgs) - require.NoError(t, err) - actual, err := json.Marshal(invocations[0]) - require.NoError(t, err) - require.EqualValues(t, expected, actual) - } - - // Verify tool usage was recorded - toolUsages := bridgeServer.Recorder.RecordedToolUsages() - require.Len(t, toolUsages, 1) - assert.Equal(t, mockToolName, toolUsages[0].Tool) - - bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) - }) - } - }) -} - -func TestOpenAIInjectedTools(t *testing.T) { - t.Parallel() - - for _, streaming := range []bool{true, false} { - 
t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { - t.Parallel() - - // Build the requirements & make the assertions which are common to all providers. - recorderClient, mockMCP, resp := setupInjectedToolTest(t, fixtures.OaiChatSingleInjectedTool, streaming, defaultTracer, pathOpenAIChatCompletions, openaiChatToolResultValidator(t)) - - // Ensure expected tool was invoked with expected input. - toolUsages := recorderClient.RecordedToolUsages() - require.Len(t, toolUsages, 1) - require.Equal(t, mockToolName, toolUsages[0].Tool) - expected, err := json.Marshal(map[string]any{"owner": "admin"}) - require.NoError(t, err) - actual, err := json.Marshal(toolUsages[0].Args) - require.NoError(t, err) - require.EqualValues(t, expected, actual) - invocations := mockMCP.getCallsByTool(mockToolName) - require.Len(t, invocations, 1) - actual, err = json.Marshal(invocations[0]) - require.NoError(t, err) - require.EqualValues(t, expected, actual) - - var ( - content *openai.ChatCompletionChoice - message openai.ChatCompletion - ) - if streaming { - // Parse the response stream. - decoder := oaissestream.NewDecoder(resp) - stream := oaissestream.NewStream[openai.ChatCompletionChunk](decoder, nil) - var acc openai.ChatCompletionAccumulator - detectedToolCalls := make(map[string]struct{}) - for stream.Next() { - chunk := stream.Current() - acc.AddChunk(chunk) - - if len(chunk.Choices) == 0 { - continue - } - - for _, c := range chunk.Choices { - if len(c.Delta.ToolCalls) == 0 { - continue - } - - for _, t := range c.Delta.ToolCalls { - if t.Function.Name == "" { - continue - } - - detectedToolCalls[t.Function.Name] = struct{}{} - } - } - } - - // Verify that no injected tool call events (or partials thereof) were sent to the client. - require.Len(t, detectedToolCalls, 0) - - message = acc.ChatCompletion - require.NoError(t, stream.Err(), "stream error") - } else { - // Parse & unmarshal the response. 
- body, err := io.ReadAll(resp.Body) - require.NoError(t, err, "read response body") - require.NoError(t, json.Unmarshal(body, &message), "unmarshal response") - - // Verify that no injected tools were sent to the client. - require.GreaterOrEqual(t, len(message.Choices), 1) - require.Len(t, message.Choices[0].Message.ToolCalls, 0) - } - - require.GreaterOrEqual(t, len(message.Choices), 1) - content = &message.Choices[0] - - // Ensure tool returned expected value. - require.NotNil(t, content) - require.Contains(t, content.Message.Content, "dd711d5c-83c6-4c08-a0af-b73055906e8c") // The ID of the workspace to be returned. +func TestEnvironmentDoNotLeak(t *testing.T) { + // NOTE: Cannot use t.Parallel() here because subtests use t.Setenv which requires sequential execution. - // Check the token usage from the client's perspective. - // This *should* work but the openai SDK doesn't accumulate the prompt token details :(. - // See https://github.com/openai/openai-go/blob/v2.7.0/streamaccumulator.go#L145-L147. - // assert.EqualValues(t, 5047, message.Usage.PromptTokens-message.Usage.PromptTokensDetails.CachedTokens) - assert.EqualValues(t, 105, message.Usage.CompletionTokens) + // Test that environment variables containing API keys/tokens are not leaked to upstream requests. + // See https://github.com/coder/aibridge/issues/60. + testCases := []struct { + name string + fixture []byte + path string + envVars map[string]string + headerName string + }{ + { + name: config.ProviderAnthropic, + fixture: fixtures.AntSimple, + path: pathAnthropicMessages, + envVars: map[string]string{ + "ANTHROPIC_AUTH_TOKEN": "should-not-leak", + }, + headerName: "Authorization", // We only send through the X-Api-Key, so this one should not be present. 
+ }, + { + name: config.ProviderOpenAI, + fixture: fixtures.OaiChatSimple, + path: pathOpenAIChatCompletions, + envVars: map[string]string{ + "OPENAI_ORG_ID": "should-not-leak", + }, + headerName: "OpenAI-Organization", + }, + } - // Ensure tokens used during injected tool invocation are accounted for. - require.EqualValues(t, 5047, recorderClient.TotalInputTokens()) - require.EqualValues(t, 105, recorderClient.TotalOutputTokens()) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // NOTE: Cannot use t.Parallel() here because t.Setenv requires sequential execution. - // Ensure we received exactly one prompt. - promptUsages := recorderClient.RecordedPromptUsages() - require.Len(t, promptUsages, 1) - }) - } -} + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) -// openaiChatToolResultValidator returns a request validator that asserts the second -// upstream request contains the assistant's tool_calls and a role=tool result message -// appended by the inner agentic loop. -func openaiChatToolResultValidator(t *testing.T) func(*http.Request, []byte) { - t.Helper() + fix := fixtures.Parse(t, tc.fixture) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - return func(_ *http.Request, raw []byte) { - messages := gjson.GetBytes(raw, "messages").Array() + // Set environment variables that the SDK would automatically read. + // These should NOT leak into upstream requests. 
+ for key, val := range tc.envVars { + t.Setenv(key, val) + } - // After the agentic loop the messages must contain at minimum: - // [0] original user message - // [N-2] assistant message with tool_calls array - // [N-1] message with role=tool - require.GreaterOrEqual(t, len(messages), 3, - "second upstream request must contain the original message, assistant tool_calls, and tool result") + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) - assistantMsg := messages[len(messages)-2] - require.Equal(t, "assistant", assistantMsg.Get("role").Str, - "penultimate message must be from the assistant") - require.NotEmpty(t, len(assistantMsg.Get("tool_calls").Array()), - "assistant message must contain a tool_calls array") + resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request()) + require.Equal(t, http.StatusOK, resp.StatusCode) + defer resp.Body.Close() - toolResultMsg := messages[len(messages)-1] - require.Equal(t, "tool", toolResultMsg.Get("role").Str, - "last message must have role=tool") - require.NotEmpty(t, toolResultMsg.Get("tool_call_id").Str, - "tool result message must have a tool_call_id") + // Verify that environment values did not leak. 
+ received := upstream.receivedRequests() + require.Len(t, received, 1) + require.Empty(t, received[0].Header.Get(tc.headerName)) + }) } } -func TestAWSBedrockIntegration(t *testing.T) { +func TestActorHeaders(t *testing.T) { t.Parallel() - t.Run("invalid config", func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) - t.Cleanup(cancel) - - // Invalid bedrock config - missing region & base url - bedrockCfg := &config.AWSBedrock{ - Region: "", - AccessKey: "test-key", - AccessKeySecret: "test-secret", - Model: "test-model", - SmallFastModel: "test-haiku", - } - - bridgeServer := newBridgeTestServer(t, ctx, "http://unused", - withCustomProvider(provider.NewAnthropic(anthropicCfg("http://unused", apiKey), bedrockCfg)), - withLogger(newLogger(t)), - ) - - resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, fixtures.Request(t, fixtures.AntSingleBuiltinTool)) - defer resp.Body.Close() + actorUsername := "bob" - require.Equal(t, http.StatusInternalServerError, resp.StatusCode) - body, err := io.ReadAll(resp.Body) - require.NoError(t, err) - require.Contains(t, string(body), "create anthropic client") - require.Contains(t, string(body), "region or base url required") - }) + cases := []struct { + name string + path string + createProviderFn func(url, key string, sendHeaders bool) aibridge.Provider + fixture []byte + streaming bool + }{ + { + name: "openai/v1/chat/completions", + path: pathOpenAIChatCompletions, + createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { + cfg := openAICfg(url, key) + cfg.SendActorHeaders = sendHeaders + return provider.NewOpenAI(cfg) + }, + fixture: fixtures.OaiChatSimple, + streaming: true, + }, + { + name: "openai/v1/chat/completions", + path: pathOpenAIChatCompletions, + createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { + cfg := openAICfg(url, key) + cfg.SendActorHeaders = sendHeaders + return provider.NewOpenAI(cfg) 
+ }, + fixture: fixtures.OaiChatSimple, + streaming: false, + }, + { + name: "openai/v1/responses", + path: pathOpenAIResponses, + createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { + cfg := openAICfg(url, key) + cfg.SendActorHeaders = sendHeaders + return provider.NewOpenAI(cfg) + }, + fixture: fixtures.OaiResponsesStreamingSimple, + streaming: true, + }, + { + name: "openai/v1/responses", + path: pathOpenAIResponses, + createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { + cfg := openAICfg(url, key) + cfg.SendActorHeaders = sendHeaders + return provider.NewOpenAI(cfg) + }, + fixture: fixtures.OaiResponsesBlockingSimple, + streaming: false, + }, + { + name: "anthropic/v1/messages", + path: pathAnthropicMessages, + createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { + cfg := anthropicCfg(url, key) + cfg.SendActorHeaders = sendHeaders + return provider.NewAnthropic(cfg, nil) + }, + fixture: fixtures.AntSimple, + streaming: true, + }, + { + name: "anthropic/v1/messages", + path: pathAnthropicMessages, + createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { + cfg := anthropicCfg(url, key) + cfg.SendActorHeaders = sendHeaders + return provider.NewAnthropic(cfg, nil) + }, + fixture: fixtures.AntSimple, + streaming: false, + }, + } - t.Run("/v1/messages", func(t *testing.T) { - for _, streaming := range []bool{true, false} { - t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), streaming), func(t *testing.T) { + for _, tc := range cases { + for _, send := range []bool{true, false} { + t.Run(fmt.Sprintf("%s/streaming=%v/send-headers=%v", tc.name, tc.streaming, send), func(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - fix := fixtures.Parse(t, fixtures.AntSingleBuiltinTool) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - - // We define region here to validate that with Region & BaseURL defined, 
the latter takes precedence. - bedrockCfg := &config.AWSBedrock{ - Region: "us-west-2", - AccessKey: "test-access-key", - AccessKeySecret: "test-secret-key", - Model: "danthropic", // This model should override the request's given one. - SmallFastModel: "danthropic-mini", // Unused but needed for validation. - BaseURL: upstream.URL, // Use the mock server. + // Track headers received by the upstream server. + var receivedHeaders http.Header + srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeaders = r.Header.Clone() + w.WriteHeader(http.StatusTeapot) + })) + srv.Config.BaseContext = func(_ net.Listener) context.Context { + return ctx } + srv.Start() + t.Cleanup(srv.Close) - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, - withCustomProvider(provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), bedrockCfg)), - withLogger(newLogger(t)), + metadataKey := "Username" + bridgeServer := newBridgeTestServer(t, ctx, srv.URL, + withCustomProvider(tc.createProviderFn(srv.URL, apiKey, send)), + withActor(defaultActorID, recorder.Metadata{ + metadataKey: actorUsername, + }), ) - // Make API call to aibridge for Anthropic /v1/messages, which will be routed via AWS Bedrock. - // We override the AWS Bedrock client to route requests through our mock server. - reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) + // Add the stream param to the request. + reqBody, err := sjson.SetBytes(fixtures.Request(t, tc.fixture), "stream", tc.streaming) require.NoError(t, err) - resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody) + + resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody) + require.NotEmpty(t, receivedHeaders) defer resp.Body.Close() - // For streaming responses, consume the body to allow the stream to complete. - if streaming { - // Read the streaming response. 
- _, err = io.ReadAll(resp.Body) - require.NoError(t, err) + // Verify that the actor headers were only received if intended. + found := make(map[string][]string) + for k, v := range receivedHeaders { + k = strings.ToLower(k) + if intercept.IsActorHeader(k) { + found[k] = v + } } - // Verify that Bedrock-specific model name was used in the request to the mock server - // and the interception data. - received := upstream.receivedRequests() - require.Len(t, received, 1) - - // The Anthropic SDK's Bedrock middleware extracts "model" and "stream" - // from the JSON body and encodes them in the URL path. - // See: https://github.com/anthropics/anthropic-sdk-go/blob/4d669338f2041f3c60640b6dd317c4895dc71cd4/bedrock/bedrock.go#L247-L248 - pathParts := strings.Split(received[0].Path, "/") - require.True(t, len(pathParts) >= 3 && pathParts[1] == "model", "unexpected path: %s", received[0].Path) - require.Equal(t, bedrockCfg.Model, pathParts[2]) - require.False(t, gjson.GetBytes(received[0].Body, "model").Exists(), "model should be stripped from body") - require.False(t, gjson.GetBytes(received[0].Body, "stream").Exists(), "stream should be stripped from body") - - interceptions := bridgeServer.Recorder.RecordedInterceptions() - require.Len(t, interceptions, 1) - require.Equal(t, interceptions[0].Model, bedrockCfg.Model) - bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) + if send { + require.Equal(t, found[strings.ToLower(intercept.ActorIDHeader())], []string{defaultActorID}) + require.Equal(t, found[strings.ToLower(intercept.ActorMetadataHeader(metadataKey))], []string{actorUsername}) + } else { + require.Empty(t, found) + } }) } - }) + } } - From 89a70b74427188e48d2971f77632889cd9d42f2e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Fri, 6 Mar 2026 14:46:22 +0000 Subject: [PATCH 21/32] remove not needed withLogger --- internal/integrationtest/bridge_test.go | 2 -- internal/integrationtest/setupbridge.go | 5 ----- 2 files changed, 7 
deletions(-) diff --git a/internal/integrationtest/bridge_test.go b/internal/integrationtest/bridge_test.go index 16fb06e..f8cf448 100644 --- a/internal/integrationtest/bridge_test.go +++ b/internal/integrationtest/bridge_test.go @@ -143,7 +143,6 @@ func TestAWSBedrockIntegration(t *testing.T) { bridgeServer := newBridgeTestServer(t, ctx, "http://unused", withCustomProvider(provider.NewAnthropic(anthropicCfg("http://unused", apiKey), bedrockCfg)), - withLogger(newLogger(t)), ) resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, fixtures.Request(t, fixtures.AntSingleBuiltinTool)) @@ -179,7 +178,6 @@ func TestAWSBedrockIntegration(t *testing.T) { bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, withCustomProvider(provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), bedrockCfg)), - withLogger(newLogger(t)), ) // Make API call to aibridge for Anthropic /v1/messages, which will be routed via AWS Bedrock. diff --git a/internal/integrationtest/setupbridge.go b/internal/integrationtest/setupbridge.go index 2e15a0d..ced939b 100644 --- a/internal/integrationtest/setupbridge.go +++ b/internal/integrationtest/setupbridge.go @@ -122,11 +122,6 @@ func withActor(id string, md recorder.Metadata) bridgeOption { return func(c *bridgeConfig) { c.userID = id; c.metadata = md } } -// withLogger overrides the default slogtest debug logger. 
-func withLogger(l slog.Logger) bridgeOption { - return func(c *bridgeConfig) { c.logger = l; c.loggerSet = true } -} - // newBridgeTestServer creates a fully configured test server running // a RequestBridge with sensible defaults: // - All standard providers (unless withProvider / withCustomProvider) From cfa535f537c2294ffbad4ad3cad43aa876f70981 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Fri, 6 Mar 2026 15:07:12 +0000 Subject: [PATCH 22/32] fix comments --- internal/integrationtest/bridge_test.go | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/internal/integrationtest/bridge_test.go b/internal/integrationtest/bridge_test.go index f8cf448..213c3fe 100644 --- a/internal/integrationtest/bridge_test.go +++ b/internal/integrationtest/bridge_test.go @@ -817,10 +817,6 @@ func TestAnthropicInjectedTools(t *testing.T) { } } -// anthropicToolResultValidator returns a request validator that asserts the second -// upstream request contains the assistant's tool_use and user's tool_result messages -// appended by the inner agentic loop. If the raw payload is not kept in sync with -// the structured messages, the second request will be identical to the first. func TestOpenAIInjectedTools(t *testing.T) { t.Parallel() @@ -919,9 +915,10 @@ func TestOpenAIInjectedTools(t *testing.T) { } } -// openaiChatToolResultValidator returns a request validator that asserts the second -// upstream request contains the assistant's tool_calls and a role=tool result message -// appended by the inner agentic loop. +// anthropicToolResultValidator returns a request validator that asserts the second +// upstream request contains the assistant's tool_use and user's tool_result messages +// appended by the inner agentic loop. If the raw payload is not kept in sync with +// the structured messages, the second request will be identical to the first. 
func anthropicToolResultValidator(t *testing.T) func(*http.Request, []byte) { t.Helper() @@ -961,9 +958,9 @@ func anthropicToolResultValidator(t *testing.T) func(*http.Request, []byte) { } } -// TestAnthropicToolChoiceParallelDisabled verifies that parallel tool use is -// correctly disabled based on the tool_choice parameter in the request. -// See https://github.com/coder/aibridge/issues/2 +// openaiChatToolResultValidator returns a request validator that asserts the second +// upstream request contains the assistant's tool_calls and a role=tool result message +// appended by the inner agentic loop. func openaiChatToolResultValidator(t *testing.T) func(*http.Request, []byte) { t.Helper() @@ -1197,6 +1194,9 @@ func TestStableRequestEncoding(t *testing.T) { } } +// TestAnthropicToolChoiceParallelDisabled verifies that parallel tool use is +// correctly disabled based on the tool_choice parameter in the request. +// See https://github.com/coder/aibridge/issues/2 func TestAnthropicToolChoiceParallelDisabled(t *testing.T) { t.Parallel() From 5bf7bc7767834c0177b5f50ab7efc3c689a89c88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Fri, 6 Mar 2026 15:12:13 +0000 Subject: [PATCH 23/32] rename --- internal/integrationtest/bridge_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/internal/integrationtest/bridge_test.go b/internal/integrationtest/bridge_test.go index 213c3fe..ea8c1ac 100644 --- a/internal/integrationtest/bridge_test.go +++ b/internal/integrationtest/bridge_test.go @@ -1275,18 +1275,18 @@ func TestAnthropicToolChoiceParallelDisabled(t *testing.T) { t.Cleanup(cancel) // Setup MCP tools conditionally. 
- var mcpMgr mcp.ServerProxier + var mockMCP mcp.ServerProxier if tc.withInjectedTools { - mcpMgr = setupMCPForTest(t, defaultTracer) + mockMCP = setupMCPForTest(t, defaultTracer) } else { - mcpMgr = newNoopMCPManager() + mockMCP = newNoopMCPManager() } fix := fixtures.Parse(t, fixtures.AntSimple) upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, - withMCP(mcpMgr), + withMCP(mockMCP), ) // Prepare request body with tool_choice set. From dfa21f097628bd50a3a2eb6a3ac36af7d74bb7e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Fri, 6 Mar 2026 15:13:37 +0000 Subject: [PATCH 24/32] formatting fix --- .../integrationtest/circuit_breaker_test.go | 30 +++++++++---------- internal/integrationtest/trace_test.go | 2 +- internal/testutil/mock_recorder.go | 1 + 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/internal/integrationtest/circuit_breaker_test.go b/internal/integrationtest/circuit_breaker_test.go index 9c3ff55..0ca7f59 100644 --- a/internal/integrationtest/circuit_breaker_test.go +++ b/internal/integrationtest/circuit_breaker_test.go @@ -45,7 +45,7 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { successBody string requestBody string headers http.Header - path string + path string createProvider func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider expectProvider string expectEndpoint string @@ -62,8 +62,8 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { successBody: anthropicSuccessResponse("claude-sonnet-4-20250514"), requestBody: `{"model":"claude-sonnet-4-20250514","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`, headers: http.Header{ - "x-api-key": {"test"}, - "anthropic-version": {"2023-06-01"}, + "x-api-key": {"test"}, + "anthropic-version": {"2023-06-01"}, }, path: pathAnthropicMessages, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { @@ -82,8 +82,8 @@ 
func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { errorBody: openAIRateLimitError, successBody: openAISuccessResponse("gpt-4o"), requestBody: `{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}`, - headers: http.Header{"Authorization": {"Bearer test-key"}}, - path: pathOpenAIChatCompletions, + headers: http.Header{"Authorization": {"Bearer test-key"}}, + path: pathOpenAIChatCompletions, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewOpenAI(config.OpenAI{ BaseURL: baseURL, @@ -210,7 +210,7 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { errorBody string requestBody string headers http.Header - path string + path string createProvider func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider expectProvider string expectEndpoint string @@ -226,8 +226,8 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { errorBody: anthropicRateLimitError, requestBody: `{"model":"claude-sonnet-4-20250514","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`, headers: http.Header{ - "x-api-key": {"test"}, - "anthropic-version": {"2023-06-01"}, + "x-api-key": {"test"}, + "anthropic-version": {"2023-06-01"}, }, path: pathAnthropicMessages, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { @@ -245,8 +245,8 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { expectModel: "gpt-4o", errorBody: openAIRateLimitError, requestBody: `{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}`, - headers: http.Header{"Authorization": {"Bearer test-key"}}, - path: pathOpenAIChatCompletions, + headers: http.Header{"Authorization": {"Bearer test-key"}}, + path: pathOpenAIChatCompletions, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewOpenAI(config.OpenAI{ BaseURL: baseURL, @@ -345,7 +345,7 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { successBody 
string requestBody string headers http.Header - path string + path string createProvider func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider expectProvider string expectEndpoint string @@ -362,8 +362,8 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { successBody: anthropicSuccessResponse("claude-sonnet-4-20250514"), requestBody: `{"model":"claude-sonnet-4-20250514","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`, headers: http.Header{ - "x-api-key": {"test"}, - "anthropic-version": {"2023-06-01"}, + "x-api-key": {"test"}, + "anthropic-version": {"2023-06-01"}, }, path: pathAnthropicMessages, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { @@ -382,8 +382,8 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { errorBody: openAIRateLimitError, successBody: openAISuccessResponse("gpt-4o"), requestBody: `{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}`, - headers: http.Header{"Authorization": {"Bearer test-key"}}, - path: pathOpenAIChatCompletions, + headers: http.Header{"Authorization": {"Bearer test-key"}}, + path: pathOpenAIChatCompletions, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewOpenAI(config.OpenAI{ BaseURL: baseURL, diff --git a/internal/integrationtest/trace_test.go b/internal/integrationtest/trace_test.go index d3e9542..7256ba9 100644 --- a/internal/integrationtest/trace_test.go +++ b/internal/integrationtest/trace_test.go @@ -18,9 +18,9 @@ import ( "github.com/tidwall/sjson" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" - oteltrace "go.opentelemetry.io/otel/trace" sdktrace "go.opentelemetry.io/otel/sdk/trace" "go.opentelemetry.io/otel/sdk/trace/tracetest" + oteltrace "go.opentelemetry.io/otel/trace" ) // expect 'count' amount of traces named 'name' with status 'status' diff --git a/internal/testutil/mock_recorder.go b/internal/testutil/mock_recorder.go 
index a256a64..18d869a 100644 --- a/internal/testutil/mock_recorder.go +++ b/internal/testutil/mock_recorder.go @@ -73,6 +73,7 @@ func (m *MockRecorder) RecordedTokenUsages() []*recorder.TokenUsageRecord { defer m.mu.Unlock() return slices.Clone(m.tokenUsages) } + // TotalInputTokens returns the sum of input tokens across all recorded token usages. func (m *MockRecorder) TotalInputTokens() int64 { var total int64 From a482ec1bcb1822a16805e70d8542eaf0046d1bb1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Fri, 6 Mar 2026 15:31:15 +0000 Subject: [PATCH 25/32] reorder to previous --- internal/integrationtest/bridge_test.go | 1 - internal/integrationtest/mockmcp.go | 14 +++--- internal/integrationtest/mockupstream.go | 56 ++++++++++++------------ 3 files changed, 35 insertions(+), 36 deletions(-) diff --git a/internal/integrationtest/bridge_test.go b/internal/integrationtest/bridge_test.go index ea8c1ac..b1991b3 100644 --- a/internal/integrationtest/bridge_test.go +++ b/internal/integrationtest/bridge_test.go @@ -1119,7 +1119,6 @@ func TestErrorHandling(t *testing.T) { resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request()) t.Cleanup(func() { _ = resp.Body.Close() }) - bridgeServer.Close() tc.responseHandlerFn(resp) bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) diff --git a/internal/integrationtest/mockmcp.go b/internal/integrationtest/mockmcp.go index e2f3dcb..eba25dd 100644 --- a/internal/integrationtest/mockmcp.go +++ b/internal/integrationtest/mockmcp.go @@ -31,13 +31,6 @@ type mockMCP struct { calls *callAccumulator } -// callAccumulator tracks all tool invocations by name and each instance's arguments. -type callAccumulator struct { - calls map[string][]any - callsMu sync.Mutex - toolErrors map[string]string -} - // getCallsByTool returns recorded arguments for a given tool name. 
func (m *mockMCP) getCallsByTool(name string) []any { return m.calls.getCallsByTool(name) @@ -91,6 +84,13 @@ func newNoopMCPManager() mcp.ServerProxier { return mcp.NewServerProxyManager(nil, noop.NewTracerProvider().Tracer("")) } +// callAccumulator tracks all tool invocations by name and each instance's arguments. +type callAccumulator struct { + calls map[string][]any + callsMu sync.Mutex + toolErrors map[string]string +} + func newCallAccumulator() *callAccumulator { return &callAccumulator{ calls: make(map[string][]any), diff --git a/internal/integrationtest/mockupstream.go b/internal/integrationtest/mockupstream.go index c0dbec0..a658b05 100644 --- a/internal/integrationtest/mockupstream.go +++ b/internal/integrationtest/mockupstream.go @@ -35,6 +35,34 @@ type upstreamResponse struct { OnRequest func(r *http.Request, body []byte) } +// newFixtureResponse creates an upstreamResponse from a parsed fixture archive. +// It reads whichever of 'streaming' and 'non-streaming' sections exist; +// not every fixture has both (e.g. error fixtures may only define one). +func newFixtureResponse(fix fixtures.Fixture) upstreamResponse { + var resp upstreamResponse + if fix.Has(fixtures.SectionStreaming) { + resp.Streaming = fix.Streaming() + } + if fix.Has(fixtures.SectionNonStreaming) { + resp.Blocking = fix.NonStreaming() + } + return resp +} + +// newFixtureToolResponse creates an upstreamResponse from the tool-call fixture files. +// It reads whichever of 'streaming/tool-call' and 'non-streaming/tool-call' +// sections exist. +func newFixtureToolResponse(fix fixtures.Fixture) upstreamResponse { + var resp upstreamResponse + if fix.Has(fixtures.SectionStreamingToolCall) { + resp.Streaming = fix.StreamingToolCall() + } + if fix.Has(fixtures.SectionNonStreamToolCall) { + resp.Blocking = fix.NonStreamingToolCall() + } + return resp +} + // receivedRequest captures the details of a single request handled by mockUpstream. 
type receivedRequest struct { Method string @@ -70,34 +98,6 @@ type mockUpstream struct { responses []upstreamResponse } -// newFixtureResponse creates an upstreamResponse from a parsed fixture archive. -// It reads whichever of 'streaming' and 'non-streaming' sections exist; -// not every fixture has both (e.g. error fixtures may only define one). -func newFixtureResponse(fix fixtures.Fixture) upstreamResponse { - var resp upstreamResponse - if fix.Has(fixtures.SectionStreaming) { - resp.Streaming = fix.Streaming() - } - if fix.Has(fixtures.SectionNonStreaming) { - resp.Blocking = fix.NonStreaming() - } - return resp -} - -// newFixtureToolResponse creates an upstreamResponse from the tool-call fixture files. -// It reads whichever of 'streaming/tool-call' and 'non-streaming/tool-call' -// sections exist. -func newFixtureToolResponse(fix fixtures.Fixture) upstreamResponse { - var resp upstreamResponse - if fix.Has(fixtures.SectionStreamingToolCall) { - resp.Streaming = fix.StreamingToolCall() - } - if fix.Has(fixtures.SectionNonStreamToolCall) { - resp.Blocking = fix.NonStreamingToolCall() - } - return resp -} - // receivedRequests returns a copy of all requests received so far. 
func (ms *mockUpstream) receivedRequests() []receivedRequest { ms.mu.Lock() From 4291a3a82922fa94039a83792a0c2af2de98cfbb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Fri, 6 Mar 2026 15:33:33 +0000 Subject: [PATCH 26/32] recorder cleanup --- internal/integrationtest/trace_test.go | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/internal/integrationtest/trace_test.go b/internal/integrationtest/trace_test.go index 7256ba9..504d717 100644 --- a/internal/integrationtest/trace_test.go +++ b/internal/integrationtest/trace_test.go @@ -123,9 +123,8 @@ func TestTraceAnthropic(t *testing.T) { defer resp.Body.Close() bridgeServer.Close() - recorder := bridgeServer.Recorder - require.Equal(t, 1, len(recorder.RecordedInterceptions())) - intcID := recorder.RecordedInterceptions()[0].ID + require.Equal(t, 1, len(bridgeServer.Recorder.RecordedInterceptions())) + intcID := bridgeServer.Recorder.RecordedInterceptions()[0].ID model := gjson.Get(string(reqBody), "model").Str if tc.bedrock { @@ -239,9 +238,8 @@ func TestTraceAnthropicErr(t *testing.T) { defer resp.Body.Close() bridgeServer.Close() - recorder := bridgeServer.Recorder - require.Equal(t, 1, len(recorder.RecordedInterceptions())) - intcID := recorder.RecordedInterceptions()[0].ID + require.Equal(t, 1, len(bridgeServer.Recorder.RecordedInterceptions())) + intcID := bridgeServer.Recorder.RecordedInterceptions()[0].ID totalCount := 0 for _, e := range tc.expect { @@ -479,9 +477,8 @@ func TestTraceOpenAI(t *testing.T) { defer resp.Body.Close() bridgeServer.Close() - recorder := bridgeServer.Recorder - require.Equal(t, 1, len(recorder.RecordedInterceptions())) - intcID := recorder.RecordedInterceptions()[0].ID + require.Equal(t, 1, len(bridgeServer.Recorder.RecordedInterceptions())) + intcID := bridgeServer.Recorder.RecordedInterceptions()[0].ID totalCount := 0 for _, e := range tc.expect { @@ -637,9 +634,8 @@ func TestTraceOpenAIErr(t *testing.T) { defer 
resp.Body.Close() bridgeServer.Close() - recorder := bridgeServer.Recorder - require.Equal(t, 1, len(recorder.RecordedInterceptions())) - intcID := recorder.RecordedInterceptions()[0].ID + require.Equal(t, 1, len(bridgeServer.Recorder.RecordedInterceptions())) + intcID := bridgeServer.Recorder.RecordedInterceptions()[0].ID totalCount := 0 for _, e := range tc.expect { From b0e32beccf70bf6f828cfdddd4eca5725bc68d86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Fri, 6 Mar 2026 15:37:10 +0000 Subject: [PATCH 27/32] MockRecorder Total*Tokens --- internal/testutil/mock_recorder.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/internal/testutil/mock_recorder.go b/internal/testutil/mock_recorder.go index 18d869a..09bcac3 100644 --- a/internal/testutil/mock_recorder.go +++ b/internal/testutil/mock_recorder.go @@ -76,8 +76,10 @@ func (m *MockRecorder) RecordedTokenUsages() []*recorder.TokenUsageRecord { // TotalInputTokens returns the sum of input tokens across all recorded token usages. func (m *MockRecorder) TotalInputTokens() int64 { + m.mu.Lock() + defer m.mu.Unlock() var total int64 - for _, el := range m.RecordedTokenUsages() { + for _, el := range m.tokenUsages { total += el.Input } return total @@ -85,8 +87,10 @@ func (m *MockRecorder) TotalInputTokens() int64 { // TotalOutputTokens returns the sum of output tokens across all recorded token usages. 
func (m *MockRecorder) TotalOutputTokens() int64 { + m.mu.Lock() + defer m.mu.Unlock() var total int64 - for _, el := range m.RecordedTokenUsages() { + for _, el := range m.tokenUsages { total += el.Output } return total From 8f4973075596651f28c15f602a397979eb5b87a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Fri, 6 Mar 2026 16:41:40 +0000 Subject: [PATCH 28/32] add Body.Close() to bridgeServer.makeRequest + some test names cleanup --- internal/integrationtest/apidump_test.go | 4 +--- internal/integrationtest/bridge_test.go | 17 +---------------- .../integrationtest/circuit_breaker_test.go | 4 ---- internal/integrationtest/metrics_test.go | 5 ----- internal/integrationtest/responses_test.go | 6 ------ internal/integrationtest/setupbridge.go | 16 +++++----------- internal/integrationtest/trace_test.go | 8 +------- 7 files changed, 8 insertions(+), 52 deletions(-) diff --git a/internal/integrationtest/apidump_test.go b/internal/integrationtest/apidump_test.go index 2349640..fcc70db 100644 --- a/internal/integrationtest/apidump_test.go +++ b/internal/integrationtest/apidump_test.go @@ -76,7 +76,6 @@ func TestAPIDump(t *testing.T) { resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request()) require.Equal(t, http.StatusOK, resp.StatusCode) - defer resp.Body.Close() _, _ = io.ReadAll(resp.Body) // Verify dump files were created. @@ -196,8 +195,7 @@ func TestAPIDumpPassthrough(t *testing.T) { withCustomProvider(tc.providerFunc(upstream.URL, dumpDir)), ) - resp := bridgeServer.makeRequest(t, http.MethodGet, tc.requestPath, nil) - defer resp.Body.Close() + bridgeServer.makeRequest(t, http.MethodGet, tc.requestPath, nil) // Find dump files in the passthrough directory. 
passthroughDir := filepath.Join(dumpDir, tc.name, "passthrough") diff --git a/internal/integrationtest/bridge_test.go b/internal/integrationtest/bridge_test.go index b1991b3..a78ae3c 100644 --- a/internal/integrationtest/bridge_test.go +++ b/internal/integrationtest/bridge_test.go @@ -80,7 +80,6 @@ func TestAnthropicMessages(t *testing.T) { require.NoError(t, err) resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody) require.Equal(t, http.StatusOK, resp.StatusCode) - defer resp.Body.Close() // Response-specific checks. if tc.streaming { @@ -146,7 +145,6 @@ func TestAWSBedrockIntegration(t *testing.T) { ) resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, fixtures.Request(t, fixtures.AntSingleBuiltinTool)) - defer resp.Body.Close() require.Equal(t, http.StatusInternalServerError, resp.StatusCode) body, err := io.ReadAll(resp.Body) @@ -185,7 +183,6 @@ func TestAWSBedrockIntegration(t *testing.T) { reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) require.NoError(t, err) resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody) - defer resp.Body.Close() // For streaming responses, consume the body to allow the stream to complete. if streaming { @@ -259,7 +256,6 @@ func TestOpenAIChatCompletions(t *testing.T) { require.NoError(t, err) resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, reqBody) require.Equal(t, http.StatusOK, resp.StatusCode) - defer resp.Body.Close() // Response-specific checks. 
if tc.streaming { @@ -350,7 +346,6 @@ func TestOpenAIChatCompletions(t *testing.T) { // Consume the full response body to ensure the interception completes _, err = io.ReadAll(resp.Body) require.NoError(t, err) - resp.Body.Close() // Verify the MCP tool was actually invoked invocations := mockMCP.getCallsByTool(mockToolName) @@ -515,7 +510,6 @@ func TestSimple(t *testing.T) { require.NoError(t, err) resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody, http.Header{"User-Agent": {tc.userAgent}}) require.Equal(t, http.StatusOK, resp.StatusCode) - defer resp.Body.Close() // Then: I expect the upstream request to have the correct path. received := upstream.receivedRequests() @@ -629,7 +623,6 @@ func TestSessionIDTracking(t *testing.T) { resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody, tc.header) require.Equal(t, http.StatusOK, resp.StatusCode) - defer resp.Body.Close() // Drain the body to let the stream complete. _, err := io.ReadAll(resp.Body) @@ -710,7 +703,6 @@ func TestFallthrough(t *testing.T) { bridgeServer := newBridgeTestServer(t, t.Context(), upstream.URL+tc.basePath) resp := bridgeServer.makeRequest(t, http.MethodGet, tc.requestPath, nil) - defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) @@ -1050,7 +1042,6 @@ func TestErrorHandling(t *testing.T) { require.NoError(t, err) resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody) - t.Cleanup(func() { _ = resp.Body.Close() }) tc.responseHandlerFn(resp) bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) @@ -1118,7 +1109,6 @@ func TestErrorHandling(t *testing.T) { bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request()) - t.Cleanup(func() { _ = resp.Body.Close() }) tc.responseHandlerFn(resp) bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) @@ -1179,7 +1169,6 @@ func TestStableRequestEncoding(t *testing.T) { for range count { resp := 
bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request()) require.Equal(t, http.StatusOK, resp.StatusCode) - _ = resp.Body.Close() } // All upstream request bodies should be identical. @@ -1294,7 +1283,6 @@ func TestAnthropicToolChoiceParallelDisabled(t *testing.T) { resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody) require.Equal(t, http.StatusOK, resp.StatusCode) - _ = resp.Body.Close() // Verify tool_choice in the upstream request. received := upstream.receivedRequests() @@ -1352,7 +1340,6 @@ func TestThinkingAdaptiveIsPreserved(t *testing.T) { resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody) require.Equal(t, http.StatusOK, resp.StatusCode) _, _ = io.ReadAll(resp.Body) - _ = resp.Body.Close() // Verify the thinking field was preserved in the upstream request. received := upstream.receivedRequests() @@ -1414,7 +1401,6 @@ func TestEnvironmentDoNotLeak(t *testing.T) { resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request()) require.Equal(t, http.StatusOK, resp.StatusCode) - defer resp.Body.Close() // Verify that environment values did not leak. received := upstream.receivedRequests() @@ -1536,9 +1522,8 @@ func TestActorHeaders(t *testing.T) { reqBody, err := sjson.SetBytes(fixtures.Request(t, tc.fixture), "stream", tc.streaming) require.NoError(t, err) - resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody) + bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody) require.NotEmpty(t, receivedHeaders) - defer resp.Body.Close() // Verify that the actor headers were only received if intended. 
found := make(map[string][]string) diff --git a/internal/integrationtest/circuit_breaker_test.go b/internal/integrationtest/circuit_breaker_test.go index 0ca7f59..4e39264 100644 --- a/internal/integrationtest/circuit_breaker_test.go +++ b/internal/integrationtest/circuit_breaker_test.go @@ -139,7 +139,6 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, []byte(tc.requestBody), tc.headers) _, err := io.ReadAll(resp.Body) require.NoError(t, err) - resp.Body.Close() return resp } @@ -293,7 +292,6 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, []byte(tc.requestBody), tc.headers) _, err := io.ReadAll(resp.Body) require.NoError(t, err) - resp.Body.Close() return resp } @@ -440,7 +438,6 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, []byte(tc.requestBody), tc.headers) _, err := io.ReadAll(resp.Body) require.NoError(t, err) - resp.Body.Close() return resp } @@ -564,7 +561,6 @@ func TestCircuitBreaker_PerModelIsolation(t *testing.T) { }) _, err := io.ReadAll(resp.Body) require.NoError(t, err) - resp.Body.Close() return resp } diff --git a/internal/integrationtest/metrics_test.go b/internal/integrationtest/metrics_test.go index 3aec692..ea9293e 100644 --- a/internal/integrationtest/metrics_test.go +++ b/internal/integrationtest/metrics_test.go @@ -126,7 +126,6 @@ func TestMetrics_Interception(t *testing.T) { ) resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request()) - defer resp.Body.Close() _, _ = io.ReadAll(resp.Body) count := promtest.ToFloat64(m.InterceptionCount.WithLabelValues( @@ -207,7 +206,6 @@ func TestMetrics_PassthroughCount(t *testing.T) { ) resp := bridgeServer.makeRequest(t, http.MethodGet, "/openai/v1/models", nil) - defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) count := 
promtest.ToFloat64(m.PassthroughCount.WithLabelValues( @@ -231,7 +229,6 @@ func TestMetrics_PromptCount(t *testing.T) { resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, fix.Request()) require.Equal(t, http.StatusOK, resp.StatusCode) - defer resp.Body.Close() _, _ = io.ReadAll(resp.Body) prompts := promtest.ToFloat64(m.PromptCount.WithLabelValues( @@ -255,7 +252,6 @@ func TestMetrics_NonInjectedToolUseCount(t *testing.T) { resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, fix.Request()) require.Equal(t, http.StatusOK, resp.StatusCode) - defer resp.Body.Close() _, _ = io.ReadAll(resp.Body) count := promtest.ToFloat64(m.NonInjectedToolUseCount.WithLabelValues( @@ -285,7 +281,6 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) { resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, fix.Request()) require.Equal(t, http.StatusOK, resp.StatusCode) - defer resp.Body.Close() _, _ = io.ReadAll(resp.Body) // Wait until full roundtrip has completed. 
diff --git a/internal/integrationtest/responses_test.go b/internal/integrationtest/responses_test.go index 612a821..2f4ddc7 100644 --- a/internal/integrationtest/responses_test.go +++ b/internal/integrationtest/responses_test.go @@ -337,7 +337,6 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, fix.Request(), http.Header{"User-Agent": {tc.userAgent}}) - defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) got, err := io.ReadAll(resp.Body) @@ -425,7 +424,6 @@ func TestResponsesBackgroundModeForbidden(t *testing.T) { // Create a request with background mode enabled reqBytes := responsesRequestBytes(t, tc.streaming, keyVal{"background", true}) resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBytes) - defer resp.Body.Close() require.Equal(t, "application/json", resp.Header.Get("Content-Type")) require.Equal(t, http.StatusNotImplemented, resp.StatusCode) @@ -531,7 +529,6 @@ func TestResponsesParallelToolsOverwritten(t *testing.T) { bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, []byte(tc.request)) - defer resp.Body.Close() _, err := io.ReadAll(resp.Body) require.NoError(t, err) }) @@ -588,7 +585,6 @@ func TestClientAndConnectionError(t *testing.T) { reqBytes := responsesRequestBytes(t, tc.streaming) resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBytes) - defer resp.Body.Close() require.Equal(t, "application/json", resp.Header.Get("Content-Type")) require.Equal(t, http.StatusInternalServerError, resp.StatusCode) @@ -666,7 +662,6 @@ func TestUpstreamError(t *testing.T) { reqBytes := responsesRequestBytes(t, tc.streaming) resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBytes) - defer resp.Body.Close() require.Equal(t, tc.statusCode, 
resp.StatusCode) require.Equal(t, tc.contentType, resp.Header.Get("Content-Type")) @@ -843,7 +838,6 @@ func TestResponsesInjectedTool(t *testing.T) { bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, withMCP(mockMCP)) resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, fix.Request()) - defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) body, err := io.ReadAll(resp.Body) diff --git a/internal/integrationtest/setupbridge.go b/internal/integrationtest/setupbridge.go index ced939b..1d2cdad 100644 --- a/internal/integrationtest/setupbridge.go +++ b/internal/integrationtest/setupbridge.go @@ -48,8 +48,7 @@ type bridgeConfig struct { mcpProxy mcp.ServerProxier userID string metadata recorder.Metadata - logger slog.Logger - loggerSet bool + logger slog.Logger } // bridgeTestServer wraps an httptest.Server running a RequestBridge. @@ -76,6 +75,7 @@ func (s *bridgeTestServer) makeRequest(t *testing.T, method string, path string, } resp, err := http.DefaultClient.Do(req) require.NoError(t, err) + t.Cleanup(func() { _ = resp.Body.Close() }) return resp } @@ -126,7 +126,7 @@ func withActor(id string, md recorder.Metadata) bridgeOption { // a RequestBridge with sensible defaults: // - All standard providers (unless withProvider / withCustomProvider) // - NoopMCPManager (unless withMCP) -// - slogtest debug logger (unless withLogger) +// - slogtest debug logger // - defaultTracer (unless withTracer) // - defaultActorID (unless withActor) func newBridgeTestServer( @@ -146,9 +146,7 @@ func newBridgeTestServer( if cfg.tracer == nil { cfg.tracer = defaultTracer } - if !cfg.loggerSet { - cfg.logger = newLogger(t) - } + cfg.logger = newLogger(t) if cfg.mcpProxy == nil { cfg.mcpProxy = newNoopMCPManager() } @@ -168,8 +166,7 @@ func newBridgeTestServer( } mockRec := &testutil.MockRecorder{} - var rec aibridge.Recorder = mockRec - rec = aibridge.NewRecorder(cfg.logger, cfg.tracer, func() (aibridge.Recorder, error) { + rec := 
aibridge.NewRecorder(cfg.logger, cfg.tracer, func() (aibridge.Recorder, error) { return mockRec, nil }) @@ -237,9 +234,6 @@ func setupInjectedToolTest( resp := bridgeServer.makeRequest(t, http.MethodPost, path, reqBody) require.Equal(t, http.StatusOK, resp.StatusCode) - t.Cleanup(func() { - _ = resp.Body.Close() - }) // We must ALWAYS have 2 calls to the bridge for injected tool tests. require.Eventually(t, func() bool { diff --git a/internal/integrationtest/trace_test.go b/internal/integrationtest/trace_test.go index 504d717..19af949 100644 --- a/internal/integrationtest/trace_test.go +++ b/internal/integrationtest/trace_test.go @@ -120,7 +120,6 @@ func TestTraceAnthropic(t *testing.T) { require.NoError(t, err) resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody) require.Equal(t, http.StatusOK, resp.StatusCode) - defer resp.Body.Close() bridgeServer.Close() require.Equal(t, 1, len(bridgeServer.Recorder.RecordedInterceptions())) @@ -235,7 +234,6 @@ func TestTraceAnthropicErr(t *testing.T) { } else { require.Equal(t, tc.expectCode, resp.StatusCode) } - defer resp.Body.Close() bridgeServer.Close() require.Equal(t, 1, len(bridgeServer.Recorder.RecordedInterceptions())) @@ -350,11 +348,10 @@ func TestInjectedToolsTrace(t *testing.T) { validatorFn = openaiChatToolResultValidator(t) } - recorderClient, mockMCP, resp := setupInjectedToolTest( + recorderClient, mockMCP, _ := setupInjectedToolTest( t, tc.fixture, tc.streaming, tracer, tc.path, validatorFn, tc.opts..., ) - defer resp.Body.Close() require.Len(t, recorderClient.RecordedInterceptions(), 1) intcID := recorderClient.RecordedInterceptions()[0].ID @@ -474,7 +471,6 @@ func TestTraceOpenAI(t *testing.T) { require.NoError(t, err) resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody) require.Equal(t, http.StatusOK, resp.StatusCode) - defer resp.Body.Close() bridgeServer.Close() require.Equal(t, 1, len(bridgeServer.Recorder.RecordedInterceptions())) @@ -631,7 +627,6 @@ 
func TestTraceOpenAIErr(t *testing.T) { resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody) require.Equal(t, tc.expectCode, resp.StatusCode) - defer resp.Body.Close() bridgeServer.Close() require.Equal(t, 1, len(bridgeServer.Recorder.RecordedInterceptions())) @@ -670,7 +665,6 @@ func TestTracePassthrough(t *testing.T) { ) resp := bridgeServer.makeRequest(t, http.MethodGet, "/openai/v1/models", nil) - defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) bridgeServer.Close() From cf4208e1770b89b8d29920e4b97045cb66bfb8f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Fri, 6 Mar 2026 17:12:22 +0000 Subject: [PATCH 29/32] changed TestResponsesParallelToolsOverwritten and TestOpenAIChatCompletions to use newBridgeTestServer + some test names cleanup --- internal/integrationtest/bridge_test.go | 41 ++++++++++----------- internal/integrationtest/responses_test.go | 42 +++++++++++----------- internal/integrationtest/trace_test.go | 3 +- 3 files changed, 42 insertions(+), 44 deletions(-) diff --git a/internal/integrationtest/bridge_test.go b/internal/integrationtest/bridge_test.go index a78ae3c..28fb52c 100644 --- a/internal/integrationtest/bridge_test.go +++ b/internal/integrationtest/bridge_test.go @@ -6,9 +6,7 @@ import ( "encoding/json" "fmt" "io" - "net" "net/http" - "net/http/httptest" "strings" "testing" "time" @@ -44,18 +42,21 @@ func TestAnthropicMessages(t *testing.T) { t.Parallel() cases := []struct { + name string streaming bool expectedInputTokens int expectedOutputTokens int expectedToolCallID string }{ { + name: "streaming", streaming: true, expectedInputTokens: 2, expectedOutputTokens: 66, expectedToolCallID: "toolu_01RX68weRSquLx6HUTj65iBo", }, { + name: "non-streaming", streaming: false, expectedInputTokens: 5, expectedOutputTokens: 84, @@ -64,7 +65,7 @@ func TestAnthropicMessages(t *testing.T) { } for _, tc := range cases { - t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), tc.streaming), 
func(t *testing.T) { + t.Run(tc.name, func(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) @@ -221,17 +222,20 @@ func TestOpenAIChatCompletions(t *testing.T) { t.Parallel() cases := []struct { + name string streaming bool expectedInputTokens, expectedOutputTokens int expectedToolCallID string }{ { + name: "streaming", streaming: true, expectedInputTokens: 60, expectedOutputTokens: 15, expectedToolCallID: "call_HjeqP7YeRkoNj0de9e3U4X4B", }, { + name: "non-streaming", streaming: false, expectedInputTokens: 60, expectedOutputTokens: 15, @@ -240,7 +244,7 @@ func TestOpenAIChatCompletions(t *testing.T) { } for _, tc := range cases { - t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), tc.streaming), func(t *testing.T) { + t.Run(tc.name, func(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) @@ -1498,32 +1502,29 @@ func TestActorHeaders(t *testing.T) { ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - // Track headers received by the upstream server. - var receivedHeaders http.Header - srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - receivedHeaders = r.Header.Clone() - w.WriteHeader(http.StatusTeapot) - })) - srv.Config.BaseContext = func(_ net.Listener) context.Context { - return ctx - } - srv.Start() - t.Cleanup(srv.Close) + fix := fixtures.Parse(t, tc.fixture) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) metadataKey := "Username" - bridgeServer := newBridgeTestServer(t, ctx, srv.URL, - withCustomProvider(tc.createProviderFn(srv.URL, apiKey, send)), + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, + withCustomProvider(tc.createProviderFn(upstream.URL, apiKey, send)), withActor(defaultActorID, recorder.Metadata{ metadataKey: actorUsername, }), ) // Add the stream param to the request. 
- reqBody, err := sjson.SetBytes(fixtures.Request(t, tc.fixture), "stream", tc.streaming) + reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) require.NoError(t, err) - bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody) - require.NotEmpty(t, receivedHeaders) + resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody) + // Drain the body so streaming responses complete without + // a "connection reset" error in the mock upstream. + _, _ = io.ReadAll(resp.Body) + + received := upstream.receivedRequests() + require.NotEmpty(t, received) + receivedHeaders := received[0].Header // Verify that the actor headers were only received if intended. found := make(map[string][]string) diff --git a/internal/integrationtest/responses_test.go b/internal/integrationtest/responses_test.go index 2f4ddc7..47cb0a8 100644 --- a/internal/integrationtest/responses_test.go +++ b/internal/integrationtest/responses_test.go @@ -506,31 +506,29 @@ func TestResponsesParallelToolsOverwritten(t *testing.T) { ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - raw, err := io.ReadAll(r.Body) - require.NoError(t, err) - defer r.Body.Close() - - var receivedRequest map[string]any - require.NoError(t, json.Unmarshal(raw, &receivedRequest)) - if tc.expectParallelToolCalls { - parallelToolCalls, ok := receivedRequest["parallel_tool_calls"].(bool) - require.True(t, ok, "parallel_tool_calls should be present in upstream request") - require.Equal(t, tc.expectParallelToolCallsValue, parallelToolCalls) - } else { - _, ok := receivedRequest["parallel_tool_calls"] - require.False(t, ok, "parallel_tool_calls should not be present when not set") - } - - w.WriteHeader(http.StatusOK) - })) - t.Cleanup(upstream.Close) - + fix := fixtures.OaiResponsesBlockingSimple + if tc.streaming { + fix = fixtures.OaiResponsesStreamingSimple + } + upstream := 
newMockUpstream(t, ctx, newFixtureResponse(fixtures.Parse(t, fix))) bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, []byte(tc.request)) - _, err := io.ReadAll(resp.Body) - require.NoError(t, err) + _, _ = io.ReadAll(resp.Body) + + received := upstream.receivedRequests() + require.Len(t, received, 1) + + var receivedRequest map[string]any + require.NoError(t, json.Unmarshal(received[0].Body, &receivedRequest)) + if tc.expectParallelToolCalls { + parallelToolCalls, ok := receivedRequest["parallel_tool_calls"].(bool) + require.True(t, ok, "parallel_tool_calls should be present in upstream request") + require.Equal(t, tc.expectParallelToolCallsValue, parallelToolCalls) + } else { + _, ok := receivedRequest["parallel_tool_calls"] + require.False(t, ok, "parallel_tool_calls should not be present when not set") + } }) } } diff --git a/internal/integrationtest/trace_test.go b/internal/integrationtest/trace_test.go index 19af949..d3f2b3d 100644 --- a/internal/integrationtest/trace_test.go +++ b/internal/integrationtest/trace_test.go @@ -2,7 +2,6 @@ package integrationtest import ( "context" - "fmt" "net/http" "slices" "strings" @@ -100,7 +99,7 @@ func TestTraceAnthropic(t *testing.T) { fixtureReqBody := fix.Request() for _, tc := range cases { - t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), tc.streaming), func(t *testing.T) { + t.Run(tc.name, func(t *testing.T) { ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) From 1086b3b3abea877ac3c588a052ac26e0dc84ecd5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Mon, 9 Mar 2026 08:18:58 +0000 Subject: [PATCH 30/32] fmt fix --- internal/integrationtest/setupbridge.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/integrationtest/setupbridge.go b/internal/integrationtest/setupbridge.go index 1d2cdad..b47dc1c 100644 --- a/internal/integrationtest/setupbridge.go +++ 
b/internal/integrationtest/setupbridge.go @@ -48,7 +48,7 @@ type bridgeConfig struct { mcpProxy mcp.ServerProxier userID string metadata recorder.Metadata - logger slog.Logger + logger slog.Logger } // bridgeTestServer wraps an httptest.Server running a RequestBridge. From 41e95988658b0abfade0379636f1849fb67a8930 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Mon, 9 Mar 2026 16:37:34 +0000 Subject: [PATCH 31/32] review 1: fixed setting headers in makeRequest, removed unneeded providerName from TestFallthrough --- internal/integrationtest/bridge_test.go | 17 ++++++----------- internal/integrationtest/setupbridge.go | 2 +- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/internal/integrationtest/bridge_test.go b/internal/integrationtest/bridge_test.go index 28fb52c..8114424 100644 --- a/internal/integrationtest/bridge_test.go +++ b/internal/integrationtest/bridge_test.go @@ -653,48 +653,43 @@ func TestFallthrough(t *testing.T) { testCases := []struct { name string - providerName string fixture []byte basePath string requestPath string expectedUpstreamPath string - authHeader string + expectAuthHeader string }{ { name: "ant_empty_base_url_path", - providerName: config.ProviderAnthropic, fixture: fixtures.AntFallthrough, basePath: "", requestPath: "/anthropic/v1/models", expectedUpstreamPath: "/v1/models", - authHeader: "X-Api-Key", + expectAuthHeader: "X-Api-Key", }, { name: "oai_empty_base_url_path", - providerName: config.ProviderOpenAI, fixture: fixtures.OaiChatFallthrough, basePath: "", requestPath: "/openai/v1/models", expectedUpstreamPath: "/models", - authHeader: "Authorization", + expectAuthHeader: "Authorization", }, { name: "ant_some_base_url_path", - providerName: config.ProviderAnthropic, fixture: fixtures.AntFallthrough, basePath: "/api", requestPath: "/anthropic/v1/models", expectedUpstreamPath: "/api/v1/models", - authHeader: "X-Api-Key", + expectAuthHeader: "X-Api-Key", }, { name: "oai_some_base_url_path", - 
providerName: config.ProviderOpenAI, fixture: fixtures.OaiChatFallthrough, basePath: "/api", requestPath: "/openai/v1/models", expectedUpstreamPath: "/api/models", - authHeader: "Authorization", + expectAuthHeader: "Authorization", }, } @@ -715,7 +710,7 @@ func TestFallthrough(t *testing.T) { received := upstream.receivedRequests() require.Len(t, received, 1) require.Equal(t, tc.expectedUpstreamPath, received[0].Path) - require.Contains(t, received[0].Header.Get(tc.authHeader), apiKey) + require.Contains(t, received[0].Header.Get(tc.expectAuthHeader), apiKey) gotBytes, err := io.ReadAll(resp.Body) require.NoError(t, err) diff --git a/internal/integrationtest/setupbridge.go b/internal/integrationtest/setupbridge.go index b47dc1c..0844f24 100644 --- a/internal/integrationtest/setupbridge.go +++ b/internal/integrationtest/setupbridge.go @@ -69,7 +69,7 @@ func (s *bridgeTestServer) makeRequest(t *testing.T, method string, path string, for _, h := range header { for k, vals := range h { for _, v := range vals { - req.Header.Set(k, v) + req.Header.Add(k, v) } } } From 0f255184d2f2df7810b1f17f9fb99d23488aebfc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Tue, 10 Mar 2026 11:36:09 +0000 Subject: [PATCH 32/32] review 2 --- internal/integrationtest/apidump_test.go | 5 +- internal/integrationtest/bridge_test.go | 73 +++++++++++----------- internal/integrationtest/metrics_test.go | 15 +++-- internal/integrationtest/responses_test.go | 7 ++- internal/integrationtest/setupbridge.go | 9 +-- internal/integrationtest/trace_test.go | 6 +- 6 files changed, 62 insertions(+), 53 deletions(-) diff --git a/internal/integrationtest/apidump_test.go b/internal/integrationtest/apidump_test.go index fcc70db..77a4ea1 100644 --- a/internal/integrationtest/apidump_test.go +++ b/internal/integrationtest/apidump_test.go @@ -76,7 +76,8 @@ func TestAPIDump(t *testing.T) { resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request()) require.Equal(t, 
http.StatusOK, resp.StatusCode) - _, _ = io.ReadAll(resp.Body) + _, err := io.ReadAll(resp.Body) + require.NoError(t, err) // Verify dump files were created. interceptions := bridgeServer.Recorder.RecordedInterceptions() @@ -85,7 +86,7 @@ func TestAPIDump(t *testing.T) { // Find dump files for this interception by walking the dump directory. var reqDumpFile, respDumpFile string - err := filepath.Walk(dumpDir, func(path string, info os.FileInfo, err error) error { + err = filepath.Walk(dumpDir, func(path string, info os.FileInfo, err error) error { if err != nil { return err } diff --git a/internal/integrationtest/bridge_test.go b/internal/integrationtest/bridge_test.go index 8114424..04f5ad1 100644 --- a/internal/integrationtest/bridge_test.go +++ b/internal/integrationtest/bridge_test.go @@ -562,19 +562,19 @@ func TestSessionIDTracking(t *testing.T) { t.Parallel() testCases := []struct { - name string - fixture []byte - expectedClient aibridge.Client - sessionID string - header http.Header - mutateBody func(t *testing.T, body []byte) []byte + name string + fixture []byte + header http.Header + metadataSessionID string + expectedClient aibridge.Client + expectSessionID string }{ // Session in header. { - name: "mux", - fixture: fixtures.AntSimple, - expectedClient: aibridge.ClientMux, - sessionID: "mux-workspace-321", + name: "mux", + fixture: fixtures.AntSimple, + expectedClient: aibridge.ClientMux, + expectSessionID: "mux-workspace-321", header: http.Header{ "User-Agent": []string{"mux/1.0.0"}, "X-Mux-Workspace-Id": []string{"mux-workspace-321"}, @@ -582,21 +582,14 @@ func TestSessionIDTracking(t *testing.T) { }, // Session in body. 
{ - name: "claude_code", - fixture: fixtures.AntSimple, - expectedClient: aibridge.ClientClaudeCode, - sessionID: "f47ac10b-58cc-4372-a567-0e02b2c3d479", + name: "claude_code", + fixture: fixtures.AntSimple, + expectedClient: aibridge.ClientClaudeCode, + expectSessionID: "f47ac10b-58cc-4372-a567-0e02b2c3d479", header: http.Header{ "User-Agent": []string{"claude-cli/2.0.67 (external, cli)"}, }, - mutateBody: func(t *testing.T, body []byte) []byte { - t.Helper() - // Claude Code embeds the session ID in metadata.user_id within the body. - body, err := sjson.SetBytes(body, "metadata.user_id", - "user_abc123_account_456_session_f47ac10b-58cc-4372-a567-0e02b2c3d479") - require.NoError(t, err) - return body - }, + metadataSessionID: "user_abc123_account_456_session_f47ac10b-58cc-4372-a567-0e02b2c3d479", }, // No session. { @@ -621,8 +614,10 @@ func TestSessionIDTracking(t *testing.T) { bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, withProvider(config.ProviderAnthropic)) reqBody := fix.Request() - if tc.mutateBody != nil { - reqBody = tc.mutateBody(t, reqBody) + if tc.metadataSessionID != "" { + var err error + reqBody, err = sjson.SetBytes(reqBody, "metadata.user_id", tc.metadataSessionID) + require.NoError(t, err) } resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody, tc.header) @@ -636,11 +631,11 @@ func TestSessionIDTracking(t *testing.T) { require.Len(t, interceptions, 1, "expected exactly one interception") assert.Equal(t, string(tc.expectedClient), interceptions[0].Client) - if tc.sessionID == "" { + if tc.expectSessionID == "" { assert.Nil(t, interceptions[0].ClientSessionID, "expected nil session ID for %s", tc.name) } else { require.NotNil(t, interceptions[0].ClientSessionID, "expected non-nil session ID for %s", tc.name) - assert.Equal(t, tc.sessionID, *interceptions[0].ClientSessionID) + assert.Equal(t, tc.expectSessionID, *interceptions[0].ClientSessionID) } bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) @@ 
-733,10 +728,10 @@ func TestAnthropicInjectedTools(t *testing.T) { t.Parallel() // Build the requirements & make the assertions which are common to all providers. - recorderClient, mockMCP, resp := setupInjectedToolTest(t, fixtures.AntSingleInjectedTool, streaming, defaultTracer, pathAnthropicMessages, anthropicToolResultValidator(t)) + bridgeServer, mockMCP, resp := setupInjectedToolTest(t, fixtures.AntSingleInjectedTool, streaming, defaultTracer, pathAnthropicMessages, anthropicToolResultValidator(t)) // Ensure expected tool was invoked with expected input. - toolUsages := recorderClient.RecordedToolUsages() + toolUsages := bridgeServer.Recorder.RecordedToolUsages() require.Len(t, toolUsages, 1) require.Equal(t, mockToolName, toolUsages[0].Tool) expected, err := json.Marshal(map[string]any{"owner": "admin"}) @@ -798,11 +793,11 @@ func TestAnthropicInjectedTools(t *testing.T) { assert.EqualValues(t, 204, message.Usage.OutputTokens) // Ensure tokens used during injected tool invocation are accounted for. - assert.EqualValues(t, 15308, recorderClient.TotalInputTokens()) - assert.EqualValues(t, 204, recorderClient.TotalOutputTokens()) + assert.EqualValues(t, 15308, bridgeServer.Recorder.TotalInputTokens()) + assert.EqualValues(t, 204, bridgeServer.Recorder.TotalOutputTokens()) // Ensure we received exactly one prompt. - promptUsages := recorderClient.RecordedPromptUsages() + promptUsages := bridgeServer.Recorder.RecordedPromptUsages() require.Len(t, promptUsages, 1) }) } @@ -816,10 +811,10 @@ func TestOpenAIInjectedTools(t *testing.T) { t.Parallel() // Build the requirements & make the assertions which are common to all providers. 
- recorderClient, mockMCP, resp := setupInjectedToolTest(t, fixtures.OaiChatSingleInjectedTool, streaming, defaultTracer, pathOpenAIChatCompletions, openaiChatToolResultValidator(t)) + bridgeServer, mockMCP, resp := setupInjectedToolTest(t, fixtures.OaiChatSingleInjectedTool, streaming, defaultTracer, pathOpenAIChatCompletions, openaiChatToolResultValidator(t)) // Ensure expected tool was invoked with expected input. - toolUsages := recorderClient.RecordedToolUsages() + toolUsages := bridgeServer.Recorder.RecordedToolUsages() require.Len(t, toolUsages, 1) require.Equal(t, mockToolName, toolUsages[0].Tool) expected, err := json.Marshal(map[string]any{"owner": "admin"}) @@ -896,11 +891,11 @@ func TestOpenAIInjectedTools(t *testing.T) { assert.EqualValues(t, 105, message.Usage.CompletionTokens) // Ensure tokens used during injected tool invocation are accounted for. - require.EqualValues(t, 5047, recorderClient.TotalInputTokens()) - require.EqualValues(t, 105, recorderClient.TotalOutputTokens()) + require.EqualValues(t, 5047, bridgeServer.Recorder.TotalInputTokens()) + require.EqualValues(t, 105, bridgeServer.Recorder.TotalOutputTokens()) // Ensure we received exactly one prompt. - promptUsages := recorderClient.RecordedPromptUsages() + promptUsages := bridgeServer.Recorder.RecordedPromptUsages() require.Len(t, promptUsages, 1) }) } @@ -1338,7 +1333,8 @@ func TestThinkingAdaptiveIsPreserved(t *testing.T) { resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody) require.Equal(t, http.StatusOK, resp.StatusCode) - _, _ = io.ReadAll(resp.Body) + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) // Verify the thinking field was preserved in the upstream request. received := upstream.receivedRequests() @@ -1515,7 +1511,8 @@ func TestActorHeaders(t *testing.T) { resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody) // Drain the body so streaming responses complete without // a "connection reset" error in the mock upstream. 
- _, _ = io.ReadAll(resp.Body) + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) received := upstream.receivedRequests() require.NotEmpty(t, received) diff --git a/internal/integrationtest/metrics_test.go b/internal/integrationtest/metrics_test.go index ea9293e..23e3847 100644 --- a/internal/integrationtest/metrics_test.go +++ b/internal/integrationtest/metrics_test.go @@ -126,7 +126,8 @@ func TestMetrics_Interception(t *testing.T) { ) resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request()) - _, _ = io.ReadAll(resp.Body) + _, err := io.ReadAll(resp.Body) + require.NoError(t, err) count := promtest.ToFloat64(m.InterceptionCount.WithLabelValues( tc.expectProvider, tc.expectModel, tc.expectStatus, tc.expectRoute, "POST", defaultActorID)) @@ -167,7 +168,8 @@ func TestMetrics_InterceptionsInflight(t *testing.T) { resp, err := http.DefaultClient.Do(req) if err == nil { defer resp.Body.Close() - _, _ = io.ReadAll(resp.Body) + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) } }() @@ -229,7 +231,8 @@ func TestMetrics_PromptCount(t *testing.T) { resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, fix.Request()) require.Equal(t, http.StatusOK, resp.StatusCode) - _, _ = io.ReadAll(resp.Body) + _, err := io.ReadAll(resp.Body) + require.NoError(t, err) prompts := promtest.ToFloat64(m.PromptCount.WithLabelValues( config.ProviderOpenAI, "gpt-4.1", defaultActorID)) @@ -252,7 +255,8 @@ func TestMetrics_NonInjectedToolUseCount(t *testing.T) { resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, fix.Request()) require.Equal(t, http.StatusOK, resp.StatusCode) - _, _ = io.ReadAll(resp.Body) + _, err := io.ReadAll(resp.Body) + require.NoError(t, err) count := promtest.ToFloat64(m.NonInjectedToolUseCount.WithLabelValues( config.ProviderOpenAI, "gpt-4.1", "read_file")) @@ -281,7 +285,8 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) { resp := bridgeServer.makeRequest(t, http.MethodPost, 
pathAnthropicMessages, fix.Request()) require.Equal(t, http.StatusOK, resp.StatusCode) - _, _ = io.ReadAll(resp.Body) + _, err := io.ReadAll(resp.Body) + require.NoError(t, err) // Wait until full roundtrip has completed. require.Eventually(t, func() bool { diff --git a/internal/integrationtest/responses_test.go b/internal/integrationtest/responses_test.go index 47cb0a8..483ce90 100644 --- a/internal/integrationtest/responses_test.go +++ b/internal/integrationtest/responses_test.go @@ -514,7 +514,8 @@ func TestResponsesParallelToolsOverwritten(t *testing.T) { bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, []byte(tc.request)) - _, _ = io.ReadAll(resp.Body) + _, err := io.ReadAll(resp.Body) + require.NoError(t, err) received := upstream.receivedRequests() require.Len(t, received, 1) @@ -932,6 +933,10 @@ func startRejectingListener(t *testing.T) (addr string) { return } + // Read at least 1 byte so the client has started writing + // before we RST, ensuring a consistent "connection reset by peer". + buf := make([]byte, 1) + _, _ = c.Read(buf) if tc, ok := c.(*net.TCPConn); ok { _ = tc.SetLinger(0) } diff --git a/internal/integrationtest/setupbridge.go b/internal/integrationtest/setupbridge.go index 0844f24..1d061c2 100644 --- a/internal/integrationtest/setupbridge.go +++ b/internal/integrationtest/setupbridge.go @@ -203,7 +203,7 @@ func setupInjectedToolTest( path string, toolRequestValidatorFn func(*http.Request, []byte), opts ...bridgeOption, -) (*testutil.MockRecorder, *mockMCP, *http.Response) { +) (*bridgeTestServer, *mockMCP, *http.Response) { t.Helper() ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) @@ -212,7 +212,8 @@ func setupInjectedToolTest( fix := fixtures.Parse(t, fixture) // Setup mock server for multi-turn interaction. - // First request → tool call response, second → tool response. 
+ // First request → tool call response + // Second request → final response. firstResp := newFixtureResponse(fix) toolResp := newFixtureToolResponse(fix) toolResp.OnRequest = toolRequestValidatorFn @@ -235,12 +236,12 @@ func setupInjectedToolTest( resp := bridgeServer.makeRequest(t, http.MethodPost, path, reqBody) require.Equal(t, http.StatusOK, resp.StatusCode) - // We must ALWAYS have 2 calls to the bridge for injected tool tests. + // Wait for both requests (initial + tool call result). require.Eventually(t, func() bool { return upstream.Calls.Load() == 2 }, time.Second*10, time.Millisecond*50) - return bridgeServer.Recorder, mockMCP, resp + return bridgeServer, mockMCP, resp } // newDefaultProvider creates a Provider with default test configuration. diff --git a/internal/integrationtest/trace_test.go b/internal/integrationtest/trace_test.go index d3f2b3d..88bec31 100644 --- a/internal/integrationtest/trace_test.go +++ b/internal/integrationtest/trace_test.go @@ -347,13 +347,13 @@ func TestInjectedToolsTrace(t *testing.T) { validatorFn = openaiChatToolResultValidator(t) } - recorderClient, mockMCP, _ := setupInjectedToolTest( + bridgeServer, mockMCP, _ := setupInjectedToolTest( t, tc.fixture, tc.streaming, tracer, tc.path, validatorFn, tc.opts..., ) - require.Len(t, recorderClient.RecordedInterceptions(), 1) - intcID := recorderClient.RecordedInterceptions()[0].ID + require.Len(t, bridgeServer.Recorder.RecordedInterceptions(), 1) + intcID := bridgeServer.Recorder.RecordedInterceptions()[0].ID tool := mockMCP.ListTools()[0]