diff --git a/README.md b/README.md index 56c3139..89722a0 100644 --- a/README.md +++ b/README.md @@ -98,14 +98,14 @@ glean search "engineering docs" --output ndjson | jq .title ### OAuth (recommended) ```bash snippet=readme/snippet-04.sh -glean auth login # opens browser, completes PKCE flow +glean auth login # browser PKCE flow, or device flow for SSO/Okta glean auth status # verify credentials, host, and token expiry glean auth logout # remove all stored credentials ``` -OAuth uses PKCE with Dynamic Client Registration — no client ID required. Tokens are stored securely in the system keyring and refreshed automatically. +OAuth uses PKCE with Dynamic Client Registration when available. For SSO configurations where DCR is unavailable (e.g. Okta), `auth login` falls back to the Device Authorization Grant (RFC 8628) — you'll approve the login on a verification page instead. Tokens are stored securely in the system keyring and refreshed automatically. -For instances that don't support OAuth, `auth login` falls back to prompting for an API token. +For instances that don't support OAuth at all, `auth login` falls back to prompting for an API token. ### API Token (CI/CD) diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 21e137b..67339a9 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -6,6 +6,7 @@ import ( _ "embed" "encoding/base64" "encoding/json" + "errors" "fmt" "net" "net/http" @@ -24,51 +25,73 @@ import ( //go:embed success.html var successHTML string -// Login performs the full OAuth 2.0 PKCE login flow for the configured Glean host. -// If the host is not configured, prompts for a work email and auto-discovers it. -// If the instance doesn't support OAuth, falls back to an inline API token prompt. +// errNoOAuthClient is returned by dcrOrStaticClient when neither DCR nor a +// static client is available. Login uses this to decide whether device flow +// is an appropriate fallback (as opposed to transient failures like network +// timeouts or the user closing their browser). +var errNoOAuthClient = errors.New("no OAuth client available") + +// Login performs the full OAuth 2.0 login flow for the configured Glean host. +// +// Strategy (in order): +// 1. Authorization Code + PKCE via DCR or static client +// 2. Device Authorization Grant (RFC 8628) using the Glean-advertised client ID +// 3. Inline API token prompt when OAuth is not available at all func Login(ctx context.Context) error { host, err := resolveHost(ctx) if err != nil { return err } - provider, endpoint, registrationEndpoint, err := discover(ctx, host) + disc, err := discover(ctx, host) if err != nil { fmt.Fprintf(os.Stderr, "\nOAuth discovery failed: %v\n", err) return promptForAPIToken(host) } - // Find a free port for the local callback server. - // This must happen before DCR so we register the exact redirect URI - // that oauth2cli will use — a mismatch causes a silent hang. + // Try DCR / static client first (standard authorization code flow). + authCodeErr := tryAuthCodeLogin(ctx, host, disc) + if authCodeErr == nil { + return nil + } + + // Only fall back to device flow when the auth code flow failed because no + // OAuth client could be obtained (DCR unsupported + no static client). + // Transient failures (network, user closing browser, port conflicts) should + // not silently switch to a different grant type. + canDeviceFlow := disc.DeviceFlowClientID != "" && disc.DeviceAuthEndpoint != "" + if errors.Is(authCodeErr, errNoOAuthClient) && canDeviceFlow { + fmt.Fprintf(os.Stderr, "Note: no OAuth client available, trying device flow…\n") + return deviceFlowLogin(ctx, host, disc) + } + + return fmt.Errorf("authentication failed: %w", authCodeErr) +} + +// tryAuthCodeLogin attempts the Authorization Code + PKCE flow via DCR or static client. +func tryAuthCodeLogin(ctx context.Context, host string, disc *discoveryResult) error { port, err := findFreePort() if err != nil { return fmt.Errorf("finding callback port: %w", err) } redirectURI := fmt.Sprintf("http://127.0.0.1:%d/callback", port) - // Always do fresh DCR per login — the redirect URI (port) changes each time. - clientID, clientSecret, err := dcrOrStaticClient(ctx, host, registrationEndpoint, redirectURI) + clientID, clientSecret, err := dcrOrStaticClient(ctx, host, disc.RegistrationEndpoint, redirectURI) if err != nil { - return fmt.Errorf("resolving OAuth client: %w", err) + return err } verifier := oauth2.GenerateVerifier() - scopes := resolveScopes(provider) + scopes := resolveScopes(disc.Provider) oauthCfg := oauth2.Config{ ClientID: clientID, ClientSecret: clientSecret, - Endpoint: endpoint, + Endpoint: disc.Endpoint, Scopes: scopes, RedirectURL: redirectURI, } - // oauth2cli v1.15.1 does not open the browser itself — the caller must do it. - // LocalServerReadyChan receives the local server URL once the callback server - // is ready. We open the browser to that URL (which the local server redirects - // to the real OAuth page), and also print the direct auth URL as a fallback. state := oauth2.GenerateVerifier()[:20] authURL := oauthCfg.AuthCodeURL(state, oauth2.S256ChallengeOption(verifier)) @@ -80,7 +103,6 @@ func Login(ctx context.Context) error { fmt.Printf("If your browser doesn't open, visit:\n %s\n\n", authURL) fmt.Printf("Waiting for you to complete login in the browser…\n") if err := browser.OpenURL(localURL); err != nil { - // Browser failed to open — the printed URL is the fallback. fmt.Printf("(Could not open browser automatically: %v)\n", err) } case <-ctx.Done(): @@ -88,11 +110,8 @@ func Login(ctx context.Context) error { }() token, err := oauth2cli.GetToken(ctx, oauth2cli.Config{ - OAuth2Config: oauthCfg, - State: state, - // LocalServerBindAddress and LocalServerCallbackPath must match the - // redirect_uri registered via DCR exactly. oauth2cli constructs the - // redirect URL from LocalServerBindAddress (127.0.0.1:{port}) + path. + OAuth2Config: oauthCfg, + State: state, LocalServerCallbackPath: "/callback", LocalServerBindAddress: []string{fmt.Sprintf("127.0.0.1:%d", port)}, LocalServerReadyChan: readyChan, @@ -104,7 +123,14 @@ func Login(ctx context.Context) error { return fmt.Errorf("OAuth login failed: %w", err) } - email := extractEmailFromToken(ctx, provider, clientID, token) + return saveAndPrintToken(ctx, host, disc, oauthCfg.ClientID, token) +} + +// saveAndPrintToken persists the OAuth token and client, then prints a success message. +func saveAndPrintToken(ctx context.Context, host string, disc *discoveryResult, clientID string, token *oauth2.Token) error { + _ = SaveClient(host, &StoredClient{ClientID: clientID}) + + email := extractEmailFromToken(ctx, disc.Provider, clientID, token) stored := &StoredTokens{ AccessToken: token.AccessToken, @@ -112,7 +138,7 @@ func Login(ctx context.Context) error { Expiry: token.Expiry, Email: email, TokenType: token.TokenType, - TokenEndpoint: oauthCfg.Endpoint.TokenURL, // enables future token refresh + TokenEndpoint: disc.Endpoint.TokenURL, } if err := persistLoginState(host, stored); err != nil { return err @@ -315,6 +341,15 @@ func resolveHost(ctx context.Context) (string, error) { return host, nil } +// discoveryResult holds all OAuth metadata discovered for a Glean backend. +type discoveryResult struct { + Provider *oidc.Provider + Endpoint oauth2.Endpoint + RegistrationEndpoint string + DeviceFlowClientID string + DeviceAuthEndpoint string +} + // discover resolves the OAuth2 endpoint and registration endpoint for the Glean backend. // // Strategy: @@ -322,14 +357,11 @@ func resolveHost(ctx context.Context) (string, error) { // 2. Try OIDC discovery (oidc.NewProvider) for full OIDC support // 3. Fall back to RFC 8414 auth server metadata when OIDC is unavailable // (Glean uses RFC 8414 but does not serve /.well-known/openid-configuration) -// -// Returns (provider, oauth2Endpoint, registrationEndpoint, error). -// provider is nil when only RFC 8414 discovery succeeded. -func discover(ctx context.Context, host string) (*oidc.Provider, oauth2.Endpoint, string, error) { +func discover(ctx context.Context, host string) (*discoveryResult, error) { baseURL := "https://" + host meta, err := fetchProtectedResource(ctx, baseURL) if err != nil { - return nil, oauth2.Endpoint{}, "", err + return nil, err } issuer := meta.AuthorizationServers[0] @@ -337,28 +369,52 @@ func discover(ctx context.Context, host string) (*oidc.Provider, oauth2.Endpoint // Try full OIDC discovery first (supports ID token, UserInfo). provider, err := oidc.NewProvider(ctx, issuer) if err == nil { - // Still need registration_endpoint, which oidc.Provider doesn't expose. - authMeta, _ := fetchAuthServerMetadata(ctx, issuer) - regEndpoint := "" - if authMeta != nil { - regEndpoint = authMeta.RegistrationEndpoint + res := &discoveryResult{Provider: provider, Endpoint: provider.Endpoint()} + res.DeviceFlowClientID = meta.GleanDeviceFlowClientID + + // Extract device_authorization_endpoint from OIDC provider claims + // (RFC 8414 metadata may omit it even when OIDC metadata includes it). + var providerClaims struct { + RegistrationEndpoint string `json:"registration_endpoint"` + DeviceAuthorizationEndpoint string `json:"device_authorization_endpoint"` } - return provider, provider.Endpoint(), regEndpoint, nil + if err := provider.Claims(&providerClaims); err == nil { + res.RegistrationEndpoint = providerClaims.RegistrationEndpoint + res.DeviceAuthEndpoint = providerClaims.DeviceAuthorizationEndpoint + } + + // Supplement from RFC 8414 if OIDC claims were incomplete. + if res.RegistrationEndpoint == "" || res.DeviceAuthEndpoint == "" { + if authMeta, err := fetchAuthServerMetadata(ctx, issuer); err == nil { + if res.RegistrationEndpoint == "" { + res.RegistrationEndpoint = authMeta.RegistrationEndpoint + } + if res.DeviceAuthEndpoint == "" { + res.DeviceAuthEndpoint = authMeta.DeviceAuthorizationEndpoint + } + } + } + return res, nil } // Fall back to RFC 8414 auth server metadata. authMeta, err := fetchAuthServerMetadata(ctx, issuer) if err != nil { - return nil, oauth2.Endpoint{}, "", fmt.Errorf("OAuth discovery failed for %s: %w", issuer, err) + return nil, fmt.Errorf("OAuth discovery failed for %s: %w", issuer, err) } if authMeta.AuthorizationEndpoint == "" || authMeta.TokenEndpoint == "" { - return nil, oauth2.Endpoint{}, "", fmt.Errorf("OAuth metadata missing required endpoints for %s", issuer) - } - - return nil, oauth2.Endpoint{ - AuthURL: authMeta.AuthorizationEndpoint, - TokenURL: authMeta.TokenEndpoint, - }, authMeta.RegistrationEndpoint, nil + return nil, fmt.Errorf("OAuth metadata missing required endpoints for %s", issuer) + } + + return &discoveryResult{ + Endpoint: oauth2.Endpoint{ + AuthURL: authMeta.AuthorizationEndpoint, + TokenURL: authMeta.TokenEndpoint, + }, + RegistrationEndpoint: authMeta.RegistrationEndpoint, + DeviceFlowClientID: meta.GleanDeviceFlowClientID, + DeviceAuthEndpoint: authMeta.DeviceAuthorizationEndpoint, + }, nil } // dcrOrStaticClient resolves the OAuth client_id/secret for a login session. @@ -367,14 +423,14 @@ func discover(ctx context.Context, host string) (*oidc.Provider, oauth2.Endpoint // credentials can be reused for token refresh later. // Falls back to a static client configured via glean config --oauth-client-id. func dcrOrStaticClient(ctx context.Context, host, registrationEndpoint, redirectURI string) (string, string, error) { + var dcrErr error if registrationEndpoint != "" { cl, err := registerClient(ctx, registrationEndpoint, redirectURI) if err == nil { - // Persist so future token refresh can use the same client credentials. _ = SaveClient(host, cl) return cl.ClientID, cl.ClientSecret, nil } - // DCR failed — log and fall through to static client + dcrErr = err fmt.Printf("Note: dynamic client registration failed (%v), trying static client\n", err) } @@ -383,7 +439,10 @@ func dcrOrStaticClient(ctx context.Context, host, registrationEndpoint, redirect return cfg.OAuthClientID, cfg.OAuthClientSecret, nil } - return "", "", fmt.Errorf("no OAuth client available — dynamic client registration failed and no static client is configured") + if dcrErr != nil { + return "", "", fmt.Errorf("dynamic client registration failed and no static client is configured: %w", dcrErr) + } + return "", "", fmt.Errorf("%w: no registration endpoint and no static client configured", errNoOAuthClient) } // resolveScopes returns the appropriate OAuth scopes for the given provider. @@ -460,10 +519,11 @@ func fetchAuthServerMetadata(ctx context.Context, issuer string) (*authServerMet } type authServerMeta struct { - Issuer string `json:"issuer"` - AuthorizationEndpoint string `json:"authorization_endpoint"` - TokenEndpoint string `json:"token_endpoint"` - RegistrationEndpoint string `json:"registration_endpoint,omitempty"` + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + RegistrationEndpoint string `json:"registration_endpoint,omitempty"` + DeviceAuthorizationEndpoint string `json:"device_authorization_endpoint,omitempty"` } // extractEmailFromToken pulls the user email from the token. diff --git a/internal/auth/auth_fallback_test.go b/internal/auth/auth_fallback_test.go new file mode 100644 index 0000000..d202cc3 --- /dev/null +++ b/internal/auth/auth_fallback_test.go @@ -0,0 +1,103 @@ +package auth + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "os" + "testing" + + "github.com/gleanwork/glean-cli/internal/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDcrOrStaticClient_NoClientAvailable(t *testing.T) { + t.Setenv("GLEAN_HOST", "") + config.ConfigPath = t.TempDir() + "/config.json" + + _, _, err := dcrOrStaticClient(context.Background(), "test-host", "", "http://127.0.0.1:9999/callback") + require.Error(t, err) + assert.True(t, errors.Is(err, errNoOAuthClient), "expected errNoOAuthClient, got: %v", err) +} + +func TestDcrOrStaticClient_DCRFails_NoStaticClient(t *testing.T) { + t.Setenv("GLEAN_HOST", "") + config.ConfigPath = t.TempDir() + "/config.json" + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusForbidden) + })) + defer srv.Close() + + restore := overrideDiscoveryHTTPClient(srv.Client()) + defer restore() + + _, _, err := dcrOrStaticClient(context.Background(), "test-host", srv.URL, "http://127.0.0.1:9999/callback") + require.Error(t, err) + assert.False(t, errors.Is(err, errNoOAuthClient), + "DCR failure should not match errNoOAuthClient (would trigger device flow on transient errors)") + assert.Contains(t, err.Error(), "dynamic client registration failed") +} + +func TestDcrOrStaticClient_DCRFails_StaticClientFallback(t *testing.T) { + dir := t.TempDir() + config.ConfigPath = dir + "/config.json" + t.Setenv("GLEAN_HOST", "test-host") + + cfgData, _ := json.Marshal(map[string]string{ + "host": "test-host", + "oauth_client_id": "static-id", + "oauth_client_secret": "static-secret", + }) + require.NoError(t, os.WriteFile(config.ConfigPath, cfgData, 0o600)) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusForbidden) + })) + defer srv.Close() + + restore := overrideDiscoveryHTTPClient(srv.Client()) + defer restore() + + clientID, clientSecret, err := dcrOrStaticClient(context.Background(), "test-host", srv.URL, "http://127.0.0.1:9999/callback") + require.NoError(t, err) + assert.Equal(t, "static-id", clientID) + assert.Equal(t, "static-secret", clientSecret) +} + +func TestDcrOrStaticClient_DCRSucceeds(t *testing.T) { + dir := t.TempDir() + config.ConfigPath = dir + "/config.json" + setStoragePath(t, dir) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusCreated) + _ = json.NewEncoder(w).Encode(map[string]string{ + "client_id": "dcr-id", + "client_secret": "dcr-secret", + }) + })) + defer srv.Close() + + restore := overrideDiscoveryHTTPClient(srv.Client()) + defer restore() + + clientID, clientSecret, err := dcrOrStaticClient(context.Background(), "test-host", srv.URL, "http://127.0.0.1:9999/callback") + require.NoError(t, err) + assert.Equal(t, "dcr-id", clientID) + assert.Equal(t, "dcr-secret", clientSecret) +} + +func TestErrNoOAuthClient_NotMatchedByOtherErrors(t *testing.T) { + other := errors.New("finding callback port: address already in use") + assert.False(t, errors.Is(other, errNoOAuthClient)) +} + +// setStoragePath points token/client storage at a temp directory. +func setStoragePath(t *testing.T, dir string) { + t.Helper() + t.Setenv("GLEAN_AUTH_DIR", dir) +} diff --git a/internal/auth/device.go b/internal/auth/device.go new file mode 100644 index 0000000..d12ae86 --- /dev/null +++ b/internal/auth/device.go @@ -0,0 +1,250 @@ +package auth + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/pkg/browser" + "golang.org/x/oauth2" +) + +const ( + defaultPollInterval = 5 * time.Second + maxPollInterval = 60 * time.Second + defaultExpiresIn = 900 // 15 minutes + maxExpiresIn = 1800 +) + +// deviceAuthResponse is the response from the device authorization endpoint (RFC 8628 §3.2). +type deviceAuthResponse struct { + DeviceCode string `json:"device_code"` + UserCode string `json:"user_code"` + VerificationURI string `json:"verification_uri"` + VerificationURIComplete string `json:"verification_uri_complete"` + ExpiresIn int `json:"expires_in"` + Interval int `json:"interval"` +} + +// deviceTokenError is the error response from the token endpoint during polling. +type deviceTokenError struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description,omitempty"` +} + +// deviceFlowLogin performs the OAuth 2.0 Device Authorization Grant (RFC 8628). +func deviceFlowLogin(ctx context.Context, host string, disc *discoveryResult) error { + scopes := resolveScopes(disc.Provider) + + authResp, err := requestDeviceCode(ctx, disc.DeviceAuthEndpoint, disc.DeviceFlowClientID, scopes) + if err != nil { + return fmt.Errorf("device authorization request failed: %w", err) + } + + verificationURL := authResp.VerificationURIComplete + if verificationURL == "" { + verificationURL = authResp.VerificationURI + } + + parsed, err := url.Parse(verificationURL) + if err != nil || parsed.Host == "" { + return fmt.Errorf("device authorization returned invalid verification URL: %q", verificationURL) + } + if parsed.Scheme != "https" { + return fmt.Errorf("device authorization returned non-HTTPS verification URL: %q", verificationURL) + } + + fmt.Printf("\nTo authenticate, open this URL in your browser:\n\n %s\n\n", verificationURL) + if authResp.VerificationURIComplete == "" { + fmt.Printf("Then enter code: %s\n\n", authResp.UserCode) + } else { + fmt.Printf("Your code: %s\n\n", authResp.UserCode) + } + fmt.Printf("Waiting for you to complete login in the browser…\n") + + _ = browser.OpenURL(verificationURL) + + token, err := pollForToken(ctx, disc.Endpoint.TokenURL, disc.DeviceFlowClientID, authResp) + if err != nil { + return fmt.Errorf("device flow login failed: %w", err) + } + + return saveAndPrintToken(ctx, host, disc, disc.DeviceFlowClientID, token) +} + +// requestDeviceCode sends the initial device authorization request (RFC 8628 §3.1). +func requestDeviceCode(ctx context.Context, endpoint, clientID string, scopes []string) (*deviceAuthResponse, error) { + data := url.Values{ + "client_id": {clientID}, + "scope": {strings.Join(scopes, " ")}, + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(data.Encode())) + if err != nil { + return nil, fmt.Errorf("building device authorization request: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + resp, err := discoveryHTTPClient.Do(req) + if err != nil { + return nil, fmt.Errorf("device authorization HTTP request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + var errResp deviceTokenError + _ = json.NewDecoder(resp.Body).Decode(&errResp) + desc := errResp.ErrorDescription + if desc == "" { + desc = errResp.Error + } + if errResp.Error == "unauthorized_client" { + return nil, fmt.Errorf("%s\n\nAsk your IdP administrator to add the device_code grant type\nto OAuth app %s", desc, clientID) + } + if desc != "" { + return nil, fmt.Errorf("device authorization failed: %s", desc) + } + return nil, fmt.Errorf("device authorization endpoint returned HTTP %d", resp.StatusCode) + } + + var authResp deviceAuthResponse + if err := json.NewDecoder(resp.Body).Decode(&authResp); err != nil { + return nil, fmt.Errorf("parsing device authorization response: %w", err) + } + if authResp.DeviceCode == "" { + return nil, fmt.Errorf("device authorization response missing device_code") + } + if authResp.VerificationURI == "" && authResp.VerificationURIComplete == "" { + return nil, fmt.Errorf("device authorization response missing verification_uri") + } + authResp.Interval = clampInt(authResp.Interval, int(defaultPollInterval/time.Second), int(maxPollInterval/time.Second)) + if authResp.ExpiresIn <= 0 { + authResp.ExpiresIn = defaultExpiresIn + } else if authResp.ExpiresIn > maxExpiresIn { + authResp.ExpiresIn = maxExpiresIn + } + return &authResp, nil +} + +func clampInt(v, min, max int) int { + if v < min { + return min + } + if v > max { + return max + } + return v +} + +// pollForToken polls the token endpoint until the user completes authorization (RFC 8628 §3.4–3.5). +func pollForToken(ctx context.Context, tokenURL, clientID string, authResp *deviceAuthResponse) (*oauth2.Token, error) { + interval := time.Duration(authResp.Interval) * time.Second + deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second) + + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(interval): + } + + if time.Now().After(deadline) { + return nil, fmt.Errorf("device code expired — run 'glean auth login' to try again") + } + + token, status, err := exchangeDeviceCode(ctx, tokenURL, clientID, authResp.DeviceCode) + if err != nil { + return nil, err + } + if status == pollSlowDown { + interval += 5 * time.Second + continue + } + if status == pollPending { + continue + } + return token, nil + } +} + +type pollStatus int + +const ( + pollDone pollStatus = iota + pollPending + pollSlowDown +) + +// exchangeDeviceCode attempts a single token exchange for a device code. +func exchangeDeviceCode(ctx context.Context, tokenURL, clientID, deviceCode string) (*oauth2.Token, pollStatus, error) { + data := url.Values{ + "grant_type": {"urn:ietf:params:oauth:grant-type:device_code"}, + "client_id": {clientID}, + "device_code": {deviceCode}, + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(data.Encode())) + if err != nil { + return nil, pollDone, fmt.Errorf("building token request: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + resp, err := discoveryHTTPClient.Do(req) + if err != nil { + return nil, pollDone, fmt.Errorf("token exchange HTTP request: %w", err) + } + defer resp.Body.Close() + + body := json.NewDecoder(resp.Body) + + if resp.StatusCode != http.StatusOK { + var tokenErr deviceTokenError + _ = body.Decode(&tokenErr) + switch tokenErr.Error { + case "authorization_pending": + return nil, pollPending, nil + case "slow_down": + return nil, pollSlowDown, nil + case "expired_token": + return nil, pollDone, fmt.Errorf("device code expired — run 'glean auth login' to try again") + case "access_denied": + return nil, pollDone, fmt.Errorf("authorization denied by user") + default: + desc := tokenErr.ErrorDescription + if desc == "" { + desc = tokenErr.Error + } + return nil, pollDone, fmt.Errorf("token request failed: %s", desc) + } + } + + var tokenResp struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + RefreshToken string `json:"refresh_token,omitempty"` + Scope string `json:"scope,omitempty"` + } + if err := body.Decode(&tokenResp); err != nil { + return nil, pollDone, fmt.Errorf("parsing token response: %w", err) + } + if tokenResp.AccessToken == "" { + return nil, pollDone, fmt.Errorf("token response missing access_token") + } + + token := &oauth2.Token{ + AccessToken: tokenResp.AccessToken, + TokenType: tokenResp.TokenType, + RefreshToken: tokenResp.RefreshToken, + } + if tokenResp.ExpiresIn > 0 { + token.Expiry = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + } + return token, pollDone, nil +} diff --git a/internal/auth/device_test.go b/internal/auth/device_test.go new file mode 100644 index 0000000..9ae4259 --- /dev/null +++ b/internal/auth/device_test.go @@ -0,0 +1,338 @@ +package auth + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRequestDeviceCode_Success(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPost, r.Method) + assert.Equal(t, "application/x-www-form-urlencoded", r.Header.Get("Content-Type")) + _ = r.ParseForm() + assert.Equal(t, "client-id", r.FormValue("client_id")) + assert.Contains(t, r.FormValue("scope"), "openid") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]any{ + "device_code": "dev-code", + "user_code": "USER-1", + "verification_uri": "https://idp.example/verify", + }) + })) + defer srv.Close() + + restore := overrideDiscoveryHTTPClient(srv.Client()) + defer restore() + + resp, err := requestDeviceCode(context.Background(), srv.URL, "client-id", []string{"openid", "profile"}) + require.NoError(t, err) + assert.Equal(t, "dev-code", resp.DeviceCode) + assert.Equal(t, "USER-1", resp.UserCode) + assert.Equal(t, "https://idp.example/verify", resp.VerificationURI) +} + +func TestRequestDeviceCode_UnauthorizedClient(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{ + "error": "unauthorized_client", + "error_description": "client cannot use this grant", + }) + })) + defer srv.Close() + + restore := overrideDiscoveryHTTPClient(srv.Client()) + defer restore() + + _, err := requestDeviceCode(context.Background(), srv.URL, "my-client", []string{"openid"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "client cannot use this grant") + assert.Contains(t, err.Error(), "device_code grant type") + assert.Contains(t, err.Error(), "my-client") +} + +func TestRequestDeviceCode_MissingDeviceCode(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]any{ + "user_code": "U", + "verification_uri": "https://x", + }) + })) + defer srv.Close() + + restore := overrideDiscoveryHTTPClient(srv.Client()) + defer restore() + + _, err := requestDeviceCode(context.Background(), srv.URL, "cid", []string{"s"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "missing device_code") +} + +func TestRequestDeviceCode_IntervalAndExpiresIn(t *testing.T) { + cases := []struct { + name string + rawInterval int + rawExpiresIn int + wantInterval int + wantExpiresIn int + }{ + {"interval_below_min_clamped_to_5", 1, 100, 5, 100}, + {"interval_above_max_clamped_to_60", 100, 100, 60, 100}, + {"interval_in_range_unchanged", 30, 100, 30, 100}, + {"expires_in_zero_defaults_to_900", 5, 0, 5, defaultExpiresIn}, + {"expires_in_capped_at_1800", 5, 99999, 5, maxExpiresIn}, + {"expires_in_in_range_unchanged", 5, 600, 5, 600}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]any{ + "device_code": "dc", + "verification_uri": "https://v", + "interval": tc.rawInterval, + "expires_in": tc.rawExpiresIn, + }) + })) + defer srv.Close() + + restore := overrideDiscoveryHTTPClient(srv.Client()) + defer restore() + + resp, err := requestDeviceCode(context.Background(), srv.URL, "cid", []string{"s"}) + require.NoError(t, err) + assert.Equal(t, tc.wantInterval, resp.Interval) + assert.Equal(t, tc.wantExpiresIn, resp.ExpiresIn) + }) + } +} + +func TestRequestDeviceCode_MissingVerificationURI(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]any{ + "device_code": "dc", + }) + })) + defer srv.Close() + + restore := overrideDiscoveryHTTPClient(srv.Client()) + defer restore() + + _, err := requestDeviceCode(context.Background(), srv.URL, "cid", []string{"s"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "missing verification_uri") +} + +func TestExchangeDeviceCode_Success(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPost, r.Method) + _ = r.ParseForm() + assert.Equal(t, "urn:ietf:params:oauth:grant-type:device_code", r.FormValue("grant_type")) + assert.Equal(t, "cid", r.FormValue("client_id")) + assert.Equal(t, "dev", r.FormValue("device_code")) + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "tok-123", + "token_type": "Bearer", + "expires_in": 3600, + }) + })) + defer srv.Close() + + restore := overrideDiscoveryHTTPClient(srv.Client()) + defer restore() + + tok, status, err := exchangeDeviceCode(context.Background(), srv.URL, "cid", "dev") + require.NoError(t, err) + assert.Equal(t, pollDone, status) + require.NotNil(t, tok) + assert.Equal(t, "tok-123", tok.AccessToken) + assert.Equal(t, "Bearer", tok.TokenType) +} + +func TestExchangeDeviceCode_AuthorizationPending(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{"error": "authorization_pending"}) + })) + defer srv.Close() + + restore := overrideDiscoveryHTTPClient(srv.Client()) + defer restore() + + tok, status, err := exchangeDeviceCode(context.Background(), srv.URL, "cid", "dev") + require.NoError(t, err) + assert.Nil(t, tok) + assert.Equal(t, pollPending, status) +} + +func TestExchangeDeviceCode_SlowDown(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{"error": "slow_down"}) + })) + defer srv.Close() + + restore := overrideDiscoveryHTTPClient(srv.Client()) + defer restore() + + tok, status, err := exchangeDeviceCode(context.Background(), srv.URL, "cid", "dev") + require.NoError(t, err) + assert.Nil(t, tok) + assert.Equal(t, pollSlowDown, status) +} + +func TestExchangeDeviceCode_AccessDenied(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{"error": "access_denied"}) + })) + defer srv.Close() + + restore := overrideDiscoveryHTTPClient(srv.Client()) + defer restore() + + tok, status, err := exchangeDeviceCode(context.Background(), srv.URL, "cid", "dev") + require.Error(t, err) + assert.Contains(t, err.Error(), "authorization denied") + assert.Nil(t, tok) + assert.Equal(t, pollDone, status) +} + +func TestExchangeDeviceCode_EmptyAccessToken(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]string{ + "access_token": "", + "token_type": "Bearer", + }) + })) + defer srv.Close() + + restore := overrideDiscoveryHTTPClient(srv.Client()) + defer restore() + + _, status, err := exchangeDeviceCode(context.Background(), srv.URL, "cid", "dev") + require.Error(t, err) + assert.Contains(t, err.Error(), "missing access_token") + assert.Equal(t, pollDone, status) +} + +func TestPollForToken_PendingThenSuccess(t *testing.T) { + var calls atomic.Int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + n := calls.Add(1) + if n == 1 { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{"error": "authorization_pending"}) + return + } + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "final", + "token_type": "Bearer", + }) + })) + defer srv.Close() + + restore := overrideDiscoveryHTTPClient(srv.Client()) + defer restore() + + auth := &deviceAuthResponse{ + DeviceCode: "dc", + Interval: 0, + ExpiresIn: 60, + VerificationURI: "https://v", + } + + tok, err := pollForToken(context.Background(), srv.URL, "cid", auth) + require.NoError(t, err) + require.NotNil(t, tok) + assert.Equal(t, "final", tok.AccessToken) + assert.Equal(t, int32(2), calls.Load()) +} + +func TestPollForToken_SlowDownIncreasesInterval(t *testing.T) { + if testing.Short() { + t.Skip("timing-based test waits ~5s after slow_down") + } + var calls atomic.Int32 + var firstAt, secondAt atomic.Int64 + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + n := calls.Add(1) + now := time.Now().UnixNano() + if n == 1 { + firstAt.Store(now) + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{"error": "slow_down"}) + return + } + secondAt.Store(now) + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "ok", + "token_type": "Bearer", + }) + })) + defer srv.Close() + + restore := overrideDiscoveryHTTPClient(srv.Client()) + defer restore() + + auth := &deviceAuthResponse{ + DeviceCode: "dc", + Interval: 0, + ExpiresIn: 120, + VerificationURI: "https://v", + } + + start := time.Now() + tok, err := pollForToken(context.Background(), srv.URL, "cid", auth) + elapsed := time.Since(start) + + require.NoError(t, err) + require.NotNil(t, tok) + assert.Equal(t, "ok", tok.AccessToken) + assert.Equal(t, int32(2), calls.Load()) + + gap := time.Duration(secondAt.Load() - firstAt.Load()) + assert.GreaterOrEqual(t, gap, 4*time.Second, + "expected ~5s wait after slow_down increased interval; gap=%v", gap) + assert.GreaterOrEqual(t, elapsed, 4*time.Second) +} + +func TestPollForToken_ContextCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + auth := &deviceAuthResponse{ + DeviceCode: "dc", + Interval: 10, + ExpiresIn: 3600, + VerificationURI: "https://v", + } + + _, err := pollForToken(ctx, "http://unused.example/token", "cid", auth) + require.Error(t, err) + assert.ErrorIs(t, err, context.Canceled) +} + +// overrideDiscoveryHTTPClient swaps the package-level client used by device/discovery helpers. +// The returned function restores the previous client (call in defer). +func overrideDiscoveryHTTPClient(cl *http.Client) func() { + prev := discoveryHTTPClient + discoveryHTTPClient = cl + return func() { discoveryHTTPClient = prev } +} diff --git a/internal/auth/discovery.go b/internal/auth/discovery.go index d7e25b0..4c3e5fc 100644 --- a/internal/auth/discovery.go +++ b/internal/auth/discovery.go @@ -22,8 +22,9 @@ func (e *ErrOAuthNotSupported) Error() string { } type protectedResourceMetadata struct { - Resource string `json:"resource"` - AuthorizationServers []string `json:"authorization_servers"` + Resource string `json:"resource"` + AuthorizationServers []string `json:"authorization_servers"` + GleanDeviceFlowClientID string `json:"glean_device_flow_client_id,omitempty"` } // fetchProtectedResource fetches RFC 9728 protected resource metadata. diff --git a/internal/auth/discovery_test.go b/internal/auth/discovery_test.go index 6e7fecf..fb84594 100644 --- a/internal/auth/discovery_test.go +++ b/internal/auth/discovery_test.go @@ -22,11 +22,33 @@ func TestFetchProtectedResource_Success(t *testing.T) { })) defer srv.Close() + restore := overrideDiscoveryHTTPClient(srv.Client()) + defer restore() + result, err := fetchProtectedResource(context.Background(), srv.URL) require.NoError(t, err) assert.Equal(t, []string{"https://auth.example.com"}, result.AuthorizationServers) } +func TestFetchProtectedResource_DeviceFlowClientID(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/.well-known/oauth-protected-resource", r.URL.Path) + json.NewEncoder(w).Encode(map[string]any{ + "resource": "https://example.glean.com", + "authorization_servers": []string{"https://auth.example.com"}, + "glean_device_flow_client_id": "device-flow-client-123", + }) + })) + defer srv.Close() + + restore := overrideDiscoveryHTTPClient(srv.Client()) + defer restore() + + result, err := fetchProtectedResource(context.Background(), srv.URL) + require.NoError(t, err) + assert.Equal(t, "device-flow-client-123", result.GleanDeviceFlowClientID) +} + func TestFetchProtectedResource_NotFound(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) diff --git a/snippets/readme/snippet-04.sh b/snippets/readme/snippet-04.sh index 029400d..c49cece 100644 --- a/snippets/readme/snippet-04.sh +++ b/snippets/readme/snippet-04.sh @@ -1,3 +1,3 @@ -glean auth login # opens browser, completes PKCE flow +glean auth login # browser PKCE flow, or device flow for SSO/Okta glean auth status # verify credentials, host, and token expiry glean auth logout # remove all stored credentials