diff --git a/cmd/api.go b/cmd/api.go index 5f15d36..1005314 100644 --- a/cmd/api.go +++ b/cmd/api.go @@ -15,6 +15,7 @@ import ( "github.com/briandowns/spinner" gleanClient "github.com/gleanwork/glean-cli/internal/client" "github.com/gleanwork/glean-cli/internal/config" + "github.com/gleanwork/glean-cli/internal/httputil" "github.com/gleanwork/glean-cli/internal/output" "github.com/spf13/cobra" "golang.org/x/term" @@ -193,7 +194,7 @@ func rawAPIRequest(ctx context.Context, cfg *config.Config, method, endpoint str req.Header.Set("X-Glean-Auth-Type", authType) } - httpClient := &http.Client{Timeout: 30 * time.Second} + httpClient := httputil.NewHTTPClient(30 * time.Second) httpResp, err := httpClient.Do(req) if err != nil { return nil, fmt.Errorf("error making request: %w", err) diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 21e137b..90739b5 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -16,6 +16,7 @@ import ( "github.com/coreos/go-oidc/v3/oidc" "github.com/gleanwork/glean-cli/internal/config" + "github.com/gleanwork/glean-cli/internal/httputil" "github.com/int128/oauth2cli" "github.com/pkg/browser" "golang.org/x/oauth2" @@ -447,7 +448,7 @@ func fetchAuthServerMetadata(ctx context.Context, issuer string) (*authServerMet return nil, err } req.Header.Set("Accept", "application/json") - resp, err := discoveryHTTPClient.Do(req) + resp, err := httputil.NewHTTPClient(10 * time.Second).Do(req) if err != nil { return nil, err } diff --git a/internal/auth/discovery.go b/internal/auth/discovery.go index d7e25b0..94b57f7 100644 --- a/internal/auth/discovery.go +++ b/internal/auth/discovery.go @@ -8,9 +8,9 @@ import ( "net/http" "strings" "time" -) -var discoveryHTTPClient = &http.Client{Timeout: 10 * time.Second} + "github.com/gleanwork/glean-cli/internal/httputil" +) // ErrOAuthNotSupported is returned when the protected resource endpoint returns 404. type ErrOAuthNotSupported struct { @@ -36,7 +36,7 @@ func fetchProtectedResource(ctx context.Context, baseURL string) (*protectedReso } req.Header.Set("Accept", "application/json") - resp, err := discoveryHTTPClient.Do(req) + resp, err := httputil.NewHTTPClient(10 * time.Second).Do(req) if err != nil { return nil, fmt.Errorf("fetching protected resource metadata: %w", err) } @@ -81,7 +81,7 @@ func registerClient(ctx context.Context, registrationEndpoint, redirectURI strin req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") - resp, err := discoveryHTTPClient.Do(req) + resp, err := httputil.NewHTTPClient(10 * time.Second).Do(req) if err != nil { return nil, fmt.Errorf("DCR request failed: %w", err) } diff --git a/internal/auth/domainlookup.go b/internal/auth/domainlookup.go index 970bc6e..70527a7 100644 --- a/internal/auth/domainlookup.go +++ b/internal/auth/domainlookup.go @@ -8,12 +8,12 @@ import ( "net/http" "strings" "time" + + "github.com/gleanwork/glean-cli/internal/httputil" ) const gleanConfigSearchURL = "https://app.glean.com/config/search" -var domainLookupHTTPClient = &http.Client{Timeout: 10 * time.Second} - // LookupBackendURL resolves a work email to a Glean backend base URL // using Glean's domain discovery API. func LookupBackendURL(ctx context.Context, email string) (string, error) { @@ -43,7 +43,7 @@ func lookupBackendURL(ctx context.Context, email, endpoint string) (string, erro req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") - resp, err := domainLookupHTTPClient.Do(req) + resp, err := httputil.NewHTTPClient(10 * time.Second).Do(req) if err != nil { return "", fmt.Errorf("domain lookup request failed: %w", err) } diff --git a/internal/client/client.go b/internal/client/client.go index a62990b..8683e1e 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -11,11 +11,9 @@ import ( glean "github.com/gleanwork/api-client-go" "github.com/gleanwork/glean-cli/internal/auth" "github.com/gleanwork/glean-cli/internal/config" + "github.com/gleanwork/glean-cli/internal/httputil" ) -// cliVersion is set at startup via SetVersion. Defaults to "dev" for local builds. -var cliVersion = "dev" - // authTypeOAuth is the X-Glean-Auth-Type header value required for External IdP OAuth tokens. const authTypeOAuth = "OAUTH" @@ -33,28 +31,6 @@ func ResolveToken(cfg *config.Config) (token, authType string) { return "", "" } -// SetVersion records the build-time version for use in the User-Agent header. -func SetVersion(v string) { cliVersion = v } - -// Version returns the current CLI version string. -func Version() string { return cliVersion } - -// cliTransport wraps an http.RoundTripper, sets the CLI User-Agent header, -// and injects X-Glean-Auth-Type when the token originates from OAuth. -type cliTransport struct { - base http.RoundTripper - authType string // "OAUTH" or "" (empty = API token, no header set) -} - -func (t *cliTransport) RoundTrip(req *http.Request) (*http.Response, error) { - req = req.Clone(req.Context()) - req.Header.Set("User-Agent", "glean-cli/"+cliVersion) - if t.authType != "" { - req.Header.Set("X-Glean-Auth-Type", t.authType) - } - return t.base.RoundTrip(req) -} - // New creates an authenticated Glean SDK client from the loaded configuration. // // Authentication priority: @@ -81,7 +57,9 @@ func New(cfg *config.Config) (*glean.Glean, error) { glean.WithInstance(instance), glean.WithSecurity(token), glean.WithClient(&http.Client{ - Transport: &cliTransport{base: http.DefaultTransport, authType: authType}, + Transport: httputil.NewTransport(http.DefaultTransport, + httputil.WithHeader("X-Glean-Auth-Type", authType), + ), }), } diff --git a/internal/client/client_test.go b/internal/client/client_test.go index efc5f38..ed62f54 100644 --- a/internal/client/client_test.go +++ b/internal/client/client_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/gleanwork/glean-cli/internal/config" + "github.com/gleanwork/glean-cli/internal/httputil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -19,36 +20,40 @@ func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) return m.fn(req) } -func TestCLITransport_OAuthSetsHeader(t *testing.T) { +func TestTransport_OAuthSetsHeader(t *testing.T) { + httputil.SetVersion("test") + var captured *http.Request base := &mockRoundTripper{fn: func(req *http.Request) (*http.Response, error) { captured = req return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader(""))}, nil }} - transport := &cliTransport{base: base, authType: authTypeOAuth} + transport := httputil.NewTransport(base, httputil.WithHeader("X-Glean-Auth-Type", authTypeOAuth)) req, err := http.NewRequest("GET", "https://example.com", nil) require.NoError(t, err) _, _ = transport.RoundTrip(req) assert.Equal(t, authTypeOAuth, captured.Header.Get("X-Glean-Auth-Type")) - assert.Contains(t, captured.Header.Get("User-Agent"), "glean-cli/") + assert.Equal(t, "glean-cli/test", captured.Header.Get("User-Agent")) } -func TestCLITransport_APITokenOmitsHeader(t *testing.T) { +func TestTransport_APITokenOmitsHeader(t *testing.T) { + httputil.SetVersion("test") + var captured *http.Request base := &mockRoundTripper{fn: func(req *http.Request) (*http.Response, error) { captured = req return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader(""))}, nil }} - transport := &cliTransport{base: base, authType: ""} + transport := httputil.NewTransport(base) req, err := http.NewRequest("GET", "https://example.com", nil) require.NoError(t, err) _, _ = transport.RoundTrip(req) assert.Empty(t, captured.Header.Get("X-Glean-Auth-Type")) - assert.Contains(t, captured.Header.Get("User-Agent"), "glean-cli/") + assert.Equal(t, "glean-cli/test", captured.Header.Get("User-Agent")) } func TestResolveToken_APIToken(t *testing.T) { diff --git a/internal/client/stream.go b/internal/client/stream.go index 077639d..73f05ca 100644 --- a/internal/client/stream.go +++ b/internal/client/stream.go @@ -11,12 +11,13 @@ import ( "github.com/gleanwork/api-client-go/models/components" "github.com/gleanwork/glean-cli/internal/config" + "github.com/gleanwork/glean-cli/internal/httputil" ) -// streamHTTPClient has a generous timeout for long-running AUTO/ADVANCED agent +// streamTimeout is a generous timeout for long-running AUTO/ADVANCED agent // responses. Context cancellation (ctrl+c in the TUI) handles user-initiated // cancellation; this timeout is only a backstop for genuine network hangs. -var streamHTTPClient = &http.Client{Timeout: 10 * time.Minute} +const streamTimeout = 10 * time.Minute // StreamChat makes a streaming chat request to the Glean API, bypassing the // SDK's buffered CreateStream which reads the entire response before returning. @@ -65,12 +66,11 @@ func StreamChat(ctx context.Context, cfg *config.Config, req components.ChatRequ httpReq.Header.Set("Content-Type", "application/json") httpReq.Header.Set("Accept", "text/event-stream") httpReq.Header.Set("Authorization", "Bearer "+token) - httpReq.Header.Set("User-Agent", "glean-cli/"+cliVersion) if authType != "" { httpReq.Header.Set("X-Glean-Auth-Type", authType) } - resp, err := streamHTTPClient.Do(httpReq) + resp, err := httputil.NewHTTPClient(streamTimeout).Do(httpReq) if err != nil { return nil, fmt.Errorf("chat request failed: %w", err) } diff --git a/internal/httputil/httputil.go b/internal/httputil/httputil.go new file mode 100644 index 0000000..56f73d0 --- /dev/null +++ b/internal/httputil/httputil.go @@ -0,0 +1,67 @@ +package httputil + +import ( + "net/http" + "time" +) + +// cliVersion is set at startup via SetVersion. Defaults to "dev" for local builds. +var cliVersion = "dev" + +// SetVersion records the build-time version for use in the User-Agent header. +func SetVersion(v string) { cliVersion = v } + +// Version returns the current CLI version string. +func Version() string { return cliVersion } + +// TransportOption configures a cliTransport. +type TransportOption func(*cliTransport) + +// WithHeader adds a static header to every outgoing request. +// If value is empty the header is not set. +func WithHeader(key, value string) TransportOption { + return func(t *cliTransport) { + if value != "" { + t.extraHeaders[key] = value + } + } +} + +// cliTransport wraps an http.RoundTripper, injects the CLI User-Agent header, +// and applies any additional static headers on every outgoing request. +type cliTransport struct { + base http.RoundTripper + extraHeaders map[string]string +} + +func (t *cliTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req = req.Clone(req.Context()) + req.Header.Set("User-Agent", "glean-cli/"+cliVersion) + for k, v := range t.extraHeaders { + req.Header.Set(k, v) + } + return t.base.RoundTrip(req) +} + +// NewTransport returns an http.RoundTripper that injects the CLI User-Agent +// header (and any extra headers from opts) before delegating to base. +// If base is nil, http.DefaultTransport is used. +func NewTransport(base http.RoundTripper, opts ...TransportOption) http.RoundTripper { + if base == nil { + base = http.DefaultTransport + } + t := &cliTransport{base: base, extraHeaders: make(map[string]string)} + for _, o := range opts { + o(t) + } + return t +} + +// NewHTTPClient returns an *http.Client with the given timeout whose transport +// injects the CLI User-Agent header on every request. +func NewHTTPClient(timeout time.Duration) *http.Client { + return &http.Client{ + Timeout: timeout, + Transport: NewTransport(http.DefaultTransport), + } +} diff --git a/internal/httputil/httputil_test.go b/internal/httputil/httputil_test.go new file mode 100644 index 0000000..992eff6 --- /dev/null +++ b/internal/httputil/httputil_test.go @@ -0,0 +1,122 @@ +package httputil + +import ( + "io" + "net/http" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type mockRoundTripper struct { + fn func(*http.Request) (*http.Response, error) +} + +func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return m.fn(req) +} + +func TestUATransport_SetsUserAgent(t *testing.T) { + SetVersion("1.2.3") + + var captured *http.Request + base := &mockRoundTripper{fn: func(req *http.Request) (*http.Response, error) { + captured = req + return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader(""))}, nil + }} + + transport := NewTransport(base) + req, err := http.NewRequest("GET", "https://example.com", nil) + require.NoError(t, err) + + _, err = transport.RoundTrip(req) + require.NoError(t, err) + + assert.Equal(t, "glean-cli/1.2.3", captured.Header.Get("User-Agent")) +} + +func TestUATransport_DefaultVersion(t *testing.T) { + SetVersion("dev") + + var captured *http.Request + base := &mockRoundTripper{fn: func(req *http.Request) (*http.Response, error) { + captured = req + return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader(""))}, nil + }} + + transport := NewTransport(base) + req, err := http.NewRequest("GET", "https://example.com", nil) + require.NoError(t, err) + + _, err = transport.RoundTrip(req) + require.NoError(t, err) + + assert.Equal(t, "glean-cli/dev", captured.Header.Get("User-Agent")) +} + +func TestUATransport_DoesNotMutateOriginalRequest(t *testing.T) { + SetVersion("1.0.0") + + base := &mockRoundTripper{fn: func(req *http.Request) (*http.Response, error) { + return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader(""))}, nil + }} + + transport := NewTransport(base) + req, err := http.NewRequest("GET", "https://example.com", nil) + require.NoError(t, err) + + _, err = transport.RoundTrip(req) + require.NoError(t, err) + + assert.Empty(t, req.Header.Get("User-Agent"), "original request should not be mutated") +} + +func TestWithHeader_SetsExtraHeader(t *testing.T) { + SetVersion("1.0.0") + + var captured *http.Request + base := &mockRoundTripper{fn: func(req *http.Request) (*http.Response, error) { + captured = req + return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader(""))}, nil + }} + + transport := NewTransport(base, WithHeader("X-Custom", "value")) + req, err := http.NewRequest("GET", "https://example.com", nil) + require.NoError(t, err) + + _, err = transport.RoundTrip(req) + require.NoError(t, err) + + assert.Equal(t, "value", captured.Header.Get("X-Custom")) + assert.Equal(t, "glean-cli/1.0.0", captured.Header.Get("User-Agent")) +} + +func TestWithHeader_EmptyValueIsIgnored(t *testing.T) { + SetVersion("1.0.0") + + var captured *http.Request + base := &mockRoundTripper{fn: func(req *http.Request) (*http.Response, error) { + captured = req + return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader(""))}, nil + }} + + transport := NewTransport(base, WithHeader("X-Custom", "")) + req, err := http.NewRequest("GET", "https://example.com", nil) + require.NoError(t, err) + + _, err = transport.RoundTrip(req) + require.NoError(t, err) + + assert.Empty(t, captured.Header.Get("X-Custom")) +} + +func TestNewHTTPClient_SetsTransport(t *testing.T) { + SetVersion("2.0.0") + + client := NewHTTPClient(0) + require.NotNil(t, client) + + assert.NotNil(t, client.Transport) +} diff --git a/internal/update/check.go b/internal/update/check.go index 47aa05b..6b20993 100644 --- a/internal/update/check.go +++ b/internal/update/check.go @@ -14,6 +14,7 @@ import ( "strings" "time" + "github.com/gleanwork/glean-cli/internal/httputil" "golang.org/x/mod/semver" ) @@ -75,7 +76,7 @@ func check(currentVersion string) string { } func fetchLatestTag() (string, error) { - client := &http.Client{Timeout: 5 * time.Second} + client := httputil.NewHTTPClient(5 * time.Second) resp, err := client.Get(releaseAPIURL) if err != nil { return "", err diff --git a/internal/update/upgrade.go b/internal/update/upgrade.go index d00b644..f08b590 100644 --- a/internal/update/upgrade.go +++ b/internal/update/upgrade.go @@ -16,6 +16,7 @@ import ( "strings" "time" + "github.com/gleanwork/glean-cli/internal/httputil" "github.com/minio/selfupdate" ) @@ -121,7 +122,7 @@ func assetFilename() string { // download fetches a URL and returns the body bytes. func download(url string) ([]byte, error) { - client := &http.Client{Timeout: 120 * time.Second} + client := httputil.NewHTTPClient(120 * time.Second) resp, err := client.Get(url) if err != nil { return nil, err diff --git a/main.go b/main.go index c79a301..7e9f3c9 100644 --- a/main.go +++ b/main.go @@ -5,7 +5,7 @@ import ( "os" "github.com/gleanwork/glean-cli/cmd" - "github.com/gleanwork/glean-cli/internal/client" + "github.com/gleanwork/glean-cli/internal/httputil" ) // version is set at build time via ldflags: -X main.version= @@ -13,7 +13,7 @@ var version = "dev" func main() { cmd.SetVersion(version) - client.SetVersion(version) + httputil.SetVersion(version) if err := cmd.Execute(); err != nil { os.Exit(1) }