From 301d56507424911c18080d448b2a6730ebfbe12a Mon Sep 17 00:00:00 2001 From: Pavlo Chernykh <526266+pavlo-v-chernykh@users.noreply.github.com> Date: Sun, 29 Mar 2026 17:11:44 -0400 Subject: [PATCH 1/2] feat: add Device Authorization Grant (RFC 8628) as login fallback When DCR is unavailable (e.g. Okta SSO), auth login now falls back to the OAuth 2.0 Device Authorization Grant. The user approves login on a verification page instead of a local redirect. --- README.md | 6 +- internal/auth/auth.go | 136 ++++++++----- internal/auth/device.go | 250 +++++++++++++++++++++++ internal/auth/device_test.go | 338 ++++++++++++++++++++++++++++++++ internal/auth/discovery.go | 5 +- internal/auth/discovery_test.go | 22 +++ snippets/readme/snippet-04.sh | 2 +- 7 files changed, 706 insertions(+), 53 deletions(-) create mode 100644 internal/auth/device.go create mode 100644 internal/auth/device_test.go diff --git a/README.md b/README.md index 56c3139..89722a0 100644 --- a/README.md +++ b/README.md @@ -98,14 +98,14 @@ glean search "engineering docs" --output ndjson | jq .title ### OAuth (recommended) ```bash snippet=readme/snippet-04.sh -glean auth login # opens browser, completes PKCE flow +glean auth login # browser PKCE flow, or device flow for SSO/Okta glean auth status # verify credentials, host, and token expiry glean auth logout # remove all stored credentials ``` -OAuth uses PKCE with Dynamic Client Registration — no client ID required. Tokens are stored securely in the system keyring and refreshed automatically. +OAuth uses PKCE with Dynamic Client Registration when available. For SSO configurations where DCR is unavailable (e.g. Okta), `auth login` falls back to the Device Authorization Grant (RFC 8628) — you'll approve the login on a verification page instead. Tokens are stored securely in the system keyring and refreshed automatically. -For instances that don't support OAuth, `auth login` falls back to prompting for an API token. +For instances that don't support OAuth at all, `auth login` falls back to prompting for an API token. ### API Token (CI/CD) diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 21e137b..1c0ae60 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -24,51 +24,59 @@ import ( //go:embed success.html var successHTML string -// Login performs the full OAuth 2.0 PKCE login flow for the configured Glean host. -// If the host is not configured, prompts for a work email and auto-discovers it. -// If the instance doesn't support OAuth, falls back to an inline API token prompt. +// Login performs the full OAuth 2.0 login flow for the configured Glean host. +// +// Strategy (in order): +// 1. Authorization Code + PKCE via DCR or static client +// 2. Device Authorization Grant (RFC 8628) using the Glean-advertised client ID +// 3. Inline API token prompt when OAuth is not available at all func Login(ctx context.Context) error { host, err := resolveHost(ctx) if err != nil { return err } - provider, endpoint, registrationEndpoint, err := discover(ctx, host) + disc, err := discover(ctx, host) if err != nil { fmt.Fprintf(os.Stderr, "\nOAuth discovery failed: %v\n", err) return promptForAPIToken(host) } - // Find a free port for the local callback server. - // This must happen before DCR so we register the exact redirect URI - // that oauth2cli will use — a mismatch causes a silent hang. + // Try DCR / static client first (standard authorization code flow). + if authCodeErr := tryAuthCodeLogin(ctx, host, disc); authCodeErr == nil { + return nil + } else if disc.DeviceFlowClientID != "" && disc.DeviceAuthEndpoint != "" { + fmt.Fprintf(os.Stderr, "Note: browser login failed (%v), trying device flow…\n", authCodeErr) + return deviceFlowLogin(ctx, host, disc) + } else { + return fmt.Errorf("authentication failed: %w", authCodeErr) + } +} + +// tryAuthCodeLogin attempts the Authorization Code + PKCE flow via DCR or static client. +func tryAuthCodeLogin(ctx context.Context, host string, disc *discoveryResult) error { port, err := findFreePort() if err != nil { return fmt.Errorf("finding callback port: %w", err) } redirectURI := fmt.Sprintf("http://127.0.0.1:%d/callback", port) - // Always do fresh DCR per login — the redirect URI (port) changes each time. - clientID, clientSecret, err := dcrOrStaticClient(ctx, host, registrationEndpoint, redirectURI) + clientID, clientSecret, err := dcrOrStaticClient(ctx, host, disc.RegistrationEndpoint, redirectURI) if err != nil { - return fmt.Errorf("resolving OAuth client: %w", err) + return err } verifier := oauth2.GenerateVerifier() - scopes := resolveScopes(provider) + scopes := resolveScopes(disc.Provider) oauthCfg := oauth2.Config{ ClientID: clientID, ClientSecret: clientSecret, - Endpoint: endpoint, + Endpoint: disc.Endpoint, Scopes: scopes, RedirectURL: redirectURI, } - // oauth2cli v1.15.1 does not open the browser itself — the caller must do it. - // LocalServerReadyChan receives the local server URL once the callback server - // is ready. We open the browser to that URL (which the local server redirects - // to the real OAuth page), and also print the direct auth URL as a fallback. state := oauth2.GenerateVerifier()[:20] authURL := oauthCfg.AuthCodeURL(state, oauth2.S256ChallengeOption(verifier)) @@ -80,7 +88,6 @@ func Login(ctx context.Context) error { fmt.Printf("If your browser doesn't open, visit:\n %s\n\n", authURL) fmt.Printf("Waiting for you to complete login in the browser…\n") if err := browser.OpenURL(localURL); err != nil { - // Browser failed to open — the printed URL is the fallback. fmt.Printf("(Could not open browser automatically: %v)\n", err) } case <-ctx.Done(): @@ -88,11 +95,8 @@ func Login(ctx context.Context) error { }() token, err := oauth2cli.GetToken(ctx, oauth2cli.Config{ - OAuth2Config: oauthCfg, - State: state, - // LocalServerBindAddress and LocalServerCallbackPath must match the - // redirect_uri registered via DCR exactly. oauth2cli constructs the - // redirect URL from LocalServerBindAddress (127.0.0.1:{port}) + path. + OAuth2Config: oauthCfg, + State: state, LocalServerCallbackPath: "/callback", LocalServerBindAddress: []string{fmt.Sprintf("127.0.0.1:%d", port)}, LocalServerReadyChan: readyChan, @@ -104,7 +108,14 @@ func Login(ctx context.Context) error { return fmt.Errorf("OAuth login failed: %w", err) } - email := extractEmailFromToken(ctx, provider, clientID, token) + return saveAndPrintToken(ctx, host, disc, oauthCfg.ClientID, token) +} + +// saveAndPrintToken persists the OAuth token and client, then prints a success message. +func saveAndPrintToken(ctx context.Context, host string, disc *discoveryResult, clientID string, token *oauth2.Token) error { + _ = SaveClient(host, &StoredClient{ClientID: clientID}) + + email := extractEmailFromToken(ctx, disc.Provider, clientID, token) stored := &StoredTokens{ AccessToken: token.AccessToken, @@ -112,7 +123,7 @@ func Login(ctx context.Context) error { Expiry: token.Expiry, Email: email, TokenType: token.TokenType, - TokenEndpoint: oauthCfg.Endpoint.TokenURL, // enables future token refresh + TokenEndpoint: disc.Endpoint.TokenURL, } if err := persistLoginState(host, stored); err != nil { return err @@ -315,6 +326,15 @@ func resolveHost(ctx context.Context) (string, error) { return host, nil } +// discoveryResult holds all OAuth metadata discovered for a Glean backend. +type discoveryResult struct { + Provider *oidc.Provider + Endpoint oauth2.Endpoint + RegistrationEndpoint string + DeviceFlowClientID string + DeviceAuthEndpoint string +} + // discover resolves the OAuth2 endpoint and registration endpoint for the Glean backend. // // Strategy: @@ -322,14 +342,11 @@ func resolveHost(ctx context.Context) (string, error) { // 2. Try OIDC discovery (oidc.NewProvider) for full OIDC support // 3. Fall back to RFC 8414 auth server metadata when OIDC is unavailable // (Glean uses RFC 8414 but does not serve /.well-known/openid-configuration) -// -// Returns (provider, oauth2Endpoint, registrationEndpoint, error). -// provider is nil when only RFC 8414 discovery succeeded. -func discover(ctx context.Context, host string) (*oidc.Provider, oauth2.Endpoint, string, error) { +func discover(ctx context.Context, host string) (*discoveryResult, error) { baseURL := "https://" + host meta, err := fetchProtectedResource(ctx, baseURL) if err != nil { - return nil, oauth2.Endpoint{}, "", err + return nil, err } issuer := meta.AuthorizationServers[0] @@ -337,28 +354,52 @@ func discover(ctx context.Context, host string) (*oidc.Provider, oauth2.Endpoint // Try full OIDC discovery first (supports ID token, UserInfo). provider, err := oidc.NewProvider(ctx, issuer) if err == nil { - // Still need registration_endpoint, which oidc.Provider doesn't expose. - authMeta, _ := fetchAuthServerMetadata(ctx, issuer) - regEndpoint := "" - if authMeta != nil { - regEndpoint = authMeta.RegistrationEndpoint + res := &discoveryResult{Provider: provider, Endpoint: provider.Endpoint()} + res.DeviceFlowClientID = meta.GleanDeviceFlowClientID + + // Extract device_authorization_endpoint from OIDC provider claims + // (RFC 8414 metadata may omit it even when OIDC metadata includes it). + var providerClaims struct { + RegistrationEndpoint string `json:"registration_endpoint"` + DeviceAuthorizationEndpoint string `json:"device_authorization_endpoint"` } - return provider, provider.Endpoint(), regEndpoint, nil + if err := provider.Claims(&providerClaims); err == nil { + res.RegistrationEndpoint = providerClaims.RegistrationEndpoint + res.DeviceAuthEndpoint = providerClaims.DeviceAuthorizationEndpoint + } + + // Supplement from RFC 8414 if OIDC claims were incomplete. + if res.RegistrationEndpoint == "" || res.DeviceAuthEndpoint == "" { + if authMeta, err := fetchAuthServerMetadata(ctx, issuer); err == nil { + if res.RegistrationEndpoint == "" { + res.RegistrationEndpoint = authMeta.RegistrationEndpoint + } + if res.DeviceAuthEndpoint == "" { + res.DeviceAuthEndpoint = authMeta.DeviceAuthorizationEndpoint + } + } + } + return res, nil } // Fall back to RFC 8414 auth server metadata. authMeta, err := fetchAuthServerMetadata(ctx, issuer) if err != nil { - return nil, oauth2.Endpoint{}, "", fmt.Errorf("OAuth discovery failed for %s: %w", issuer, err) + return nil, fmt.Errorf("OAuth discovery failed for %s: %w", issuer, err) } if authMeta.AuthorizationEndpoint == "" || authMeta.TokenEndpoint == "" { - return nil, oauth2.Endpoint{}, "", fmt.Errorf("OAuth metadata missing required endpoints for %s", issuer) - } - - return nil, oauth2.Endpoint{ - AuthURL: authMeta.AuthorizationEndpoint, - TokenURL: authMeta.TokenEndpoint, - }, authMeta.RegistrationEndpoint, nil + return nil, fmt.Errorf("OAuth metadata missing required endpoints for %s", issuer) + } + + return &discoveryResult{ + Endpoint: oauth2.Endpoint{ + AuthURL: authMeta.AuthorizationEndpoint, + TokenURL: authMeta.TokenEndpoint, + }, + RegistrationEndpoint: authMeta.RegistrationEndpoint, + DeviceFlowClientID: meta.GleanDeviceFlowClientID, + DeviceAuthEndpoint: authMeta.DeviceAuthorizationEndpoint, + }, nil } // dcrOrStaticClient resolves the OAuth client_id/secret for a login session. @@ -460,10 +501,11 @@ func fetchAuthServerMetadata(ctx context.Context, issuer string) (*authServerMet } type authServerMeta struct { - Issuer string `json:"issuer"` - AuthorizationEndpoint string `json:"authorization_endpoint"` - TokenEndpoint string `json:"token_endpoint"` - RegistrationEndpoint string `json:"registration_endpoint,omitempty"` + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + RegistrationEndpoint string `json:"registration_endpoint,omitempty"` + DeviceAuthorizationEndpoint string `json:"device_authorization_endpoint,omitempty"` } // extractEmailFromToken pulls the user email from the token. diff --git a/internal/auth/device.go b/internal/auth/device.go new file mode 100644 index 0000000..d12ae86 --- /dev/null +++ b/internal/auth/device.go @@ -0,0 +1,250 @@ +package auth + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/pkg/browser" + "golang.org/x/oauth2" +) + +const ( + defaultPollInterval = 5 * time.Second + maxPollInterval = 60 * time.Second + defaultExpiresIn = 900 // 15 minutes + maxExpiresIn = 1800 +) + +// deviceAuthResponse is the response from the device authorization endpoint (RFC 8628 §3.2). +type deviceAuthResponse struct { + DeviceCode string `json:"device_code"` + UserCode string `json:"user_code"` + VerificationURI string `json:"verification_uri"` + VerificationURIComplete string `json:"verification_uri_complete"` + ExpiresIn int `json:"expires_in"` + Interval int `json:"interval"` +} + +// deviceTokenError is the error response from the token endpoint during polling. +type deviceTokenError struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description,omitempty"` +} + +// deviceFlowLogin performs the OAuth 2.0 Device Authorization Grant (RFC 8628). +func deviceFlowLogin(ctx context.Context, host string, disc *discoveryResult) error { + scopes := resolveScopes(disc.Provider) + + authResp, err := requestDeviceCode(ctx, disc.DeviceAuthEndpoint, disc.DeviceFlowClientID, scopes) + if err != nil { + return fmt.Errorf("device authorization request failed: %w", err) + } + + verificationURL := authResp.VerificationURIComplete + if verificationURL == "" { + verificationURL = authResp.VerificationURI + } + + parsed, err := url.Parse(verificationURL) + if err != nil || parsed.Host == "" { + return fmt.Errorf("device authorization returned invalid verification URL: %q", verificationURL) + } + if parsed.Scheme != "https" { + return fmt.Errorf("device authorization returned non-HTTPS verification URL: %q", verificationURL) + } + + fmt.Printf("\nTo authenticate, open this URL in your browser:\n\n %s\n\n", verificationURL) + if authResp.VerificationURIComplete == "" { + fmt.Printf("Then enter code: %s\n\n", authResp.UserCode) + } else { + fmt.Printf("Your code: %s\n\n", authResp.UserCode) + } + fmt.Printf("Waiting for you to complete login in the browser…\n") + + _ = browser.OpenURL(verificationURL) + + token, err := pollForToken(ctx, disc.Endpoint.TokenURL, disc.DeviceFlowClientID, authResp) + if err != nil { + return fmt.Errorf("device flow login failed: %w", err) + } + + return saveAndPrintToken(ctx, host, disc, disc.DeviceFlowClientID, token) +} + +// requestDeviceCode sends the initial device authorization request (RFC 8628 §3.1). +func requestDeviceCode(ctx context.Context, endpoint, clientID string, scopes []string) (*deviceAuthResponse, error) { + data := url.Values{ + "client_id": {clientID}, + "scope": {strings.Join(scopes, " ")}, + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(data.Encode())) + if err != nil { + return nil, fmt.Errorf("building device authorization request: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + resp, err := discoveryHTTPClient.Do(req) + if err != nil { + return nil, fmt.Errorf("device authorization HTTP request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + var errResp deviceTokenError + _ = json.NewDecoder(resp.Body).Decode(&errResp) + desc := errResp.ErrorDescription + if desc == "" { + desc = errResp.Error + } + if errResp.Error == "unauthorized_client" { + return nil, fmt.Errorf("%s\n\nAsk your IdP administrator to add the device_code grant type\nto OAuth app %s", desc, clientID) + } + if desc != "" { + return nil, fmt.Errorf("device authorization failed: %s", desc) + } + return nil, fmt.Errorf("device authorization endpoint returned HTTP %d", resp.StatusCode) + } + + var authResp deviceAuthResponse + if err := json.NewDecoder(resp.Body).Decode(&authResp); err != nil { + return nil, fmt.Errorf("parsing device authorization response: %w", err) + } + if authResp.DeviceCode == "" { + return nil, fmt.Errorf("device authorization response missing device_code") + } + if authResp.VerificationURI == "" && authResp.VerificationURIComplete == "" { + return nil, fmt.Errorf("device authorization response missing verification_uri") + } + authResp.Interval = clampInt(authResp.Interval, int(defaultPollInterval/time.Second), int(maxPollInterval/time.Second)) + if authResp.ExpiresIn <= 0 { + authResp.ExpiresIn = defaultExpiresIn + } else if authResp.ExpiresIn > maxExpiresIn { + authResp.ExpiresIn = maxExpiresIn + } + return &authResp, nil +} + +func clampInt(v, min, max int) int { + if v < min { + return min + } + if v > max { + return max + } + return v +} + +// pollForToken polls the token endpoint until the user completes authorization (RFC 8628 §3.4–3.5). +func pollForToken(ctx context.Context, tokenURL, clientID string, authResp *deviceAuthResponse) (*oauth2.Token, error) { + interval := time.Duration(authResp.Interval) * time.Second + deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second) + + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(interval): + } + + if time.Now().After(deadline) { + return nil, fmt.Errorf("device code expired — run 'glean auth login' to try again") + } + + token, status, err := exchangeDeviceCode(ctx, tokenURL, clientID, authResp.DeviceCode) + if err != nil { + return nil, err + } + if status == pollSlowDown { + interval += 5 * time.Second + continue + } + if status == pollPending { + continue + } + return token, nil + } +} + +type pollStatus int + +const ( + pollDone pollStatus = iota + pollPending + pollSlowDown +) + +// exchangeDeviceCode attempts a single token exchange for a device code. +func exchangeDeviceCode(ctx context.Context, tokenURL, clientID, deviceCode string) (*oauth2.Token, pollStatus, error) { + data := url.Values{ + "grant_type": {"urn:ietf:params:oauth:grant-type:device_code"}, + "client_id": {clientID}, + "device_code": {deviceCode}, + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(data.Encode())) + if err != nil { + return nil, pollDone, fmt.Errorf("building token request: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + resp, err := discoveryHTTPClient.Do(req) + if err != nil { + return nil, pollDone, fmt.Errorf("token exchange HTTP request: %w", err) + } + defer resp.Body.Close() + + body := json.NewDecoder(resp.Body) + + if resp.StatusCode != http.StatusOK { + var tokenErr deviceTokenError + _ = body.Decode(&tokenErr) + switch tokenErr.Error { + case "authorization_pending": + return nil, pollPending, nil + case "slow_down": + return nil, pollSlowDown, nil + case "expired_token": + return nil, pollDone, fmt.Errorf("device code expired — run 'glean auth login' to try again") + case "access_denied": + return nil, pollDone, fmt.Errorf("authorization denied by user") + default: + desc := tokenErr.ErrorDescription + if desc == "" { + desc = tokenErr.Error + } + return nil, pollDone, fmt.Errorf("token request failed: %s", desc) + } + } + + var tokenResp struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + RefreshToken string `json:"refresh_token,omitempty"` + Scope string `json:"scope,omitempty"` + } + if err := body.Decode(&tokenResp); err != nil { + return nil, pollDone, fmt.Errorf("parsing token response: %w", err) + } + if tokenResp.AccessToken == "" { + return nil, pollDone, fmt.Errorf("token response missing access_token") + } + + token := &oauth2.Token{ + AccessToken: tokenResp.AccessToken, + TokenType: tokenResp.TokenType, + RefreshToken: tokenResp.RefreshToken, + } + if tokenResp.ExpiresIn > 0 { + token.Expiry = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + } + return token, pollDone, nil +} diff --git a/internal/auth/device_test.go b/internal/auth/device_test.go new file mode 100644 index 0000000..9ae4259 --- /dev/null +++ b/internal/auth/device_test.go @@ -0,0 +1,338 @@ +package auth + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRequestDeviceCode_Success(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPost, r.Method) + assert.Equal(t, "application/x-www-form-urlencoded", r.Header.Get("Content-Type")) + _ = r.ParseForm() + assert.Equal(t, "client-id", r.FormValue("client_id")) + assert.Contains(t, r.FormValue("scope"), "openid") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]any{ + "device_code": "dev-code", + "user_code": "USER-1", + "verification_uri": "https://idp.example/verify", + }) + })) + defer srv.Close() + + restore := overrideDiscoveryHTTPClient(srv.Client()) + defer restore() + + resp, err := requestDeviceCode(context.Background(), srv.URL, "client-id", []string{"openid", "profile"}) + require.NoError(t, err) + assert.Equal(t, "dev-code", resp.DeviceCode) + assert.Equal(t, "USER-1", resp.UserCode) + assert.Equal(t, "https://idp.example/verify", resp.VerificationURI) +} + +func TestRequestDeviceCode_UnauthorizedClient(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{ + "error": "unauthorized_client", + "error_description": "client cannot use this grant", + }) + })) + defer srv.Close() + + restore := overrideDiscoveryHTTPClient(srv.Client()) + defer restore() + + _, err := requestDeviceCode(context.Background(), srv.URL, "my-client", []string{"openid"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "client cannot use this grant") + assert.Contains(t, err.Error(), "device_code grant type") + assert.Contains(t, err.Error(), "my-client") +} + +func TestRequestDeviceCode_MissingDeviceCode(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]any{ + "user_code": "U", + "verification_uri": "https://x", + }) + })) + defer srv.Close() + + restore := overrideDiscoveryHTTPClient(srv.Client()) + defer restore() + + _, err := requestDeviceCode(context.Background(), srv.URL, "cid", []string{"s"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "missing device_code") +} + +func TestRequestDeviceCode_IntervalAndExpiresIn(t *testing.T) { + cases := []struct { + name string + rawInterval int + rawExpiresIn int + wantInterval int + wantExpiresIn int + }{ + {"interval_below_min_clamped_to_5", 1, 100, 5, 100}, + {"interval_above_max_clamped_to_60", 100, 100, 60, 100}, + {"interval_in_range_unchanged", 30, 100, 30, 100}, + {"expires_in_zero_defaults_to_900", 5, 0, 5, defaultExpiresIn}, + {"expires_in_capped_at_1800", 5, 99999, 5, maxExpiresIn}, + {"expires_in_in_range_unchanged", 5, 600, 5, 600}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]any{ + "device_code": "dc", + "verification_uri": "https://v", + "interval": tc.rawInterval, + "expires_in": tc.rawExpiresIn, + }) + })) + defer srv.Close() + + restore := overrideDiscoveryHTTPClient(srv.Client()) + defer restore() + + resp, err := requestDeviceCode(context.Background(), srv.URL, "cid", []string{"s"}) + require.NoError(t, err) + assert.Equal(t, tc.wantInterval, resp.Interval) + assert.Equal(t, tc.wantExpiresIn, resp.ExpiresIn) + }) + } +} + +func TestRequestDeviceCode_MissingVerificationURI(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]any{ + "device_code": "dc", + }) + })) + defer srv.Close() + + restore := overrideDiscoveryHTTPClient(srv.Client()) + defer restore() + + _, err := requestDeviceCode(context.Background(), srv.URL, "cid", []string{"s"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "missing verification_uri") +} + +func TestExchangeDeviceCode_Success(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPost, r.Method) + _ = r.ParseForm() + assert.Equal(t, "urn:ietf:params:oauth:grant-type:device_code", r.FormValue("grant_type")) + assert.Equal(t, "cid", r.FormValue("client_id")) + assert.Equal(t, "dev", r.FormValue("device_code")) + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "tok-123", + "token_type": "Bearer", + "expires_in": 3600, + }) + })) + defer srv.Close() + + restore := overrideDiscoveryHTTPClient(srv.Client()) + defer restore() + + tok, status, err := exchangeDeviceCode(context.Background(), srv.URL, "cid", "dev") + require.NoError(t, err) + assert.Equal(t, pollDone, status) + require.NotNil(t, tok) + assert.Equal(t, "tok-123", tok.AccessToken) + assert.Equal(t, "Bearer", tok.TokenType) +} + +func TestExchangeDeviceCode_AuthorizationPending(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{"error": "authorization_pending"}) + })) + defer srv.Close() + + restore := overrideDiscoveryHTTPClient(srv.Client()) + defer restore() + + tok, status, err := exchangeDeviceCode(context.Background(), srv.URL, "cid", "dev") + require.NoError(t, err) + assert.Nil(t, tok) + assert.Equal(t, pollPending, status) +} + +func TestExchangeDeviceCode_SlowDown(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{"error": "slow_down"}) + })) + defer srv.Close() + + restore := overrideDiscoveryHTTPClient(srv.Client()) + defer restore() + + tok, status, err := exchangeDeviceCode(context.Background(), srv.URL, "cid", "dev") + require.NoError(t, err) + assert.Nil(t, tok) + assert.Equal(t, pollSlowDown, status) +} + +func TestExchangeDeviceCode_AccessDenied(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{"error": "access_denied"}) + })) + defer srv.Close() + + restore := overrideDiscoveryHTTPClient(srv.Client()) + defer restore() + + tok, status, err := exchangeDeviceCode(context.Background(), srv.URL, "cid", "dev") + require.Error(t, err) + assert.Contains(t, err.Error(), "authorization denied") + assert.Nil(t, tok) + assert.Equal(t, pollDone, status) +} + +func TestExchangeDeviceCode_EmptyAccessToken(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]string{ + "access_token": "", + "token_type": "Bearer", + }) + })) + defer srv.Close() + + restore := overrideDiscoveryHTTPClient(srv.Client()) + defer restore() + + _, status, err := exchangeDeviceCode(context.Background(), srv.URL, "cid", "dev") + require.Error(t, err) + assert.Contains(t, err.Error(), "missing access_token") + assert.Equal(t, pollDone, status) +} + +func TestPollForToken_PendingThenSuccess(t *testing.T) { + var calls atomic.Int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + n := calls.Add(1) + if n == 1 { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{"error": "authorization_pending"}) + return + } + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "final", + "token_type": "Bearer", + }) + })) + defer srv.Close() + + restore := overrideDiscoveryHTTPClient(srv.Client()) + defer restore() + + auth := &deviceAuthResponse{ + DeviceCode: "dc", + Interval: 0, + ExpiresIn: 60, + VerificationURI: "https://v", + } + + tok, err := pollForToken(context.Background(), srv.URL, "cid", auth) + require.NoError(t, err) + require.NotNil(t, tok) + assert.Equal(t, "final", tok.AccessToken) + assert.Equal(t, int32(2), calls.Load()) +} + +func TestPollForToken_SlowDownIncreasesInterval(t *testing.T) { + if testing.Short() { + t.Skip("timing-based test waits ~5s after slow_down") + } + var calls atomic.Int32 + var firstAt, secondAt atomic.Int64 + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + n := calls.Add(1) + now := time.Now().UnixNano() + if n == 1 { + firstAt.Store(now) + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{"error": "slow_down"}) + return + } + secondAt.Store(now) + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "ok", + "token_type": "Bearer", + }) + })) + defer srv.Close() + + restore := overrideDiscoveryHTTPClient(srv.Client()) + defer restore() + + auth := &deviceAuthResponse{ + DeviceCode: "dc", + Interval: 0, + ExpiresIn: 120, + VerificationURI: "https://v", + } + + start := time.Now() + tok, err := pollForToken(context.Background(), srv.URL, "cid", auth) + elapsed := time.Since(start) + + require.NoError(t, err) + require.NotNil(t, tok) + assert.Equal(t, "ok", tok.AccessToken) + assert.Equal(t, int32(2), calls.Load()) + + gap := time.Duration(secondAt.Load() - firstAt.Load()) + assert.GreaterOrEqual(t, gap, 4*time.Second, + "expected ~5s wait after slow_down increased interval; gap=%v", gap) + assert.GreaterOrEqual(t, elapsed, 4*time.Second) +} + +func TestPollForToken_ContextCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + auth := &deviceAuthResponse{ + DeviceCode: "dc", + Interval: 10, + ExpiresIn: 3600, + VerificationURI: "https://v", + } + + _, err := pollForToken(ctx, "http://unused.example/token", "cid", auth) + require.Error(t, err) + assert.ErrorIs(t, err, context.Canceled) +} + +// overrideDiscoveryHTTPClient swaps the package-level client used by device/discovery helpers. +// The returned function restores the previous client (call in defer). +func overrideDiscoveryHTTPClient(cl *http.Client) func() { + prev := discoveryHTTPClient + discoveryHTTPClient = cl + return func() { discoveryHTTPClient = prev } +} diff --git a/internal/auth/discovery.go b/internal/auth/discovery.go index d7e25b0..4c3e5fc 100644 --- a/internal/auth/discovery.go +++ b/internal/auth/discovery.go @@ -22,8 +22,9 @@ func (e *ErrOAuthNotSupported) Error() string { } type protectedResourceMetadata struct { - Resource string `json:"resource"` - AuthorizationServers []string `json:"authorization_servers"` + Resource string `json:"resource"` + AuthorizationServers []string `json:"authorization_servers"` + GleanDeviceFlowClientID string `json:"glean_device_flow_client_id,omitempty"` } // fetchProtectedResource fetches RFC 9728 protected resource metadata. diff --git a/internal/auth/discovery_test.go b/internal/auth/discovery_test.go index 6e7fecf..fb84594 100644 --- a/internal/auth/discovery_test.go +++ b/internal/auth/discovery_test.go @@ -22,11 +22,33 @@ func TestFetchProtectedResource_Success(t *testing.T) { })) defer srv.Close() + restore := overrideDiscoveryHTTPClient(srv.Client()) + defer restore() + result, err := fetchProtectedResource(context.Background(), srv.URL) require.NoError(t, err) assert.Equal(t, []string{"https://auth.example.com"}, result.AuthorizationServers) } +func TestFetchProtectedResource_DeviceFlowClientID(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/.well-known/oauth-protected-resource", r.URL.Path) + json.NewEncoder(w).Encode(map[string]any{ + "resource": "https://example.glean.com", + "authorization_servers": []string{"https://auth.example.com"}, + "glean_device_flow_client_id": "device-flow-client-123", + }) + })) + defer srv.Close() + + restore := overrideDiscoveryHTTPClient(srv.Client()) + defer restore() + + result, err := fetchProtectedResource(context.Background(), srv.URL) + require.NoError(t, err) + assert.Equal(t, "device-flow-client-123", result.GleanDeviceFlowClientID) +} + func TestFetchProtectedResource_NotFound(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) diff --git a/snippets/readme/snippet-04.sh b/snippets/readme/snippet-04.sh index 029400d..c49cece 100644 --- a/snippets/readme/snippet-04.sh +++ b/snippets/readme/snippet-04.sh @@ -1,3 +1,3 @@ -glean auth login # opens browser, completes PKCE flow +glean auth login # browser PKCE flow, or device flow for SSO/Okta glean auth status # verify credentials, host, and token expiry glean auth logout # remove all stored credentials From 0478ef9b89cdf578ad486e7c277b33e1b2e7f885 Mon Sep 17 00:00:00 2001 From: Pavlo Chernykh <526266+pavlo-v-chernykh@users.noreply.github.com> Date: Sat, 4 Apr 2026 17:55:07 -0400 Subject: [PATCH 2/2] fix: narrow device flow fallback to DCR-unavailable errors only The previous fallback triggered on any tryAuthCodeLogin failure, including transient issues like network timeouts or the user closing their browser. Now device flow only activates when dcrOrStaticClient returns errNoOAuthClient (no registration endpoint + no static client), not when DCR was attempted and failed. Made-with: Cursor --- internal/auth/auth.go | 34 ++++++--- internal/auth/auth_fallback_test.go | 103 ++++++++++++++++++++++++++++ 2 files changed, 129 insertions(+), 8 deletions(-) create mode 100644 internal/auth/auth_fallback_test.go diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 1c0ae60..32cdae8 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -6,6 +6,7 @@ import ( _ "embed" "encoding/base64" "encoding/json" + "errors" "fmt" "net" "net/http" @@ -24,6 +25,12 @@ import ( //go:embed success.html var successHTML string +// errNoOAuthClient is returned by dcrOrStaticClient when neither DCR nor a +// static client is available. Login uses this to decide whether device flow +// is an appropriate fallback (as opposed to transient failures like network +// timeouts or the user closing their browser). +var errNoOAuthClient = errors.New("no OAuth client available") + // Login performs the full OAuth 2.0 login flow for the configured Glean host. // // Strategy (in order): @@ -43,14 +50,22 @@ func Login(ctx context.Context) error { } // Try DCR / static client first (standard authorization code flow). - if authCodeErr := tryAuthCodeLogin(ctx, host, disc); authCodeErr == nil { + authCodeErr := tryAuthCodeLogin(ctx, host, disc) + if authCodeErr == nil { return nil - } else if disc.DeviceFlowClientID != "" && disc.DeviceAuthEndpoint != "" { - fmt.Fprintf(os.Stderr, "Note: browser login failed (%v), trying device flow…\n", authCodeErr) + } + + // Only fall back to device flow when the auth code flow failed because no + // OAuth client could be obtained (DCR unsupported + no static client). + // Transient failures (network, user closing browser, port conflicts) should + // not silently switch to a different grant type. + canDeviceFlow := disc.DeviceFlowClientID != "" && disc.DeviceAuthEndpoint != "" + if errors.Is(authCodeErr, errNoOAuthClient) && canDeviceFlow { + fmt.Fprintf(os.Stderr, "Note: no OAuth client available, trying device flow…\n") return deviceFlowLogin(ctx, host, disc) - } else { - return fmt.Errorf("authentication failed: %w", authCodeErr) } + + return fmt.Errorf("authentication failed: %w", authCodeErr) } // tryAuthCodeLogin attempts the Authorization Code + PKCE flow via DCR or static client. @@ -408,14 +423,14 @@ func discover(ctx context.Context, host string) (*discoveryResult, error) { // credentials can be reused for token refresh later. // Falls back to a static client configured via glean config --oauth-client-id. func dcrOrStaticClient(ctx context.Context, host, registrationEndpoint, redirectURI string) (string, string, error) { + var dcrErr error if registrationEndpoint != "" { cl, err := registerClient(ctx, registrationEndpoint, redirectURI) if err == nil { - // Persist so future token refresh can use the same client credentials. _ = SaveClient(host, cl) return cl.ClientID, cl.ClientSecret, nil } - // DCR failed — log and fall through to static client + dcrErr = err fmt.Printf("Note: dynamic client registration failed (%v), trying static client\n", err) } @@ -424,7 +439,10 @@ func dcrOrStaticClient(ctx context.Context, host, registrationEndpoint, redirect return cfg.OAuthClientID, cfg.OAuthClientSecret, nil } - return "", "", fmt.Errorf("no OAuth client available — dynamic client registration failed and no static client is configured") + if dcrErr != nil { + return "", "", fmt.Errorf("%w: dynamic client registration failed (%v) and no static client is configured", errNoOAuthClient, dcrErr) + } + return "", "", fmt.Errorf("%w: no registration endpoint and no static client configured", errNoOAuthClient) } // resolveScopes returns the appropriate OAuth scopes for the given provider. diff --git a/internal/auth/auth_fallback_test.go b/internal/auth/auth_fallback_test.go new file mode 100644 index 0000000..234624a --- /dev/null +++ b/internal/auth/auth_fallback_test.go @@ -0,0 +1,103 @@ +package auth + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "os" + "testing" + + "github.com/gleanwork/glean-cli/internal/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDcrOrStaticClient_NoClientAvailable(t *testing.T) { + t.Setenv("GLEAN_HOST", "") + config.ConfigPath = t.TempDir() + "/config.json" + + _, _, err := dcrOrStaticClient(context.Background(), "test-host", "", "http://127.0.0.1:9999/callback") + require.Error(t, err) + assert.True(t, errors.Is(err, errNoOAuthClient), "expected errNoOAuthClient, got: %v", err) +} + +func TestDcrOrStaticClient_DCRFails_NoStaticClient(t *testing.T) { + t.Setenv("GLEAN_HOST", "") + config.ConfigPath = t.TempDir() + "/config.json" + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusForbidden) + })) + defer srv.Close() + + restore := overrideDiscoveryHTTPClient(srv.Client()) + defer restore() + + _, _, err := dcrOrStaticClient(context.Background(), "test-host", srv.URL, "http://127.0.0.1:9999/callback") + require.Error(t, err) + assert.True(t, errors.Is(err, errNoOAuthClient), + "DCR rejection (403) with no static client means no OAuth client is available") + assert.Contains(t, err.Error(), "dynamic client registration failed") +} + +func TestDcrOrStaticClient_DCRFails_StaticClientFallback(t *testing.T) { + dir := t.TempDir() + config.ConfigPath = dir + "/config.json" + t.Setenv("GLEAN_HOST", "test-host") + + cfgData, _ := json.Marshal(map[string]string{ + "host": "test-host", + "oauth_client_id": "static-id", + "oauth_client_secret": "static-secret", + }) + require.NoError(t, os.WriteFile(config.ConfigPath, cfgData, 0o600)) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusForbidden) + })) + defer srv.Close() + + restore := overrideDiscoveryHTTPClient(srv.Client()) + defer restore() + + clientID, clientSecret, err := dcrOrStaticClient(context.Background(), "test-host", srv.URL, "http://127.0.0.1:9999/callback") + require.NoError(t, err) + assert.Equal(t, "static-id", clientID) + assert.Equal(t, "static-secret", clientSecret) +} + +func TestDcrOrStaticClient_DCRSucceeds(t *testing.T) { + dir := t.TempDir() + config.ConfigPath = dir + "/config.json" + setStoragePath(t, dir) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusCreated) + _ = json.NewEncoder(w).Encode(map[string]string{ + "client_id": "dcr-id", + "client_secret": "dcr-secret", + }) + })) + defer srv.Close() + + restore := overrideDiscoveryHTTPClient(srv.Client()) + defer restore() + + clientID, clientSecret, err := dcrOrStaticClient(context.Background(), "test-host", srv.URL, "http://127.0.0.1:9999/callback") + require.NoError(t, err) + assert.Equal(t, "dcr-id", clientID) + assert.Equal(t, "dcr-secret", clientSecret) +} + +func TestErrNoOAuthClient_NotMatchedByOtherErrors(t *testing.T) { + other := errors.New("finding callback port: address already in use") + assert.False(t, errors.Is(other, errNoOAuthClient)) +} + +// setStoragePath points token/client storage at a temp directory. +func setStoragePath(t *testing.T, dir string) { + t.Helper() + t.Setenv("GLEAN_AUTH_DIR", dir) +}