diff --git a/cmd/azure-keyvault-controller/main.go b/cmd/azure-keyvault-controller/main.go index 2cd86e6f..4383cc59 100644 --- a/cmd/azure-keyvault-controller/main.go +++ b/cmd/azure-keyvault-controller/main.go @@ -176,16 +176,17 @@ func main() { eventBroadcaster.StartRecordingToSink(&typedcorev1.EventSinkImpl{Interface: kubeClient.CoreV1().Events("")}) var token azcore.TokenCredential + var keyVaultDNSSuffix string klog.Infof("use `%s` as authType", authType) switch authType { case "azureCloudConfig": - token, err = getCredentialsFromCloudConfig(cloudconfig) + token, keyVaultDNSSuffix, err = getCredentialsFromCloudConfig(cloudconfig) if err != nil { klog.ErrorS(err, "failed to create cloud config provider for azure key vault", "file", cloudconfig) os.Exit(1) } case "environment": - token, err = getCredentialsFromEnvironment() + token, keyVaultDNSSuffix, err = getCredentialsFromEnvironment() if err != nil { klog.ErrorS(err, "failed to create credentials provider from environment for azure key vault") os.Exit(1) @@ -197,7 +198,7 @@ func main() { klog.Infof(msg) }) } - token, err = getCredentialsFromAzidentity() + token, keyVaultDNSSuffix, err = getCredentialsFromAzidentity() if err != nil { klog.ErrorS(err, "failed to create credentials provider from azidentity for azure key vault") os.Exit(1) @@ -208,7 +209,7 @@ func main() { os.Exit(1) } - vaultService := vault.NewService(token) + vaultService := vault.NewService(token, keyVaultDNSSuffix) recorder := eventBroadcaster.NewRecorder(scheme.Scheme, corev1.EventSource{Component: controllerAgentName}) @@ -256,34 +257,49 @@ func healthHandler(w http.ResponseWriter, r *http.Request) { } } -func getCredentialsFromCloudConfig(cloudconfig string) (azure.LegacyTokenCredential, error) { +func getCredentialsFromCloudConfig(cloudconfig string) (azure.LegacyTokenCredential, string, error) { f, err := os.Open(cloudconfig) if err != nil { - return nil, fmt.Errorf("failed reading azure config from %s, error: %+v", cloudconfig, err) + return nil, "", fmt.Errorf("failed reading azure config from %s, error: %+v", cloudconfig, err) } defer f.Close() cloudCnfProvider, err := credentialprovider.NewFromCloudConfig(f) if err != nil { - return nil, fmt.Errorf("failed reading azure config from %s, error: %+v", cloudconfig, err) + return nil, "", fmt.Errorf("failed reading azure config from %s, error: %+v", cloudconfig, err) } - return cloudCnfProvider.GetAzureKeyVaultCredentials() + token, err := cloudCnfProvider.GetAzureKeyVaultCredentials() + if err != nil { + return nil, "", nil + } + + return token, cloudCnfProvider.GetAzureKeyVaultDNSSuffix(), err } -func getCredentialsFromEnvironment() (azure.LegacyTokenCredential, error) { +func getCredentialsFromEnvironment() (azure.LegacyTokenCredential, string, error) { provider, err := credentialprovider.NewFromEnvironment() if err != nil { - return nil, fmt.Errorf("failed to create azure credentials provider, error: %+v", err) + return nil, "", fmt.Errorf("failed to create azure credentials provider, error: %+v", err) } - return provider.GetAzureKeyVaultCredentials() + token, err := provider.GetAzureKeyVaultCredentials() + if err != nil { + return nil, "", nil + } + + return token, provider.GetAzureKeyVaultDNSSuffix(), err } -func getCredentialsFromAzidentity() (azure.LegacyTokenCredential, error) { +func getCredentialsFromAzidentity() (azure.LegacyTokenCredential, string, error) { provider, err := credentialprovider.NewFromAzidentity() if err != nil { - return nil, fmt.Errorf("failed to create azure identity provider, error: %+v", err) + return nil, "", fmt.Errorf("failed to create azure identity provider, error: %+v", err) + } + token, err := provider.GetAzureKeyVaultCredentials() + if err != nil { + return nil, "", nil } - return provider.GetAzureKeyVaultCredentials() + + return token, provider.GetAzureKeyVaultDNSSuffix(), err } diff --git a/cmd/azure-keyvault-env/authentication.go b/cmd/azure-keyvault-env/authentication.go index a4c02a05..ed19864b 100644 --- a/cmd/azure-keyvault-env/authentication.go +++ b/cmd/azure-keyvault-env/authentication.go @@ -87,17 +87,17 @@ func createMtlsClient(clientCertDir string) (*http.Client, error) { return client, nil } -func getCredentials() (azure.LegacyTokenCredential, error) { +func getCredentials() (azure.LegacyTokenCredential, string, error) { provider, err := credentialprovider.NewFromEnvironment() if err != nil { - return nil, fmt.Errorf("failed to create credentials provider for azure key vault, error: %w", err) + return nil, "", fmt.Errorf("failed to create credentials provider for azure key vault, error: %w", err) } creds, err := provider.GetAzureKeyVaultCredentials() if err != nil { - return nil, fmt.Errorf("failed to get credentials for azure key vault, error: %w", err) + return nil, "", fmt.Errorf("failed to get credentials for azure key vault, error: %w", err) } - return creds, nil + return creds, provider.GetAzureKeyVaultDNSSuffix(), nil } func getCredentialsAuthService(authServiceAddress string, authServiceValidationAddress string, clientCertDir string) (azure.LegacyTokenCredential, error) { diff --git a/cmd/azure-keyvault-env/main.go b/cmd/azure-keyvault-env/main.go index 5e4cd850..4bbf7bce 100644 --- a/cmd/azure-keyvault-env/main.go +++ b/cmd/azure-keyvault-env/main.go @@ -30,6 +30,7 @@ import ( "github.com/SparebankenVest/azure-key-vault-to-kubernetes/pkg/akv2k8s" "github.com/SparebankenVest/azure-key-vault-to-kubernetes/pkg/akv2k8s/transformers" "github.com/SparebankenVest/azure-key-vault-to-kubernetes/pkg/azure" + "github.com/SparebankenVest/azure-key-vault-to-kubernetes/pkg/azure/credentialprovider" vault "github.com/SparebankenVest/azure-key-vault-to-kubernetes/pkg/azure/keyvault/client" akv "github.com/SparebankenVest/azure-key-vault-to-kubernetes/pkg/k8s/apis/azurekeyvault/v2beta1" clientset "github.com/SparebankenVest/azure-key-vault-to-kubernetes/pkg/k8s/client/clientset/versioned" @@ -250,21 +251,28 @@ func main() { } var creds azure.LegacyTokenCredential + var keyVaultDNSSuffix string if config.useAuthService { + provider, err := credentialprovider.NewFromEnvironment() + if err != nil { + klog.ErrorS(err, "failed to get provider from environment", "failedTimes", config.retryTimes) + os.Exit(1) + } + keyVaultDNSSuffix = provider.GetAzureKeyVaultDNSSuffix() creds, err = getCredentialsAuthService(config.authServiceAddress, config.authServiceValidationAddress, config.clientCertDir) if err != nil { klog.ErrorS(err, "failed to get credentials", "failedTimes", config.retryTimes) os.Exit(1) } } else { - creds, err = getCredentials() + creds, keyVaultDNSSuffix, err = getCredentials() if err != nil { klog.ErrorS(err, "failed to get credentials", "failedTimes", config.retryTimes) os.Exit(1) } } - vaultService := vault.NewService(creds) + vaultService := vault.NewService(creds, keyVaultDNSSuffix) klog.V(4).InfoS("reading azurekeyvaultsecret's referenced in env variables") cfg, err := rest.InClusterConfig() diff --git a/pkg/azure/credentialprovider/akv.go b/pkg/azure/credentialprovider/akv.go index 6a860829..daf66f7d 100644 --- a/pkg/azure/credentialprovider/akv.go +++ b/pkg/azure/credentialprovider/akv.go @@ -84,6 +84,11 @@ func (c CloudConfigCredentialProvider) GetAzureKeyVaultCredentials() (azure.Lega return azure.NewLegacyTokenCredentialAdal(token), nil } +// GetAzureKeyVaultDNSSuffix returns the environment specific Azure Key Vault DNS suffix +func (c CloudConfigCredentialProvider) GetAzureKeyVaultDNSSuffix() string { + return c.environment.KeyVaultDNSSuffix +} + // GetAzureKeyVaultCredentials will get Azure credentials func (c EnvironmentCredentialProvider) GetAzureKeyVaultCredentials() (azure.LegacyTokenCredential, error) { azureToken, err := getCredentials(c.envSettings, c.envSettings.Environment.ResourceIdentifiers.KeyVault) @@ -92,7 +97,11 @@ func (c EnvironmentCredentialProvider) GetAzureKeyVaultCredentials() (azure.Lega } return azure.NewLegacyTokenCredentialAdal(azureToken.token), nil +} +// GetAzureKeyVaultDNSSuffix returns the environment specific Azure Key Vault DNS suffix +func (c EnvironmentCredentialProvider) GetAzureKeyVaultDNSSuffix() string { + return c.envSettings.Environment.KeyVaultDNSSuffix } func getCredentialsAzidentity() (azure.LegacyTokenCredential, error) { @@ -108,3 +117,8 @@ func getCredentialsAzidentity() (azure.LegacyTokenCredential, error) { func (c AzidentityCredentialProvider) GetAzureKeyVaultCredentials() (azure.LegacyTokenCredential, error) { return getCredentialsAzidentity() } + +// GetAzureKeyVaultDNSSuffix returns the environment specific Azure Key Vault DNS suffix +func (c AzidentityCredentialProvider) GetAzureKeyVaultDNSSuffix() string { + return c.envSettings.Environment.KeyVaultDNSSuffix +} diff --git a/pkg/azure/credentialprovider/provider.go b/pkg/azure/credentialprovider/provider.go index ffa6616e..4aad0998 100644 --- a/pkg/azure/credentialprovider/provider.go +++ b/pkg/azure/credentialprovider/provider.go @@ -62,6 +62,7 @@ type CredentialProvider interface { GetAzureKeyVaultCredentials() (myazure.LegacyTokenCredential, error) GetAcrCredentials(image string) (k8sCredentialProvider.DockerConfigEntry, error) // IsAcrRegistry(image string) bool + GetAzureKeyVaultDNSSuffix() string } // UserAssignedManagedIdentityProvider provides credentials for Azure using managed identity diff --git a/pkg/azure/keyvault/client/service.go b/pkg/azure/keyvault/client/service.go index 28d1a4a5..8fc8ce95 100644 --- a/pkg/azure/keyvault/client/service.go +++ b/pkg/azure/keyvault/client/service.go @@ -48,18 +48,24 @@ type CertificateOptions struct { } type azureKeyVaultService struct { - credentials azure.LegacyTokenCredential + credentials azure.LegacyTokenCredential + keyVaultDNSSuffix string } // NewService creates a new AzureKeyVaultService -func NewService(creds azure.LegacyTokenCredential) Service { +func NewService(creds azure.LegacyTokenCredential, keyVaultDNSSuffix string) Service { return &azureKeyVaultService{ - credentials: creds, + credentials: creds, + keyVaultDNSSuffix: keyVaultDNSSuffix, } } -func vaultNameToURL(name string) string { - return fmt.Sprintf("https://%s.vault.azure.net", name) +func (a *azureKeyVaultService) vaultNameToURL(name string) string { + suffix := a.keyVaultDNSSuffix + if suffix == "" { + suffix = "vault.azure.net" + } + return fmt.Sprintf("https://%s.%s", name, suffix) } // GetSecret download secrets from Azure Key Vault @@ -68,7 +74,7 @@ func (a *azureKeyVaultService) GetSecret(vaultSpec *akvs.AzureKeyVault) (string, return "", fmt.Errorf("azurekeyvaultsecret.spec.vault.object.name not set") } - client, err := azsecrets.NewClient(vaultNameToURL(vaultSpec.Name), a.credentials, nil) + client, err := azsecrets.NewClient(a.vaultNameToURL(vaultSpec.Name), a.credentials, nil) if err != nil { return "", err } @@ -88,7 +94,7 @@ func (a *azureKeyVaultService) GetKey(vaultSpec *akvs.AzureKeyVault) (string, er return "", fmt.Errorf("azurekeyvaultsecret.spec.vault.object.name not set") } - client, err := azkeys.NewClient(vaultNameToURL(vaultSpec.Name), a.credentials, nil) + client, err := azkeys.NewClient(a.vaultNameToURL(vaultSpec.Name), a.credentials, nil) if err != nil { return "", err } @@ -107,11 +113,11 @@ func (a *azureKeyVaultService) GetKey(vaultSpec *akvs.AzureKeyVault) (string, er // GetCertificate download public/private certificates from Azure Key Vault func (a *azureKeyVaultService) GetCertificate(vaultSpec *akvs.AzureKeyVault, options *CertificateOptions) (*Certificate, error) { - client, err := azcertificates.NewClient(vaultNameToURL(vaultSpec.Name), a.credentials, &azcertificates.ClientOptions{}) + client, err := azcertificates.NewClient(a.vaultNameToURL(vaultSpec.Name), a.credentials, &azcertificates.ClientOptions{}) if err != nil { return nil, err } - clientSecret, err := azsecrets.NewClient(vaultNameToURL(vaultSpec.Name), a.credentials, &azsecrets.ClientOptions{}) + clientSecret, err := azsecrets.NewClient(a.vaultNameToURL(vaultSpec.Name), a.credentials, &azsecrets.ClientOptions{}) if err != nil { return nil, err } diff --git a/pkg/azure/keyvault/client/service_test.go b/pkg/azure/keyvault/client/service_test.go index 81d0a776..4b4233b6 100644 --- a/pkg/azure/keyvault/client/service_test.go +++ b/pkg/azure/keyvault/client/service_test.go @@ -88,7 +88,7 @@ func TestIntegrationGetSecret(t *testing.T) { t.Error(err) } - srvc := NewService(creds) + srvc := NewService(creds, provider.GetAzureKeyVaultDNSSuffix()) akvSecret := newAzureKeyVaultSecret("mySecret", "akv2k8s-test", "my-secret") secret, err := srvc.GetSecret(&akvSecret.Spec.Vault) @@ -115,7 +115,7 @@ func TestIntegrationEnvironmentGetSecret(t *testing.T) { t.Error(err) } - srvc := NewService(creds) + srvc := NewService(creds, provider.GetAzureKeyVaultDNSSuffix()) akvSecret := newAzureKeyVaultSecret("mySecret", "akv2k8s-test", "my-secret") secret, err := srvc.GetSecret(&akvSecret.Spec.Vault)