diff --git a/util/util.go b/util/util.go index 9b309d69d..de94dc967 100644 --- a/util/util.go +++ b/util/util.go @@ -42,11 +42,16 @@ func ToStringArray(input []interface{}) []string { } func Is404(err error) bool { - return IsHTTPErrorCode(err, http.StatusNotFound) + return ErrorContainsHTTPCode(err, http.StatusNotFound) } -func IsHTTPErrorCode(err error, code int) bool { - return strings.Contains(err.Error(), fmt.Sprintf("Code: %d", code)) +func ErrorContainsHTTPCode(err error, codes ...int) bool { + for _, code := range codes { + if strings.Contains(err.Error(), fmt.Sprintf("Code: %d", code)) { + return true + } + } + return false } func CalculateConflictsWith(self string, group []string) []string { @@ -300,3 +305,22 @@ func SetResourceData(d *schema.ResourceData, data map[string]interface{}) error return nil } + +// NormalizeMountPath to be in a form valid for accessing values from api.MountOutput +func NormalizeMountPath(path string) string { + return strings.Trim(path, "/") + "/" +} + +// CheckMountEnabled in Vault, path must contain a trailing '/', +func CheckMountEnabled(client *api.Client, path string) (bool, error) { + mounts, err := client.Sys().ListMounts() + if err != nil { + return false, err + } + + if _, ok := mounts[NormalizeMountPath(path)]; !ok { + return true, nil + } + + return false, nil +} diff --git a/vault/resource_pki_secret_backend_root_cert.go b/vault/resource_pki_secret_backend_root_cert.go index f317166d3..ff6f3a2ee 100644 --- a/vault/resource_pki_secret_backend_root_cert.go +++ b/vault/resource_pki_secret_backend_root_cert.go @@ -25,10 +25,18 @@ func pkiSecretBackendRootCertResource() *schema.Resource { Update: func(data *schema.ResourceData, i interface{}) error { return nil }, - Read: func(data *schema.ResourceData, i interface{}) error { - return nil - }, + //Read: func(data *schema.ResourceData, i interface{}) error { + // return nil + //}, + Read: pkiSecretBackendRootCertRead, CustomizeDiff: func(_ context.Context, d *schema.ResourceDiff, meta interface{}) error { + key := "serial" + o, _ := d.GetChange(key) + // skip on new resource + if o.(string) == "" { + return nil + } + client := meta.(*api.Client) cert, err := getCACertificate(client, d.Get("backend").(string)) if err != nil { @@ -36,22 +44,18 @@ func pkiSecretBackendRootCertResource() *schema.Resource { } if cert != nil { - key := "serial" - cur := d.Get(key).(string) n := certutil.GetHexFormatted(cert.SerialNumber.Bytes(), ":") - if err := d.SetNew(key, n); err != nil { - return err - } - - o, _ := d.GetChange(key) - // don't force new on new resources - if o.(string) != "" && cur != n { + if d.Get(key).(string) != n { + if err := d.SetNewComputed(key); err != nil { + return err + } if err := d.ForceNew(key); err != nil { return err } } } + return nil }, @@ -323,13 +327,35 @@ func pkiSecretBackendRootCertCreate(d *schema.ResourceData, meta interface{}) er return nil } +func pkiSecretBackendRootCertRead(d *schema.ResourceData, meta interface{}) error { + if d.IsNewResource() { + return nil + } + + client := meta.(*api.Client) + path := d.Get("backend").(string) + enabled, err := util.CheckMountEnabled(client, path) + if err != nil { + log.Printf("[WARN] Failed to check if mount %q exist, preempting the read operation", path) + return nil + } + + // trigger a resource re-creation whenever the engine's mount has disappeared + if enabled { + log.Printf("[WARN] Mount %q does not exist, setting resource for re-creation", path) + d.SetId("") + } + + return nil +} + func getCACertificate(client *api.Client, mount string) (*x509.Certificate, error) { path := fmt.Sprintf("/v1/%s/ca/pem", mount) req := client.NewRequest(http.MethodGet, path) req.ClientToken = "" resp, err := client.RawRequest(req) if err != nil { - if util.IsHTTPErrorCode(err, http.StatusNotFound) || util.IsHTTPErrorCode(err, http.StatusForbidden) { + if util.ErrorContainsHTTPCode(err, http.StatusNotFound, http.StatusForbidden) { return nil, nil } return nil, err @@ -345,7 +371,6 @@ func getCACertificate(client *api.Client, mount string) (*x509.Certificate, erro return nil, err } - log.Printf("[INFO] Reading current CA") b, _ := pem.Decode(data) if b != nil { cert, err := x509.ParseCertificate(b.Bytes)