diff --git a/cmd/auth/login.go b/cmd/auth/login.go index 56a0caa90a..33378f7cfc 100644 --- a/cmd/auth/login.go +++ b/cmd/auth/login.go @@ -152,6 +152,8 @@ depends on the existing profiles you have set in your configuration file return err } + authArguments.Profile = profileName + var scopesList []string if scopes != "" { for _, s := range strings.Split(scopes, ",") { diff --git a/cmd/auth/token.go b/cmd/auth/token.go index 4987915e21..106c6b9f0b 100644 --- a/cmd/auth/token.go +++ b/cmd/auth/token.go @@ -113,6 +113,8 @@ func loadToken(ctx context.Context, args loadTokenArgs) (*oauth2.Token, error) { return nil, err } + args.authArguments.Profile = args.profileName + ctx, cancel := context.WithTimeout(ctx, args.tokenTimeout) defer cancel() oauthArgument, err := args.authArguments.ToOAuthArgument() diff --git a/cmd/auth/token_test.go b/cmd/auth/token_test.go index e342b27e77..01c6e83ea3 100644 --- a/cmd/auth/token_test.go +++ b/cmd/auth/token_test.go @@ -100,6 +100,13 @@ func TestToken_loadToken(t *testing.T) { RefreshToken: "active", Expiry: time.Now().Add(1 * time.Hour), // Hopefully unit tests don't take an hour to run }, + "expired": { + RefreshToken: "expired", + }, + "active": { + RefreshToken: "active", + Expiry: time.Now().Add(1 * time.Hour), + }, }, } validateToken := func(resp *oauth2.Token) { diff --git a/libs/auth/arguments.go b/libs/auth/arguments.go index 007ea0867e..8e00d89507 100644 --- a/libs/auth/arguments.go +++ b/libs/auth/arguments.go @@ -14,6 +14,10 @@ type AuthArguments struct { AccountID string WorkspaceID string IsUnifiedHost bool + + // Profile is the optional profile name. When set, the OAuth token cache + // key is the profile name instead of the host-based key. + Profile string } // ToOAuthArgument converts the AuthArguments to an OAuthArgument from the Go SDK. @@ -28,13 +32,13 @@ func (a AuthArguments) ToOAuthArgument() (u2m.OAuthArgument, error) { switch cfg.HostType() { case config.AccountHost: - return u2m.NewBasicAccountOAuthArgument(host, cfg.AccountID) + return u2m.NewProfileAccountOAuthArgument(host, cfg.AccountID, a.Profile) case config.WorkspaceHost: - return u2m.NewBasicWorkspaceOAuthArgument(host) + return u2m.NewProfileWorkspaceOAuthArgument(host, a.Profile) case config.UnifiedHost: // For unified hosts, always use the unified OAuth argument with account ID. // The workspace ID is stored in the config for API routing, not OAuth. - return u2m.NewBasicUnifiedOAuthArgument(host, cfg.AccountID) + return u2m.NewProfileUnifiedOAuthArgument(host, cfg.AccountID, a.Profile) default: return nil, fmt.Errorf("unknown host type: %v", cfg.HostType()) } diff --git a/libs/auth/arguments_test.go b/libs/auth/arguments_test.go index 957061815e..7b41b9dbfd 100644 --- a/libs/auth/arguments_test.go +++ b/libs/auth/arguments_test.go @@ -9,24 +9,36 @@ import ( func TestToOAuthArgument(t *testing.T) { tests := []struct { - name string - args AuthArguments - wantHost string - wantError bool + name string + args AuthArguments + wantHost string + wantCacheKey string + wantError bool }{ { name: "workspace with no scheme", args: AuthArguments{ Host: "my-workspace.cloud.databricks.com", }, - wantHost: "https://my-workspace.cloud.databricks.com", + wantHost: "https://my-workspace.cloud.databricks.com", + wantCacheKey: "https://my-workspace.cloud.databricks.com", }, { name: "workspace with https", args: AuthArguments{ Host: "https://my-workspace.cloud.databricks.com", }, - wantHost: "https://my-workspace.cloud.databricks.com", + wantHost: "https://my-workspace.cloud.databricks.com", + wantCacheKey: "https://my-workspace.cloud.databricks.com", + }, + { + name: "workspace with profile uses profile-based cache key", + args: AuthArguments{ + Host: "https://my-workspace.cloud.databricks.com", + Profile: "my-profile", + }, + wantHost: "https://my-workspace.cloud.databricks.com", + wantCacheKey: "my-profile", }, { name: "account with no scheme", @@ -34,7 +46,8 @@ func TestToOAuthArgument(t *testing.T) { Host: "accounts.cloud.databricks.com", AccountID: "123456789", }, - wantHost: "https://accounts.cloud.databricks.com", + wantHost: "https://accounts.cloud.databricks.com", + wantCacheKey: "https://accounts.cloud.databricks.com/oidc/accounts/123456789", }, { name: "account with https", @@ -42,21 +55,34 @@ func TestToOAuthArgument(t *testing.T) { Host: "https://accounts.cloud.databricks.com", AccountID: "123456789", }, - wantHost: "https://accounts.cloud.databricks.com", + wantHost: "https://accounts.cloud.databricks.com", + wantCacheKey: "https://accounts.cloud.databricks.com/oidc/accounts/123456789", + }, + { + name: "account with profile uses profile-based cache key", + args: AuthArguments{ + Host: "https://accounts.cloud.databricks.com", + AccountID: "123456789", + Profile: "my-account-profile", + }, + wantHost: "https://accounts.cloud.databricks.com", + wantCacheKey: "my-account-profile", }, { name: "workspace with query parameter", args: AuthArguments{ Host: "https://my-workspace.cloud.databricks.com?o=123456789", }, - wantHost: "https://my-workspace.cloud.databricks.com", + wantHost: "https://my-workspace.cloud.databricks.com", + wantCacheKey: "https://my-workspace.cloud.databricks.com", }, { name: "workspace with query parameter and path", args: AuthArguments{ Host: "https://my-workspace.cloud.databricks.com/path?o=123456789", }, - wantHost: "https://my-workspace.cloud.databricks.com", + wantHost: "https://my-workspace.cloud.databricks.com", + wantCacheKey: "https://my-workspace.cloud.databricks.com", }, { name: "unified host with account ID only", @@ -65,7 +91,8 @@ func TestToOAuthArgument(t *testing.T) { AccountID: "123456789", IsUnifiedHost: true, }, - wantHost: "https://unified.cloud.databricks.com", + wantHost: "https://unified.cloud.databricks.com", + wantCacheKey: "https://unified.cloud.databricks.com/oidc/accounts/123456789", }, { name: "unified host with both account ID and workspace ID", @@ -75,7 +102,19 @@ func TestToOAuthArgument(t *testing.T) { WorkspaceID: "123456789", IsUnifiedHost: true, }, - wantHost: "https://unified.cloud.databricks.com", + wantHost: "https://unified.cloud.databricks.com", + wantCacheKey: "https://unified.cloud.databricks.com/oidc/accounts/123456789", + }, + { + name: "unified host with profile uses profile-based cache key", + args: AuthArguments{ + Host: "https://unified.cloud.databricks.com", + AccountID: "123456789", + IsUnifiedHost: true, + Profile: "my-unified-profile", + }, + wantHost: "https://unified.cloud.databricks.com", + wantCacheKey: "my-unified-profile", }, } @@ -87,10 +126,10 @@ func TestToOAuthArgument(t *testing.T) { return } assert.NoError(t, err) + assert.Equal(t, tt.wantCacheKey, got.GetCacheKey()) // Check if we got the right type of argument and verify the hostname if tt.args.IsUnifiedHost { - // Unified hosts return UnifiedOAuthArgument (distinct from Account/Workspace) arg, ok := got.(u2m.UnifiedOAuthArgument) assert.True(t, ok, "expected UnifiedOAuthArgument for unified host") assert.Equal(t, tt.wantHost, arg.GetHost())