diff --git a/apidump_integration_test.go b/internal/integrationtest/apidump_test.go similarity index 66% rename from apidump_integration_test.go rename to internal/integrationtest/apidump_test.go index 8aac244..77a4ea1 100644 --- a/apidump_integration_test.go +++ b/internal/integrationtest/apidump_test.go @@ -1,11 +1,10 @@ -package aibridge_test +package integrationtest import ( "bufio" "bytes" "context" "io" - "net" "net/http" "net/http/httptest" "os" @@ -14,66 +13,46 @@ import ( "testing" "time" - "cdr.dev/slog/v3" - "cdr.dev/slog/v3/sloggers/slogtest" "github.com/coder/aibridge" "github.com/coder/aibridge/config" - aibcontext "github.com/coder/aibridge/context" "github.com/coder/aibridge/fixtures" "github.com/coder/aibridge/intercept/apidump" - "github.com/coder/aibridge/internal/testutil" "github.com/coder/aibridge/provider" "github.com/stretchr/testify/require" ) -func openaiCfgWithAPIDump(url, key, dumpDir string) config.OpenAI { - return config.OpenAI{ - BaseURL: url, - Key: key, - APIDumpDir: dumpDir, - } -} - -func anthropicCfgWithAPIDump(url, key, dumpDir string) config.Anthropic { - return config.Anthropic{ - BaseURL: url, - Key: key, - APIDumpDir: dumpDir, - } -} - func TestAPIDump(t *testing.T) { t.Parallel() cases := []struct { - name string - fixture []byte - providersFunc func(addr, dumpDir string) []aibridge.Provider - createRequestFunc createRequestFunc + name string + fixture []byte + providerFunc func(addr, dumpDir string) aibridge.Provider + path string }{ { name: "anthropic", fixture: fixtures.AntSimple, - providersFunc: func(addr, dumpDir string) []aibridge.Provider { - return []aibridge.Provider{provider.NewAnthropic(anthropicCfgWithAPIDump(addr, apiKey, dumpDir), nil)} + providerFunc: func(addr, dumpDir string) aibridge.Provider { + return provider.NewAnthropic(anthropicCfgWithAPIDump(addr, apiKey, dumpDir), nil) }, - createRequestFunc: createAnthropicMessagesReq, + path: pathAnthropicMessages, }, { name: "openai_chat_completions", fixture: 
fixtures.OaiChatSimple, - providersFunc: func(addr, dumpDir string) []aibridge.Provider { - return []aibridge.Provider{provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir))} + providerFunc: func(addr, dumpDir string) aibridge.Provider { + return provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir)) }, - createRequestFunc: createOpenAIChatCompletionsReq, + path: pathOpenAIChatCompletions, }, { name: "openai_responses", fixture: fixtures.OaiResponsesBlockingSimple, - providersFunc: func(addr, dumpDir string) []aibridge.Provider { - return []aibridge.Provider{provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir))} + providerFunc: func(addr, dumpDir string) aibridge.Provider { + return provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir)) }, - createRequestFunc: createOpenAIResponsesReq, + path: pathOpenAIResponses, }, } @@ -81,38 +60,27 @@ func TestAPIDump(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) // Setup mock upstream server. fix := fixtures.Parse(t, tc.fixture) - srv := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) + srv := newMockUpstream(t, ctx, newFixtureResponse(fix)) // Create temp dir for API dumps. 
dumpDir := t.TempDir() - recorderClient := &testutil.MockRecorder{} - b, err := aibridge.NewRequestBridge(t.Context(), tc.providersFunc(srv.URL, dumpDir), recorderClient, testutil.NewNoopMCPManager(), logger, nil, testTracer) - require.NoError(t, err) - - mockSrv := httptest.NewUnstartedServer(b) - t.Cleanup(mockSrv.Close) - mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(ctx, userID, nil) - } - mockSrv.Start() + bridgeServer := newBridgeTestServer(t, ctx, srv.URL, + withCustomProvider(tc.providerFunc(srv.URL, dumpDir)), + ) - req := tc.createRequestFunc(t, mockSrv.URL, fix.Request()) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request()) require.Equal(t, http.StatusOK, resp.StatusCode) - defer resp.Body.Close() - _, _ = io.ReadAll(resp.Body) + _, err := io.ReadAll(resp.Body) + require.NoError(t, err) // Verify dump files were created. - interceptions := recorderClient.RecordedInterceptions() + interceptions := bridgeServer.Recorder.RecordedInterceptions() require.Len(t, interceptions, 1) interceptionID := interceptions[0].ID @@ -167,7 +135,7 @@ func TestAPIDump(t *testing.T) { expectedRespBody := fix.NonStreaming() require.JSONEq(t, string(expectedRespBody), string(dumpRespBody), "response body JSON should match semantically") - recorderClient.VerifyAllInterceptionsEnded(t) + bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) }) } } @@ -213,8 +181,6 @@ func TestAPIDumpPassthrough(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) @@ -226,30 +192,16 @@ func TestAPIDumpPassthrough(t *testing.T) { dumpDir := t.TempDir() - recorderClient := &testutil.MockRecorder{} - prov := tc.providerFunc(upstream.URL, dumpDir) - provs := 
[]aibridge.Provider{prov} - b, err := aibridge.NewRequestBridge(t.Context(), provs, recorderClient, testutil.NewNoopMCPManager(), logger, nil, testTracer) - require.NoError(t, err) - - bridgeSrv := httptest.NewUnstartedServer(b) - t.Cleanup(bridgeSrv.Close) - bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(ctx, userID, nil) - } - bridgeSrv.Start() + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, + withCustomProvider(tc.providerFunc(upstream.URL, dumpDir)), + ) - req, err := http.NewRequestWithContext(ctx, http.MethodGet, bridgeSrv.URL+tc.requestPath, nil) - require.NoError(t, err) - - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) - defer resp.Body.Close() + bridgeServer.makeRequest(t, http.MethodGet, tc.requestPath, nil) // Find dump files in the passthrough directory. passthroughDir := filepath.Join(dumpDir, tc.name, "passthrough") var reqDumpFile, respDumpFile string - err = filepath.Walk(passthroughDir, func(path string, info os.FileInfo, err error) error { + err := filepath.Walk(passthroughDir, func(path string, info os.FileInfo, err error) error { if err != nil { return err } diff --git a/bridge_integration_test.go b/internal/integrationtest/bridge_test.go similarity index 56% rename from bridge_integration_test.go rename to internal/integrationtest/bridge_test.go index 5eed920..04f5ad1 100644 --- a/bridge_integration_test.go +++ b/internal/integrationtest/bridge_test.go @@ -1,4 +1,4 @@ -package aibridge_test +package integrationtest import ( "bytes" @@ -6,24 +6,18 @@ import ( "encoding/json" "fmt" "io" - "net" "net/http" - "net/http/httptest" "strings" "testing" "time" - "cdr.dev/slog/v3" - "cdr.dev/slog/v3/sloggers/slogtest" "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/packages/ssestream" "github.com/anthropics/anthropic-sdk-go/shared/constant" "github.com/coder/aibridge" "github.com/coder/aibridge/config" - aibcontext 
"github.com/coder/aibridge/context" "github.com/coder/aibridge/fixtures" "github.com/coder/aibridge/intercept" - "github.com/coder/aibridge/internal/testutil" "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/provider" "github.com/coder/aibridge/recorder" @@ -34,35 +28,9 @@ import ( "github.com/stretchr/testify/require" "github.com/tidwall/gjson" "github.com/tidwall/sjson" - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/trace" "go.uber.org/goleak" ) -var testTracer = otel.Tracer("forTesting") - -const ( - apiKey = "api-key" - userID = "ae235cc1-9f8f-417d-a636-a7b170bac62e" -) - -type ( - providerFunc func(addr string) aibridge.Provider - createRequestFunc func(*testing.T, string, []byte) *http.Request -) - -func newAnthropicProvider(addr string) aibridge.Provider { - return provider.NewAnthropic(anthropicCfg(addr, apiKey), nil) -} - -func newOpenAIProvider(addr string) aibridge.Provider { - return provider.NewOpenAI(openaiCfg(addr, apiKey)) -} - -func newBedrockProvider(addr string) aibridge.Provider { - return provider.NewAnthropic(anthropicCfg(addr, apiKey), testBedrockCfg(addr)) -} - func TestMain(m *testing.M) { goleak.VerifyTestMain(m) } @@ -74,18 +42,21 @@ func TestAnthropicMessages(t *testing.T) { t.Parallel() cases := []struct { + name string streaming bool expectedInputTokens int expectedOutputTokens int expectedToolCallID string }{ { + name: "streaming", streaming: true, expectedInputTokens: 2, expectedOutputTokens: 66, expectedToolCallID: "toolu_01RX68weRSquLx6HUTj65iBo", }, { + name: "non-streaming", streaming: false, expectedInputTokens: 5, expectedOutputTokens: 84, @@ -94,37 +65,22 @@ func TestAnthropicMessages(t *testing.T) { } for _, tc := range cases { - t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), tc.streaming), func(t *testing.T) { + t.Run(tc.name, func(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) fix := fixtures.Parse(t, fixtures.AntSingleBuiltinTool) - 
upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - recorderClient := &testutil.MockRecorder{} - logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), nil)} - b, err := aibridge.NewRequestBridge(ctx, providers, recorderClient, testutil.NewNoopMCPManager(), logger, nil, testTracer) - require.NoError(t, err) - - mockSrv := httptest.NewUnstartedServer(b) - t.Cleanup(mockSrv.Close) - mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(ctx, userID, nil) - } - mockSrv.Start() + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) // Make API call to aibridge for Anthropic /v1/messages reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) require.NoError(t, err) - req := createAnthropicMessagesReq(t, mockSrv.URL, reqBody) - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody) require.Equal(t, http.StatusOK, resp.StatusCode) - defer resp.Body.Close() // Response-specific checks. if tc.streaming { @@ -141,13 +97,13 @@ func TestAnthropicMessages(t *testing.T) { // One for message_start, one for message_delta. 
expectedTokenRecordings = 2 } - tokenUsages := recorderClient.RecordedTokenUsages() + tokenUsages := bridgeServer.Recorder.RecordedTokenUsages() require.Len(t, tokenUsages, expectedTokenRecordings) - assert.EqualValues(t, tc.expectedInputTokens, calculateTotalInputTokens(tokenUsages), "input tokens miscalculated") - assert.EqualValues(t, tc.expectedOutputTokens, calculateTotalOutputTokens(tokenUsages), "output tokens miscalculated") + assert.EqualValues(t, tc.expectedInputTokens, bridgeServer.Recorder.TotalInputTokens(), "input tokens miscalculated") + assert.EqualValues(t, tc.expectedOutputTokens, bridgeServer.Recorder.TotalOutputTokens(), "output tokens miscalculated") - toolUsages := recorderClient.RecordedToolUsages() + toolUsages := bridgeServer.Recorder.RecordedToolUsages() require.Len(t, toolUsages, 1) assert.Equal(t, "Read", toolUsages[0].Tool) assert.Equal(t, tc.expectedToolCallID, toolUsages[0].ToolCallID) @@ -157,11 +113,11 @@ func TestAnthropicMessages(t *testing.T) { require.Contains(t, args, "file_path") assert.Equal(t, "/tmp/blah/foo", args["file_path"]) - promptUsages := recorderClient.RecordedPromptUsages() + promptUsages := bridgeServer.Recorder.RecordedPromptUsages() require.Len(t, promptUsages, 1) assert.Equal(t, "read the foo file", promptUsages[0].Prompt) - recorderClient.VerifyAllInterceptionsEnded(t) + bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) }) } }) @@ -185,24 +141,11 @@ func TestAWSBedrockIntegration(t *testing.T) { SmallFastModel: "test-haiku", } - recorderClient := &testutil.MockRecorder{} - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) - b, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{ - provider.NewAnthropic(anthropicCfg("http://unused", apiKey), bedrockCfg), - }, recorderClient, testutil.NewNoopMCPManager(), logger, nil, testTracer) - require.NoError(t, err) + bridgeServer := newBridgeTestServer(t, ctx, "http://unused", + 
withCustomProvider(provider.NewAnthropic(anthropicCfg("http://unused", apiKey), bedrockCfg)), + ) - mockSrv := httptest.NewUnstartedServer(b) - t.Cleanup(mockSrv.Close) - mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(ctx, userID, nil) - } - mockSrv.Start() - - req := createAnthropicMessagesReq(t, mockSrv.URL, fixtures.Request(t, fixtures.AntSingleBuiltinTool)) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) - defer resp.Body.Close() + resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, fixtures.Request(t, fixtures.AntSingleBuiltinTool)) require.Equal(t, http.StatusInternalServerError, resp.StatusCode) body, err := io.ReadAll(resp.Body) @@ -220,7 +163,7 @@ func TestAWSBedrockIntegration(t *testing.T) { t.Cleanup(cancel) fix := fixtures.Parse(t, fixtures.AntSingleBuiltinTool) - upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) // We define region here to validate that with Region & BaseURL defined, the latter takes precedence. bedrockCfg := &config.AWSBedrock{ @@ -232,29 +175,15 @@ func TestAWSBedrockIntegration(t *testing.T) { BaseURL: upstream.URL, // Use the mock server. 
} - recorderClient := &testutil.MockRecorder{} - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) - b, err := aibridge.NewRequestBridge( - ctx, []aibridge.Provider{provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), bedrockCfg)}, - recorderClient, testutil.NewNoopMCPManager(), logger, nil, testTracer) - require.NoError(t, err) - - mockBridgeSrv := httptest.NewUnstartedServer(b) - t.Cleanup(mockBridgeSrv.Close) - mockBridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(ctx, userID, nil) - } - mockBridgeSrv.Start() + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, + withCustomProvider(provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), bedrockCfg)), + ) // Make API call to aibridge for Anthropic /v1/messages, which will be routed via AWS Bedrock. // We override the AWS Bedrock client to route requests through our mock server. reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) require.NoError(t, err) - req := createAnthropicMessagesReq(t, mockBridgeSrv.URL, reqBody) - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) - defer resp.Body.Close() + resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody) // For streaming responses, consume the body to allow the stream to complete. if streaming { @@ -265,7 +194,7 @@ func TestAWSBedrockIntegration(t *testing.T) { // Verify that Bedrock-specific model name was used in the request to the mock server // and the interception data. 
- received := upstream.ReceivedRequests() + received := upstream.receivedRequests() require.Len(t, received, 1) // The Anthropic SDK's Bedrock middleware extracts "model" and "stream" @@ -277,10 +206,10 @@ func TestAWSBedrockIntegration(t *testing.T) { require.False(t, gjson.GetBytes(received[0].Body, "model").Exists(), "model should be stripped from body") require.False(t, gjson.GetBytes(received[0].Body, "stream").Exists(), "stream should be stripped from body") - interceptions := recorderClient.RecordedInterceptions() + interceptions := bridgeServer.Recorder.RecordedInterceptions() require.Len(t, interceptions, 1) require.Equal(t, interceptions[0].Model, bedrockCfg.Model) - recorderClient.VerifyAllInterceptionsEnded(t) + bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) }) } }) @@ -293,17 +222,20 @@ func TestOpenAIChatCompletions(t *testing.T) { t.Parallel() cases := []struct { + name string streaming bool expectedInputTokens, expectedOutputTokens int expectedToolCallID string }{ { + name: "streaming", streaming: true, expectedInputTokens: 60, expectedOutputTokens: 15, expectedToolCallID: "call_HjeqP7YeRkoNj0de9e3U4X4B", }, { + name: "non-streaming", streaming: false, expectedInputTokens: 60, expectedOutputTokens: 15, @@ -312,38 +244,22 @@ func TestOpenAIChatCompletions(t *testing.T) { } for _, tc := range cases { - t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), tc.streaming), func(t *testing.T) { + t.Run(tc.name, func(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) fix := fixtures.Parse(t, fixtures.OaiChatSingleBuiltinTool) - upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) - - recorderClient := &testutil.MockRecorder{} - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewOpenAI(openaiCfg(upstream.URL, apiKey))} - b, err := aibridge.NewRequestBridge(t.Context(), providers, 
recorderClient, testutil.NewNoopMCPManager(), logger, nil, testTracer) - require.NoError(t, err) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - mockSrv := httptest.NewUnstartedServer(b) - t.Cleanup(mockSrv.Close) - mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(ctx, userID, nil) - } - mockSrv.Start() + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) // Make API call to aibridge for OpenAI /v1/chat/completions reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) require.NoError(t, err) - req := createOpenAIChatCompletionsReq(t, mockSrv.URL, reqBody) - - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, reqBody) require.Equal(t, http.StatusOK, resp.StatusCode) - defer resp.Body.Close() // Response-specific checks. if tc.streaming { @@ -359,12 +275,12 @@ func TestOpenAIChatCompletions(t *testing.T) { assert.Equal(t, "[DONE]", lastEvent.Data) } - tokenUsages := recorderClient.RecordedTokenUsages() + tokenUsages := bridgeServer.Recorder.RecordedTokenUsages() require.Len(t, tokenUsages, 1) - assert.EqualValues(t, tc.expectedInputTokens, calculateTotalInputTokens(tokenUsages), "input tokens miscalculated") - assert.EqualValues(t, tc.expectedOutputTokens, calculateTotalOutputTokens(tokenUsages), "output tokens miscalculated") + assert.EqualValues(t, tc.expectedInputTokens, bridgeServer.Recorder.TotalInputTokens(), "input tokens miscalculated") + assert.EqualValues(t, tc.expectedOutputTokens, bridgeServer.Recorder.TotalOutputTokens(), "output tokens miscalculated") - toolUsages := recorderClient.RecordedToolUsages() + toolUsages := bridgeServer.Recorder.RecordedToolUsages() require.Len(t, toolUsages, 1) assert.Equal(t, "read_file", toolUsages[0].Tool) assert.Equal(t, tc.expectedToolCallID, toolUsages[0].ToolCallID) @@ -372,11 +288,11 @@ func TestOpenAIChatCompletions(t 
*testing.T) { require.Contains(t, toolUsages[0].Args, "path") assert.Equal(t, "README.md", toolUsages[0].Args.(map[string]any)["path"]) - promptUsages := recorderClient.RecordedPromptUsages() + promptUsages := bridgeServer.Recorder.RecordedPromptUsages() require.Len(t, promptUsages, 1) assert.Equal(t, "how large is the README.md file in my current path", promptUsages[0].Prompt) - recorderClient.VerifyAllInterceptionsEnded(t) + bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) }) } }) @@ -411,33 +327,19 @@ func TestOpenAIChatCompletions(t *testing.T) { // Setup mock server for multi-turn interaction. // First request → tool call response, second → tool response. fix := fixtures.Parse(t, tc.fixture) - upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix), testutil.NewFixtureToolResponse(fix)) - - recorderClient := &testutil.MockRecorder{} + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix), newFixtureToolResponse(fix)) // Setup MCP proxies with the tool from the fixture - mockMCP := testutil.SetupMCPForTest(t, testTracer) + mockMCP := setupMCPForTest(t, defaultTracer) - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewOpenAI(openaiCfg(upstream.URL, apiKey))} - b, err := aibridge.NewRequestBridge(t.Context(), providers, recorderClient, mockMCP, logger, nil, testTracer) - require.NoError(t, err) - - mockSrv := httptest.NewUnstartedServer(b) - t.Cleanup(mockSrv.Close) - mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(ctx, userID, nil) - } - mockSrv.Start() + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, + withMCP(mockMCP), + ) // Add the stream param to the request. 
reqBody, err := sjson.SetBytes(fix.Request(), "stream", true) require.NoError(t, err) - req := createOpenAIChatCompletionsReq(t, mockSrv.URL, reqBody) - - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, reqBody) require.Equal(t, http.StatusOK, resp.StatusCode) // Verify SSE headers are sent correctly @@ -448,10 +350,9 @@ func TestOpenAIChatCompletions(t *testing.T) { // Consume the full response body to ensure the interception completes _, err = io.ReadAll(resp.Body) require.NoError(t, err) - resp.Body.Close() // Verify the MCP tool was actually invoked - invocations := mockMCP.GetCallsByTool(testutil.MockToolName) + invocations := mockMCP.getCallsByTool(mockToolName) require.Len(t, invocations, 1, "expected MCP tool to be invoked") // Verify tool was invoked with the expected args (if specified) @@ -464,11 +365,11 @@ func TestOpenAIChatCompletions(t *testing.T) { } // Verify tool usage was recorded - toolUsages := recorderClient.RecordedToolUsages() + toolUsages := bridgeServer.Recorder.RecordedToolUsages() require.Len(t, toolUsages, 1) - assert.Equal(t, testutil.MockToolName, toolUsages[0].Tool) + assert.Equal(t, mockToolName, toolUsages[0].Tool) - recorderClient.VerifyAllInterceptionsEnded(t) + bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) }) } }) @@ -535,29 +436,13 @@ func TestSimple(t *testing.T) { return message.ID, nil } - // Common configuration functions for each provider type. 
- configureAnthropic := func(t *testing.T, addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { - t.Helper() - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)} - return aibridge.NewRequestBridge(t.Context(), providers, client, testutil.NewNoopMCPManager(), logger, nil, testTracer) - } - - configureOpenAI := func(t *testing.T, addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { - t.Helper() - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewOpenAI(openaiCfg(addr, apiKey))} - return aibridge.NewRequestBridge(t.Context(), providers, client, testutil.NewNoopMCPManager(), logger, nil, testTracer) - } - testCases := []struct { name string fixture []byte basePath string expectedPath string - configureFunc func(*testing.T, string, aibridge.Recorder) (*aibridge.RequestBridge, error) getResponseIDFunc func(streaming bool, resp *http.Response) (string, error) - createRequest func(*testing.T, string, []byte) *http.Request + path string expectedMsgID string userAgent string expectedClient aibridge.Client @@ -567,9 +452,8 @@ func TestSimple(t *testing.T) { fixture: fixtures.AntSimple, basePath: "", expectedPath: "/v1/messages", - configureFunc: configureAnthropic, getResponseIDFunc: getAnthropicResponseID, - createRequest: createAnthropicMessagesReq, + path: pathAnthropicMessages, expectedMsgID: "msg_01Pvyf26bY17RcjmWfJsXGBn", userAgent: "claude-cli/2.0.67 (external, cli)", expectedClient: aibridge.ClientClaudeCode, @@ -579,9 +463,8 @@ func TestSimple(t *testing.T) { fixture: fixtures.OaiChatSimple, basePath: "", expectedPath: "/chat/completions", - configureFunc: configureOpenAI, getResponseIDFunc: getOpenAIResponseID, - createRequest: createOpenAIChatCompletionsReq, + path: pathOpenAIChatCompletions, expectedMsgID: 
"chatcmpl-BwoiPTGRbKkY5rncfaM0s9KtWrq5N", userAgent: "codex_cli_rs/0.87.0 (Mac OS 26.2.0; arm64)", expectedClient: aibridge.ClientCodex, @@ -591,9 +474,8 @@ func TestSimple(t *testing.T) { fixture: fixtures.AntSimple, basePath: "/api", expectedPath: "/api/v1/messages", - configureFunc: configureAnthropic, getResponseIDFunc: getAnthropicResponseID, - createRequest: createAnthropicMessagesReq, + path: pathAnthropicMessages, expectedMsgID: "msg_01Pvyf26bY17RcjmWfJsXGBn", userAgent: "GitHubCopilotChat/0.37.2026011603", expectedClient: aibridge.ClientCopilotVSC, @@ -603,9 +485,8 @@ func TestSimple(t *testing.T) { fixture: fixtures.OaiChatSimple, basePath: "/api", expectedPath: "/api/chat/completions", - configureFunc: configureOpenAI, getResponseIDFunc: getOpenAIResponseID, - createRequest: createOpenAIChatCompletionsReq, + path: pathOpenAIChatCompletions, expectedMsgID: "chatcmpl-BwoiPTGRbKkY5rncfaM0s9KtWrq5N", userAgent: "Zed/0.219.4+stable.119.abc123 (macos; aarch64)", expectedClient: aibridge.ClientZed, @@ -624,33 +505,18 @@ func TestSimple(t *testing.T) { t.Cleanup(cancel) fix := fixtures.Parse(t, tc.fixture) - upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) - - recorderClient := &testutil.MockRecorder{} + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - b, err := tc.configureFunc(t, upstream.URL+tc.basePath, recorderClient) - require.NoError(t, err) - mockSrv := httptest.NewUnstartedServer(b) - t.Cleanup(mockSrv.Close) - - mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(ctx, userID, nil) - } - mockSrv.Start() + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL+tc.basePath) // When: calling the "API server" with the fixture's request body. 
reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) require.NoError(t, err) - req := tc.createRequest(t, mockSrv.URL, reqBody) - req.Header.Set("User-Agent", tc.userAgent) - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody, http.Header{"User-Agent": {tc.userAgent}}) require.Equal(t, http.StatusOK, resp.StatusCode) - defer resp.Body.Close() // Then: I expect the upstream request to have the correct path. - received := upstream.ReceivedRequests() + received := upstream.receivedRequests() require.Len(t, received, 1) require.Equal(t, tc.expectedPath, received[0].Path) @@ -663,7 +529,7 @@ func TestSimple(t *testing.T) { resp.Body = io.NopCloser(bytes.NewReader(bodyBytes)) // Then: I expect the prompt to have been tracked. - promptUsages := recorderClient.RecordedPromptUsages() + promptUsages := bridgeServer.Recorder.RecordedPromptUsages() require.NotEmpty(t, promptUsages, "no prompts tracked") assert.Contains(t, promptUsages[0].Prompt, "how many angels can dance on the head of a pin") @@ -674,17 +540,18 @@ func TestSimple(t *testing.T) { require.NoError(t, err, "failed to retrieve response ID") require.Nilf(t, uuid.Validate(id), "%s is not a valid UUID", id) - tokenUsages := recorderClient.RecordedTokenUsages() + tokenUsages := bridgeServer.Recorder.RecordedTokenUsages() require.GreaterOrEqual(t, len(tokenUsages), 1) require.Equal(t, tokenUsages[0].MsgID, tc.expectedMsgID) // Validate user agent and client have been recorded. 
- interceptions := recorderClient.RecordedInterceptions() + interceptions := bridgeServer.Recorder.RecordedInterceptions() require.Len(t, interceptions, 1, "expected exactly one interception, got: %v", interceptions) + assert.Equal(t, id, interceptions[0].ID) assert.Equal(t, tc.userAgent, interceptions[0].UserAgent) assert.Equal(t, string(tc.expectedClient), interceptions[0].Client) - recorderClient.VerifyAllInterceptionsEnded(t) + bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) }) } }) @@ -695,72 +562,42 @@ func TestSessionIDTracking(t *testing.T) { t.Parallel() testCases := []struct { - name string - fixture []byte - expectedClient aibridge.Client - sessionID string - configureFunc func(*testing.T, string, aibridge.Recorder) (*aibridge.RequestBridge, error) - createRequest func(t *testing.T, baseURL string, body []byte) *http.Request + name string + fixture []byte + header http.Header + metadataSessionID string + expectedClient aibridge.Client + expectSessionID string }{ // Session in header. 
{ - name: "mux", - fixture: fixtures.AntSimple, - expectedClient: aibridge.ClientMux, - sessionID: "mux-workspace-321", - configureFunc: func(t *testing.T, addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { - t.Helper() - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)} - return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) - }, - createRequest: func(t *testing.T, baseURL string, body []byte) *http.Request { - t.Helper() - req := createAnthropicMessagesReq(t, baseURL, body) - req.Header.Set("User-Agent", "mux/1.0.0") - req.Header.Set("X-Mux-Workspace-Id", "mux-workspace-321") - return req + name: "mux", + fixture: fixtures.AntSimple, + expectedClient: aibridge.ClientMux, + expectSessionID: "mux-workspace-321", + header: http.Header{ + "User-Agent": []string{"mux/1.0.0"}, + "X-Mux-Workspace-Id": []string{"mux-workspace-321"}, }, }, // Session in body. { - name: "claude_code", - fixture: fixtures.AntSimple, - expectedClient: aibridge.ClientClaudeCode, - sessionID: "f47ac10b-58cc-4372-a567-0e02b2c3d479", - configureFunc: func(t *testing.T, addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { - t.Helper() - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)} - return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) - }, - createRequest: func(t *testing.T, baseURL string, body []byte) *http.Request { - t.Helper() - // Claude Code embeds the session ID in metadata.user_id within the body. 
- body, err := sjson.SetBytes(body, "metadata.user_id", - "user_abc123_account_456_session_f47ac10b-58cc-4372-a567-0e02b2c3d479") - require.NoError(t, err) - req := createAnthropicMessagesReq(t, baseURL, body) - req.Header.Set("User-Agent", "claude-cli/2.0.67 (external, cli)") - return req + name: "claude_code", + fixture: fixtures.AntSimple, + expectedClient: aibridge.ClientClaudeCode, + expectSessionID: "f47ac10b-58cc-4372-a567-0e02b2c3d479", + header: http.Header{ + "User-Agent": []string{"claude-cli/2.0.67 (external, cli)"}, }, + metadataSessionID: "user_abc123_account_456_session_f47ac10b-58cc-4372-a567-0e02b2c3d479", }, // No session. { name: "zed", fixture: fixtures.AntSimple, expectedClient: aibridge.ClientZed, - configureFunc: func(t *testing.T, addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { - t.Helper() - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)} - return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) - }, - createRequest: func(t *testing.T, baseURL string, body []byte) *http.Request { - t.Helper() - req := createAnthropicMessagesReq(t, baseURL, body) - req.Header.Set("User-Agent", "Zed/0.219.4+stable.119.abc123 (macos; aarch64)") - return req + header: http.Header{ + "User-Agent": []string{"Zed/0.219.4+stable.119.abc123 (macos; aarch64)"}, }, }, } @@ -773,41 +610,35 @@ func TestSessionIDTracking(t *testing.T) { t.Cleanup(cancel) fix := fixtures.Parse(t, tc.fixture) - upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) - - recorderClient := &testutil.MockRecorder{} + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, withProvider(config.ProviderAnthropic)) - b, err := tc.configureFunc(t, upstream.URL, recorderClient) 
- require.NoError(t, err) - mockSrv := httptest.NewUnstartedServer(b) - t.Cleanup(mockSrv.Close) - mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(ctx, userID, nil) + reqBody := fix.Request() + if tc.metadataSessionID != "" { + var err error + reqBody, err = sjson.SetBytes(reqBody, "metadata.user_id", tc.metadataSessionID) + require.NoError(t, err) } - mockSrv.Start() - req := tc.createRequest(t, mockSrv.URL, fix.Request()) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody, tc.header) require.Equal(t, http.StatusOK, resp.StatusCode) - defer resp.Body.Close() // Drain the body to let the stream complete. - _, err = io.ReadAll(resp.Body) + _, err := io.ReadAll(resp.Body) require.NoError(t, err) - interceptions := recorderClient.RecordedInterceptions() + interceptions := bridgeServer.Recorder.RecordedInterceptions() require.Len(t, interceptions, 1, "expected exactly one interception") assert.Equal(t, string(tc.expectedClient), interceptions[0].Client) - if tc.sessionID == "" { + if tc.expectSessionID == "" { assert.Nil(t, interceptions[0].ClientSessionID, "expected nil session ID for %s", tc.name) } else { require.NotNil(t, interceptions[0].ClientSessionID, "expected non-nil session ID for %s", tc.name) - assert.Equal(t, tc.sessionID, *interceptions[0].ClientSessionID) + assert.Equal(t, tc.expectSessionID, *interceptions[0].ClientSessionID) } - recorderClient.VerifyAllInterceptionsEnded(t) + bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) }) } } @@ -817,72 +648,43 @@ func TestFallthrough(t *testing.T) { testCases := []struct { name string - providerName string fixture []byte basePath string requestPath string expectedUpstreamPath string - configureFunc func(string, aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) + expectAuthHeader string }{ { name: "ant_empty_base_url_path", - providerName: 
config.ProviderAnthropic, fixture: fixtures.AntFallthrough, basePath: "", requestPath: "/anthropic/v1/models", expectedUpstreamPath: "/v1/models", - configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) { - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - provider := provider.NewAnthropic(anthropicCfg(addr, apiKey), nil) - bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, testutil.NewNoopMCPManager(), logger, nil, testTracer) - require.NoError(t, err) - return provider, bridge - }, + expectAuthHeader: "X-Api-Key", }, { name: "oai_empty_base_url_path", - providerName: config.ProviderOpenAI, fixture: fixtures.OaiChatFallthrough, basePath: "", requestPath: "/openai/v1/models", expectedUpstreamPath: "/models", - configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) { - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - provider := provider.NewOpenAI(openaiCfg(addr, apiKey)) - bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, testutil.NewNoopMCPManager(), logger, nil, testTracer) - require.NoError(t, err) - return provider, bridge - }, + expectAuthHeader: "Authorization", }, { name: "ant_some_base_url_path", - providerName: config.ProviderAnthropic, fixture: fixtures.AntFallthrough, basePath: "/api", requestPath: "/anthropic/v1/models", expectedUpstreamPath: "/api/v1/models", - configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) { - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - provider := provider.NewAnthropic(anthropicCfg(addr, apiKey), nil) - bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, testutil.NewNoopMCPManager(), logger, nil, testTracer) - 
require.NoError(t, err) - return provider, bridge - }, + expectAuthHeader: "X-Api-Key", }, { name: "oai_some_base_url_path", - providerName: config.ProviderOpenAI, fixture: fixtures.OaiChatFallthrough, basePath: "/api", requestPath: "/openai/v1/models", expectedUpstreamPath: "/api/models", - configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) { - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - provider := provider.NewOpenAI(openaiCfg(addr, apiKey)) - bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, testutil.NewNoopMCPManager(), logger, nil, testTracer) - require.NoError(t, err) - return provider, bridge - }, + expectAuthHeader: "Authorization", }, } @@ -891,32 +693,19 @@ func TestFallthrough(t *testing.T) { t.Parallel() fix := fixtures.Parse(t, tc.fixture) - upstream := testutil.NewMockUpstream(t, t.Context(), testutil.NewFixtureResponse(fix)) - recorderClient := &testutil.MockRecorder{} - provider, bridge := tc.configureFunc(upstream.URL+tc.basePath, recorderClient) + upstream := newMockUpstream(t, t.Context(), newFixtureResponse(fix)) + bridgeServer := newBridgeTestServer(t, t.Context(), upstream.URL+tc.basePath) - bridgeSrv := httptest.NewUnstartedServer(bridge) - bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(t.Context(), userID, nil) - } - bridgeSrv.Start() - t.Cleanup(bridgeSrv.Close) - - req, err := http.NewRequestWithContext(t.Context(), "GET", fmt.Sprintf("%s%s", bridgeSrv.URL, tc.requestPath), nil) - require.NoError(t, err) - - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) - defer resp.Body.Close() + resp := bridgeServer.makeRequest(t, http.MethodGet, tc.requestPath, nil) require.Equal(t, http.StatusOK, resp.StatusCode) // Verify upstream received the request at the expected path // with the API key header. 
- received := upstream.ReceivedRequests() + received := upstream.receivedRequests() require.Len(t, received, 1) require.Equal(t, tc.expectedUpstreamPath, received[0].Path) - require.Contains(t, received[0].Header.Get(provider.AuthHeader()), apiKey) + require.Contains(t, received[0].Header.Get(tc.expectAuthHeader), apiKey) gotBytes, err := io.ReadAll(resp.Body) require.NoError(t, err) @@ -939,18 +728,18 @@ func TestAnthropicInjectedTools(t *testing.T) { t.Parallel() // Build the requirements & make the assertions which are common to all providers. - recorderClient, mockMCP, resp := setupInjectedToolTest(t, fixtures.AntSingleInjectedTool, streaming, newAnthropicProvider, testTracer, userID, createAnthropicMessagesReq, anthropicToolResultValidator(t)) + bridgeServer, mockMCP, resp := setupInjectedToolTest(t, fixtures.AntSingleInjectedTool, streaming, defaultTracer, pathAnthropicMessages, anthropicToolResultValidator(t)) // Ensure expected tool was invoked with expected input. - toolUsages := recorderClient.RecordedToolUsages() + toolUsages := bridgeServer.Recorder.RecordedToolUsages() require.Len(t, toolUsages, 1) - require.Equal(t, testutil.MockToolName, toolUsages[0].Tool) + require.Equal(t, mockToolName, toolUsages[0].Tool) expected, err := json.Marshal(map[string]any{"owner": "admin"}) require.NoError(t, err) actual, err := json.Marshal(toolUsages[0].Args) require.NoError(t, err) require.EqualValues(t, expected, actual) - invocations := mockMCP.GetCallsByTool(testutil.MockToolName) + invocations := mockMCP.getCallsByTool(mockToolName) require.Len(t, invocations, 1) actual, err = json.Marshal(invocations[0]) require.NoError(t, err) @@ -1004,12 +793,11 @@ func TestAnthropicInjectedTools(t *testing.T) { assert.EqualValues(t, 204, message.Usage.OutputTokens) // Ensure tokens used during injected tool invocation are accounted for. 
- tokenUsages := recorderClient.RecordedTokenUsages() - assert.EqualValues(t, 15308, calculateTotalInputTokens(tokenUsages)) - assert.EqualValues(t, 204, calculateTotalOutputTokens(tokenUsages)) + assert.EqualValues(t, 15308, bridgeServer.Recorder.TotalInputTokens()) + assert.EqualValues(t, 204, bridgeServer.Recorder.TotalOutputTokens()) // Ensure we received exactly one prompt. - promptUsages := recorderClient.RecordedPromptUsages() + promptUsages := bridgeServer.Recorder.RecordedPromptUsages() require.Len(t, promptUsages, 1) }) } @@ -1023,18 +811,18 @@ func TestOpenAIInjectedTools(t *testing.T) { t.Parallel() // Build the requirements & make the assertions which are common to all providers. - recorderClient, mockMCP, resp := setupInjectedToolTest(t, fixtures.OaiChatSingleInjectedTool, streaming, newOpenAIProvider, testTracer, userID, createOpenAIChatCompletionsReq, openaiChatToolResultValidator(t)) + bridgeServer, mockMCP, resp := setupInjectedToolTest(t, fixtures.OaiChatSingleInjectedTool, streaming, defaultTracer, pathOpenAIChatCompletions, openaiChatToolResultValidator(t)) // Ensure expected tool was invoked with expected input. 
- toolUsages := recorderClient.RecordedToolUsages() + toolUsages := bridgeServer.Recorder.RecordedToolUsages() require.Len(t, toolUsages, 1) - require.Equal(t, testutil.MockToolName, toolUsages[0].Tool) + require.Equal(t, mockToolName, toolUsages[0].Tool) expected, err := json.Marshal(map[string]any{"owner": "admin"}) require.NoError(t, err) actual, err := json.Marshal(toolUsages[0].Args) require.NoError(t, err) require.EqualValues(t, expected, actual) - invocations := mockMCP.GetCallsByTool(testutil.MockToolName) + invocations := mockMCP.getCallsByTool(mockToolName) require.Len(t, invocations, 1) actual, err = json.Marshal(invocations[0]) require.NoError(t, err) @@ -1103,12 +891,11 @@ func TestOpenAIInjectedTools(t *testing.T) { assert.EqualValues(t, 105, message.Usage.CompletionTokens) // Ensure tokens used during injected tool invocation are accounted for. - tokenUsages := recorderClient.RecordedTokenUsages() - require.EqualValues(t, 5047, calculateTotalInputTokens(tokenUsages)) - require.EqualValues(t, 105, calculateTotalOutputTokens(tokenUsages)) + require.EqualValues(t, 5047, bridgeServer.Recorder.TotalInputTokens()) + require.EqualValues(t, 105, bridgeServer.Recorder.TotalOutputTokens()) // Ensure we received exactly one prompt. - promptUsages := recorderClient.RecordedPromptUsages() + promptUsages := bridgeServer.Recorder.RecordedPromptUsages() require.Len(t, promptUsages, 1) }) } @@ -1187,75 +974,6 @@ func openaiChatToolResultValidator(t *testing.T) func(*http.Request, []byte) { } } -// setupInjectedToolTest abstracts common setup required for injected-tool integration tests. 
-func setupInjectedToolTest( - t *testing.T, - fixture []byte, - streaming bool, - providerFn providerFunc, - tracer trace.Tracer, - userID string, - createRequestFn func(*testing.T, string, []byte) *http.Request, - toolRequestValidatorFn func(*http.Request, []byte), -) (*testutil.MockRecorder, *testutil.MockMCP, *http.Response) { - t.Helper() - - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) - t.Cleanup(cancel) - - fix := fixtures.Parse(t, fixture) - - // Setup mock server for multi-turn interaction. - // First request → tool call response, second → tool response. - firstResp := testutil.NewFixtureResponse(fix) - toolResp := testutil.NewFixtureToolResponse(fix) - toolResp.OnRequest = toolRequestValidatorFn - upstream := testutil.NewMockUpstream(t, ctx, firstResp, toolResp) - - recorderClient := &testutil.MockRecorder{} - - mockMCP := testutil.SetupMCPForTest(t, tracer) - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - b, err := aibridge.NewRequestBridge( - t.Context(), - []aibridge.Provider{providerFn(upstream.URL)}, - recorderClient, - mockMCP, - logger, - nil, - tracer, - ) - require.NoError(t, err) - - // Invoke request to mocked API via aibridge. - bridgeSrv := httptest.NewUnstartedServer(b) - bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(ctx, userID, nil) - } - bridgeSrv.Start() - t.Cleanup(bridgeSrv.Close) - - // Add the stream param to the request. - reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) - require.NoError(t, err) - - req := createRequestFn(t, bridgeSrv.URL, reqBody) - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) - require.Equal(t, http.StatusOK, resp.StatusCode) - t.Cleanup(func() { - _ = resp.Body.Close() - }) - - // We must ALWAYS have 2 calls to the bridge for injected tool tests. 
- require.Eventually(t, func() bool { - return upstream.Calls.Load() == 2 - }, time.Second*10, time.Millisecond*50) - - return recorderClient, mockMCP, resp -} - func TestErrorHandling(t *testing.T) { t.Parallel() @@ -1264,19 +982,13 @@ func TestErrorHandling(t *testing.T) { cases := []struct { name string fixture []byte - createRequestFunc createRequestFunc - configureFunc func(string, aibridge.Recorder, mcp.ServerProxier) (*aibridge.RequestBridge, error) + path string responseHandlerFn func(resp *http.Response) }{ { - name: config.ProviderAnthropic, - fixture: fixtures.AntNonStreamError, - createRequestFunc: createAnthropicMessagesReq, - configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr mcp.ServerProxier) (*aibridge.RequestBridge, error) { - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) - }, + name: config.ProviderAnthropic, + fixture: fixtures.AntNonStreamError, + path: pathAnthropicMessages, responseHandlerFn: func(resp *http.Response) { require.Equal(t, http.StatusBadRequest, resp.StatusCode) body, err := io.ReadAll(resp.Body) @@ -1287,14 +999,9 @@ func TestErrorHandling(t *testing.T) { }, }, { - name: config.ProviderOpenAI, - fixture: fixtures.OaiChatNonStreamError, - createRequestFunc: createOpenAIChatCompletionsReq, - configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr mcp.ServerProxier) (*aibridge.RequestBridge, error) { - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewOpenAI(openaiCfg(addr, apiKey))} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) - }, + name: config.ProviderOpenAI, + fixture: fixtures.OaiChatNonStreamError, + path: 
pathOpenAIChatCompletions, responseHandlerFn: func(resp *http.Response) { require.Equal(t, http.StatusBadRequest, resp.StatusCode) body, err := io.ReadAll(resp.Body) @@ -1320,32 +1027,18 @@ func TestErrorHandling(t *testing.T) { // Setup mock server. Error fixtures contain raw HTTP // responses that may cause the bridge to retry. fix := fixtures.Parse(t, tc.fixture) - mockSrv := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) - - recorderClient := &testutil.MockRecorder{} + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - b, err := tc.configureFunc(mockSrv.URL, recorderClient, testutil.NewNoopMCPManager()) - require.NoError(t, err) - - // Invoke request to mocked API via aibridge. - bridgeSrv := httptest.NewUnstartedServer(b) - bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(ctx, userID, nil) - } - bridgeSrv.Start() - t.Cleanup(bridgeSrv.Close) + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) // Add the stream param to the request. 
reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) require.NoError(t, err) - req := tc.createRequestFunc(t, bridgeSrv.URL, reqBody) - resp, err := http.DefaultClient.Do(req) - t.Cleanup(func() { _ = resp.Body.Close() }) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody) tc.responseHandlerFn(resp) - recorderClient.VerifyAllInterceptionsEnded(t) + bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) }) } }) @@ -1357,19 +1050,13 @@ func TestErrorHandling(t *testing.T) { cases := []struct { name string fixture []byte - createRequestFunc createRequestFunc - configureFunc func(string, aibridge.Recorder, mcp.ServerProxier) (*aibridge.RequestBridge, error) + path string responseHandlerFn func(resp *http.Response) }{ { - name: config.ProviderAnthropic, - fixture: fixtures.AntMidStreamError, - createRequestFunc: createAnthropicMessagesReq, - configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr mcp.ServerProxier) (*aibridge.RequestBridge, error) { - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) - }, + name: config.ProviderAnthropic, + fixture: fixtures.AntMidStreamError, + path: pathAnthropicMessages, responseHandlerFn: func(resp *http.Response) { // Server responds first with 200 OK then starts streaming. 
require.Equal(t, http.StatusOK, resp.StatusCode) @@ -1381,14 +1068,9 @@ func TestErrorHandling(t *testing.T) { }, }, { - name: config.ProviderOpenAI, - fixture: fixtures.OaiChatMidStreamError, - createRequestFunc: createOpenAIChatCompletionsReq, - configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr mcp.ServerProxier) (*aibridge.RequestBridge, error) { - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewOpenAI(openaiCfg(addr, apiKey))} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) - }, + name: config.ProviderOpenAI, + fixture: fixtures.OaiChatMidStreamError, + path: pathOpenAIChatCompletions, responseHandlerFn: func(resp *http.Response) { // Server responds first with 200 OK then starts streaming. require.Equal(t, http.StatusOK, resp.StatusCode) @@ -1415,30 +1097,15 @@ func TestErrorHandling(t *testing.T) { // Setup mock server. fix := fixtures.Parse(t, tc.fixture) - upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) upstream.StatusCode = http.StatusInternalServerError - recorderClient := &testutil.MockRecorder{} + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) - b, err := tc.configureFunc(upstream.URL, recorderClient, testutil.NewNoopMCPManager()) - require.NoError(t, err) - - // Invoke request to mocked API via aibridge. 
- bridgeSrv := httptest.NewUnstartedServer(b) - bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(ctx, userID, nil) - } - bridgeSrv.Start() - t.Cleanup(bridgeSrv.Close) - - req := tc.createRequestFunc(t, bridgeSrv.URL, fix.Request()) - resp, err := http.DefaultClient.Do(req) - t.Cleanup(func() { _ = resp.Body.Close() }) - require.NoError(t, err) - bridgeSrv.Close() + resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request()) tc.responseHandlerFn(resp) - recorderClient.VerifyAllInterceptionsEnded(t) + bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) }) } }) @@ -1451,31 +1118,20 @@ func TestErrorHandling(t *testing.T) { func TestStableRequestEncoding(t *testing.T) { t.Parallel() - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - cases := []struct { - name string - fixture []byte - createRequestFunc createRequestFunc - configureFunc func(string, aibridge.Recorder, mcp.ServerProxier) (*aibridge.RequestBridge, error) + name string + fixture []byte + path string }{ { - name: config.ProviderAnthropic, - fixture: fixtures.AntSimple, - createRequestFunc: createAnthropicMessagesReq, - configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr mcp.ServerProxier) (*aibridge.RequestBridge, error) { - providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)} - return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) - }, + name: config.ProviderAnthropic, + fixture: fixtures.AntSimple, + path: pathAnthropicMessages, }, { - name: config.ProviderOpenAI, - fixture: fixtures.OaiChatSimple, - createRequestFunc: createOpenAIChatCompletionsReq, - configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr mcp.ServerProxier) (*aibridge.RequestBridge, error) { - providers := []aibridge.Provider{provider.NewOpenAI(openaiCfg(addr, apiKey))} - return 
aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, logger, nil, testTracer) - }, + name: config.ProviderOpenAI, + fixture: fixtures.OaiChatSimple, + path: pathOpenAIChatCompletions, }, } @@ -1487,42 +1143,30 @@ func TestStableRequestEncoding(t *testing.T) { t.Cleanup(cancel) // Setup MCP tools. - mockMCP := testutil.SetupMCPForTest(t, testTracer) + mockMCP := setupMCPForTest(t, defaultTracer) fix := fixtures.Parse(t, tc.fixture) // Create a mock upstream that serves the same blocking response for each request. count := 10 - responses := make([]testutil.UpstreamResponse, count) + responses := make([]upstreamResponse, count) for i := range count { - responses[i] = testutil.NewFixtureResponse(fix) + responses[i] = newFixtureResponse(fix) } - upstream := testutil.NewMockUpstream(t, ctx, responses...) - - recorder := &testutil.MockRecorder{} - bridge, err := tc.configureFunc(upstream.URL, recorder, mockMCP) - require.NoError(t, err) + upstream := newMockUpstream(t, ctx, responses...) - // Invoke request to mocked API via aibridge. - bridgeSrv := httptest.NewUnstartedServer(bridge) - bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(ctx, userID, nil) - } - bridgeSrv.Start() - t.Cleanup(bridgeSrv.Close) + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, + withMCP(mockMCP), + ) // Make multiple requests and verify they all have identical payloads. for range count { - req := tc.createRequestFunc(t, bridgeSrv.URL, fix.Request()) - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request()) require.Equal(t, http.StatusOK, resp.StatusCode) - _ = resp.Body.Close() } // All upstream request bodies should be identical. 
- received := upstream.ReceivedRequests() + received := upstream.receivedRequests() require.Len(t, received, count) reference := string(received[0].Body) for _, r := range received[1:] { @@ -1613,43 +1257,29 @@ func TestAnthropicToolChoiceParallelDisabled(t *testing.T) { t.Cleanup(cancel) // Setup MCP tools conditionally. - var mcpMgr mcp.ServerProxier + var mockMCP mcp.ServerProxier if tc.withInjectedTools { - mcpMgr = testutil.SetupMCPForTest(t, testTracer) + mockMCP = setupMCPForTest(t, defaultTracer) } else { - mcpMgr = testutil.NewNoopMCPManager() + mockMCP = newNoopMCPManager() } fix := fixtures.Parse(t, fixtures.AntSimple) - upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) - - recorder := &testutil.MockRecorder{} - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), nil)} - bridge, err := aibridge.NewRequestBridge(ctx, providers, recorder, mcpMgr, logger, nil, testTracer) - require.NoError(t, err) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - // Invoke request to mocked API via aibridge. - bridgeSrv := httptest.NewUnstartedServer(bridge) - bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(ctx, userID, nil) - } - bridgeSrv.Start() - t.Cleanup(bridgeSrv.Close) + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, + withMCP(mockMCP), + ) // Prepare request body with tool_choice set. reqBody, err := sjson.SetBytes(fix.Request(), "tool_choice", tc.toolChoice) require.NoError(t, err) - req := createAnthropicMessagesReq(t, bridgeSrv.URL, reqBody) - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody) require.Equal(t, http.StatusOK, resp.StatusCode) - _ = resp.Body.Close() // Verify tool_choice in the upstream request. 
- received := upstream.ReceivedRequests() + received := upstream.receivedRequests() require.Len(t, received, 1) var receivedRequest map[string]any require.NoError(t, json.Unmarshal(received[0].Body, &receivedRequest)) @@ -1691,20 +1321,9 @@ func TestThinkingAdaptiveIsPreserved(t *testing.T) { t.Cleanup(cancel) // Create a mock server that captures the request body sent upstream. - upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - recorderClient := &testutil.MockRecorder{} - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), nil)} - bridge, err := aibridge.NewRequestBridge(ctx, providers, recorderClient, testutil.NewNoopMCPManager(), logger, nil, testTracer) - require.NoError(t, err) - - bridgeSrv := httptest.NewUnstartedServer(bridge) - bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(ctx, userID, nil) - } - bridgeSrv.Start() - t.Cleanup(bridgeSrv.Close) + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) // Inject adaptive thinking into the fixture request. reqBody, err := sjson.SetBytes(fix.Request(), "thinking", map[string]string{"type": "adaptive"}) @@ -1712,15 +1331,13 @@ func TestThinkingAdaptiveIsPreserved(t *testing.T) { reqBody, err = sjson.SetBytes(reqBody, "stream", streaming) require.NoError(t, err) - req := createAnthropicMessagesReq(t, bridgeSrv.URL, reqBody) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody) require.Equal(t, http.StatusOK, resp.StatusCode) - _, _ = io.ReadAll(resp.Body) - _ = resp.Body.Close() + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) // Verify the thinking field was preserved in the upstream request. 
- received := upstream.ReceivedRequests() + received := upstream.receivedRequests() require.Len(t, received, 1) assert.Equal(t, "adaptive", gjson.GetBytes(received[0].Body, "thinking.type").Str) }) @@ -1733,22 +1350,16 @@ func TestEnvironmentDoNotLeak(t *testing.T) { // Test that environment variables containing API keys/tokens are not leaked to upstream requests. // See https://github.com/coder/aibridge/issues/60. testCases := []struct { - name string - fixture []byte - configureFunc func(string, aibridge.Recorder) (*aibridge.RequestBridge, error) - createRequest func(*testing.T, string, []byte) *http.Request - envVars map[string]string - headerName string + name string + fixture []byte + path string + envVars map[string]string + headerName string }{ { name: config.ProviderAnthropic, fixture: fixtures.AntSimple, - configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { - logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)} - return aibridge.NewRequestBridge(t.Context(), providers, client, testutil.NewNoopMCPManager(), logger, nil, testTracer) - }, - createRequest: createAnthropicMessagesReq, + path: pathAnthropicMessages, envVars: map[string]string{ "ANTHROPIC_AUTH_TOKEN": "should-not-leak", }, @@ -1757,12 +1368,7 @@ func TestEnvironmentDoNotLeak(t *testing.T) { { name: config.ProviderOpenAI, fixture: fixtures.OaiChatSimple, - configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { - logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewOpenAI(openaiCfg(addr, apiKey))} - return aibridge.NewRequestBridge(t.Context(), providers, client, testutil.NewNoopMCPManager(), logger, nil, testTracer) - }, - createRequest: createOpenAIChatCompletionsReq, + path: pathOpenAIChatCompletions, envVars: map[string]string{ 
"OPENAI_ORG_ID": "should-not-leak", }, @@ -1778,7 +1384,7 @@ func TestEnvironmentDoNotLeak(t *testing.T) { t.Cleanup(cancel) fix := fixtures.Parse(t, tc.fixture) - upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) // Set environment variables that the SDK would automatically read. // These should NOT leak into upstream requests. @@ -1786,26 +1392,13 @@ func TestEnvironmentDoNotLeak(t *testing.T) { t.Setenv(key, val) } - recorderClient := &testutil.MockRecorder{} - b, err := tc.configureFunc(upstream.URL, recorderClient) - require.NoError(t, err) - - mockSrv := httptest.NewUnstartedServer(b) - t.Cleanup(mockSrv.Close) - mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(ctx, userID, nil) - } - mockSrv.Start() + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) - req := tc.createRequest(t, mockSrv.URL, fix.Request()) - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request()) require.Equal(t, http.StatusOK, resp.StatusCode) - defer resp.Body.Close() // Verify that environment values did not leak. 
- received := upstream.ReceivedRequests() + received := upstream.receivedRequests() require.Len(t, received, 1) require.Empty(t, received[0].Header.Get(tc.headerName)) }) @@ -1819,16 +1412,16 @@ func TestActorHeaders(t *testing.T) { cases := []struct { name string - createRequest createRequestFunc + path string createProviderFn func(url, key string, sendHeaders bool) aibridge.Provider fixture []byte streaming bool }{ { - name: "openai/v1/chat/completions", - createRequest: createOpenAIChatCompletionsReq, + name: "openai/v1/chat/completions", + path: pathOpenAIChatCompletions, createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { - cfg := openaiCfg(url, key) + cfg := openAICfg(url, key) cfg.SendActorHeaders = sendHeaders return provider.NewOpenAI(cfg) }, @@ -1836,10 +1429,10 @@ func TestActorHeaders(t *testing.T) { streaming: true, }, { - name: "openai/v1/chat/completions", - createRequest: createOpenAIChatCompletionsReq, + name: "openai/v1/chat/completions", + path: pathOpenAIChatCompletions, createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { - cfg := openaiCfg(url, key) + cfg := openAICfg(url, key) cfg.SendActorHeaders = sendHeaders return provider.NewOpenAI(cfg) }, @@ -1847,10 +1440,10 @@ func TestActorHeaders(t *testing.T) { streaming: false, }, { - name: "openai/v1/responses", - createRequest: createOpenAIResponsesReq, + name: "openai/v1/responses", + path: pathOpenAIResponses, createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { - cfg := openaiCfg(url, key) + cfg := openAICfg(url, key) cfg.SendActorHeaders = sendHeaders return provider.NewOpenAI(cfg) }, @@ -1858,10 +1451,10 @@ func TestActorHeaders(t *testing.T) { streaming: true, }, { - name: "openai/v1/responses", - createRequest: createOpenAIResponsesReq, + name: "openai/v1/responses", + path: pathOpenAIResponses, createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { - cfg := openaiCfg(url, key) + cfg := 
openAICfg(url, key) cfg.SendActorHeaders = sendHeaders return provider.NewOpenAI(cfg) }, @@ -1869,8 +1462,8 @@ func TestActorHeaders(t *testing.T) { streaming: false, }, { - name: "anthropic/v1/messages", - createRequest: createAnthropicMessagesReq, + name: "anthropic/v1/messages", + path: pathAnthropicMessages, createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { cfg := anthropicCfg(url, key) cfg.SendActorHeaders = sendHeaders @@ -1880,8 +1473,8 @@ func TestActorHeaders(t *testing.T) { streaming: true, }, { - name: "anthropic/v1/messages", - createRequest: createAnthropicMessagesReq, + name: "anthropic/v1/messages", + path: pathAnthropicMessages, createProviderFn: func(url, key string, sendHeaders bool) aibridge.Provider { cfg := anthropicCfg(url, key) cfg.SendActorHeaders = sendHeaders @@ -1900,47 +1493,30 @@ func TestActorHeaders(t *testing.T) { ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - // Track headers received by the upstream server. 
- var receivedHeaders http.Header - srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - receivedHeaders = r.Header.Clone() - w.WriteHeader(http.StatusTeapot) - })) - srv.Config.BaseContext = func(_ net.Listener) context.Context { - return ctx - } - srv.Start() - t.Cleanup(srv.Close) - - rec := &testutil.MockRecorder{} - provider := tc.createProviderFn(srv.URL, apiKey, send) - logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - - b, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, rec, testutil.NewNoopMCPManager(), logger, nil, testTracer) - require.NoError(t, err, "failed to create handler") - - mockSrv := httptest.NewUnstartedServer(b) - t.Cleanup(mockSrv.Close) + fix := fixtures.Parse(t, tc.fixture) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) metadataKey := "Username" - mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { - // Attach an actor to the request context. - return aibcontext.AsActor(ctx, userID, recorder.Metadata{ + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, + withCustomProvider(tc.createProviderFn(upstream.URL, apiKey, send)), + withActor(defaultActorID, recorder.Metadata{ metadataKey: actorUsername, - }) - } - mockSrv.Start() + }), + ) // Add the stream param to the request. - reqBody, err := sjson.SetBytes(fixtures.Request(t, tc.fixture), "stream", tc.streaming) + reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) require.NoError(t, err) - req := tc.createRequest(t, mockSrv.URL, reqBody) - client := &http.Client{} - resp, err := client.Do(req) + resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody) + // Drain the body so streaming responses complete without + // a "connection reset" error in the mock upstream. 
+ _, err = io.ReadAll(resp.Body) require.NoError(t, err) - require.NotEmpty(t, receivedHeaders) - defer resp.Body.Close() + + received := upstream.receivedRequests() + require.NotEmpty(t, received) + receivedHeaders := received[0].Header // Verify that the actor headers were only received if intended. found := make(map[string][]string) @@ -1952,7 +1528,7 @@ func TestActorHeaders(t *testing.T) { } if send { - require.Equal(t, found[strings.ToLower(intercept.ActorIDHeader())], []string{userID}) + require.Equal(t, found[strings.ToLower(intercept.ActorIDHeader())], []string{defaultActorID}) require.Equal(t, found[strings.ToLower(intercept.ActorMetadataHeader(metadataKey))], []string{actorUsername}) } else { require.Empty(t, found) @@ -1961,53 +1537,3 @@ func TestActorHeaders(t *testing.T) { } } } - -func calculateTotalInputTokens(in []*recorder.TokenUsageRecord) int64 { - var total int64 - for _, el := range in { - total += el.Input - } - return total -} - -func calculateTotalOutputTokens(in []*recorder.TokenUsageRecord) int64 { - var total int64 - for _, el := range in { - total += el.Output - } - return total -} - -func createAnthropicMessagesReq(t *testing.T, baseURL string, input []byte) *http.Request { - t.Helper() - - req, err := http.NewRequestWithContext(t.Context(), "POST", baseURL+"/anthropic/v1/messages", bytes.NewReader(input)) - require.NoError(t, err) - req.Header.Set("Content-Type", "application/json") - - return req -} - -func createOpenAIChatCompletionsReq(t *testing.T, baseURL string, input []byte) *http.Request { - t.Helper() - - req, err := http.NewRequestWithContext(t.Context(), "POST", baseURL+"/openai/v1/chat/completions", bytes.NewReader(input)) - require.NoError(t, err) - req.Header.Set("Content-Type", "application/json") - - return req -} - -func openaiCfg(url, key string) config.OpenAI { - return config.OpenAI{ - BaseURL: url, - Key: key, - } -} - -func anthropicCfg(url, key string) config.Anthropic { - return config.Anthropic{ - BaseURL: 
url, - Key: key, - } -} diff --git a/circuit_breaker_integration_test.go b/internal/integrationtest/circuit_breaker_test.go similarity index 72% rename from circuit_breaker_integration_test.go rename to internal/integrationtest/circuit_breaker_test.go index 643b868..4e39264 100644 --- a/circuit_breaker_integration_test.go +++ b/internal/integrationtest/circuit_breaker_test.go @@ -1,10 +1,8 @@ -package aibridge_test +package integrationtest import ( - "context" "fmt" "io" - "net" "net/http" "net/http/httptest" "strings" @@ -13,18 +11,13 @@ import ( "testing" "time" - "cdr.dev/slog/v3" - "cdr.dev/slog/v3/sloggers/slogtest" - "github.com/coder/aibridge" "github.com/coder/aibridge/config" - "github.com/coder/aibridge/internal/testutil" "github.com/coder/aibridge/metrics" "github.com/coder/aibridge/provider" "github.com/prometheus/client_golang/prometheus" promtest "github.com/prometheus/client_golang/prometheus/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "go.opentelemetry.io/otel" ) // Common response bodies for circuit breaker tests. 
@@ -51,8 +44,8 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { errorBody string successBody string requestBody string - setupHeaders func(req *http.Request) - createRequest func(t *testing.T, baseURL string, input []byte) *http.Request + headers http.Header + path string createProvider func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider expectProvider string expectEndpoint string @@ -68,11 +61,11 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { errorBody: anthropicRateLimitError, successBody: anthropicSuccessResponse("claude-sonnet-4-20250514"), requestBody: `{"model":"claude-sonnet-4-20250514","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`, - setupHeaders: func(req *http.Request) { - req.Header.Set("x-api-key", "test") - req.Header.Set("anthropic-version", "2023-06-01") + headers: http.Header{ + "x-api-key": {"test"}, + "anthropic-version": {"2023-06-01"}, }, - createRequest: createAnthropicMessagesReq, + path: pathAnthropicMessages, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewAnthropic(config.Anthropic{ BaseURL: baseURL, @@ -89,10 +82,8 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { errorBody: openAIRateLimitError, successBody: openAISuccessResponse("gpt-4o"), requestBody: `{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}`, - setupHeaders: func(req *http.Request) { - req.Header.Set("Authorization", "Bearer test-key") - }, - createRequest: createOpenAIChatCompletionsReq, + headers: http.Header{"Authorization": {"Bearer test-key"}}, + path: pathOpenAIChatCompletions, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewOpenAI(config.OpenAI{ BaseURL: baseURL, @@ -127,7 +118,7 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { })) defer mockUpstream.Close() - metrics := metrics.NewMetrics(prometheus.NewRegistry()) + m := 
metrics.NewMetrics(prometheus.NewRegistry()) // Create provider with circuit breaker config cbConfig := &config.CircuitBreaker{ @@ -136,61 +127,43 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { Timeout: 50 * time.Millisecond, MaxRequests: 1, } - prov := tc.createProvider(mockUpstream.URL, cbConfig) ctx := t.Context() - tracer := otel.Tracer("forTesting") - logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - bridge, err := aibridge.NewRequestBridge(ctx, - []provider.Provider{prov}, - &testutil.MockRecorder{}, - testutil.NewNoopMCPManager(), - logger, - metrics, - tracer, + bridgeServer := newBridgeTestServer(t, ctx, mockUpstream.URL, + withCustomProvider(tc.createProvider(mockUpstream.URL, cbConfig)), + withMetrics(m), + withActor("test-user-id", nil), ) - require.NoError(t, err) - - mockSrv := httptest.NewUnstartedServer(bridge) - t.Cleanup(mockSrv.Close) - mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibridge.AsActor(ctx, "test-user-id", nil) - } - mockSrv.Start() - makeRequest := func() *http.Response { - req := tc.createRequest(t, mockSrv.URL, []byte(tc.requestBody)) - tc.setupHeaders(req) - resp, err := http.DefaultClient.Do(req) + doRequest := func() *http.Response { + resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, []byte(tc.requestBody), tc.headers) + _, err := io.ReadAll(resp.Body) require.NoError(t, err) - _, err = io.ReadAll(resp.Body) - require.NoError(t, err) - resp.Body.Close() return resp } // Phase 1: Trip the circuit breaker // First FailureThreshold requests hit upstream, get 429 for i := uint32(0); i < cbConfig.FailureThreshold; i++ { - resp := makeRequest() + resp := doRequest() assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) } assert.Equal(t, int32(cbConfig.FailureThreshold), upstreamCalls.Load()) // Phase 2: Verify circuit is open // Request should be blocked by circuit breaker (no upstream call) - resp := makeRequest() + resp := doRequest() 
assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) assert.Equal(t, int32(cbConfig.FailureThreshold), upstreamCalls.Load(), "No new upstream call when circuit is open") // Verify metrics show circuit is open - trips := promtest.ToFloat64(metrics.CircuitBreakerTrips.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) + trips := promtest.ToFloat64(m.CircuitBreakerTrips.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) assert.Equal(t, 1.0, trips, "CircuitBreakerTrips should be 1") - state := promtest.ToFloat64(metrics.CircuitBreakerState.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) + state := promtest.ToFloat64(m.CircuitBreakerState.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) assert.Equal(t, 1.0, state, "CircuitBreakerState should be 1 (open)") - rejects := promtest.ToFloat64(metrics.CircuitBreakerRejects.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) + rejects := promtest.ToFloat64(m.CircuitBreakerRejects.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) assert.Equal(t, 1.0, rejects, "CircuitBreakerRejects should be 1") // Phase 3: Wait for timeout to transition to half-open @@ -201,18 +174,18 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { // Phase 4: Recovery - request in half-open state should succeed and close circuit upstreamCallsBefore := upstreamCalls.Load() - resp = makeRequest() + resp = doRequest() assert.Equal(t, http.StatusOK, resp.StatusCode, "Request should succeed in half-open state") assert.Equal(t, upstreamCallsBefore+1, upstreamCalls.Load(), "Request should reach upstream in half-open state") // Verify circuit is now closed - state = promtest.ToFloat64(metrics.CircuitBreakerState.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) + state = promtest.ToFloat64(m.CircuitBreakerState.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) assert.Equal(t, 0.0, 
state, "CircuitBreakerState should be 0 (closed) after recovery") // Phase 5: Verify circuit is fully functional again // Multiple requests should all succeed and reach upstream for i := 0; i < 3; i++ { - resp = makeRequest() + resp = doRequest() assert.Equal(t, http.StatusOK, resp.StatusCode, "Request should succeed after circuit closes") } @@ -220,7 +193,7 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { assert.Equal(t, upstreamCallsBefore+4, upstreamCalls.Load(), "All requests should reach upstream after circuit closes") // Rejects count should not have increased - rejects = promtest.ToFloat64(metrics.CircuitBreakerRejects.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) + rejects = promtest.ToFloat64(m.CircuitBreakerRejects.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) assert.Equal(t, 1.0, rejects, "CircuitBreakerRejects should still be 1 (no new rejects)") }) } @@ -235,8 +208,8 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { name string errorBody string requestBody string - setupHeaders func(req *http.Request) - createRequest func(t *testing.T, baseURL string, input []byte) *http.Request + headers http.Header + path string createProvider func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider expectProvider string expectEndpoint string @@ -251,11 +224,11 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { expectModel: "claude-sonnet-4-20250514", errorBody: anthropicRateLimitError, requestBody: `{"model":"claude-sonnet-4-20250514","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`, - setupHeaders: func(req *http.Request) { - req.Header.Set("x-api-key", "test") - req.Header.Set("anthropic-version", "2023-06-01") + headers: http.Header{ + "x-api-key": {"test"}, + "anthropic-version": {"2023-06-01"}, }, - createRequest: createAnthropicMessagesReq, + path: pathAnthropicMessages, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) 
provider.Provider { return provider.NewAnthropic(config.Anthropic{ BaseURL: baseURL, @@ -271,10 +244,8 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { expectModel: "gpt-4o", errorBody: openAIRateLimitError, requestBody: `{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}`, - setupHeaders: func(req *http.Request) { - req.Header.Set("Authorization", "Bearer test-key") - }, - createRequest: createOpenAIChatCompletionsReq, + headers: http.Header{"Authorization": {"Bearer test-key"}}, + path: pathOpenAIChatCompletions, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewOpenAI(config.OpenAI{ BaseURL: baseURL, @@ -301,7 +272,7 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { })) defer mockUpstream.Close() - metrics := metrics.NewMetrics(prometheus.NewRegistry()) + m := metrics.NewMetrics(prometheus.NewRegistry()) cbConfig := &config.CircuitBreaker{ FailureThreshold: 2, @@ -309,50 +280,32 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { Timeout: 50 * time.Millisecond, MaxRequests: 1, } - prov := tc.createProvider(mockUpstream.URL, cbConfig) ctx := t.Context() - tracer := otel.Tracer("forTesting") - logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - bridge, err := aibridge.NewRequestBridge(ctx, - []provider.Provider{prov}, - &testutil.MockRecorder{}, - testutil.NewNoopMCPManager(), - logger, - metrics, - tracer, + bridgeServer := newBridgeTestServer(t, ctx, mockUpstream.URL, + withCustomProvider(tc.createProvider(mockUpstream.URL, cbConfig)), + withMetrics(m), + withActor("test-user-id", nil), ) - require.NoError(t, err) - mockSrv := httptest.NewUnstartedServer(bridge) - t.Cleanup(mockSrv.Close) - mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibridge.AsActor(ctx, "test-user-id", nil) - } - mockSrv.Start() - - makeRequest := func() *http.Response { - req := tc.createRequest(t, mockSrv.URL, []byte(tc.requestBody)) - 
tc.setupHeaders(req) - resp, err := http.DefaultClient.Do(req) + doRequest := func() *http.Response { + resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, []byte(tc.requestBody), tc.headers) + _, err := io.ReadAll(resp.Body) require.NoError(t, err) - _, err = io.ReadAll(resp.Body) - require.NoError(t, err) - resp.Body.Close() return resp } // Phase 1: Trip the circuit for i := uint32(0); i < cbConfig.FailureThreshold; i++ { - resp := makeRequest() + resp := doRequest() assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) } // Verify circuit is open - resp := makeRequest() + resp := doRequest() assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) - trips := promtest.ToFloat64(metrics.CircuitBreakerTrips.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) + trips := promtest.ToFloat64(m.CircuitBreakerTrips.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) assert.Equal(t, 1.0, trips, "CircuitBreakerTrips should be 1") // Phase 2: Wait for half-open state @@ -360,20 +313,20 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { // Phase 3: Request in half-open state fails, circuit should re-open upstreamCallsBefore := upstreamCalls.Load() - resp = makeRequest() + resp = doRequest() assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode, "Request should fail in half-open state") assert.Equal(t, upstreamCallsBefore+1, upstreamCalls.Load(), "Request should reach upstream in half-open state") // Circuit should be open again - next request should be rejected immediately - resp = makeRequest() + resp = doRequest() assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode, "Circuit should be open again after half-open failure") assert.Equal(t, upstreamCallsBefore+1, upstreamCalls.Load(), "Request should NOT reach upstream when circuit re-opens") // Verify metrics: trips should be 2 now (tripped twice) - trips = promtest.ToFloat64(metrics.CircuitBreakerTrips.WithLabelValues(tc.expectProvider, 
tc.expectEndpoint, tc.expectModel)) + trips = promtest.ToFloat64(m.CircuitBreakerTrips.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) assert.Equal(t, 2.0, trips, "CircuitBreakerTrips should be 2 after half-open failure") - state := promtest.ToFloat64(metrics.CircuitBreakerState.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) + state := promtest.ToFloat64(m.CircuitBreakerState.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) assert.Equal(t, 1.0, state, "CircuitBreakerState should be 1 (open) after half-open failure") }) } @@ -389,8 +342,8 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { errorBody string successBody string requestBody string - setupHeaders func(req *http.Request) - createRequest func(t *testing.T, baseURL string, input []byte) *http.Request + headers http.Header + path string createProvider func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider expectProvider string expectEndpoint string @@ -406,11 +359,11 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { errorBody: anthropicRateLimitError, successBody: anthropicSuccessResponse("claude-sonnet-4-20250514"), requestBody: `{"model":"claude-sonnet-4-20250514","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`, - setupHeaders: func(req *http.Request) { - req.Header.Set("x-api-key", "test") - req.Header.Set("anthropic-version", "2023-06-01") + headers: http.Header{ + "x-api-key": {"test"}, + "anthropic-version": {"2023-06-01"}, }, - createRequest: createAnthropicMessagesReq, + path: pathAnthropicMessages, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewAnthropic(config.Anthropic{ BaseURL: baseURL, @@ -427,10 +380,8 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { errorBody: openAIRateLimitError, successBody: openAISuccessResponse("gpt-4o"), requestBody: 
`{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}`, - setupHeaders: func(req *http.Request) { - req.Header.Set("Authorization", "Bearer test-key") - }, - createRequest: createOpenAIChatCompletionsReq, + headers: http.Header{"Authorization": {"Bearer test-key"}}, + path: pathOpenAIChatCompletions, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewOpenAI(config.OpenAI{ BaseURL: baseURL, @@ -466,7 +417,7 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { })) defer mockUpstream.Close() - metrics := metrics.NewMetrics(prometheus.NewRegistry()) + m := metrics.NewMetrics(prometheus.NewRegistry()) const maxRequests = 2 cbConfig := &config.CircuitBreaker{ @@ -475,47 +426,29 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { Timeout: 50 * time.Millisecond, MaxRequests: maxRequests, // Allow only 2 concurrent requests in half-open } - prov := tc.createProvider(mockUpstream.URL, cbConfig) ctx := t.Context() - tracer := otel.Tracer("forTesting") - logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - bridge, err := aibridge.NewRequestBridge(ctx, - []provider.Provider{prov}, - &testutil.MockRecorder{}, - testutil.NewNoopMCPManager(), - logger, - metrics, - tracer, + bridgeServer := newBridgeTestServer(t, ctx, mockUpstream.URL, + withCustomProvider(tc.createProvider(mockUpstream.URL, cbConfig)), + withMetrics(m), + withActor("test-user-id", nil), ) - require.NoError(t, err) - mockSrv := httptest.NewUnstartedServer(bridge) - t.Cleanup(mockSrv.Close) - mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibridge.AsActor(ctx, "test-user-id", nil) - } - mockSrv.Start() - - makeRequest := func() *http.Response { - req := tc.createRequest(t, mockSrv.URL, []byte(tc.requestBody)) - tc.setupHeaders(req) - resp, err := http.DefaultClient.Do(req) + doRequest := func() *http.Response { + resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, 
[]byte(tc.requestBody), tc.headers) + _, err := io.ReadAll(resp.Body) require.NoError(t, err) - _, err = io.ReadAll(resp.Body) - require.NoError(t, err) - resp.Body.Close() return resp } // Phase 1: Trip the circuit for i := uint32(0); i < cbConfig.FailureThreshold; i++ { - resp := makeRequest() + resp := doRequest() assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) } // Verify circuit is open - resp := makeRequest() + resp := doRequest() assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) // Phase 2: Wait for half-open state and switch upstream to success @@ -532,7 +465,7 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - resp := makeRequest() + resp := doRequest() responses <- resp.StatusCode }() } @@ -562,7 +495,7 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { "%d requests should be rejected (ErrTooManyRequests)", totalRequests-maxRequests) // Verify rejects metric increased - rejects := promtest.ToFloat64(metrics.CircuitBreakerRejects.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) + rejects := promtest.ToFloat64(m.CircuitBreakerRejects.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) assert.Equal(t, float64(1+totalRequests-maxRequests), rejects, "CircuitBreakerRejects should include half-open rejections") }) @@ -609,54 +542,37 @@ func TestCircuitBreaker_PerModelIsolation(t *testing.T) { Timeout: 500 * time.Millisecond, MaxRequests: 1, } - prov := provider.NewAnthropic(config.Anthropic{ - BaseURL: mockUpstream.URL, - Key: "test-key", - CircuitBreaker: cbConfig, - }, nil) - ctx := t.Context() - tracer := otel.Tracer("forTesting") - logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - bridge, err := aibridge.NewRequestBridge(ctx, - []provider.Provider{prov}, - &testutil.MockRecorder{}, - testutil.NewNoopMCPManager(), - logger, - m, - tracer, + bridgeServer := newBridgeTestServer(t, ctx, mockUpstream.URL, + 
withCustomProvider(provider.NewAnthropic(config.Anthropic{ + BaseURL: mockUpstream.URL, + Key: "test-key", + CircuitBreaker: cbConfig, + }, nil)), + withMetrics(m), + withActor("test-user-id", nil), ) - require.NoError(t, err) - mockSrv := httptest.NewUnstartedServer(bridge) - t.Cleanup(mockSrv.Close) - mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibridge.AsActor(ctx, "test-user-id", nil) - } - mockSrv.Start() - - makeRequest := func(model string) *http.Response { + doRequest := func(model string) *http.Response { body := fmt.Sprintf(`{"model":"%s","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`, model) - req := createAnthropicMessagesReq(t, mockSrv.URL, []byte(body)) - req.Header.Set("x-api-key", "test") - req.Header.Set("anthropic-version", "2023-06-01") - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) - _, err = io.ReadAll(resp.Body) + resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, []byte(body), http.Header{ + "x-api-key": {"test"}, + "anthropic-version": {"2023-06-01"}, + }) + _, err := io.ReadAll(resp.Body) require.NoError(t, err) - resp.Body.Close() return resp } // Phase 1: Trip the circuit for sonnet model for i := uint32(0); i < cbConfig.FailureThreshold; i++ { - resp := makeRequest("claude-sonnet-4-20250514") + resp := doRequest("claude-sonnet-4-20250514") assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) } assert.Equal(t, int32(cbConfig.FailureThreshold), sonnetCalls.Load()) // Verify sonnet circuit is open - resp := makeRequest("claude-sonnet-4-20250514") + resp := doRequest("claude-sonnet-4-20250514") assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode, "Sonnet circuit should be open") assert.Equal(t, int32(cbConfig.FailureThreshold), sonnetCalls.Load(), "No new sonnet calls when circuit is open") @@ -668,13 +584,13 @@ func TestCircuitBreaker_PerModelIsolation(t *testing.T) { assert.Equal(t, 1.0, sonnetState, "Sonnet 
CircuitBreakerState should be 1 (open)") // Phase 2: Haiku model should still work (independent circuit) - resp = makeRequest("claude-3-5-haiku-20241022") + resp = doRequest("claude-3-5-haiku-20241022") assert.Equal(t, http.StatusOK, resp.StatusCode, "Haiku should succeed while sonnet circuit is open") assert.Equal(t, int32(1), haikuCalls.Load(), "Haiku call should reach upstream") // Make multiple haiku requests - all should succeed for i := 0; i < 3; i++ { - resp = makeRequest("claude-3-5-haiku-20241022") + resp = doRequest("claude-3-5-haiku-20241022") assert.Equal(t, http.StatusOK, resp.StatusCode, "Haiku should continue to succeed") } assert.Equal(t, int32(4), haikuCalls.Load(), "All haiku calls should reach upstream") @@ -690,7 +606,7 @@ func TestCircuitBreaker_PerModelIsolation(t *testing.T) { time.Sleep(cbConfig.Timeout + 10*time.Millisecond) sonnetShouldFail.Store(false) - resp = makeRequest("claude-sonnet-4-20250514") + resp = doRequest("claude-sonnet-4-20250514") assert.Equal(t, http.StatusOK, resp.StatusCode, "Sonnet should recover after timeout") // Verify sonnet circuit is now closed diff --git a/internal/integrationtest/helpers.go b/internal/integrationtest/helpers.go new file mode 100644 index 0000000..84bd64d --- /dev/null +++ b/internal/integrationtest/helpers.go @@ -0,0 +1,55 @@ +package integrationtest + +import ( + "testing" + + "cdr.dev/slog/v3" + "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/aibridge/config" +) + +// anthropicCfg creates a minimal Anthropic config for testing. +func anthropicCfg(url string, key string) config.Anthropic { + return config.Anthropic{ + BaseURL: url, + Key: key, + } +} + +func anthropicCfgWithAPIDump(url string, key string, dumpDir string) config.Anthropic { + cfg := anthropicCfg(url, key) + cfg.APIDumpDir = dumpDir + return cfg +} + +// bedrockCfg returns a test AWS Bedrock config pointing at the given URL. 
+func bedrockCfg(url string) *config.AWSBedrock { + return &config.AWSBedrock{ + Region: "us-west-2", + AccessKey: "test-access-key", + AccessKeySecret: "test-secret-key", + Model: "beddel", // This model should override the request's given one. + SmallFastModel: "modrock", // Unused but needed for validation. + BaseURL: url, + } +} + +// openAICfg creates a minimal OpenAI config for testing. +func openAICfg(url string, key string) config.OpenAI { + return config.OpenAI{ + BaseURL: url, + Key: key, + } +} + +func openaiCfgWithAPIDump(url string, key string, dumpDir string) config.OpenAI { + cfg := openAICfg(url, key) + cfg.APIDumpDir = dumpDir + return cfg +} + +// newLogger creates a test logger at Debug level. +func newLogger(t *testing.T) slog.Logger { + t.Helper() + return slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) +} diff --git a/metrics_integration_test.go b/internal/integrationtest/metrics_test.go similarity index 54% rename from metrics_integration_test.go rename to internal/integrationtest/metrics_test.go index ac0d42a..23e3847 100644 --- a/metrics_integration_test.go +++ b/internal/integrationtest/metrics_test.go @@ -1,27 +1,21 @@ -package aibridge_test +package integrationtest import ( + "bytes" "context" "io" - "net" "net/http" "net/http/httptest" "testing" "time" - "cdr.dev/slog/v3" - "cdr.dev/slog/v3/sloggers/slogtest" "github.com/coder/aibridge" "github.com/coder/aibridge/config" - aibcontext "github.com/coder/aibridge/context" "github.com/coder/aibridge/fixtures" - "github.com/coder/aibridge/internal/testutil" "github.com/coder/aibridge/metrics" - "github.com/coder/aibridge/provider" "github.com/prometheus/client_golang/prometheus" promtest "github.com/prometheus/client_golang/prometheus/testutil" "github.com/stretchr/testify/require" - "go.opentelemetry.io/otel/trace" ) func TestMetrics_Interception(t *testing.T) { @@ -30,7 +24,7 @@ func TestMetrics_Interception(t *testing.T) { cases := []struct { name string fixture []byte - 
reqFunc func(*testing.T, string, []byte) *http.Request + path string expectStatus string expectModel string expectRoute string @@ -40,7 +34,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "ant_simple", fixture: fixtures.AntSimple, - reqFunc: createAnthropicMessagesReq, + path: pathAnthropicMessages, expectStatus: metrics.InterceptionCountStatusCompleted, expectModel: "claude-sonnet-4-0", expectRoute: "/v1/messages", @@ -49,7 +43,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "ant_error", fixture: fixtures.AntNonStreamError, - reqFunc: createAnthropicMessagesReq, + path: pathAnthropicMessages, expectStatus: metrics.InterceptionCountStatusFailed, expectModel: "claude-sonnet-4-0", expectRoute: "/v1/messages", @@ -59,7 +53,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "oai_chat_simple", fixture: fixtures.OaiChatSimple, - reqFunc: createOpenAIChatCompletionsReq, + path: pathOpenAIChatCompletions, expectStatus: metrics.InterceptionCountStatusCompleted, expectModel: "gpt-4.1", expectRoute: "/v1/chat/completions", @@ -68,7 +62,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "oai_chat_error", fixture: fixtures.OaiChatNonStreamError, - reqFunc: createOpenAIChatCompletionsReq, + path: pathOpenAIChatCompletions, expectStatus: metrics.InterceptionCountStatusFailed, expectModel: "gpt-4.1", expectRoute: "/v1/chat/completions", @@ -78,7 +72,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "oai_responses_blocking_simple", fixture: fixtures.OaiResponsesBlockingSimple, - reqFunc: createOpenAIResponsesReq, + path: pathOpenAIResponses, expectStatus: metrics.InterceptionCountStatusCompleted, expectModel: "gpt-4o-mini", expectRoute: "/v1/responses", @@ -87,7 +81,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "oai_responses_blocking_error", fixture: fixtures.OaiResponsesBlockingHttpErr, - reqFunc: createOpenAIResponsesReq, + path: pathOpenAIResponses, expectStatus: metrics.InterceptionCountStatusFailed, expectModel: 
"gpt-4o-mini", expectRoute: "/v1/responses", @@ -97,7 +91,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "oai_responses_streaming_simple", fixture: fixtures.OaiResponsesStreamingSimple, - reqFunc: createOpenAIResponsesReq, + path: pathOpenAIResponses, expectStatus: metrics.InterceptionCountStatusCompleted, expectModel: "gpt-4o-mini", expectRoute: "/v1/responses", @@ -106,7 +100,7 @@ func TestMetrics_Interception(t *testing.T) { { name: "oai_responses_streaming_error", fixture: fixtures.OaiResponsesStreamingHttpErr, - reqFunc: createOpenAIResponsesReq, + path: pathOpenAIResponses, expectStatus: metrics.InterceptionCountStatusFailed, expectModel: "gpt-4o-mini", expectRoute: "/v1/responses", @@ -123,29 +117,23 @@ func TestMetrics_Interception(t *testing.T) { t.Cleanup(cancel) fix := fixtures.Parse(t, tc.fixture) - upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) upstream.AllowOverflow = tc.allowOverflow - metrics := aibridge.NewMetrics(prometheus.NewRegistry()) - var prov aibridge.Provider - if tc.expectProvider == config.ProviderAnthropic { - prov = provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), nil) - } else { - prov = provider.NewOpenAI(openaiCfg(upstream.URL, apiKey)) - } - srv, _ := newTestSrv(t, ctx, prov, metrics, testTracer) - - req := tc.reqFunc(t, srv.URL, fix.Request()) - resp, err := http.DefaultClient.Do(req) + m := aibridge.NewMetrics(prometheus.NewRegistry()) + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, + withMetrics(m), + ) + + resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request()) + _, err := io.ReadAll(resp.Body) require.NoError(t, err) - defer resp.Body.Close() - _, _ = io.ReadAll(resp.Body) - count := promtest.ToFloat64(metrics.InterceptionCount.WithLabelValues( - tc.expectProvider, tc.expectModel, tc.expectStatus, tc.expectRoute, "POST", userID)) + count := 
promtest.ToFloat64(m.InterceptionCount.WithLabelValues( + tc.expectProvider, tc.expectModel, tc.expectStatus, tc.expectRoute, "POST", defaultActorID)) require.Equal(t, 1.0, count) - require.Equal(t, 1, promtest.CollectAndCount(metrics.InterceptionDuration)) - require.Equal(t, 1, promtest.CollectAndCount(metrics.InterceptionCount)) + require.Equal(t, 1, promtest.CollectAndCount(m.InterceptionDuration)) + require.Equal(t, 1, promtest.CollectAndCount(m.InterceptionCount)) }) } } @@ -166,26 +154,29 @@ func TestMetrics_InterceptionsInflight(t *testing.T) { })) t.Cleanup(srv.Close) - metrics := aibridge.NewMetrics(prometheus.NewRegistry()) - provider := provider.NewAnthropic(anthropicCfg(srv.URL, apiKey), nil) - bridgeSrv, _ := newTestSrv(t, ctx, provider, metrics, testTracer) + m := aibridge.NewMetrics(prometheus.NewRegistry()) + bridgeServer := newBridgeTestServer(t, ctx, srv.URL, + withMetrics(m), + ) // Make request in background. doneCh := make(chan struct{}) go func() { defer close(doneCh) - req := createAnthropicMessagesReq(t, bridgeSrv.URL, fix.Request()) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, bridgeServer.URL+pathAnthropicMessages, bytes.NewReader(fix.Request())) + req.Header.Set("Content-Type", "application/json") resp, err := http.DefaultClient.Do(req) if err == nil { defer resp.Body.Close() - _, _ = io.ReadAll(resp.Body) + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) } }() // Wait until request is detected as inflight. require.Eventually(t, func() bool { return promtest.ToFloat64( - metrics.InterceptionsInflight.WithLabelValues(config.ProviderAnthropic, "claude-sonnet-4-0", "/v1/messages"), + m.InterceptionsInflight.WithLabelValues(config.ProviderAnthropic, "claude-sonnet-4-0", "/v1/messages"), ) == 1 }, time.Second*10, time.Millisecond*50) @@ -200,7 +191,7 @@ func TestMetrics_InterceptionsInflight(t *testing.T) { // Metric is not updated immediately after request completes, so wait until it is. 
require.Eventually(t, func() bool { return promtest.ToFloat64( - metrics.InterceptionsInflight.WithLabelValues(config.ProviderAnthropic, "claude-sonnet-4-0", "/v1/messages"), + m.InterceptionsInflight.WithLabelValues(config.ProviderAnthropic, "claude-sonnet-4-0", "/v1/messages"), ) == 0 }, time.Second*10, time.Millisecond*50) } @@ -211,19 +202,15 @@ func TestMetrics_PassthroughCount(t *testing.T) { upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) t.Cleanup(upstream.Close) - metrics := aibridge.NewMetrics(prometheus.NewRegistry()) - provider := provider.NewOpenAI(openaiCfg(upstream.URL, apiKey)) - srv, _ := newTestSrv(t, t.Context(), provider, metrics, testTracer) + m := aibridge.NewMetrics(prometheus.NewRegistry()) + bridgeServer := newBridgeTestServer(t, t.Context(), upstream.URL, + withMetrics(m), + ) - req, err := http.NewRequestWithContext(t.Context(), "GET", srv.URL+"/openai/v1/models", nil) - require.NoError(t, err) - - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) - defer resp.Body.Close() + resp := bridgeServer.makeRequest(t, http.MethodGet, "/openai/v1/models", nil) require.Equal(t, http.StatusOK, resp.StatusCode) - count := promtest.ToFloat64(metrics.PassthroughCount.WithLabelValues( + count := promtest.ToFloat64(m.PassthroughCount.WithLabelValues( config.ProviderOpenAI, "/models", "GET")) require.Equal(t, 1.0, count) } @@ -235,21 +222,20 @@ func TestMetrics_PromptCount(t *testing.T) { t.Cleanup(cancel) fix := fixtures.Parse(t, fixtures.OaiChatSimple) - upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - metrics := aibridge.NewMetrics(prometheus.NewRegistry()) - provider := provider.NewOpenAI(openaiCfg(upstream.URL, apiKey)) - srv, _ := newTestSrv(t, ctx, provider, metrics, testTracer) + m := aibridge.NewMetrics(prometheus.NewRegistry()) + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, 
+ withMetrics(m), + ) - req := createOpenAIChatCompletionsReq(t, srv.URL, fix.Request()) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, fix.Request()) require.Equal(t, http.StatusOK, resp.StatusCode) - defer resp.Body.Close() - _, _ = io.ReadAll(resp.Body) + _, err := io.ReadAll(resp.Body) + require.NoError(t, err) - prompts := promtest.ToFloat64(metrics.PromptCount.WithLabelValues( - config.ProviderOpenAI, "gpt-4.1", userID)) + prompts := promtest.ToFloat64(m.PromptCount.WithLabelValues( + config.ProviderOpenAI, "gpt-4.1", defaultActorID)) require.Equal(t, 1.0, prompts) } @@ -260,20 +246,19 @@ func TestMetrics_NonInjectedToolUseCount(t *testing.T) { t.Cleanup(cancel) fix := fixtures.Parse(t, fixtures.OaiChatSingleBuiltinTool) - upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - metrics := aibridge.NewMetrics(prometheus.NewRegistry()) - provider := provider.NewOpenAI(openaiCfg(upstream.URL, apiKey)) - srv, _ := newTestSrv(t, ctx, provider, metrics, testTracer) + m := aibridge.NewMetrics(prometheus.NewRegistry()) + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, + withMetrics(m), + ) - req := createOpenAIChatCompletionsReq(t, srv.URL, fix.Request()) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, fix.Request()) require.Equal(t, http.StatusOK, resp.StatusCode) - defer resp.Body.Close() - _, _ = io.ReadAll(resp.Body) + _, err := io.ReadAll(resp.Body) + require.NoError(t, err) - count := promtest.ToFloat64(metrics.NonInjectedToolUseCount.WithLabelValues( + count := promtest.ToFloat64(m.NonInjectedToolUseCount.WithLabelValues( config.ProviderOpenAI, "gpt-4.1", "read_file")) require.Equal(t, 1.0, count) } @@ -286,67 +271,35 @@ func TestMetrics_InjectedToolUseCount(t 
*testing.T) { // First request returns the tool invocation, the second returns the mocked response to the tool result. fix := fixtures.Parse(t, fixtures.AntSingleInjectedTool) - upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix), testutil.NewFixtureToolResponse(fix)) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix), newFixtureToolResponse(fix)) - recorder := &testutil.MockRecorder{} - logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - metrics := aibridge.NewMetrics(prometheus.NewRegistry()) - provider := provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), nil) + m := aibridge.NewMetrics(prometheus.NewRegistry()) // Setup mocked MCP server & tools. - mockMCP := testutil.SetupMCPForTest(t, testTracer) + mockMCP := setupMCPForTest(t, defaultTracer) - bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, recorder, mockMCP, logger, metrics, testTracer) - require.NoError(t, err) - - srv := httptest.NewUnstartedServer(bridge) - srv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(ctx, userID, nil) - } - srv.Start() - t.Cleanup(srv.Close) + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, + withMetrics(m), + withMCP(mockMCP), + ) - req := createAnthropicMessagesReq(t, srv.URL, fix.Request()) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, fix.Request()) require.Equal(t, http.StatusOK, resp.StatusCode) - defer resp.Body.Close() - _, _ = io.ReadAll(resp.Body) + _, err := io.ReadAll(resp.Body) + require.NoError(t, err) // Wait until full roundtrip has completed. 
require.Eventually(t, func() bool { return upstream.Calls.Load() == 2 }, time.Second*10, time.Millisecond*50) + recorder := bridgeServer.Recorder require.Len(t, recorder.ToolUsages(), 1) require.True(t, recorder.ToolUsages()[0].Injected) require.NotNil(t, recorder.ToolUsages()[0].ServerURL) actualServerURL := *recorder.ToolUsages()[0].ServerURL - count := promtest.ToFloat64(metrics.InjectedToolUseCount.WithLabelValues( - config.ProviderAnthropic, "claude-sonnet-4-20250514", actualServerURL, testutil.MockToolName)) + count := promtest.ToFloat64(m.InjectedToolUseCount.WithLabelValues( + config.ProviderAnthropic, "claude-sonnet-4-20250514", actualServerURL, mockToolName)) require.Equal(t, 1.0, count) } - -func newTestSrv(t *testing.T, ctx context.Context, provider aibridge.Provider, metrics *metrics.Metrics, tracer trace.Tracer) (*httptest.Server, *testutil.MockRecorder) { - t.Helper() - - logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - mockRecorder := &testutil.MockRecorder{} - clientFn := func() (aibridge.Recorder, error) { - return mockRecorder, nil - } - wrappedRecorder := aibridge.NewRecorder(logger, tracer, clientFn) - - bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, wrappedRecorder, testutil.NewNoopMCPManager(), logger, metrics, tracer) - require.NoError(t, err) - - srv := httptest.NewUnstartedServer(bridge) - srv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(ctx, userID, nil) - } - srv.Start() - t.Cleanup(srv.Close) - - return srv, mockRecorder -} diff --git a/internal/testutil/mockmcp.go b/internal/integrationtest/mockmcp.go similarity index 79% rename from internal/testutil/mockmcp.go rename to internal/integrationtest/mockmcp.go index 212e400..eba25dd 100644 --- a/internal/testutil/mockmcp.go +++ b/internal/integrationtest/mockmcp.go @@ -1,4 +1,4 @@ -package testutil +package integrationtest import ( "context" @@ -21,33 +21,33 @@ import ( 
"go.opentelemetry.io/otel/trace/noop" ) -// MockToolName is the primary mock tool name used in MCP tests. -const MockToolName = "coder_list_workspaces" +// mockToolName is the primary mock tool name used in MCP tests. +const mockToolName = "coder_list_workspaces" -// MockMCP wraps a real mcp.ServerProxier with test assertion helpers. +// mockMCP wraps a real mcp.ServerProxier with test assertion helpers. // Implements mcp.ServerProxier so it can be passed directly to NewRequestBridge. -type MockMCP struct { +type mockMCP struct { mcp.ServerProxier calls *callAccumulator } -// GetCallsByTool returns recorded arguments for a given tool name. -func (m *MockMCP) GetCallsByTool(name string) []any { +// getCallsByTool returns recorded arguments for a given tool name. +func (m *mockMCP) getCallsByTool(name string) []any { return m.calls.getCallsByTool(name) } -// SetToolError configures a tool to return an error when invoked. -func (m *MockMCP) SetToolError(tool, errMsg string) { +// setToolError configures a tool to return an error when invoked. +func (m *mockMCP) setToolError(tool, errMsg string) { m.calls.setToolError(tool, errMsg) } -// SetupMCPForTest creates a ready-to-use MCP server with proxy named "coder". -func SetupMCPForTest(t *testing.T, tracer trace.Tracer) *MockMCP { +// setupMCPForTest creates a ready-to-use MCP server with proxy named "coder". 
+func setupMCPForTest(t *testing.T, tracer trace.Tracer) *mockMCP { t.Helper() - return SetupMCPForTestWithName(t, "coder", tracer) + return setupMCPForTestWithName(t, "coder", tracer) } -func SetupMCPForTestWithName(t *testing.T, name string, tracer trace.Tracer) *MockMCP { +func setupMCPForTestWithName(t *testing.T, name string, tracer trace.Tracer) *mockMCP { t.Helper() srv, acc := createMockMCPSrv(t) @@ -59,7 +59,9 @@ func SetupMCPForTestWithName(t *testing.T, name string, tracer trace.Tracer) *Mo // which can break when httptest.Server calls CloseIdleConnections in parallel // resulting in error `init MCP client: failed to send initialized notification: failed to send request: failed to send request: Post "http://127.0.0.1:43843": net/http: HTTP/1.x transport connection broken: http: CloseIdleConnections called` // https://github.com/golang/go/blob/44ec057a3e89482cf775f5eaaf03b0b5fcab1fa4/src/net/http/httptest/server.go#L268 - httpClient := &http.Client{Transport: &http.Transport{}} + httpTransport := &http.Transport{} + t.Cleanup(httpTransport.CloseIdleConnections) + httpClient := &http.Client{Transport: httpTransport} proxy, err := mcp.NewStreamableHTTPServerProxy(name, mcpSrv.URL, nil, nil, nil, logger, tracer, transport.WithHTTPBasicClient(httpClient)) require.NoError(t, err) @@ -75,10 +77,10 @@ func SetupMCPForTestWithName(t *testing.T, name string, tracer trace.Tracer) *Mo require.NoError(t, mgr.Init(ctx)) require.NotEmpty(t, mgr.ListTools(), "mock MCP server should expose tools after init") - return &MockMCP{ServerProxier: mgr, calls: acc} + return &mockMCP{ServerProxier: mgr, calls: acc} } -func NewNoopMCPManager() mcp.ServerProxier { +func newNoopMCPManager() mcp.ServerProxier { return mcp.NewServerProxyManager(nil, noop.NewTracerProvider().Tracer("")) } @@ -134,7 +136,7 @@ func createMockMCPSrv(t *testing.T) (http.Handler, *callAccumulator) { acc := newCallAccumulator() - for _, name := range []string{MockToolName, "coder_list_templates", 
"coder_template_version_parameters", "coder_get_authenticated_user", "coder_create_workspace_build", "coder_delete_template"} { + for _, name := range []string{mockToolName, "coder_list_templates", "coder_template_version_parameters", "coder_get_authenticated_user", "coder_create_workspace_build", "coder_delete_template"} { tool := mcplib.NewTool(name, mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", name)), ) diff --git a/internal/testutil/upstream.go b/internal/integrationtest/mockupstream.go similarity index 83% rename from internal/testutil/upstream.go rename to internal/integrationtest/mockupstream.go index bb935a8..a658b05 100644 --- a/internal/testutil/upstream.go +++ b/internal/integrationtest/mockupstream.go @@ -1,4 +1,4 @@ -package testutil +package integrationtest import ( "bufio" @@ -23,10 +23,10 @@ import ( "github.com/tidwall/gjson" ) -// UpstreamResponse defines a single response that MockUpstream will replay -// for one incoming request. Use [NewFixtureResponse] or [NewFixtureToolResponse] to +// upstreamResponse defines a single response that mockUpstream will replay +// for one incoming request. Use [newFixtureResponse] or [newFixtureToolResponse] to // construct one from a parsed txtar archive. -type UpstreamResponse struct { +type upstreamResponse struct { Streaming []byte // returned when the request has "stream": true. Blocking []byte // returned for non-streaming requests. @@ -35,11 +35,11 @@ type UpstreamResponse struct { OnRequest func(r *http.Request, body []byte) } -// NewFixtureResponse creates an UpstreamResponse from a parsed fixture archive. +// newFixtureResponse creates an upstreamResponse from a parsed fixture archive. // It reads whichever of 'streaming' and 'non-streaming' sections exist; // not every fixture has both (e.g. error fixtures may only define one). 
-func NewFixtureResponse(fix fixtures.Fixture) UpstreamResponse { - var resp UpstreamResponse +func newFixtureResponse(fix fixtures.Fixture) upstreamResponse { + var resp upstreamResponse if fix.Has(fixtures.SectionStreaming) { resp.Streaming = fix.Streaming() } @@ -49,11 +49,11 @@ func NewFixtureResponse(fix fixtures.Fixture) UpstreamResponse { return resp } -// NewFixtureToolResponse creates an UpstreamResponse from the tool-call fixture files. +// newFixtureToolResponse creates an upstreamResponse from the tool-call fixture files. // It reads whichever of 'streaming/tool-call' and 'non-streaming/tool-call' // sections exist. -func NewFixtureToolResponse(fix fixtures.Fixture) UpstreamResponse { - var resp UpstreamResponse +func newFixtureToolResponse(fix fixtures.Fixture) upstreamResponse { + var resp upstreamResponse if fix.Has(fixtures.SectionStreamingToolCall) { resp.Streaming = fix.StreamingToolCall() } @@ -63,18 +63,18 @@ func NewFixtureToolResponse(fix fixtures.Fixture) UpstreamResponse { return resp } -// ReceivedRequest captures the details of a single request handled by MockUpstream. -type ReceivedRequest struct { +// receivedRequest captures the details of a single request handled by mockUpstream. +type receivedRequest struct { Method string Path string Header http.Header Body []byte } -// MockUpstream replays txtar fixture responses, validates incoming request +// mockUpstream replays txtar fixture responses, validates incoming request // bodies, and counts calls. It stands in for a real AI provider API // (Anthropic, OpenAI) during integration tests. -type MockUpstream struct { +type mockUpstream struct { *httptest.Server // Calls is incremented atomically on every request. 
@@ -92,31 +92,31 @@ type MockUpstream struct { AllowOverflow bool mu sync.Mutex - requests []ReceivedRequest + requests []receivedRequest t *testing.T - responses []UpstreamResponse + responses []upstreamResponse } -// ReceivedRequests returns a copy of all requests received so far. -func (ms *MockUpstream) ReceivedRequests() []ReceivedRequest { +// receivedRequests returns a copy of all requests received so far. +func (ms *mockUpstream) receivedRequests() []receivedRequest { ms.mu.Lock() defer ms.mu.Unlock() - return append([]ReceivedRequest(nil), ms.requests...) + return append([]receivedRequest(nil), ms.requests...) } -// NewMockUpstream creates a started httptest.Server that replays fixture +// newMockUpstream creates a started httptest.Server that replays fixture // responses. Responses are returned in order: first call → first response. // The test fails if the number of requests doesn't match the number of // responses (when AllowOverflow is not set, default). // -// srv := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) // simple -// srv := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix), testutil.NewFixtureToolResponse(fix)) // multi-turn -func NewMockUpstream(t *testing.T, ctx context.Context, responses ...UpstreamResponse) *MockUpstream { +// srv := newMockUpstream(t, ctx, newFixtureResponse(fix)) // simple +// srv := newMockUpstream(t, ctx, newFixtureResponse(fix), newFixtureToolResponse(fix)) // multi-turn +func newMockUpstream(t *testing.T, ctx context.Context, responses ...upstreamResponse) *mockUpstream { t.Helper() - require.NotEmpty(t, responses, "at least one UpstreamResponse required") + require.NotEmpty(t, responses, "at least one upstreamResponse required") - ms := &MockUpstream{ + ms := &mockUpstream{ t: t, responses: responses, } @@ -141,7 +141,7 @@ func NewMockUpstream(t *testing.T, ctx context.Context, responses ...UpstreamRes return ms } -func (ms *MockUpstream) handle(w http.ResponseWriter, r 
*http.Request) { +func (ms *mockUpstream) handle(w http.ResponseWriter, r *http.Request) { call := int(ms.Calls.Add(1) - 1) body, err := io.ReadAll(r.Body) @@ -149,7 +149,7 @@ func (ms *MockUpstream) handle(w http.ResponseWriter, r *http.Request) { require.NoError(ms.t, err) ms.mu.Lock() - ms.requests = append(ms.requests, ReceivedRequest{ + ms.requests = append(ms.requests, receivedRequest{ Method: r.Method, Path: r.URL.Path, Header: r.Header.Clone(), @@ -186,7 +186,7 @@ func (ms *MockUpstream) handle(w http.ResponseWriter, r *http.Request) { _, _ = w.Write(resp.Blocking) } -func (ms *MockUpstream) responseForCall(call int) UpstreamResponse { +func (ms *mockUpstream) responseForCall(call int) upstreamResponse { if call >= len(ms.responses) { if ms.AllowOverflow { return ms.responses[len(ms.responses)-1] @@ -203,7 +203,7 @@ func isStreaming(body []byte, urlPath string) bool { return gjson.GetBytes(body, "stream").Bool() || strings.HasSuffix(urlPath, "invoke-with-response-stream") } -func (ms *MockUpstream) writeSSE(w http.ResponseWriter, data []byte) { +func (ms *mockUpstream) writeSSE(w http.ResponseWriter, data []byte) { ms.t.Helper() w.Header().Set("Content-Type", "text/event-stream") @@ -237,7 +237,7 @@ func isRawHTTPResponse(data []byte) bool { // writeRawHTTPResponse parses data as a complete HTTP response and replays it, // copying the status code, headers, and body to w. This supports error fixtures // that contain full HTTP responses (e.g. "HTTP/2.0 400 Bad Request\r\n..."). 
-func (ms *MockUpstream) writeRawHTTPResponse(w http.ResponseWriter, r *http.Request, data []byte) { +func (ms *mockUpstream) writeRawHTTPResponse(w http.ResponseWriter, r *http.Request, data []byte) { ms.t.Helper() resp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(data)), r) diff --git a/responses_integration_test.go b/internal/integrationtest/responses_test.go similarity index 86% rename from responses_integration_test.go rename to internal/integrationtest/responses_test.go index 7a26fff..483ce90 100644 --- a/responses_integration_test.go +++ b/internal/integrationtest/responses_test.go @@ -1,7 +1,6 @@ -package aibridge_test +package integrationtest import ( - "bytes" "context" "encoding/json" "io" @@ -14,13 +13,9 @@ import ( "testing" "time" - "cdr.dev/slog/v3" - "cdr.dev/slog/v3/sloggers/slogtest" "github.com/coder/aibridge" "github.com/coder/aibridge/config" - aibcontext "github.com/coder/aibridge/context" "github.com/coder/aibridge/fixtures" - "github.com/coder/aibridge/internal/testutil" "github.com/coder/aibridge/provider" "github.com/coder/aibridge/recorder" "github.com/openai/openai-go/v3/responses" @@ -335,22 +330,13 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - ctx = aibcontext.AsActor(ctx, userID, nil) fix := fixtures.Parse(t, tc.fixture) - upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - provider := provider.NewOpenAI(openaiCfg(upstream.URL, apiKey)) - srv, mockRecorder := newTestSrv(t, ctx, provider, nil, testTracer) - defer srv.Close() + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) - req := createOpenAIResponsesReq(t, srv.URL, fix.Request()) - req.Header.Set("User-Agent", tc.userAgent) - client := &http.Client{} - - resp, err := client.Do(req) - require.NoError(t, err) - defer resp.Body.Close() + resp := bridgeServer.makeRequest(t, 
http.MethodPost, pathOpenAIResponses, fix.Request(), http.Header{"User-Agent": {tc.userAgent}}) require.Equal(t, http.StatusOK, resp.StatusCode) got, err := io.ReadAll(resp.Body) @@ -361,16 +347,16 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { require.Equal(t, string(fix.NonStreaming()), string(got)) } - interceptions := mockRecorder.RecordedInterceptions() + interceptions := bridgeServer.Recorder.RecordedInterceptions() require.Len(t, interceptions, 1) intc := interceptions[0] - require.Equal(t, intc.InitiatorID, userID) + require.Equal(t, intc.InitiatorID, defaultActorID) require.Equal(t, intc.Provider, config.ProviderOpenAI) require.Equal(t, intc.Model, tc.expectModel) require.Equal(t, tc.userAgent, intc.UserAgent) require.Equal(t, string(tc.expectedClient), intc.Client) - recordedPrompts := mockRecorder.RecordedPromptUsages() + recordedPrompts := bridgeServer.Recorder.RecordedPromptUsages() if tc.expectPromptRecorded != "" { require.Len(t, recordedPrompts, 1) promptEq := func(pur *recorder.PromptUsageRecord) bool { return pur.Prompt == tc.expectPromptRecorded } @@ -379,7 +365,7 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { require.Empty(t, recordedPrompts) } - recordedTools := mockRecorder.RecordedToolUsages() + recordedTools := bridgeServer.Recorder.RecordedToolUsages() if tc.expectToolRecorded != nil { require.Len(t, recordedTools, 1) recordedTools[0].InterceptionID = tc.expectToolRecorded.InterceptionID // ignore interception id (interception id is not constant and response doesn't contain it) @@ -389,7 +375,7 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { require.Empty(t, recordedTools) } - recordedTokens := mockRecorder.RecordedTokenUsages() + recordedTokens := bridgeServer.Recorder.RecordedTokenUsages() if tc.expectTokenUsage != nil { require.Len(t, recordedTokens, 1) recordedTokens[0].InterceptionID = tc.expectTokenUsage.InterceptionID // ignore interception id @@ -433,18 +419,11 @@ func 
TestResponsesBackgroundModeForbidden(t *testing.T) { })) t.Cleanup(upstream.Close) - prov := provider.NewOpenAI(openaiCfg(upstream.URL, apiKey)) - srv, _ := newTestSrv(t, ctx, prov, nil, testTracer) - defer srv.Close() + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) // Create a request with background mode enabled reqBytes := responsesRequestBytes(t, tc.streaming, keyVal{"background", true}) - req := createOpenAIResponsesReq(t, srv.URL, reqBytes) - client := &http.Client{} - - resp, err := client.Do(req) - require.NoError(t, err) - defer resp.Body.Close() + resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBytes) require.Equal(t, "application/json", resp.Header.Get("Content-Type")) require.Equal(t, http.StatusNotImplemented, resp.StatusCode) @@ -527,38 +506,30 @@ func TestResponsesParallelToolsOverwritten(t *testing.T) { ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - raw, err := io.ReadAll(r.Body) - require.NoError(t, err) - defer r.Body.Close() - - var receivedRequest map[string]any - require.NoError(t, json.Unmarshal(raw, &receivedRequest)) - if tc.expectParallelToolCalls { - parallelToolCalls, ok := receivedRequest["parallel_tool_calls"].(bool) - require.True(t, ok, "parallel_tool_calls should be present in upstream request") - require.Equal(t, tc.expectParallelToolCallsValue, parallelToolCalls) - } else { - _, ok := receivedRequest["parallel_tool_calls"] - require.False(t, ok, "parallel_tool_calls should not be present when not set") - } - - w.WriteHeader(http.StatusOK) - })) - t.Cleanup(upstream.Close) + fix := fixtures.OaiResponsesBlockingSimple + if tc.streaming { + fix = fixtures.OaiResponsesStreamingSimple + } + upstream := newMockUpstream(t, ctx, newFixtureResponse(fixtures.Parse(t, fix))) + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) - prov := 
provider.NewOpenAI(openaiCfg(upstream.URL, apiKey)) - srv, _ := newTestSrv(t, ctx, prov, nil, testTracer) - defer srv.Close() + resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, []byte(tc.request)) + _, err := io.ReadAll(resp.Body) + require.NoError(t, err) - req := createOpenAIResponsesReq(t, srv.URL, []byte(tc.request)) - client := &http.Client{} + received := upstream.receivedRequests() + require.Len(t, received, 1) - resp, err := client.Do(req) - require.NoError(t, err) - defer resp.Body.Close() - _, err = io.ReadAll(resp.Body) - require.NoError(t, err) + var receivedRequest map[string]any + require.NoError(t, json.Unmarshal(received[0].Body, &receivedRequest)) + if tc.expectParallelToolCalls { + parallelToolCalls, ok := receivedRequest["parallel_tool_calls"].(bool) + require.True(t, ok, "parallel_tool_calls should be present in upstream request") + require.Equal(t, tc.expectParallelToolCallsValue, parallelToolCalls) + } else { + _, ok := receivedRequest["parallel_tool_calls"] + require.False(t, ok, "parallel_tool_calls should not be present when not set") + } }) } } @@ -608,17 +579,11 @@ func TestClientAndConnectionError(t *testing.T) { ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - prov := provider.NewOpenAI(openaiCfg(tc.addr, apiKey)) - srv, mockRecorder := newTestSrv(t, ctx, prov, nil, testTracer) - defer srv.Close() + // tc.addr may be an intentionally invalid URL; use withCustomProvider. 
+ bridgeServer := newBridgeTestServer(t, ctx, tc.addr, withCustomProvider(provider.NewOpenAI(openAICfg(tc.addr, apiKey)))) reqBytes := responsesRequestBytes(t, tc.streaming) - req := createOpenAIResponsesReq(t, srv.URL, reqBytes) - client := &http.Client{} - - resp, err := client.Do(req) - require.NoError(t, err) - defer resp.Body.Close() + resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBytes) require.Equal(t, "application/json", resp.Header.Get("Content-Type")) require.Equal(t, http.StatusInternalServerError, resp.StatusCode) @@ -626,7 +591,7 @@ func TestClientAndConnectionError(t *testing.T) { body, err := io.ReadAll(resp.Body) require.NoError(t, err) requireResponsesError(t, http.StatusInternalServerError, tc.errContains, body) - require.Empty(t, mockRecorder.RecordedPromptUsages()) + require.Empty(t, bridgeServer.Recorder.RecordedPromptUsages()) }) } } @@ -692,17 +657,10 @@ func TestUpstreamError(t *testing.T) { })) t.Cleanup(upstream.Close) - prov := provider.NewOpenAI(openaiCfg(upstream.URL, apiKey)) - srv, _ := newTestSrv(t, ctx, prov, nil, testTracer) - defer srv.Close() + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) reqBytes := responsesRequestBytes(t, tc.streaming) - req := createOpenAIResponsesReq(t, srv.URL, reqBytes) - client := &http.Client{} - - resp, err := client.Do(req) - require.NoError(t, err) - defer resp.Body.Close() + resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBytes) require.Equal(t, tc.statusCode, resp.StatusCode) require.Equal(t, tc.contentType, resp.Header.Get("Content-Type")) @@ -868,32 +826,17 @@ func TestResponsesInjectedTool(t *testing.T) { // Setup mock server for multi-turn interaction. // First request → tool call response, second → tool response. 
fix := fixtures.Parse(t, tc.fixture) - upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix), testutil.NewFixtureToolResponse(fix)) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix), newFixtureToolResponse(fix)) // Setup MCP server proxies (with mock tools). - mockMCP := testutil.SetupMCPForTest(t, testTracer) + mockMCP := setupMCPForTest(t, defaultTracer) if tc.expectToolError != "" { - mockMCP.SetToolError(tc.mcpToolName, tc.expectToolError) + mockMCP.setToolError(tc.mcpToolName, tc.expectToolError) } - prov := provider.NewOpenAI(openaiCfg(upstream.URL, apiKey)) - mockRecorder := &testutil.MockRecorder{} - logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - - bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{prov}, mockRecorder, mockMCP, logger, nil, testTracer) - require.NoError(t, err) - - srv := httptest.NewUnstartedServer(bridge) - srv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibcontext.AsActor(ctx, userID, nil) - } - srv.Start() - t.Cleanup(srv.Close) + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, withMCP(mockMCP)) - req := createOpenAIResponsesReq(t, srv.URL, fix.Request()) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) - defer resp.Body.Close() + resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, fix.Request()) require.Equal(t, http.StatusOK, resp.StatusCode) body, err := io.ReadAll(resp.Body) @@ -905,11 +848,11 @@ func TestResponsesInjectedTool(t *testing.T) { }, time.Second*10, time.Millisecond*50) // Verify the injected tool was invoked via MCP. - invocations := mockMCP.GetCallsByTool(tc.mcpToolName) + invocations := mockMCP.getCallsByTool(tc.mcpToolName) require.Len(t, invocations, 1, "expected MCP tool to be invoked once") // Verify the injected tool usage was recorded. 
- toolUsages := mockRecorder.RecordedToolUsages() + toolUsages := bridgeServer.Recorder.RecordedToolUsages() require.Len(t, toolUsages, 1) require.Equal(t, tc.mcpToolName, toolUsages[0].Tool) require.Equal(t, tc.expectToolArgs, toolUsages[0].Args) @@ -919,11 +862,11 @@ func TestResponsesInjectedTool(t *testing.T) { } // Verify prompt was recorded. - prompts := mockRecorder.RecordedPromptUsages() + prompts := bridgeServer.Recorder.RecordedPromptUsages() require.Len(t, prompts, 1) require.Equal(t, tc.expectPrompt, prompts[0].Prompt) - tokenUsages := mockRecorder.RecordedTokenUsages() + tokenUsages := bridgeServer.Recorder.RecordedTokenUsages() require.Len(t, tokenUsages, len(tc.expectTokenUsages)) for i := range tokenUsages { tokenUsages[i].InterceptionID = "" // ignore interception ID and time creation when comparing @@ -941,15 +884,6 @@ func TestResponsesInjectedTool(t *testing.T) { } } -func createOpenAIResponsesReq(t *testing.T, baseURL string, input []byte) *http.Request { - t.Helper() - - req, err := http.NewRequestWithContext(t.Context(), "POST", baseURL+"/openai/v1/responses", bytes.NewReader(input)) - require.NoError(t, err) - req.Header.Set("Content-Type", "application/json") - return req -} - func requireResponsesError(t *testing.T, code int, message string, body []byte) { var respErr responses.Error err := json.Unmarshal(body, &respErr) @@ -999,6 +933,10 @@ func startRejectingListener(t *testing.T) (addr string) { return } + // Read at least 1 byte so the client has started writing + // before we RST, ensuring a consistent "connection reset by peer". 
+ buf := make([]byte, 1) + _, _ = c.Read(buf) if tc, ok := c.(*net.TCPConn); ok { _ = tc.SetLinger(0) } diff --git a/internal/integrationtest/setupbridge.go b/internal/integrationtest/setupbridge.go new file mode 100644 index 0000000..1d061c2 --- /dev/null +++ b/internal/integrationtest/setupbridge.go @@ -0,0 +1,259 @@ +package integrationtest + +import ( + "bytes" + "context" + "net" + "net/http" + "net/http/httptest" + "testing" + "time" + + "cdr.dev/slog/v3" + "github.com/coder/aibridge" + "github.com/coder/aibridge/config" + aibcontext "github.com/coder/aibridge/context" + "github.com/coder/aibridge/fixtures" + "github.com/coder/aibridge/internal/testutil" + "github.com/coder/aibridge/mcp" + "github.com/coder/aibridge/metrics" + "github.com/coder/aibridge/provider" + "github.com/coder/aibridge/recorder" + "github.com/stretchr/testify/require" + "github.com/tidwall/sjson" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/trace" +) + +const ( + pathAnthropicMessages = "/anthropic/v1/messages" + pathOpenAIChatCompletions = "/openai/v1/chat/completions" + pathOpenAIResponses = "/openai/v1/responses" + + // providerBedrock identifies a Bedrock provider in [withProvider]. + // Other providers use config.Provider* constants. + providerBedrock = "bedrock" + + // defaults + apiKey = "api-key" + defaultActorID = "ae235cc1-9f8f-417d-a636-a7b170bac62e" +) + +var defaultTracer = otel.Tracer("integrationtest") + +type bridgeConfig struct { + providerBuilders []func(upstreamURL string) aibridge.Provider + metrics *metrics.Metrics + tracer trace.Tracer + mcpProxy mcp.ServerProxier + userID string + metadata recorder.Metadata + logger slog.Logger +} + +// bridgeTestServer wraps an httptest.Server running a RequestBridge. +type bridgeTestServer struct { + *httptest.Server + Recorder *testutil.MockRecorder + Bridge *aibridge.RequestBridge +} + +// makeRequest builds and executes an HTTP request against this server. 
+// Optional headers are applied after the default Content-Type. +func (s *bridgeTestServer) makeRequest(t *testing.T, method string, path string, body []byte, header ...http.Header) *http.Response { + t.Helper() + + req, err := http.NewRequestWithContext(t.Context(), method, s.URL+path, bytes.NewReader(body)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + for _, h := range header { + for k, vals := range h { + for _, v := range vals { + req.Header.Add(k, v) + } + } + } + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + t.Cleanup(func() { _ = resp.Body.Close() }) + return resp +} + +type bridgeOption func(*bridgeConfig) + +// withProvider adds a default-configured provider of the given type. +// When any provider option is used, the default "all providers" set is not created. +func withProvider(providerType string) bridgeOption { + return func(c *bridgeConfig) { + c.providerBuilders = append(c.providerBuilders, func(addr string) aibridge.Provider { + return newDefaultProvider(providerType, addr) + }) + } +} + +// withCustomProvider adds a pre-built provider. The upstream URL passed to +// [newBridgeTestServer] is ignored for this provider. +// When any provider option is used, the default "all providers" set is not created. +func withCustomProvider(p aibridge.Provider) bridgeOption { + return func(c *bridgeConfig) { + c.providerBuilders = append(c.providerBuilders, func(string) aibridge.Provider { + return p + }) + } +} + +// withMetrics sets the Prometheus metrics for the bridge. +func withMetrics(m *metrics.Metrics) bridgeOption { + return func(c *bridgeConfig) { c.metrics = m } +} + +// withTracer overrides the default tracer. +func withTracer(t trace.Tracer) bridgeOption { + return func(c *bridgeConfig) { c.tracer = t } +} + +// withMCP sets the MCP server proxier (default: NoopMCPManager). 
+func withMCP(p mcp.ServerProxier) bridgeOption { + return func(c *bridgeConfig) { c.mcpProxy = p } +} + +// withActor sets the actor ID and metadata for the BaseContext. +func withActor(id string, md recorder.Metadata) bridgeOption { + return func(c *bridgeConfig) { c.userID = id; c.metadata = md } +} + +// newBridgeTestServer creates a fully configured test server running +// a RequestBridge with sensible defaults: +// - All standard providers (unless withProvider / withCustomProvider) +// - NoopMCPManager (unless withMCP) +// - slogtest debug logger +// - defaultTracer (unless withTracer) +// - defaultActorID (unless withActor) +func newBridgeTestServer( + t *testing.T, + ctx context.Context, + upstreamURL string, + opts ...bridgeOption, +) *bridgeTestServer { + t.Helper() + + cfg := &bridgeConfig{ + userID: defaultActorID, + } + for _, o := range opts { + o(cfg) + } + if cfg.tracer == nil { + cfg.tracer = defaultTracer + } + cfg.logger = newLogger(t) + if cfg.mcpProxy == nil { + cfg.mcpProxy = newNoopMCPManager() + } + + // Resolve providers: use explicit builders when provided, otherwise + // create default providers for every supported type. 
+ var providers []aibridge.Provider + if len(cfg.providerBuilders) > 0 { + for _, b := range cfg.providerBuilders { + providers = append(providers, b(upstreamURL)) + } + } else { + providers = []aibridge.Provider{ + newDefaultProvider(config.ProviderAnthropic, upstreamURL), + newDefaultProvider(config.ProviderOpenAI, upstreamURL), + } + } + + mockRec := &testutil.MockRecorder{} + rec := aibridge.NewRecorder(cfg.logger, cfg.tracer, func() (aibridge.Recorder, error) { + return mockRec, nil + }) + + bridge, err := aibridge.NewRequestBridge( + ctx, providers, rec, cfg.mcpProxy, + cfg.logger, cfg.metrics, cfg.tracer, + ) + require.NoError(t, err) + + actorID, md := cfg.userID, cfg.metadata + srv := httptest.NewUnstartedServer(bridge) + srv.Config.BaseContext = func(_ net.Listener) context.Context { + return aibcontext.AsActor(ctx, actorID, md) + } + srv.Start() + t.Cleanup(srv.Close) + + return &bridgeTestServer{ + Server: srv, + Recorder: mockRec, + Bridge: bridge, + } +} + +// setupInjectedToolTest abstracts common setup required for injected-tool integration tests. +// Extra bridge options (e.g. [withProvider]) are appended after the built-in +// MCP / tracer / actor options. When no provider option is given the default +// provider set (all providers) is used. +func setupInjectedToolTest( + t *testing.T, + fixture []byte, + streaming bool, + tracer trace.Tracer, + path string, + toolRequestValidatorFn func(*http.Request, []byte), + opts ...bridgeOption, +) (*bridgeTestServer, *mockMCP, *http.Response) { + t.Helper() + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + fix := fixtures.Parse(t, fixture) + + // Setup mock server for multi-turn interaction. + // First request → tool call response + // Second request → final response. 
+ firstResp := newFixtureResponse(fix) + toolResp := newFixtureToolResponse(fix) + toolResp.OnRequest = toolRequestValidatorFn + upstream := newMockUpstream(t, ctx, firstResp, toolResp) + + mockMCP := setupMCPForTest(t, tracer) + + allOpts := []bridgeOption{ + withMCP(mockMCP), + withTracer(tracer), + withActor(defaultActorID, nil), + } + allOpts = append(allOpts, opts...) + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, allOpts...) + + // Add the stream param to the request. + reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) + require.NoError(t, err) + + resp := bridgeServer.makeRequest(t, http.MethodPost, path, reqBody) + require.Equal(t, http.StatusOK, resp.StatusCode) + + // Wait for both requests (initial + tool call result) + require.Eventually(t, func() bool { + return upstream.Calls.Load() == 2 + }, time.Second*10, time.Millisecond*50) + + return bridgeServer, mockMCP, resp +} + +// newDefaultProvider creates a Provider with default test configuration. +func newDefaultProvider(providerType string, addr string) aibridge.Provider { + switch providerType { + case config.ProviderAnthropic: + return provider.NewAnthropic(anthropicCfg(addr, apiKey), nil) + case config.ProviderOpenAI: + return provider.NewOpenAI(openAICfg(addr, apiKey)) + case providerBedrock: + return provider.NewAnthropic(anthropicCfg(addr, apiKey), bedrockCfg(addr)) + default: + panic("unknown provider type: " + providerType) + } +} diff --git a/trace_integration_test.go b/internal/integrationtest/trace_test.go similarity index 69% rename from trace_integration_test.go rename to internal/integrationtest/trace_test.go index a62e58e..88bec31 100644 --- a/trace_integration_test.go +++ b/internal/integrationtest/trace_test.go @@ -1,8 +1,7 @@ -package aibridge_test +package integrationtest import ( "context" - "fmt" "net/http" "slices" "strings" @@ -11,8 +10,6 @@ import ( "github.com/coder/aibridge/config" "github.com/coder/aibridge/fixtures" - 
"github.com/coder/aibridge/internal/testutil" - "github.com/coder/aibridge/provider" "github.com/coder/aibridge/tracing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -22,6 +19,7 @@ import ( "go.opentelemetry.io/otel/codes" sdktrace "go.opentelemetry.io/otel/sdk/trace" "go.opentelemetry.io/otel/sdk/trace/tracetest" + oteltrace "go.opentelemetry.io/otel/trace" ) // expect 'count' amount of traces named 'name' with status 'status' @@ -31,6 +29,18 @@ type expectTrace struct { status codes.Code } +func setupTracer(t *testing.T) (*tracetest.SpanRecorder, oteltrace.Tracer) { + t.Helper() + + sr := tracetest.NewSpanRecorder() + tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) + t.Cleanup(func() { + _ = tp.Shutdown(t.Context()) + }) + + return sr, tp.Tracer(t.Name()) +} + func TestTraceAnthropic(t *testing.T) { expectNonStreaming := []expectTrace{ {"Intercept", 1, codes.Unset}, @@ -89,36 +99,30 @@ func TestTraceAnthropic(t *testing.T) { fixtureReqBody := fix.Request() for _, tc := range cases { - t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), tc.streaming), func(t *testing.T) { + t.Run(tc.name, func(t *testing.T) { ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - sr := tracetest.NewSpanRecorder() - tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) - tracer := tp.Tracer(t.Name()) - defer func() { _ = tp.Shutdown(t.Context()) }() + sr, tracer := setupTracer(t) - upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - var bedrockCfg *config.AWSBedrock + opts := []bridgeOption{ + withTracer(tracer), + } if tc.bedrock { - bedrockCfg = testBedrockCfg(upstream.URL) + opts = append(opts, withProvider(providerBedrock)) } - provider := provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), bedrockCfg) - srv, recorder := newTestSrv(t, ctx, provider, nil, tracer) + bridgeServer := 
newBridgeTestServer(t, ctx, upstream.URL, opts...) reqBody, err := sjson.SetBytes(fixtureReqBody, "stream", tc.streaming) require.NoError(t, err) - req := createAnthropicMessagesReq(t, srv.URL, reqBody) - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody) require.Equal(t, http.StatusOK, resp.StatusCode) - defer resp.Body.Close() - srv.Close() + bridgeServer.Close() - require.Equal(t, 1, len(recorder.RecordedInterceptions())) - intcID := recorder.RecordedInterceptions()[0].ID + require.Equal(t, 1, len(bridgeServer.Recorder.RecordedInterceptions())) + intcID := bridgeServer.Recorder.RecordedInterceptions()[0].ID model := gjson.Get(string(reqBody), "model").Str if tc.bedrock { @@ -135,7 +139,7 @@ func TestTraceAnthropic(t *testing.T) { attribute.String(tracing.InterceptionID, intcID), attribute.String(tracing.Provider, config.ProviderAnthropic), attribute.String(tracing.Model, model), - attribute.String(tracing.InitiatorID, userID), + attribute.String(tracing.InitiatorID, defaultActorID), attribute.Bool(tracing.Streaming, tc.streaming), attribute.Bool(tracing.IsBedrock, tc.bedrock), } @@ -208,37 +212,31 @@ func TestTraceAnthropicErr(t *testing.T) { ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - sr := tracetest.NewSpanRecorder() - tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) - tracer := tp.Tracer(t.Name()) - defer func() { _ = tp.Shutdown(t.Context()) }() + sr, tracer := setupTracer(t) fix := fixtures.Parse(t, tc.fixture) - upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - var bedrockCfg *config.AWSBedrock + opts := []bridgeOption{ + withTracer(tracer), + } if tc.bedrock { - bedrockCfg = testBedrockCfg(upstream.URL) + opts = append(opts, withProvider(providerBedrock)) } - provider := 
provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), bedrockCfg) - srv, recorder := newTestSrv(t, ctx, provider, nil, tracer) + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, opts...) reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) require.NoError(t, err) - req := createAnthropicMessagesReq(t, srv.URL, reqBody) - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody) if tc.streaming { require.Equal(t, http.StatusOK, resp.StatusCode) } else { require.Equal(t, tc.expectCode, resp.StatusCode) } - defer resp.Body.Close() - srv.Close() + bridgeServer.Close() - require.Equal(t, 1, len(recorder.RecordedInterceptions())) - intcID := recorder.RecordedInterceptions()[0].ID + require.Equal(t, 1, len(bridgeServer.Recorder.RecordedInterceptions())) + intcID := bridgeServer.Recorder.RecordedInterceptions()[0].ID totalCount := 0 for _, e := range tc.expect { @@ -259,7 +257,7 @@ func TestTraceAnthropicErr(t *testing.T) { attribute.String(tracing.InterceptionID, intcID), attribute.String(tracing.Provider, config.ProviderAnthropic), attribute.String(tracing.Model, model), - attribute.String(tracing.InitiatorID, userID), + attribute.String(tracing.InitiatorID, defaultActorID), attribute.Bool(tracing.Streaming, tc.streaming), attribute.Bool(tracing.IsBedrock, tc.bedrock), } @@ -277,30 +275,25 @@ func TestInjectedToolsTrace(t *testing.T) { streaming bool bedrock bool fixture []byte - providerFn providerFunc - createReqFn func(*testing.T, string, []byte) *http.Request + path string expectModel string - expectPath string expectProvider string + opts []bridgeOption }{ { name: "anthr_blocking", streaming: false, fixture: fixtures.AntSingleInjectedTool, - providerFn: newAnthropicProvider, - createReqFn: createAnthropicMessagesReq, + path: pathAnthropicMessages, expectModel: "claude-sonnet-4-20250514", - expectPath: "/anthropic/v1/messages", 
expectProvider: config.ProviderAnthropic, }, { name: "anthr_streaming", streaming: true, fixture: fixtures.AntSingleInjectedTool, - providerFn: newAnthropicProvider, - createReqFn: createAnthropicMessagesReq, + path: pathAnthropicMessages, expectModel: "claude-sonnet-4-20250514", - expectPath: "/anthropic/v1/messages", expectProvider: config.ProviderAnthropic, }, { @@ -308,41 +301,35 @@ func TestInjectedToolsTrace(t *testing.T) { streaming: false, bedrock: true, fixture: fixtures.AntSingleInjectedTool, - providerFn: newBedrockProvider, - createReqFn: createAnthropicMessagesReq, + path: pathAnthropicMessages, expectModel: "beddel", - expectPath: "/anthropic/v1/messages", expectProvider: config.ProviderAnthropic, + opts: []bridgeOption{withProvider(providerBedrock)}, }, { name: "bedrock_streaming", streaming: true, bedrock: true, fixture: fixtures.AntSingleInjectedTool, - providerFn: newBedrockProvider, - createReqFn: createAnthropicMessagesReq, + path: pathAnthropicMessages, expectModel: "beddel", - expectPath: "/anthropic/v1/messages", expectProvider: config.ProviderAnthropic, + opts: []bridgeOption{withProvider(providerBedrock)}, }, { name: "openai_blocking", streaming: false, fixture: fixtures.OaiChatSingleInjectedTool, - providerFn: newOpenAIProvider, - createReqFn: createOpenAIChatCompletionsReq, + path: pathOpenAIChatCompletions, expectModel: "gpt-4.1", - expectPath: "/openai/v1/chat/completions", expectProvider: config.ProviderOpenAI, }, { name: "openai_streaming", streaming: true, fixture: fixtures.OaiChatSingleInjectedTool, - providerFn: newOpenAIProvider, - createReqFn: createOpenAIChatCompletionsReq, + path: pathOpenAIChatCompletions, expectModel: "gpt-4.1", - expectPath: "/openai/v1/chat/completions", expectProvider: config.ProviderOpenAI, }, } @@ -351,10 +338,7 @@ func TestInjectedToolsTrace(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - sr := tracetest.NewSpanRecorder() - tp := 
sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) - tracer := tp.Tracer(t.Name()) - defer func() { _ = tp.Shutdown(t.Context()) }() + sr, tracer := setupTracer(t) var validatorFn func(*http.Request, []byte) if tc.expectProvider == config.ProviderAnthropic { @@ -363,23 +347,22 @@ func TestInjectedToolsTrace(t *testing.T) { validatorFn = openaiChatToolResultValidator(t) } - recorderClient, mockMCP, resp := setupInjectedToolTest( - t, tc.fixture, tc.streaming, tc.providerFn, tracer, userID, - tc.createReqFn, validatorFn, + bridgeServer, mockMCP, _ := setupInjectedToolTest( + t, tc.fixture, tc.streaming, tracer, + tc.path, validatorFn, tc.opts..., ) - defer resp.Body.Close() - require.Len(t, recorderClient.RecordedInterceptions(), 1) - intcID := recorderClient.RecordedInterceptions()[0].ID + require.Len(t, bridgeServer.Recorder.RecordedInterceptions(), 1) + intcID := bridgeServer.Recorder.RecordedInterceptions()[0].ID tool := mockMCP.ListTools()[0] attrs := []attribute.KeyValue{ - attribute.String(tracing.RequestPath, tc.expectPath), + attribute.String(tracing.RequestPath, tc.path), attribute.String(tracing.InterceptionID, intcID), attribute.String(tracing.Provider, tc.expectProvider), attribute.String(tracing.Model, tc.expectModel), - attribute.String(tracing.InitiatorID, userID), + attribute.String(tracing.InitiatorID, defaultActorID), attribute.String(tracing.MCPInput, `{"owner":"admin"}`), attribute.String(tracing.MCPToolName, "coder_list_workspaces"), attribute.String(tracing.MCPServerName, tool.ServerName), @@ -397,19 +380,18 @@ func TestInjectedToolsTrace(t *testing.T) { func TestTraceOpenAI(t *testing.T) { cases := []struct { - name string - fixture []byte - streaming bool - expectPath string - reqFunc func(t *testing.T, baseURL string, input []byte) *http.Request - expect []expectTrace + name string + fixture []byte + streaming bool + path string + + expect []expectTrace }{ { - name: "trace_openai_chat_streaming", - fixture: fixtures.OaiChatSimple, - 
streaming: true, - expectPath: "/openai/v1/chat/completions", - reqFunc: createOpenAIChatCompletionsReq, + name: "trace_openai_chat_streaming", + fixture: fixtures.OaiChatSimple, + streaming: true, + path: pathOpenAIChatCompletions, expect: []expectTrace{ {"Intercept", 1, codes.Unset}, {"Intercept.CreateInterceptor", 1, codes.Unset}, @@ -422,11 +404,10 @@ func TestTraceOpenAI(t *testing.T) { }, }, { - name: "trace_openai_chat_blocking", - fixture: fixtures.OaiChatSimple, - reqFunc: createOpenAIChatCompletionsReq, - streaming: false, - expectPath: "/openai/v1/chat/completions", + name: "trace_openai_chat_blocking", + fixture: fixtures.OaiChatSimple, + streaming: false, + path: pathOpenAIChatCompletions, expect: []expectTrace{ {"Intercept", 1, codes.Unset}, {"Intercept.CreateInterceptor", 1, codes.Unset}, @@ -439,11 +420,10 @@ func TestTraceOpenAI(t *testing.T) { }, }, { - name: "trace_openai_responses_streaming", - fixture: fixtures.OaiResponsesStreamingSimple, - streaming: true, - expectPath: "/openai/v1/responses", - reqFunc: createOpenAIResponsesReq, + name: "trace_openai_responses_streaming", + fixture: fixtures.OaiResponsesStreamingSimple, + streaming: true, + path: pathOpenAIResponses, expect: []expectTrace{ {"Intercept", 1, codes.Unset}, {"Intercept.CreateInterceptor", 1, codes.Unset}, @@ -456,11 +436,10 @@ func TestTraceOpenAI(t *testing.T) { }, }, { - name: "trace_openai_responses_blocking", - fixture: fixtures.OaiResponsesBlockingSimple, - streaming: false, - expectPath: "/openai/v1/responses", - reqFunc: createOpenAIResponsesReq, + name: "trace_openai_responses_blocking", + fixture: fixtures.OaiResponsesBlockingSimple, + streaming: false, + path: pathOpenAIResponses, expect: []expectTrace{ {"Intercept", 1, codes.Unset}, {"Intercept.CreateInterceptor", 1, codes.Unset}, @@ -479,28 +458,22 @@ func TestTraceOpenAI(t *testing.T) { ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - sr := tracetest.NewSpanRecorder() - tp := 
sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) - tracer := tp.Tracer(t.Name()) - defer func() { _ = tp.Shutdown(t.Context()) }() + sr, tracer := setupTracer(t) fix := fixtures.Parse(t, tc.fixture) - upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) - provider := provider.NewOpenAI(openaiCfg(upstream.URL, apiKey)) - srv, recorder := newTestSrv(t, ctx, provider, nil, tracer) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, + withTracer(tracer), + ) reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) require.NoError(t, err) - req := tc.reqFunc(t, srv.URL, reqBody) - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody) require.Equal(t, http.StatusOK, resp.StatusCode) - defer resp.Body.Close() - srv.Close() + bridgeServer.Close() - require.Equal(t, 1, len(recorder.RecordedInterceptions())) - intcID := recorder.RecordedInterceptions()[0].ID + require.Equal(t, 1, len(bridgeServer.Recorder.RecordedInterceptions())) + intcID := bridgeServer.Recorder.RecordedInterceptions()[0].ID totalCount := 0 for _, e := range tc.expect { @@ -509,11 +482,11 @@ func TestTraceOpenAI(t *testing.T) { require.Len(t, sr.Ended(), totalCount) attrs := []attribute.KeyValue{ - attribute.String(tracing.RequestPath, tc.expectPath), + attribute.String(tracing.RequestPath, tc.path), attribute.String(tracing.InterceptionID, intcID), attribute.String(tracing.Provider, config.ProviderOpenAI), attribute.String(tracing.Model, gjson.Get(string(reqBody), "model").Str), - attribute.String(tracing.InitiatorID, userID), + attribute.String(tracing.InitiatorID, defaultActorID), attribute.Bool(tracing.Streaming, tc.streaming), } verifyTraces(t, sr, tc.expect, attrs) @@ -527,17 +500,16 @@ func TestTraceOpenAIErr(t *testing.T) { fixture []byte streaming bool allowOverflow bool - 
expectPath string - reqFunc func(t *testing.T, baseURL string, input []byte) *http.Request - expect []expectTrace - expectCode int + path string + + expect []expectTrace + expectCode int }{ { name: "trace_openai_chat_streaming_error", fixture: fixtures.OaiChatMidStreamError, streaming: true, - expectPath: "/openai/v1/chat/completions", - reqFunc: createOpenAIChatCompletionsReq, + path: pathOpenAIChatCompletions, expectCode: http.StatusOK, expect: []expectTrace{ {"Intercept", 1, codes.Error}, @@ -553,8 +525,7 @@ func TestTraceOpenAIErr(t *testing.T) { name: "trace_openai_chat_blocking_error", fixture: fixtures.OaiChatNonStreamError, streaming: false, - expectPath: "/openai/v1/chat/completions", - reqFunc: createOpenAIChatCompletionsReq, + path: pathOpenAIChatCompletions, expectCode: http.StatusBadRequest, expect: []expectTrace{ {"Intercept", 1, codes.Error}, @@ -569,8 +540,7 @@ func TestTraceOpenAIErr(t *testing.T) { name: "trace_openai_responses_streaming_error", streaming: true, fixture: fixtures.OaiResponsesStreamingWrongResponseFormat, - expectPath: "/openai/v1/responses", - reqFunc: createOpenAIResponsesReq, + path: pathOpenAIResponses, expectCode: http.StatusOK, expect: []expectTrace{ {"Intercept", 1, codes.Error}, @@ -583,11 +553,10 @@ func TestTraceOpenAIErr(t *testing.T) { }, }, { - name: "trace_openai_responses_blocking_error", - fixture: fixtures.OaiResponsesBlockingWrongResponseFormat, - streaming: false, - expectPath: "/openai/v1/responses", - reqFunc: createOpenAIResponsesReq, + name: "trace_openai_responses_blocking_error", + fixture: fixtures.OaiResponsesBlockingWrongResponseFormat, + streaming: false, + path: pathOpenAIResponses, // Fixture returns http 200 response with wrong body // responses forward received response as is so // expected code == 200 even though ProcessRequest @@ -608,8 +577,7 @@ func TestTraceOpenAIErr(t *testing.T) { streaming: true, allowOverflow: true, // 429 error causes retries - expectPath: "/openai/v1/responses", - 
reqFunc: createOpenAIResponsesReq, + path: pathOpenAIResponses, expectCode: http.StatusTooManyRequests, expect: []expectTrace{ {"Intercept", 1, codes.Error}, @@ -625,8 +593,7 @@ func TestTraceOpenAIErr(t *testing.T) { fixture: fixtures.OaiResponsesBlockingHttpErr, streaming: false, - expectPath: "/openai/v1/responses", - reqFunc: createOpenAIResponsesReq, + path: pathOpenAIResponses, expectCode: http.StatusUnauthorized, expect: []expectTrace{ {"Intercept", 1, codes.Error}, @@ -644,31 +611,25 @@ func TestTraceOpenAIErr(t *testing.T) { ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - sr := tracetest.NewSpanRecorder() - tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) - tracer := tp.Tracer(t.Name()) - defer func() { _ = tp.Shutdown(t.Context()) }() + sr, tracer := setupTracer(t) fix := fixtures.Parse(t, tc.fixture) - mockAPI := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) + mockAPI := newMockUpstream(t, ctx, newFixtureResponse(fix)) mockAPI.AllowOverflow = tc.allowOverflow - prov := provider.NewOpenAI(openaiCfg(mockAPI.URL, apiKey)) - srv, recorder := newTestSrv(t, ctx, prov, nil, tracer) + bridgeServer := newBridgeTestServer(t, ctx, mockAPI.URL, + withTracer(tracer), + ) reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) require.NoError(t, err) - req := tc.reqFunc(t, srv.URL, reqBody) - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) + resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody) require.Equal(t, tc.expectCode, resp.StatusCode) - defer resp.Body.Close() - srv.Close() + bridgeServer.Close() - require.Equal(t, 1, len(recorder.RecordedInterceptions())) - intcID := recorder.RecordedInterceptions()[0].ID + require.Equal(t, 1, len(bridgeServer.Recorder.RecordedInterceptions())) + intcID := bridgeServer.Recorder.RecordedInterceptions()[0].ID totalCount := 0 for _, e := range tc.expect { @@ -677,11 +638,11 @@ func 
TestTraceOpenAIErr(t *testing.T) { require.Len(t, sr.Ended(), totalCount) attrs := []attribute.KeyValue{ - attribute.String(tracing.RequestPath, tc.expectPath), + attribute.String(tracing.RequestPath, tc.path), attribute.String(tracing.InterceptionID, intcID), attribute.String(tracing.Provider, config.ProviderOpenAI), attribute.String(tracing.Model, gjson.Get(string(reqBody), "model").Str), - attribute.String(tracing.InitiatorID, userID), + attribute.String(tracing.InitiatorID, defaultActorID), attribute.Bool(tracing.Streaming, tc.streaming), } verifyTraces(t, sr, tc.expect, attrs) @@ -694,24 +655,17 @@ func TestTracePassthrough(t *testing.T) { fix := fixtures.Parse(t, fixtures.OaiChatFallthrough) - upstream := testutil.NewMockUpstream(t, t.Context(), testutil.NewFixtureResponse(fix)) - - sr := tracetest.NewSpanRecorder() - tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) - tracer := tp.Tracer(t.Name()) - defer func() { _ = tp.Shutdown(t.Context()) }() + upstream := newMockUpstream(t, t.Context(), newFixtureResponse(fix)) - provider := provider.NewOpenAI(openaiCfg(upstream.URL, apiKey)) - srv, _ := newTestSrv(t, t.Context(), provider, nil, tracer) + sr, tracer := setupTracer(t) - req, err := http.NewRequestWithContext(t.Context(), "GET", srv.URL+"/openai/v1/models", nil) - require.NoError(t, err) + bridgeServer := newBridgeTestServer(t, t.Context(), upstream.URL, + withTracer(tracer), + ) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) - defer resp.Body.Close() + resp := bridgeServer.makeRequest(t, http.MethodGet, "/openai/v1/models", nil) require.Equal(t, http.StatusOK, resp.StatusCode) - srv.Close() + bridgeServer.Close() spans := sr.Ended() require.Len(t, spans, 1) @@ -727,13 +681,10 @@ func TestTracePassthrough(t *testing.T) { } func TestNewServerProxyManagerTraces(t *testing.T) { - sr := tracetest.NewSpanRecorder() - tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) - tracer := tp.Tracer(t.Name()) - defer 
func() { _ = tp.Shutdown(t.Context()) }() + sr, tracer := setupTracer(t) serverName := "serverName" - mockMCP := testutil.SetupMCPForTestWithName(t, serverName, tracer) + mockMCP := setupMCPForTestWithName(t, serverName, tracer) tool := mockMCP.ListTools()[0] require.Len(t, sr.Ended(), 3) @@ -775,14 +726,3 @@ func verifyTraces(t *testing.T, spanRecorder *tracetest.SpanRecorder, expect []e } } } - -func testBedrockCfg(url string) *config.AWSBedrock { - return &config.AWSBedrock{ - Region: "us-west-2", - AccessKey: "test-access-key", - AccessKeySecret: "test-secret-key", - Model: "beddel", // This model should override the request's given one. - SmallFastModel: "modrock", // Unused but needed for validation. - BaseURL: url, - } -} diff --git a/internal/testutil/mock_recorder.go b/internal/testutil/mock_recorder.go index ac39006..09bcac3 100644 --- a/internal/testutil/mock_recorder.go +++ b/internal/testutil/mock_recorder.go @@ -74,6 +74,28 @@ func (m *MockRecorder) RecordedTokenUsages() []*recorder.TokenUsageRecord { return slices.Clone(m.tokenUsages) } +// TotalInputTokens returns the sum of input tokens across all recorded token usages. +func (m *MockRecorder) TotalInputTokens() int64 { + m.mu.Lock() + defer m.mu.Unlock() + var total int64 + for _, el := range m.tokenUsages { + total += el.Input + } + return total +} + +// TotalOutputTokens returns the sum of output tokens across all recorded token usages. +func (m *MockRecorder) TotalOutputTokens() int64 { + m.mu.Lock() + defer m.mu.Unlock() + var total int64 + for _, el := range m.tokenUsages { + total += el.Output + } + return total +} + // RecordedPromptUsages returns a copy of recorded prompt usages in a thread-safe manner. // Note: This is a shallow clone (see RecordedTokenUsages for details). func (m *MockRecorder) RecordedPromptUsages() []*recorder.PromptUsageRecord {