Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cmd/auth/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ depends on the existing profiles you have set in your configuration file
return err
}

authArguments.Profile = profileName

var scopesList []string
if scopes != "" {
for _, s := range strings.Split(scopes, ",") {
Expand Down
2 changes: 2 additions & 0 deletions cmd/auth/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ func loadToken(ctx context.Context, args loadTokenArgs) (*oauth2.Token, error) {
return nil, err
}

args.authArguments.Profile = args.profileName

ctx, cancel := context.WithTimeout(ctx, args.tokenTimeout)
defer cancel()
oauthArgument, err := args.authArguments.ToOAuthArgument()
Expand Down
7 changes: 7 additions & 0 deletions cmd/auth/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,13 @@ func TestToken_loadToken(t *testing.T) {
RefreshToken: "active",
Expiry: time.Now().Add(1 * time.Hour), // Hopefully unit tests don't take an hour to run
},
"expired": {
RefreshToken: "expired",
},
"active": {
RefreshToken: "active",
Expiry: time.Now().Add(1 * time.Hour),
},
},
}
validateToken := func(resp *oauth2.Token) {
Expand Down
10 changes: 7 additions & 3 deletions libs/auth/arguments.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ type AuthArguments struct {
AccountID string
WorkspaceID string
IsUnifiedHost bool

// Profile is the optional profile name. When set, the OAuth token cache
// key is the profile name instead of the host-based key.
Profile string
}

// ToOAuthArgument converts the AuthArguments to an OAuthArgument from the Go SDK.
Expand All @@ -28,13 +32,13 @@ func (a AuthArguments) ToOAuthArgument() (u2m.OAuthArgument, error) {

switch cfg.HostType() {
case config.AccountHost:
return u2m.NewBasicAccountOAuthArgument(host, cfg.AccountID)
return u2m.NewProfileAccountOAuthArgument(host, cfg.AccountID, a.Profile)
case config.WorkspaceHost:
return u2m.NewBasicWorkspaceOAuthArgument(host)
return u2m.NewProfileWorkspaceOAuthArgument(host, a.Profile)
case config.UnifiedHost:
// For unified hosts, always use the unified OAuth argument with account ID.
// The workspace ID is stored in the config for API routing, not OAuth.
return u2m.NewBasicUnifiedOAuthArgument(host, cfg.AccountID)
return u2m.NewProfileUnifiedOAuthArgument(host, cfg.AccountID, a.Profile)
default:
return nil, fmt.Errorf("unknown host type: %v", cfg.HostType())
}
Expand Down
65 changes: 52 additions & 13 deletions libs/auth/arguments_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,54 +9,80 @@ import (

func TestToOAuthArgument(t *testing.T) {
tests := []struct {
name string
args AuthArguments
wantHost string
wantError bool
name string
args AuthArguments
wantHost string
wantCacheKey string
wantError bool
}{
{
name: "workspace with no scheme",
args: AuthArguments{
Host: "my-workspace.cloud.databricks.com",
},
wantHost: "https://my-workspace.cloud.databricks.com",
wantHost: "https://my-workspace.cloud.databricks.com",
wantCacheKey: "https://my-workspace.cloud.databricks.com",
},
{
name: "workspace with https",
args: AuthArguments{
Host: "https://my-workspace.cloud.databricks.com",
},
wantHost: "https://my-workspace.cloud.databricks.com",
wantHost: "https://my-workspace.cloud.databricks.com",
wantCacheKey: "https://my-workspace.cloud.databricks.com",
},
{
name: "workspace with profile uses profile-based cache key",
args: AuthArguments{
Host: "https://my-workspace.cloud.databricks.com",
Profile: "my-profile",
},
wantHost: "https://my-workspace.cloud.databricks.com",
wantCacheKey: "my-profile",
},
{
name: "account with no scheme",
args: AuthArguments{
Host: "accounts.cloud.databricks.com",
AccountID: "123456789",
},
wantHost: "https://accounts.cloud.databricks.com",
wantHost: "https://accounts.cloud.databricks.com",
wantCacheKey: "https://accounts.cloud.databricks.com/oidc/accounts/123456789",
},
{
name: "account with https",
args: AuthArguments{
Host: "https://accounts.cloud.databricks.com",
AccountID: "123456789",
},
wantHost: "https://accounts.cloud.databricks.com",
wantHost: "https://accounts.cloud.databricks.com",
wantCacheKey: "https://accounts.cloud.databricks.com/oidc/accounts/123456789",
},
{
name: "account with profile uses profile-based cache key",
args: AuthArguments{
Host: "https://accounts.cloud.databricks.com",
AccountID: "123456789",
Profile: "my-account-profile",
},
wantHost: "https://accounts.cloud.databricks.com",
wantCacheKey: "my-account-profile",
},
{
name: "workspace with query parameter",
args: AuthArguments{
Host: "https://my-workspace.cloud.databricks.com?o=123456789",
},
wantHost: "https://my-workspace.cloud.databricks.com",
wantHost: "https://my-workspace.cloud.databricks.com",
wantCacheKey: "https://my-workspace.cloud.databricks.com",
},
{
name: "workspace with query parameter and path",
args: AuthArguments{
Host: "https://my-workspace.cloud.databricks.com/path?o=123456789",
},
wantHost: "https://my-workspace.cloud.databricks.com",
wantHost: "https://my-workspace.cloud.databricks.com",
wantCacheKey: "https://my-workspace.cloud.databricks.com",
},
{
name: "unified host with account ID only",
Expand All @@ -65,7 +91,8 @@ func TestToOAuthArgument(t *testing.T) {
AccountID: "123456789",
IsUnifiedHost: true,
},
wantHost: "https://unified.cloud.databricks.com",
wantHost: "https://unified.cloud.databricks.com",
wantCacheKey: "https://unified.cloud.databricks.com/oidc/accounts/123456789",
},
{
name: "unified host with both account ID and workspace ID",
Expand All @@ -75,7 +102,19 @@ func TestToOAuthArgument(t *testing.T) {
WorkspaceID: "123456789",
IsUnifiedHost: true,
},
wantHost: "https://unified.cloud.databricks.com",
wantHost: "https://unified.cloud.databricks.com",
wantCacheKey: "https://unified.cloud.databricks.com/oidc/accounts/123456789",
},
{
name: "unified host with profile uses profile-based cache key",
args: AuthArguments{
Host: "https://unified.cloud.databricks.com",
AccountID: "123456789",
IsUnifiedHost: true,
Profile: "my-unified-profile",
},
wantHost: "https://unified.cloud.databricks.com",
wantCacheKey: "my-unified-profile",
},
}

Expand All @@ -87,10 +126,10 @@ func TestToOAuthArgument(t *testing.T) {
return
}
assert.NoError(t, err)
assert.Equal(t, tt.wantCacheKey, got.GetCacheKey())

// Check if we got the right type of argument and verify the hostname
if tt.args.IsUnifiedHost {
// Unified hosts return UnifiedOAuthArgument (distinct from Account/Workspace)
arg, ok := got.(u2m.UnifiedOAuthArgument)
assert.True(t, ok, "expected UnifiedOAuthArgument for unified host")
assert.Equal(t, tt.wantHost, arg.GetHost())
Expand Down