diff --git a/pkg/lock/local/local.go b/pkg/lock/local/local.go index 784681a..86b2851 100644 --- a/pkg/lock/local/local.go +++ b/pkg/lock/local/local.go @@ -1,6 +1,7 @@ package local import ( + "fmt" "sync" "github.com/nimbolus/terraform-backend/pkg/terraform" @@ -63,3 +64,15 @@ func (l *Lock) Unlock(s *terraform.State) (bool, error) { return true, nil } + +func (l *Lock) GetLock(s *terraform.State) ([]byte, error) { + l.mutex.Lock() + defer l.mutex.Unlock() + + lock, ok := l.db[s.ID] + if !ok { + return nil, fmt.Errorf("no lock found for state %s", s.ID) + } + + return lock, nil +} diff --git a/pkg/lock/locker.go b/pkg/lock/locker.go index 4cc9224..dd43e8f 100644 --- a/pkg/lock/locker.go +++ b/pkg/lock/locker.go @@ -8,4 +8,5 @@ type Locker interface { GetName() string Lock(s *terraform.State) (ok bool, err error) Unlock(s *terraform.State) (ok bool, err error) + GetLock(s *terraform.State) ([]byte, error) } diff --git a/pkg/lock/postgres/postgres.go b/pkg/lock/postgres/postgres.go index d090b97..06e0fd2 100644 --- a/pkg/lock/postgres/postgres.go +++ b/pkg/lock/postgres/postgres.go @@ -130,3 +130,16 @@ func (l *Lock) Unlock(s *terraform.State) (bool, error) { return true, nil } + +func (l *Lock) GetLock(s *terraform.State) ([]byte, error) { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + var lock []byte + + if err := l.db.QueryRowContext(ctx, `SELECT lock_data FROM `+l.table+` WHERE state_id = $1`, s.ID).Scan(&lock); err != nil { + return nil, err + } + + return lock, nil +} diff --git a/pkg/lock/redis/redis.go b/pkg/lock/redis/redis.go index 277ab8e..4ec3eab 100644 --- a/pkg/lock/redis/redis.go +++ b/pkg/lock/redis/redis.go @@ -134,6 +134,32 @@ func (r *Lock) Unlock(s *terraform.State) (unlocked bool, err error) { return true, nil } +func (r *Lock) GetLock(s *terraform.State) (lock []byte, err error) { + mutex := r.client.NewMutex(lockKey, redsync.WithExpiry(12*time.Hour), redsync.WithTries(1), redsync.WithGenValueFunc(func() (string, error) { + return uuid.New().String(), nil + })) + + // lock the global redis mutex + if err := mutex.Lock(); err != nil { + log.Errorf("failed to lock redsync mutex: %v", err) + + return nil, err + } + + defer func() { + // unlock the global redis mutex + if _, mutErr := mutex.Unlock(); mutErr != nil { + log.Errorf("failed to unlock redsync mutex: %v", mutErr) + + if err != nil { + err = multierr.Append(err, mutErr) + } + } + }() + + return r.getLock(s) +} + func (r *Lock) setLock(s *terraform.State) error { ctx := context.Background() diff --git a/pkg/lock/util/locktest.go b/pkg/lock/util/locktest.go index 0f5810d..91c2166 100644 --- a/pkg/lock/util/locktest.go +++ b/pkg/lock/util/locktest.go @@ -43,6 +43,12 @@ func LockTest(t *testing.T, l lock.Locker) { t.Error(err) } + if lock, err := l.GetLock(&s1); err != nil { + t.Error(err) + } else if string(lock) != string(s1.Lock) { + t.Errorf("lock is not equal: %s != %s", lock, s1.Lock) + } + if locked, err := l.Lock(&s1); err != nil || !locked { t.Error("should be able to lock twice from the same process") } diff --git a/pkg/server/handler.go b/pkg/server/handler.go index dce28b1..4e0322a 100644 --- a/pkg/server/handler.go +++ b/pkg/server/handler.go @@ -4,6 +4,7 @@ import ( "fmt" "io" "net/http" + "strings" "github.com/gorilla/mux" log "github.com/sirupsen/logrus" @@ -63,7 +64,7 @@ func StateHandler(store storage.Storage, locker lock.Locker, kms kms.KMS) func(h case http.MethodGet: Get(w, state, store, kms) case http.MethodPost: - Post(w, state, body, store, kms) + Post(req, w, state, body, locker, store, kms) case http.MethodDelete: Delete(w, state, store) default: @@ -128,7 +129,22 @@ func Get(w http.ResponseWriter, state *terraform.State, store storage.Storage, k HTTPResponse(w, http.StatusOK, string(state.Data)) } -func Post(w http.ResponseWriter, state *terraform.State, body []byte, store storage.Storage, kms kms.KMS) { +func Post(r *http.Request, w http.ResponseWriter, state *terraform.State, body []byte, locker lock.Locker, store storage.Storage, kms kms.KMS) { + reqLockID := r.URL.Query().Get("ID") + + lockID, err := locker.GetLock(state) + if err != nil { + log.Warnf("failed to get lock for state with id %s: %v", state.ID, err) + HTTPResponse(w, http.StatusInternalServerError, "") + return + } + + if !strings.Contains(string(lockID), fmt.Sprintf(`"ID":"%s"`, reqLockID)) { + log.Warnf("attempting to write state with wrong lock %s (expected %s)", reqLockID, lockID) + HTTPResponse(w, http.StatusBadRequest, "") + return + } + log.Debugf("save state with id %s", state.ID) data, err := kms.Encrypt(body) diff --git a/pkg/server/handler_test.go b/pkg/server/handler_test.go index c25fe38..da0d3ea 100644 --- a/pkg/server/handler_test.go +++ b/pkg/server/handler_test.go @@ -15,36 +15,83 @@ import ( "github.com/gorilla/mux" "github.com/gruntwork-io/terratest/modules/terraform" + localkms "github.com/nimbolus/terraform-backend/pkg/kms/local" locallock "github.com/nimbolus/terraform-backend/pkg/lock/local" "github.com/nimbolus/terraform-backend/pkg/storage/filesystem" ) -var terraformBinary = flag.String("tf", "terraform", "terraform binary") - -func TestServerHandler(t *testing.T) { - s := httptest.NewServer(NewStateHandler()) - defer s.Close() - - address, err := url.JoinPath(s.URL, "/state/project1/example") +func NewStateHandler(t *testing.T) http.Handler { + store, err := filesystem.NewFileSystemStorage(filepath.Join("./handler_test", "storage")) if err != nil { t.Fatal(err) } - terraformOptions := terraform.WithDefaultRetryableErrors(t, &terraform.Options{ + locker := locallock.NewLock() + + key := "x8DiIkAKRQT7cF55NQLkAZk637W3bGVOUjGeMX5ZGXY=" + kms, _ := localkms.NewKMS(key) + + r := mux.NewRouter().StrictSlash(true) + r.HandleFunc("/state/{project}/{name}", StateHandler(store, locker, kms)) + + return r +} + +var terraformBinary = flag.String("tf", "terraform", "terraform binary") + +func terraformOptions(t *testing.T, addr string) *terraform.Options { + return terraform.WithDefaultRetryableErrors(t, &terraform.Options{ TerraformDir: "./handler_test", TerraformBinary: *terraformBinary, Vars: map[string]interface{}{}, Reconfigure: true, BackendConfig: map[string]interface{}{ - "address": address, - "lock_address": address, - "unlock_address": address, + "address": addr, + "lock_address": addr, + "unlock_address": addr, "username": "basic", "password": "some-random-secret", }, - Lock: true, + LockTimeout: "200ms", + Lock: true, }) +} + +func TestServerHandler_VerifyLockOnPush(t *testing.T) { + s := httptest.NewServer(NewStateHandler(t)) + defer s.Close() + + address, err := url.JoinPath(s.URL, "/state/project1/example") + if err != nil { + t.Fatal(err) + } + + simulateLock(t, address, true) + + for _, doLock := range []bool{true, false} { + terraformOptions := terraformOptions(t, address) + terraformOptions.Lock = doLock + + _, err = terraform.InitAndApplyE(t, terraformOptions) + if err == nil { + t.Fatal("expected error") + } + + simulateLock(t, address, false) + } +} + +func TestServerHandler(t *testing.T) { + s := httptest.NewServer(NewStateHandler(t)) + defer s.Close() + + address, err := url.JoinPath(s.URL, "/state/project1/example") + if err != nil { + t.Fatal(err) + } + + terraformOptions := terraformOptions(t, address) // Clean up resources with "terraform destroy" at the end of the test. defer terraform.Destroy(t, terraformOptions) @@ -91,20 +138,3 @@ func simulateLock(t *testing.T, address string, lock bool) { t.Fatal(err) } } - -func NewStateHandler() http.Handler { - store, err := filesystem.NewFileSystemStorage(filepath.Join("./handler_test", "storage")) - if err != nil { - panic(err) - } - - locker := locallock.NewLock() - - key := "x8DiIkAKRQT7cF55NQLkAZk637W3bGVOUjGeMX5ZGXY=" - kms, _ := localkms.NewKMS(key) - - r := mux.NewRouter().StrictSlash(true) - r.HandleFunc("/state/{project}/{name}", StateHandler(store, locker, kms)) - - return r -}