diff --git a/client/v3/concurrency/mutex.go b/client/v3/concurrency/mutex.go index c3800d6282a..7080f0b08dd 100644 --- a/client/v3/concurrency/mutex.go +++ b/client/v3/concurrency/mutex.go @@ -18,6 +18,7 @@ import ( "context" "errors" "fmt" + "strings" "sync" pb "go.etcd.io/etcd/api/v3/etcdserverpb" @@ -27,6 +28,7 @@ import ( // ErrLocked is returned by TryLock when Mutex is already locked by another session. var ErrLocked = errors.New("mutex: Locked by another session") var ErrSessionExpired = errors.New("mutex: session is expired") +var ErrLockReleased = errors.New("mutex: lock has already been released") // Mutex implements the sync Locker interface with etcd type Mutex struct { @@ -128,6 +130,14 @@ func (m *Mutex) tryAcquire(ctx context.Context) (*v3.TxnResponse, error) { } func (m *Mutex) Unlock(ctx context.Context) error { + if m.myKey == "" || m.myRev <= 0 || m.myKey == "\x00" { + return ErrLockReleased + } + + if !strings.HasPrefix(m.myKey, m.pfx) { + return fmt.Errorf("invalid key %q, it should have prefix %q", m.myKey, m.pfx) + } + client := m.s.Client() if _, err := client.Delete(ctx, m.myKey); err != nil { return err diff --git a/tests/integration/clientv3/concurrency/mutex_test.go b/tests/integration/clientv3/concurrency/mutex_test.go index 8220788cfb0..0ddca0e022e 100644 --- a/tests/integration/clientv3/concurrency/mutex_test.go +++ b/tests/integration/clientv3/concurrency/mutex_test.go @@ -16,6 +16,7 @@ package concurrency_test import ( "context" + "errors" "testing" "go.etcd.io/etcd/client/v3" @@ -70,3 +71,42 @@ func TestMutexLockSessionExpired(t *testing.T) { <-m2Locked } + +func TestMutexUnlock(t *testing.T) { + cli, err := integration2.NewClient(t, clientv3.Config{Endpoints: exampleEndpoints()}) + if err != nil { + t.Fatal(err) + } + defer cli.Close() + + s1, err := concurrency.NewSession(cli) + if err != nil { + t.Fatal(err) + } + defer s1.Close() + + m1 := concurrency.NewMutex(s1, "/my-lock/") + err = m1.Unlock(context.TODO()) + if err == nil { + t.Fatal("expect lock released error") + } + if !errors.Is(err, concurrency.ErrLockReleased) { + t.Fatal(err) + } + + if err := m1.Lock(context.TODO()); err != nil { + t.Fatal(err) + } + + if err := m1.Unlock(context.TODO()); err != nil { + t.Fatal(err) + } + + err = m1.Unlock(context.TODO()) + if err == nil { + t.Fatal("expect lock released error") + } + if !errors.Is(err, concurrency.ErrLockReleased) { + t.Fatal(err) + } +}