diff --git a/auth/oauth/oauth.go b/auth/oauth/oauth.go index ad76dad..1eea595 100644 --- a/auth/oauth/oauth.go +++ b/auth/oauth/oauth.go @@ -10,9 +10,18 @@ import ( "golang.org/x/oauth2" ) -const ( - azureTenantId = "4a67d088-db5c-48f1-9ff2-0aace800ae68" -) +type AzureTenant struct { + DnsZone string + AzureApplicationID string +} + +var azureTenants = map[string]string{ + ".dev.azuredatabricks.net": "62a912ac-b58e-4c1d-89ea-b2dbfc7358fc", + ".staging.azuredatabricks.net": "4a67d088-db5c-48f1-9ff2-0aace800ae68", + ".azuredatabricks.net": "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d", + ".databricks.azure.us": "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d", + ".databricks.azure.cn": "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d", +} func GetEndpoint(ctx context.Context, hostName string) (oauth2.Endpoint, error) { if ctx == nil { @@ -52,7 +61,8 @@ func GetScopes(hostName string, scopes []string) []string { cloudType := InferCloudFromHost(hostName) if cloudType == Azure { - userImpersonationScope := fmt.Sprintf("%s/user_impersonation", azureTenantId) + + userImpersonationScope := fmt.Sprintf("%s/user_impersonation", azureTenants[GetAzureDnsZone(hostName)]) if !HasScope(scopes, userImpersonationScope) { scopes = append(scopes, userImpersonationScope) } @@ -133,3 +143,12 @@ func InferCloudFromHost(hostname string) CloudType { } return Unknown } + +func GetAzureDnsZone(hostname string) string { + for _, d := range databricksAzureDomains { + if strings.Contains(hostname, d) { + return d + } + } + return "" +}