From 7bb25739d293e4416babdb1f1b40f9e87d62b0d8 Mon Sep 17 00:00:00 2001 From: John Warren Date: Sat, 9 Oct 2021 05:25:22 -0400 Subject: [PATCH 01/14] Stopping point --- cmd/lock/main.go | 37 ++++ prunelocks/pruner.go | 92 +++++++++ prunelocks/pruner_test.go | 40 ++++ rules/concurrency/doc.go | 17 ++ rules/concurrency/key.go | 65 ++++++ rules/concurrency/mutex.go | 167 +++++++++++++++ rules/concurrency/session.go | 141 +++++++++++++ rules/concurrency/stm.go | 387 +++++++++++++++++++++++++++++++++++ 8 files changed, 946 insertions(+) create mode 100644 cmd/lock/main.go create mode 100644 prunelocks/pruner.go create mode 100644 prunelocks/pruner_test.go create mode 100644 rules/concurrency/doc.go create mode 100644 rules/concurrency/key.go create mode 100644 rules/concurrency/mutex.go create mode 100644 rules/concurrency/session.go create mode 100644 rules/concurrency/stm.go diff --git a/cmd/lock/main.go b/cmd/lock/main.go new file mode 100644 index 0000000..a017689 --- /dev/null +++ b/cmd/lock/main.go @@ -0,0 +1,37 @@ +package main + +import ( + "context" + "fmt" + "time" + + "github.com/IBM-Cloud/go-etcd-rules/rules/concurrency" + "go.etcd.io/etcd/clientv3" +) + +func check(err error) { + if err != nil { + panic(err.Error()) + } +} + +func main() { + cfg := clientv3.Config{Endpoints: []string{"http://127.0.0.1:2379"}} + cl, err := clientv3.New(cfg) + check(err) + session, err := concurrency.NewSession(cl) + check(err) + mutex := concurrency.NewMutex(session, "/locks/hello") + err = mutex.TryLock(context.Background()) + check(err) + fmt.Println(mutex.Key()) + time.Sleep(time.Minute) + mutex.Unlock(context.Background()) + fmt.Println("Unlocked") + time.Sleep(time.Minute) + session.Close() + fmt.Println("Session closed") + for { + time.Sleep(time.Second) + } +} diff --git a/prunelocks/pruner.go b/prunelocks/pruner.go new file mode 100644 index 0000000..8fa19a8 --- /dev/null +++ b/prunelocks/pruner.go @@ -0,0 +1,92 @@ +package prunelocks + +import ( + "context" + "strings" + "time" + + "go.etcd.io/etcd/clientv3" + "go.uber.org/zap" +) + +type lockKey struct { + createRevision int64 + firstSeen time.Time +} + +type Pruner struct { + keys map[string]lockKey + timeout time.Duration + lockPrefixes []string + client *clientv3.Client + kv clientv3.KV + lease clientv3.Lease + logger *zap.Logger +} + +func (p Pruner) checkLocks() { + ctx := context.Background() + for _, lockPrefix := range p.lockPrefixes { + p.checkLockPrefix(ctx, lockPrefix, p.logger) + } +} + +func (p Pruner) checkLockPrefix(ctx context.Context, lockPrefix string, prefixLogger *zap.Logger) { + keysRetrieved := make(map[string]bool) + resp, _ := p.kv.Get(ctx, lockPrefix, clientv3.WithPrefix()) + for _, kv := range resp.Kvs { + // There are three possibilities: + // 1. This lock was not seen before + // 2. This lock was seen but has a different create revision + // 3. 
This lock was seen and has the same create revision + keyString := string(kv.Key) + keysRetrieved[keyString] = true + keyLogger := prefixLogger.With(zap.String("key", keyString), zap.Int64("create_revision", kv.CreateRevision), zap.Int64("lease", kv.Lease)) + keyLogger.Info("Found lock") + var key lockKey + var found bool + // Key not seen before or seen before with different create revision + key, found = p.keys[keyString] + keyLogger = keyLogger.With(zap.Bool("found", found)) + if found { + keyLogger = keyLogger.With(zap.String("first_seen", key.firstSeen.Format(time.RFC3339)), zap.Int64("existing_create_revision", key.createRevision)) + } + if !found || kv.CreateRevision != key.createRevision { + keyLogger.Info("creating new key entry") + key = lockKey{ + createRevision: kv.CreateRevision, + firstSeen: time.Now(), + } + p.keys[keyString] = key + } + // Key seen before with same create revision + now := time.Now() + + if now.Sub(key.firstSeen) < p.timeout { + keyLogger.Info("Lock not expired") + } else { + keyLogger.Info("Lock expired; deleting key") + resp, err := p.kv.Txn(ctx).If(clientv3.Compare(clientv3.CreateRevision(keyString), "=", key.createRevision)).Then(clientv3.OpDelete(keyString)).Commit() + if err != nil { + keyLogger.Error("error deleting key", zap.Error(err)) + } else { + keyLogger.Info("deleted key", zap.Bool("succeeded", resp.Succeeded)) + if resp.Succeeded && kv.Lease != 0 { + keyLogger.Error("revoking lease") + _, err := p.lease.Revoke(ctx, clientv3.LeaseID(kv.Lease)) + if err != nil { + keyLogger.Error("error revoking lease", zap.Error(err)) + } else { + keyLogger.Info("revoked lease") + } + } + } + } + } + for keyString := range p.keys { + if strings.HasPrefix(keyString, lockPrefix) && !keysRetrieved[keyString] { + prefixLogger.Info("removing key from map", zap.String("key", keyString)) + delete(p.keys, keyString) + } + } +} diff --git a/prunelocks/pruner_test.go b/prunelocks/pruner_test.go new file mode 100644 index 0000000..d60d716 --- /dev/null +++ b/prunelocks/pruner_test.go @@ -0,0 +1,40 @@ +package prunelocks + +import ( + "testing" + "time" + + "go.etcd.io/etcd/clientv3" + "go.uber.org/zap/zaptest" +) + +func check(err error) { + if err != nil { + panic(err.Error()) + } +} + +func Test_Blah(t *testing.T) { + // ctx := context.Background() + cfg := clientv3.Config{Endpoints: []string{"http://127.0.0.1:2379"}} + cl, err := clientv3.New(cfg) + check(err) + kv := clientv3.NewKV(cl) + // resp, err := kv.Get(ctx, "/locks", clientv3.WithPrefix()) + // check(err) + // for _, kv := range resp.Kvs { + // fmt.Printf("%v\n", kv) + // } + p := Pruner{ + keys: make(map[string]lockKey), + timeout: time.Minute, + kv: kv, + lease: clientv3.NewLease(cl), + logger: zaptest.NewLogger(t), + lockPrefixes: []string{"/locks/hello"}, + } + for i := 0; i < 10; i++ { + p.checkLocks() + time.Sleep(10 * time.Second) + } +} diff --git a/rules/concurrency/doc.go b/rules/concurrency/doc.go new file mode 100644 index 0000000..dcdbf51 --- /dev/null +++ b/rules/concurrency/doc.go @@ -0,0 +1,17 @@ +// Copyright 2016 The etcd Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Package concurrency implements concurrency operations on top of +// etcd such as distributed locks, barriers, and elections. +package concurrency diff --git a/rules/concurrency/key.go b/rules/concurrency/key.go new file mode 100644 index 0000000..e4cf775 --- /dev/null +++ b/rules/concurrency/key.go @@ -0,0 +1,65 @@ +// Copyright 2016 The etcd Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package concurrency + +import ( + "context" + "fmt" + + v3 "go.etcd.io/etcd/clientv3" + pb "go.etcd.io/etcd/etcdserver/etcdserverpb" + "go.etcd.io/etcd/mvcc/mvccpb" +) + +func waitDelete(ctx context.Context, client *v3.Client, key string, rev int64) error { + cctx, cancel := context.WithCancel(ctx) + defer cancel() + + var wr v3.WatchResponse + wch := client.Watch(cctx, key, v3.WithRev(rev)) + for wr = range wch { + for _, ev := range wr.Events { + if ev.Type == mvccpb.DELETE { + return nil + } + } + } + if err := wr.Err(); err != nil { + return err + } + if err := ctx.Err(); err != nil { + return err + } + return fmt.Errorf("lost watcher waiting for delete") +} + +// waitDeletes efficiently waits until all keys matching the prefix and no greater +// than the create revision. +func waitDeletes(ctx context.Context, client *v3.Client, pfx string, maxCreateRev int64) (*pb.ResponseHeader, error) { + getOpts := append(v3.WithLastCreate(), v3.WithMaxCreateRev(maxCreateRev)) + for { + resp, err := client.Get(ctx, pfx, getOpts...) + if err != nil { + return nil, err + } + if len(resp.Kvs) == 0 { + return resp.Header, nil + } + lastKey := string(resp.Kvs[0].Key) + if err = waitDelete(ctx, client, lastKey, resp.Header.Revision); err != nil { + return nil, err + } + } +} diff --git a/rules/concurrency/mutex.go b/rules/concurrency/mutex.go new file mode 100644 index 0000000..50a87d8 --- /dev/null +++ b/rules/concurrency/mutex.go @@ -0,0 +1,167 @@ +// Copyright 2016 The etcd Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package concurrency + +import ( + "context" + "errors" + "fmt" + "sync" + + v3 "go.etcd.io/etcd/clientv3" + pb "go.etcd.io/etcd/etcdserver/etcdserverpb" +) + +// ErrLocked is returned by TryLock when Mutex is already locked by another session. 
+var ErrLocked = errors.New("mutex: Locked by another session") +var ErrSessionExpired = errors.New("mutex: session is expired") + +// Mutex implements the sync Locker interface with etcd +type Mutex struct { + s *Session + + pfx string + myKey string + myRev int64 + hdr *pb.ResponseHeader +} + +func NewMutex(s *Session, pfx string) *Mutex { + return &Mutex{s, pfx + "/", "", -1, nil} +} + +// TryLock locks the mutex if not already locked by another session. +// If lock is held by another session, return immediately after attempting necessary cleanup +// The ctx argument is used for the sending/receiving Txn RPC. +func (m *Mutex) TryLock(ctx context.Context) error { + resp, err := m.tryAcquire(ctx) + if err != nil { + return err + } + // if no key on prefix / the minimum rev is key, already hold the lock + ownerKey := resp.Responses[1].GetResponseRange().Kvs + if len(ownerKey) == 0 || ownerKey[0].CreateRevision == m.myRev { + m.hdr = resp.Header + return nil + } + client := m.s.Client() + // Cannot lock, so delete the key + if _, err := client.Delete(ctx, m.myKey); err != nil { + return err + } + m.myKey = "\x00" + m.myRev = -1 + return ErrLocked +} + +// Lock locks the mutex with a cancelable context. If the context is canceled +// while trying to acquire the lock, the mutex tries to clean its stale lock entry. +func (m *Mutex) Lock(ctx context.Context) error { + resp, err := m.tryAcquire(ctx) + if err != nil { + return err + } + // if no key on prefix / the minimum rev is key, already hold the lock + ownerKey := resp.Responses[1].GetResponseRange().Kvs + if len(ownerKey) == 0 || ownerKey[0].CreateRevision == m.myRev { + m.hdr = resp.Header + return nil + } + client := m.s.Client() + // wait for deletion revisions prior to myKey + // TODO: early termination if the session key is deleted before other session keys with smaller revisions. + _, werr := waitDeletes(ctx, client, m.pfx, m.myRev-1) + // release lock key if wait failed + if werr != nil { + m.Unlock(client.Ctx()) + return werr + } + + // make sure the session is not expired, and the owner key still exists. + gresp, werr := client.Get(ctx, m.myKey) + if werr != nil { + m.Unlock(client.Ctx()) + return werr + } + + if len(gresp.Kvs) == 0 { // is the session key lost? + return ErrSessionExpired + } + m.hdr = gresp.Header + + return nil +} + +func (m *Mutex) tryAcquire(ctx context.Context) (*v3.TxnResponse, error) { + s := m.s + client := m.s.Client() + + m.myKey = fmt.Sprintf("%s%x", m.pfx, s.Lease()) + cmp := v3.Compare(v3.CreateRevision(m.myKey), "=", 0) + // put self in lock waiters via myKey; oldest waiter holds lock + put := v3.OpPut(m.myKey, "", v3.WithLease(s.Lease())) + // reuse key in case this session already holds the lock + get := v3.OpGet(m.myKey) + // fetch current holder to complete uncontended path with only one RPC + getOwner := v3.OpGet(m.pfx, v3.WithFirstCreate()...) 
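+	// Descriptive note (added): the transaction below creates myKey and fetches the
+	// current lock owner in one round trip when myKey does not exist yet; otherwise it
+	// reuses the existing key and still fetches the owner.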
+ resp, err := client.Txn(ctx).If(cmp).Then(put, getOwner).Else(get, getOwner).Commit() + if err != nil { + return nil, err + } + m.myRev = resp.Header.Revision + if !resp.Succeeded { + m.myRev = resp.Responses[0].GetResponseRange().Kvs[0].CreateRevision + } + return resp, nil +} + +func (m *Mutex) Unlock(ctx context.Context) error { + client := m.s.Client() + if _, err := client.Delete(ctx, m.myKey); err != nil { + return err + } + m.myKey = "\x00" + m.myRev = -1 + return nil +} + +func (m *Mutex) IsOwner() v3.Cmp { + return v3.Compare(v3.CreateRevision(m.myKey), "=", m.myRev) +} + +func (m *Mutex) Key() string { return m.myKey } + +// Header is the response header received from etcd on acquiring the lock. +func (m *Mutex) Header() *pb.ResponseHeader { return m.hdr } + +type lockerMutex struct{ *Mutex } + +func (lm *lockerMutex) Lock() { + client := lm.s.Client() + if err := lm.Mutex.Lock(client.Ctx()); err != nil { + panic(err) + } +} +func (lm *lockerMutex) Unlock() { + client := lm.s.Client() + if err := lm.Mutex.Unlock(client.Ctx()); err != nil { + panic(err) + } +} + +// NewLocker creates a sync.Locker backed by an etcd mutex. +func NewLocker(s *Session, pfx string) sync.Locker { + return &lockerMutex{NewMutex(s, pfx)} +} diff --git a/rules/concurrency/session.go b/rules/concurrency/session.go new file mode 100644 index 0000000..97eb763 --- /dev/null +++ b/rules/concurrency/session.go @@ -0,0 +1,141 @@ +// Copyright 2016 The etcd Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package concurrency + +import ( + "context" + "time" + + v3 "go.etcd.io/etcd/clientv3" +) + +const defaultSessionTTL = 60 + +// Session represents a lease kept alive for the lifetime of a client. +// Fault-tolerant applications may use sessions to reason about liveness. +type Session struct { + client *v3.Client + opts *sessionOptions + id v3.LeaseID + + cancel context.CancelFunc + donec <-chan struct{} +} + +// NewSession gets the leased session for a client. +func NewSession(client *v3.Client, opts ...SessionOption) (*Session, error) { + ops := &sessionOptions{ttl: defaultSessionTTL, ctx: client.Ctx()} + for _, opt := range opts { + opt(ops) + } + + id := ops.leaseID + if id == v3.NoLease { + resp, err := client.Grant(ops.ctx, int64(ops.ttl)) + if err != nil { + return nil, err + } + id = resp.ID + } + + ctx, cancel := context.WithCancel(ops.ctx) + keepAlive, err := client.KeepAlive(ctx, id) + if err != nil || keepAlive == nil { + cancel() + return nil, err + } + + donec := make(chan struct{}) + s := &Session{client: client, opts: ops, id: id, cancel: cancel, donec: donec} + + // keep the lease alive until client error or cancelled context + go func() { + defer close(donec) + for range keepAlive { + // eat messages until keep alive channel closes + } + }() + + return s, nil +} + +// Client is the etcd client that is attached to the session. +func (s *Session) Client() *v3.Client { + return s.client +} + +// Lease is the lease ID for keys bound to the session. 
+func (s *Session) Lease() v3.LeaseID { return s.id } + +// Done returns a channel that closes when the lease is orphaned, expires, or +// is otherwise no longer being refreshed. +func (s *Session) Done() <-chan struct{} { return s.donec } + +// Orphan ends the refresh for the session lease. This is useful +// in case the state of the client connection is indeterminate (revoke +// would fail) or when transferring lease ownership. +func (s *Session) Orphan() { + s.cancel() + <-s.donec +} + +// Close orphans the session and revokes the session lease. +func (s *Session) Close() error { + s.Orphan() + // if revoke takes longer than the ttl, lease is expired anyway + ctx, cancel := context.WithTimeout(s.opts.ctx, time.Duration(s.opts.ttl)*time.Second) + _, err := s.client.Revoke(ctx, s.id) + cancel() + return err +} + +type sessionOptions struct { + ttl int + leaseID v3.LeaseID + ctx context.Context +} + +// SessionOption configures Session. +type SessionOption func(*sessionOptions) + +// WithTTL configures the session's TTL in seconds. +// If TTL is <= 0, the default 60 seconds TTL will be used. +func WithTTL(ttl int) SessionOption { + return func(so *sessionOptions) { + if ttl > 0 { + so.ttl = ttl + } + } +} + +// WithLease specifies the existing leaseID to be used for the session. +// This is useful in process restart scenario, for example, to reclaim +// leadership from an election prior to restart. +func WithLease(leaseID v3.LeaseID) SessionOption { + return func(so *sessionOptions) { + so.leaseID = leaseID + } +} + +// WithContext assigns a context to the session instead of defaulting to +// using the client context. This is useful for canceling NewSession and +// Close operations immediately without having to close the client. If the +// context is canceled before Close() completes, the session's lease will be +// abandoned and left to expire instead of being revoked. +func WithContext(ctx context.Context) SessionOption { + return func(so *sessionOptions) { + so.ctx = ctx + } +} diff --git a/rules/concurrency/stm.go b/rules/concurrency/stm.go new file mode 100644 index 0000000..ee11510 --- /dev/null +++ b/rules/concurrency/stm.go @@ -0,0 +1,387 @@ +// Copyright 2016 The etcd Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package concurrency + +import ( + "context" + "math" + + v3 "go.etcd.io/etcd/clientv3" +) + +// STM is an interface for software transactional memory. +type STM interface { + // Get returns the value for a key and inserts the key in the txn's read set. + // If Get fails, it aborts the transaction with an error, never returning. + Get(key ...string) string + // Put adds a value for a key to the write set. + Put(key, val string, opts ...v3.OpOption) + // Rev returns the revision of a key in the read set. + Rev(key string) int64 + // Del deletes a key. + Del(key string) + + // commit attempts to apply the txn's changes to the server. 
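+	// Descriptive note (added): commit returns nil when the changes could not be applied
+	// because of a conflict, in which case runSTM resets and retries the transaction.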
+ commit() *v3.TxnResponse + reset() +} + +// Isolation is an enumeration of transactional isolation levels which +// describes how transactions should interfere and conflict. +type Isolation int + +const ( + // SerializableSnapshot provides serializable isolation and also checks + // for write conflicts. + SerializableSnapshot Isolation = iota + // Serializable reads within the same transaction attempt return data + // from the at the revision of the first read. + Serializable + // RepeatableReads reads within the same transaction attempt always + // return the same data. + RepeatableReads + // ReadCommitted reads keys from any committed revision. + ReadCommitted +) + +// stmError safely passes STM errors through panic to the STM error channel. +type stmError struct{ err error } + +type stmOptions struct { + iso Isolation + ctx context.Context + prefetch []string +} + +type stmOption func(*stmOptions) + +// WithIsolation specifies the transaction isolation level. +func WithIsolation(lvl Isolation) stmOption { + return func(so *stmOptions) { so.iso = lvl } +} + +// WithAbortContext specifies the context for permanently aborting the transaction. +func WithAbortContext(ctx context.Context) stmOption { + return func(so *stmOptions) { so.ctx = ctx } +} + +// WithPrefetch is a hint to prefetch a list of keys before trying to apply. +// If an STM transaction will unconditionally fetch a set of keys, prefetching +// those keys will save the round-trip cost from requesting each key one by one +// with Get(). +func WithPrefetch(keys ...string) stmOption { + return func(so *stmOptions) { so.prefetch = append(so.prefetch, keys...) } +} + +// NewSTM initiates a new STM instance, using serializable snapshot isolation by default. +func NewSTM(c *v3.Client, apply func(STM) error, so ...stmOption) (*v3.TxnResponse, error) { + opts := &stmOptions{ctx: c.Ctx()} + for _, f := range so { + f(opts) + } + if len(opts.prefetch) != 0 { + f := apply + apply = func(s STM) error { + s.Get(opts.prefetch...) + return f(s) + } + } + return runSTM(mkSTM(c, opts), apply) +} + +func mkSTM(c *v3.Client, opts *stmOptions) STM { + switch opts.iso { + case SerializableSnapshot: + s := &stmSerializable{ + stm: stm{client: c, ctx: opts.ctx}, + prefetch: make(map[string]*v3.GetResponse), + } + s.conflicts = func() []v3.Cmp { + return append(s.rset.cmps(), s.wset.cmps(s.rset.first()+1)...) 
+ } + return s + case Serializable: + s := &stmSerializable{ + stm: stm{client: c, ctx: opts.ctx}, + prefetch: make(map[string]*v3.GetResponse), + } + s.conflicts = func() []v3.Cmp { return s.rset.cmps() } + return s + case RepeatableReads: + s := &stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}} + s.conflicts = func() []v3.Cmp { return s.rset.cmps() } + return s + case ReadCommitted: + s := &stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}} + s.conflicts = func() []v3.Cmp { return nil } + return s + default: + panic("unsupported stm") + } +} + +type stmResponse struct { + resp *v3.TxnResponse + err error +} + +func runSTM(s STM, apply func(STM) error) (*v3.TxnResponse, error) { + outc := make(chan stmResponse, 1) + go func() { + defer func() { + if r := recover(); r != nil { + e, ok := r.(stmError) + if !ok { + // client apply panicked + panic(r) + } + outc <- stmResponse{nil, e.err} + } + }() + var out stmResponse + for { + s.reset() + if out.err = apply(s); out.err != nil { + break + } + if out.resp = s.commit(); out.resp != nil { + break + } + } + outc <- out + }() + r := <-outc + return r.resp, r.err +} + +// stm implements repeatable-read software transactional memory over etcd +type stm struct { + client *v3.Client + ctx context.Context + // rset holds read key values and revisions + rset readSet + // wset holds overwritten keys and their values + wset writeSet + // getOpts are the opts used for gets + getOpts []v3.OpOption + // conflicts computes the current conflicts on the txn + conflicts func() []v3.Cmp +} + +type stmPut struct { + val string + op v3.Op +} + +type readSet map[string]*v3.GetResponse + +func (rs readSet) add(keys []string, txnresp *v3.TxnResponse) { + for i, resp := range txnresp.Responses { + rs[keys[i]] = (*v3.GetResponse)(resp.GetResponseRange()) + } +} + +// first returns the store revision from the first fetch +func (rs readSet) first() int64 { + ret := int64(math.MaxInt64 - 1) + for _, resp := range rs { + if rev := resp.Header.Revision; rev < ret { + ret = rev + } + } + return ret +} + +// cmps guards the txn from updates to read set +func (rs readSet) cmps() []v3.Cmp { + cmps := make([]v3.Cmp, 0, len(rs)) + for k, rk := range rs { + cmps = append(cmps, isKeyCurrent(k, rk)) + } + return cmps +} + +type writeSet map[string]stmPut + +func (ws writeSet) get(keys ...string) *stmPut { + for _, key := range keys { + if wv, ok := ws[key]; ok { + return &wv + } + } + return nil +} + +// cmps returns a cmp list testing no writes have happened past rev +func (ws writeSet) cmps(rev int64) []v3.Cmp { + cmps := make([]v3.Cmp, 0, len(ws)) + for key := range ws { + cmps = append(cmps, v3.Compare(v3.ModRevision(key), "<", rev)) + } + return cmps +} + +// puts is the list of ops for all pending writes +func (ws writeSet) puts() []v3.Op { + puts := make([]v3.Op, 0, len(ws)) + for _, v := range ws { + puts = append(puts, v.op) + } + return puts +} + +func (s *stm) Get(keys ...string) string { + if wv := s.wset.get(keys...); wv != nil { + return wv.val + } + return respToValue(s.fetch(keys...)) +} + +func (s *stm) Put(key, val string, opts ...v3.OpOption) { + s.wset[key] = stmPut{val, v3.OpPut(key, val, opts...)} +} + +func (s *stm) Del(key string) { s.wset[key] = stmPut{"", v3.OpDelete(key)} } + +func (s *stm) Rev(key string) int64 { + if resp := s.fetch(key); resp != nil && len(resp.Kvs) != 0 { + return resp.Kvs[0].ModRevision + } + return 0 +} + +func (s *stm) commit() *v3.TxnResponse { + txnresp, err := 
s.client.Txn(s.ctx).If(s.conflicts()...).Then(s.wset.puts()...).Commit() + if err != nil { + panic(stmError{err}) + } + if txnresp.Succeeded { + return txnresp + } + return nil +} + +func (s *stm) fetch(keys ...string) *v3.GetResponse { + if len(keys) == 0 { + return nil + } + ops := make([]v3.Op, len(keys)) + for i, key := range keys { + if resp, ok := s.rset[key]; ok { + return resp + } + ops[i] = v3.OpGet(key, s.getOpts...) + } + txnresp, err := s.client.Txn(s.ctx).Then(ops...).Commit() + if err != nil { + panic(stmError{err}) + } + s.rset.add(keys, txnresp) + return (*v3.GetResponse)(txnresp.Responses[0].GetResponseRange()) +} + +func (s *stm) reset() { + s.rset = make(map[string]*v3.GetResponse) + s.wset = make(map[string]stmPut) +} + +type stmSerializable struct { + stm + prefetch map[string]*v3.GetResponse +} + +func (s *stmSerializable) Get(keys ...string) string { + if wv := s.wset.get(keys...); wv != nil { + return wv.val + } + firstRead := len(s.rset) == 0 + for _, key := range keys { + if resp, ok := s.prefetch[key]; ok { + delete(s.prefetch, key) + s.rset[key] = resp + } + } + resp := s.stm.fetch(keys...) + if firstRead { + // txn's base revision is defined by the first read + s.getOpts = []v3.OpOption{ + v3.WithRev(resp.Header.Revision), + v3.WithSerializable(), + } + } + return respToValue(resp) +} + +func (s *stmSerializable) Rev(key string) int64 { + s.Get(key) + return s.stm.Rev(key) +} + +func (s *stmSerializable) gets() ([]string, []v3.Op) { + keys := make([]string, 0, len(s.rset)) + ops := make([]v3.Op, 0, len(s.rset)) + for k := range s.rset { + keys = append(keys, k) + ops = append(ops, v3.OpGet(k)) + } + return keys, ops +} + +func (s *stmSerializable) commit() *v3.TxnResponse { + keys, getops := s.gets() + txn := s.client.Txn(s.ctx).If(s.conflicts()...).Then(s.wset.puts()...) + // use Else to prefetch keys in case of conflict to save a round trip + txnresp, err := txn.Else(getops...).Commit() + if err != nil { + panic(stmError{err}) + } + if txnresp.Succeeded { + return txnresp + } + // load prefetch with Else data + s.rset.add(keys, txnresp) + s.prefetch = s.rset + s.getOpts = nil + return nil +} + +func isKeyCurrent(k string, r *v3.GetResponse) v3.Cmp { + if len(r.Kvs) != 0 { + return v3.Compare(v3.ModRevision(k), "=", r.Kvs[0].ModRevision) + } + return v3.Compare(v3.ModRevision(k), "=", 0) +} + +func respToValue(resp *v3.GetResponse) string { + if resp == nil || len(resp.Kvs) == 0 { + return "" + } + return string(resp.Kvs[0].Value) +} + +// NewSTMRepeatable is deprecated. +func NewSTMRepeatable(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) { + return NewSTM(c, apply, WithAbortContext(ctx), WithIsolation(RepeatableReads)) +} + +// NewSTMSerializable is deprecated. +func NewSTMSerializable(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) { + return NewSTM(c, apply, WithAbortContext(ctx), WithIsolation(Serializable)) +} + +// NewSTMReadCommitted is deprecated. 
+func NewSTMReadCommitted(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) { + return NewSTM(c, apply, WithAbortContext(ctx), WithIsolation(ReadCommitted)) +} From 619b147f32913b932032c167673561ee19055690 Mon Sep 17 00:00:00 2001 From: John Warren Date: Mon, 11 Oct 2021 06:53:57 -0400 Subject: [PATCH 02/14] Builds --- cmd/lock/main.go | 74 +++++++++++++++++++++++----- rules/concurrency/session_manager.go | 63 +++++++++++++++++++++++ rules/engine.go | 8 ++- rules/engine_test.go | 8 +-- rules/int_crawler.go | 13 +++-- rules/lock.go | 47 +++++++++--------- rules/lock_test.go | 22 ++++++--- rules/options.go | 4 +- rules/worker.go | 8 +-- 9 files changed, 187 insertions(+), 60 deletions(-) create mode 100644 rules/concurrency/session_manager.go diff --git a/cmd/lock/main.go b/cmd/lock/main.go index a017689..50be87e 100644 --- a/cmd/lock/main.go +++ b/cmd/lock/main.go @@ -3,6 +3,7 @@ package main import ( "context" "fmt" + "sync" "time" "github.com/IBM-Cloud/go-etcd-rules/rules/concurrency" @@ -15,23 +16,70 @@ func check(err error) { } } +var session *concurrency.Session +var sessionMutex sync.Mutex +var sessionDone <-chan struct{} + +func manageSession(client *clientv3.Client) { + initSession(client) +} + +func initSession(client *clientv3.Client) { + sessionMutex.Lock() + defer sessionMutex.Unlock() + var err error + session, err = concurrency.NewSession(client) + check(err) + fmt.Printf("Session lease ID: %x\n", session.Lease()) + sessionDone = session.Done() + go func() { + <-sessionDone + initSession(client) + }() +} + func main() { cfg := clientv3.Config{Endpoints: []string{"http://127.0.0.1:2379"}} cl, err := clientv3.New(cfg) check(err) - session, err := concurrency.NewSession(cl) - check(err) - mutex := concurrency.NewMutex(session, "/locks/hello") - err = mutex.TryLock(context.Background()) - check(err) - fmt.Println(mutex.Key()) - time.Sleep(time.Minute) - mutex.Unlock(context.Background()) - fmt.Println("Unlocked") - time.Sleep(time.Minute) - session.Close() - fmt.Println("Session closed") + // session, err = concurrency.NewSession(cl) + // check(err) + manageSession(cl) + // mutex := concurrency.NewMutex(session, "/locks/hello") + // err = mutex.TryLock(context.Background()) + // check(err) + // fmt.Println(mutex.Key()) + // time.Sleep(time.Minute) + // mutex.Unlock(context.Background()) + // fmt.Println("Unlocked") + // time.Sleep(time.Minute) + // session.Close() + // fmt.Println("Session closed") + // d := session.Done() + // go func() { + // <-d + // fmt.Println("done") + // }() for { - time.Sleep(time.Second) + sessionMutex.Lock() + mutex := concurrency.NewMutex(session, "/locks/hello") + err = mutex.TryLock(context.Background()) + sessionMutex.Unlock() + if err != nil { + fmt.Println(err) + time.Sleep(time.Second * 10) + continue + // break + } + fmt.Println(mutex.Key()) + time.Sleep(time.Second * 3) + err = mutex.Unlock(context.Background()) + if err == nil { + fmt.Println("Unlocked") + } else { + fmt.Println(err) + // break + } + time.Sleep(time.Second * 3) } } diff --git a/rules/concurrency/session_manager.go b/rules/concurrency/session_manager.go new file mode 100644 index 0000000..91a3ff4 --- /dev/null +++ b/rules/concurrency/session_manager.go @@ -0,0 +1,63 @@ +package concurrency + +import ( + "fmt" + "sync" + "time" + + "go.etcd.io/etcd/clientv3" + "go.uber.org/zap" +) + +type SessionManager struct { + client *clientv3.Client + logger *zap.Logger + session *Session + mutex sync.Mutex + err error +} + +// NewSessionManager creates a new 
session manager that will return an error if the +// attempt to get a session fails or return a session manager instance that will +// create new sessions if the existing one dies. +func NewSessionManager(client *clientv3.Client, logger *zap.Logger) (*SessionManager, error) { + sm := &SessionManager{ + client: client, + logger: logger, + } + err := sm.initSession() + return sm, err +} + +func (sm *SessionManager) initSession() error { + sm.mutex.Lock() + defer sm.mutex.Unlock() + sm.session, sm.err = NewSession(sm.client) + if sm.err != nil { + sm.logger.Error("error initializing session", zap.Error(sm.err)) + return sm.err + } + sm.logger.Info("new session initialized", zap.String("lease_id", fmt.Sprintf("%x", sm.session.Lease()))) + sessionDone := sm.session.Done() + go func() { + time.Sleep(time.Minute) + // Create a new session if the session dies, most likely due to an etcd + // server issue. + <-sessionDone + err := sm.initSession() + for err != nil { + // If getting a new session fails, retry unti it succeeds. + // Attempts to get the managed session will fail quickly, which + // seems to be best alternative. + time.Sleep(time.Second * 10) + err = sm.initSession() + } + }() + return nil +} + +func (sm *SessionManager) GetSession() (*Session, error) { + sm.mutex.Lock() + defer sm.mutex.Unlock() + return sm.session, sm.err +} diff --git a/rules/engine.go b/rules/engine.go index be1059e..fa92502 100644 --- a/rules/engine.go +++ b/rules/engine.go @@ -5,6 +5,7 @@ import ( "strings" "time" + "github.com/IBM-Cloud/go-etcd-rules/rules/concurrency" "go.etcd.io/etcd/clientv3" "go.uber.org/zap" "golang.org/x/net/context" @@ -38,6 +39,7 @@ type baseEngine struct { crawlers []stoppable watchers []stoppable workers []stoppable + locker ruleLocker } type v3Engine struct { @@ -93,7 +95,7 @@ func newV3Engine(logger *zap.Logger, cl *clientv3.Client, options ...EngineOptio MetricsCollector: baseMetrics, } } - + sessionMgr, _ := concurrency.NewSessionManager(cl, logger) eng := v3Engine{ baseEngine: baseEngine{ keyProc: &keyProc, @@ -102,6 +104,7 @@ func newV3Engine(logger *zap.Logger, cl *clientv3.Client, options ...EngineOptio options: opts, ruleLockTTLs: map[int]int{}, ruleMgr: ruleMgr, + locker: newV3Locker(cl, opts.lockAcquisitionTimeout, sessionMgr.GetSession), }, keyProc: keyProc, workChannel: channel, @@ -267,7 +270,8 @@ func (e *v3Engine) Run() { e.options.lockAcquisitionTimeout, prefixSlice, e.kvWrapper, - e.options.syncDelay) + e.options.syncDelay, + e.locker) if err != nil { e.logger.Fatal("Failed to initialize crawler", zap.Error(err)) } diff --git a/rules/engine_test.go b/rules/engine_test.go index 8f18e56..61ee726 100644 --- a/rules/engine_test.go +++ b/rules/engine_test.go @@ -2,10 +2,11 @@ package rules import ( "errors" - "github.com/stretchr/testify/require" "testing" "time" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/assert" "go.etcd.io/etcd/clientv3" "golang.org/x/net/context" @@ -29,7 +30,7 @@ type testLocker struct { errorMsg *string } -func (tlkr *testLocker) lock(key string, ttl int) (ruleLock, error) { +func (tlkr *testLocker) lock(key string) (ruleLock, error) { if tlkr.errorMsg != nil { return nil, errors.New(*tlkr.errorMsg) } @@ -43,8 +44,9 @@ type testLock struct { channel chan bool } -func (tl *testLock) unlock() { +func (tl *testLock) unlock() error { tl.channel <- true + return nil } func TestV3EngineConstructor(t *testing.T) { diff --git a/rules/int_crawler.go b/rules/int_crawler.go index 1cc74df..8eef2f9 100644 --- a/rules/int_crawler.go +++ 
b/rules/int_crawler.go @@ -17,10 +17,11 @@ func newIntCrawler( logger *zap.Logger, mutex *string, mutexTTL int, - mutexTimeout int, + mutexTimeout time.Duration, prefixes []string, kvWrapper WrapKV, delay int, + locker ruleLocker, ) (crawler, error) { kv := kvWrapper(cl) api := etcdV3ReadAPI{ @@ -40,6 +41,7 @@ func newIntCrawler( kv: kv, delay: delay, rulesProcessedCount: make(map[string]int), + locker: locker, } return &c, nil } @@ -78,13 +80,14 @@ type intCrawler struct { logger *zap.Logger mutex *string mutexTTL int - mutexTimeout int + mutexTimeout time.Duration prefixes []string stopped uint32 stopping uint32 // tracks the number of times a rule is processed in a single run rulesProcessedCount map[string]int metricMutex sync.Mutex + locker ruleLocker } func (ic *intCrawler) isStopping() bool { @@ -117,9 +120,9 @@ func (ic *intCrawler) run() { } else { mutex := "/crawler/" + *ic.mutex logger.Debug("Attempting to obtain mutex", - zap.String("mutex", mutex), zap.Int("TTL", ic.mutexTTL), zap.Int("Timeout", ic.mutexTimeout)) - locker := newV3Locker(ic.cl, ic.mutexTimeout) - lock, err := locker.lock(mutex, ic.mutexTTL) + zap.String("mutex", mutex), zap.Int("TTL", ic.mutexTTL), zap.Duration("Timeout", ic.mutexTimeout)) + // locker := newV3Locker(ic.cl, ic.mutexTimeout) + lock, err := ic.locker.lock(mutex) if err != nil { logger.Debug("Could not obtain mutex; skipping crawler run", zap.Error(err)) } else { diff --git a/rules/lock.go b/rules/lock.go index 02b0b13..8c5f8e5 100644 --- a/rules/lock.go +++ b/rules/lock.go @@ -1,70 +1,67 @@ package rules import ( + "errors" "time" + "github.com/IBM-Cloud/go-etcd-rules/rules/concurrency" "go.etcd.io/etcd/clientv3" - "go.etcd.io/etcd/clientv3/concurrency" "golang.org/x/net/context" ) type ruleLocker interface { - lock(string, int) (ruleLock, error) + lock(string) (ruleLock, error) } type ruleLock interface { - unlock() + unlock() error } -func newV3Locker(cl *clientv3.Client, lockTimeout int) ruleLocker { +func newV3Locker(cl *clientv3.Client, lockTimeout time.Duration, getSessn getSession) ruleLocker { return &v3Locker{ cl: cl, + getSession: getSessn, lockTimeout: lockTimeout, } } +type getSession func() (*concurrency.Session, error) + type v3Locker struct { cl *clientv3.Client - lockTimeout int + getSession getSession + lockTimeout time.Duration } -func (v3l *v3Locker) lock(key string, ttl int) (ruleLock, error) { - return v3l.lockWithTimeout(key, ttl, v3l.lockTimeout) +func (v3l *v3Locker) lock(key string) (ruleLock, error) { + return v3l.lockWithTimeout(key, v3l.lockTimeout) } -func (v3l *v3Locker) lockWithTimeout(key string, ttl int, timeout int) (ruleLock, error) { - s, err := concurrency.NewSession(v3l.cl, concurrency.WithTTL(ttl)) +func (v3l *v3Locker) lockWithTimeout(key string, timeout time.Duration) (ruleLock, error) { + s, err := v3l.getSession() if err != nil { return nil, err } m := concurrency.NewMutex(s, key) - ctx, cancel := context.WithTimeout(SetMethod(context.Background(), "lock"), time.Duration(timeout)*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - err = m.Lock(ctx) + err = m.TryLock(ctx) if err != nil { return nil, err } return &v3Lock{ - mutex: m, - session: s, + mutex: m, }, nil } type v3Lock struct { - mutex *concurrency.Mutex - session *concurrency.Session + mutex *concurrency.Mutex } -func (v3l *v3Lock) unlock() { +func (v3l *v3Lock) unlock() error { if v3l.mutex != nil { - // TODO: Should the timeout for this be configurable too? Or use the same value as lock? 
- // It's a slightly different case in that here we want to make sure the unlock - // succeeds to free it for the use of others. In the lock case we want to give up - // early if someone already has the lock. - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(5)*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - err := v3l.mutex.Unlock(ctx) - if err == nil && v3l.session != nil { - v3l.session.Close() - } + return v3l.mutex.Unlock(ctx) } + return errors.New("nil mutex") } diff --git a/rules/lock_test.go b/rules/lock_test.go index a98dc1b..c549443 100644 --- a/rules/lock_test.go +++ b/rules/lock_test.go @@ -3,21 +3,28 @@ package rules import ( "testing" + "github.com/IBM-Cloud/go-etcd-rules/rules/concurrency" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.etcd.io/etcd/clientv3" ) func TestV3Locker(t *testing.T) { cfg, cl := initV3Etcd(t) c, err := clientv3.New(cfg) - assert.NoError(t, err) + require.NoError(t, err) + session, err := concurrency.NewSession(cl) + require.NoError(t, err) + defer session.Close() + rlckr := v3Locker{ cl: cl, lockTimeout: 5, + getSession: func() (*concurrency.Session, error) { return session, nil }, } - rlck, err1 := rlckr.lock("test", 10) + rlck, err1 := rlckr.lock("test") assert.NoError(t, err1) - _, err2 := rlckr.lockWithTimeout("test", 10, 1) + _, err2 := rlckr.lockWithTimeout("test", 10) assert.Error(t, err2) rlck.unlock() @@ -25,8 +32,11 @@ func TestV3Locker(t *testing.T) { done2 := make(chan bool) go func() { - lckr := newV3Locker(c, 5) - lck, lErr := lckr.lock("test1", 10) + session, err := concurrency.NewSession(cl) + require.NoError(t, err) + defer session.Close() + lckr := newV3Locker(c, 5, func() (*concurrency.Session, error) { return session, nil }) + lck, lErr := lckr.lock("test1") assert.NoError(t, lErr) done1 <- true <-done2 @@ -35,7 +45,7 @@ func TestV3Locker(t *testing.T) { } }() <-done1 - _, err = rlckr.lock("test1", 1) + _, err = rlckr.lock("test1") assert.Error(t, err) done2 <- true } diff --git a/rules/options.go b/rules/options.go index 1797d2d..d22cbc6 100644 --- a/rules/options.go +++ b/rules/options.go @@ -58,7 +58,7 @@ type engineOptions struct { contextProvider ContextProvider keyExpansion map[string][]string lockTimeout int - lockAcquisitionTimeout int + lockAcquisitionTimeout time.Duration crawlMutex *string ruleWorkBuffer int enhancedRuleFilter bool @@ -126,7 +126,7 @@ func EngineLockTimeout(lockTimeout int) EngineOption { // wait to acquire a lock. 
func EngineLockAcquisitionTimeout(lockAcquisitionTimeout int) EngineOption { return engineOptionFunction(func(o *engineOptions) { - o.lockAcquisitionTimeout = lockAcquisitionTimeout + o.lockAcquisitionTimeout = time.Second * time.Duration(lockAcquisitionTimeout) }) } diff --git a/rules/worker.go b/rules/worker.go index 30f4e3c..dd13e16 100644 --- a/rules/worker.go +++ b/rules/worker.go @@ -24,17 +24,17 @@ type v3Worker struct { func newV3Worker(workerID string, engine *v3Engine) (v3Worker, error) { var api readAPI - var locker ruleLocker + // var locker ruleLocker c := engine.cl kv := engine.kvWrapper(c) - locker = newV3Locker(c, engine.options.lockAcquisitionTimeout) + // locker = newV3Locker(c, engine.options.lockAcquisitionTimeout) api = &etcdV3ReadAPI{ kV: kv, } w := v3Worker{ baseWorker: baseWorker{ api: api, - locker: locker, + locker: engine.locker, metrics: engine.metrics, workerID: workerID, done: make(chan bool, 1), @@ -85,7 +85,7 @@ func (bw *baseWorker) doWork(loggerPtr **zap.Logger, } return } - l, err2 := bw.locker.lock(lockKey, lockTTL) + l, err2 := bw.locker.lock(lockKey) if err2 != nil { logger.Debug("Failed to acquire lock", zap.String("lock_key", lockKey), zap.Error(err2)) incLockMetric(metricsInfo.method, metricsInfo.keyPattern, false) From 05dcf36bc4ce4c178d6c493b45708e75c520a4d1 Mon Sep 17 00:00:00 2001 From: John Warren Date: Mon, 11 Oct 2021 09:34:11 -0400 Subject: [PATCH 03/14] UTs passing --- rules/concurrency/session_manager.go | 1 + rules/engine.go | 11 +++++++-- rules/lock.go | 8 +++---- rules/lock_test.go | 35 ++++++++++++++++++++-------- rules/options.go | 8 +++++++ rules/worker_test.go | 7 +++++- 6 files changed, 52 insertions(+), 18 deletions(-) diff --git a/rules/concurrency/session_manager.go b/rules/concurrency/session_manager.go index 91a3ff4..d668b15 100644 --- a/rules/concurrency/session_manager.go +++ b/rules/concurrency/session_manager.go @@ -30,6 +30,7 @@ func NewSessionManager(client *clientv3.Client, logger *zap.Logger) (*SessionMan } func (sm *SessionManager) initSession() error { + sm.logger.Info("Initializing session") sm.mutex.Lock() defer sm.mutex.Unlock() sm.session, sm.err = NewSession(sm.client) diff --git a/rules/engine.go b/rules/engine.go index fa92502..6d0ae3d 100644 --- a/rules/engine.go +++ b/rules/engine.go @@ -95,7 +95,14 @@ func newV3Engine(logger *zap.Logger, cl *clientv3.Client, options ...EngineOptio MetricsCollector: baseMetrics, } } - sessionMgr, _ := concurrency.NewSessionManager(cl, logger) + getSession := opts.getSession + if getSession == nil { + sessionMgr, err := concurrency.NewSessionManager(cl, logger) + if err != nil { + logger.Fatal("error getting session", zap.Error(err)) + } + getSession = sessionMgr.GetSession + } eng := v3Engine{ baseEngine: baseEngine{ keyProc: &keyProc, @@ -104,7 +111,7 @@ func newV3Engine(logger *zap.Logger, cl *clientv3.Client, options ...EngineOptio options: opts, ruleLockTTLs: map[int]int{}, ruleMgr: ruleMgr, - locker: newV3Locker(cl, opts.lockAcquisitionTimeout, sessionMgr.GetSession), + locker: newV3Locker(cl, opts.lockAcquisitionTimeout, getSession), }, keyProc: keyProc, workChannel: channel, diff --git a/rules/lock.go b/rules/lock.go index 8c5f8e5..1a12a42 100644 --- a/rules/lock.go +++ b/rules/lock.go @@ -19,7 +19,6 @@ type ruleLock interface { func newV3Locker(cl *clientv3.Client, lockTimeout time.Duration, getSessn getSession) ruleLocker { return &v3Locker{ - cl: cl, getSession: getSessn, lockTimeout: lockTimeout, } @@ -28,7 +27,6 @@ func newV3Locker(cl *clientv3.Client, 
lockTimeout time.Duration, getSessn getSes type getSession func() (*concurrency.Session, error) type v3Locker struct { - cl *clientv3.Client getSession getSession lockTimeout time.Duration } @@ -42,9 +40,9 @@ func (v3l *v3Locker) lockWithTimeout(key string, timeout time.Duration) (ruleLoc return nil, err } m := concurrency.NewMutex(s, key) - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - err = m.TryLock(ctx) + // ctx, cancel := context.WithTimeout(context.Background(), timeout) + // defer cancel() + err = m.TryLock( /*ctx*/ context.Background()) if err != nil { return nil, err } diff --git a/rules/lock_test.go b/rules/lock_test.go index c549443..5400ea1 100644 --- a/rules/lock_test.go +++ b/rules/lock_test.go @@ -2,6 +2,7 @@ package rules import ( "testing" + "time" "github.com/IBM-Cloud/go-etcd-rules/rules/concurrency" "github.com/stretchr/testify/assert" @@ -13,21 +14,35 @@ func TestV3Locker(t *testing.T) { cfg, cl := initV3Etcd(t) c, err := clientv3.New(cfg) require.NoError(t, err) - session, err := concurrency.NewSession(cl) + session1, err := concurrency.NewSession(cl) require.NoError(t, err) - defer session.Close() + defer session1.Close() - rlckr := v3Locker{ - cl: cl, - lockTimeout: 5, - getSession: func() (*concurrency.Session, error) { return session, nil }, + rlckr1 := v3Locker{ + // cl: cl, + lockTimeout: time.Minute, + getSession: func() (*concurrency.Session, error) { return session1, nil }, } - rlck, err1 := rlckr.lock("test") + rlck, err1 := rlckr1.lock("/test") assert.NoError(t, err1) - _, err2 := rlckr.lockWithTimeout("test", 10) + require.NotNil(t, rlck) + + session2, err := concurrency.NewSession(cl) + require.NoError(t, err) + defer session2.Close() + + rlckr2 := v3Locker{ + // cl: cl, + lockTimeout: time.Minute, + getSession: func() (*concurrency.Session, error) { return session2, nil }, + } + + _, err2 := rlckr2.lockWithTimeout("/test", 10*time.Second) assert.Error(t, err2) rlck.unlock() + // Verify that behavior holds across goroutines + done1 := make(chan bool) done2 := make(chan bool) @@ -36,7 +51,7 @@ func TestV3Locker(t *testing.T) { require.NoError(t, err) defer session.Close() lckr := newV3Locker(c, 5, func() (*concurrency.Session, error) { return session, nil }) - lck, lErr := lckr.lock("test1") + lck, lErr := lckr.lock("/test1") assert.NoError(t, lErr) done1 <- true <-done2 @@ -45,7 +60,7 @@ func TestV3Locker(t *testing.T) { } }() <-done1 - _, err = rlckr.lock("test1") + _, err = rlckr1.lock("/test1") assert.Error(t, err) done2 <- true } diff --git a/rules/options.go b/rules/options.go index d22cbc6..45c0f60 100644 --- a/rules/options.go +++ b/rules/options.go @@ -3,6 +3,7 @@ package rules import ( "time" + "github.com/IBM-Cloud/go-etcd-rules/rules/concurrency" "golang.org/x/net/context" ) @@ -63,6 +64,7 @@ type engineOptions struct { ruleWorkBuffer int enhancedRuleFilter bool metrics MetricsCollectorOpt + getSession func() (*concurrency.Session, error) } func makeEngineOptions(options ...EngineOption) engineOptions { @@ -145,6 +147,12 @@ func EngineWatchTimeout(watchTimeout int) EngineOption { }) } +func EngineGetSession(getSession func() (*concurrency.Session, error)) EngineOption { + return engineOptionFunction(func(o *engineOptions) { + o.getSession = getSession + }) +} + // KeyExpansion enables attributes in rules to be fixed at run time // while allowing the rule declarations to continue to use the // attribute placeholders. 
For instance, an application may diff --git a/rules/worker_test.go b/rules/worker_test.go index addba7e..a17c4ff 100644 --- a/rules/worker_test.go +++ b/rules/worker_test.go @@ -1,8 +1,10 @@ package rules import ( + "fmt" "testing" + "github.com/IBM-Cloud/go-etcd-rules/rules/concurrency" "github.com/stretchr/testify/assert" "go.etcd.io/etcd/clientv3" "go.uber.org/zap" @@ -19,7 +21,9 @@ func TestWorkerSingleRun(t *testing.T) { metrics.SetLogger(lgr) cl, err := clientv3.New(conf) assert.NoError(t, err) - e := newV3Engine(getTestLogger(), cl, EngineLockTimeout(300)) + e := newV3Engine(getTestLogger(), cl, EngineLockTimeout(300), EngineGetSession(func() (*concurrency.Session, error) { + return nil, nil + })) channel := e.workChannel lockChannel := make(chan bool) locker := testLocker{ @@ -68,6 +72,7 @@ func TestWorkerSingleRun(t *testing.T) { expectedIncLockMetricsPatterns := []string{"/test/item"} expectedIncLockMetricsLockSuccess := []bool{true} + fmt.Println("Calling single run") go w.singleRun() channel <- rw assert.True(t, <-cbChannel) From c7464feaa7c9c8074c5b4a35bee80f5631f19b33 Mon Sep 17 00:00:00 2001 From: John Warren Date: Tue, 12 Oct 2021 09:19:49 -0400 Subject: [PATCH 04/14] Add local locker since etcd locks now reentrant --- rules/lock.go | 111 ++++++++++++++++++++++++++++++++++++++++++--- rules/lock_test.go | 11 +++-- rules/options.go | 2 +- 3 files changed, 114 insertions(+), 10 deletions(-) diff --git a/rules/lock.go b/rules/lock.go index 1a12a42..ddabfe7 100644 --- a/rules/lock.go +++ b/rules/lock.go @@ -2,6 +2,8 @@ package rules import ( "errors" + "fmt" + "sync" "time" "github.com/IBM-Cloud/go-etcd-rules/rules/concurrency" @@ -18,10 +20,90 @@ type ruleLock interface { } func newV3Locker(cl *clientv3.Client, lockTimeout time.Duration, getSessn getSession) ruleLocker { - return &v3Locker{ + locker := &v3Locker{ getSession: getSessn, lockTimeout: lockTimeout, + lLocker: newLocalLocker(), } + return locker +} + +type localLockItem struct { + // The key to lock + key string + // When lock is true the request is to lock, otherwise it is to unlock + lock bool + // true is sent in the response channel if the operator was successful + // unlocks are always successful. + response chan<- bool +} + +type localLocker struct { + once sync.Once + stopCh chan struct{} + lockLocal chan localLockItem +} + +func (ll localLocker) close() { + ll.once.Do(func() { + // This is thread safe because no goroutine is writing + // to this channel. + close(ll.stopCh) + }) +} + +func (ll localLocker) toggle(key string, lock bool) bool { + fmt.Println("***toggle called", lock) + resp := make(chan bool) + item := localLockItem{ + key: key, + response: resp, + lock: lock, + } + select { + case <-ll.stopCh: + // Return false if the locker is closed. + return false + case ll.lockLocal <- item: + } + out := <-resp + fmt.Println("***Response received", out) + return out +} + +func newLocalLocker() localLocker { + locker := localLocker{ + stopCh: make(chan struct{}), + lockLocal: make(chan localLockItem), + } + // Thread safety is achieved by allowing only one goroutine to access + // this map and having it read from channels that multiple goroutines + // writing to them. + locks := make(map[string]bool) + count := 0 + go func() { + for item := range locker.lockLocal { + count++ + fmt.Println(locks, count) + fmt.Println("lockLocal", count) + // extraneous else's and continue's to make flow clearer. 
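+			// Descriptive note (added): a lock request fails if the key is already held
+			// locally; an unlock request always succeeds and removes the key from the map.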
+ if item.lock { + if locks[item.key] { + item.response <- false + continue + } else { + locks[item.key] = true + item.response <- true + continue + } + } else { + delete(locks, item.key) + item.response <- true + continue + } + } + }() + return locker } type getSession func() (*concurrency.Session, error) @@ -29,33 +111,50 @@ type getSession func() (*concurrency.Session, error) type v3Locker struct { getSession getSession lockTimeout time.Duration + lLocker localLocker } func (v3l *v3Locker) lock(key string) (ruleLock, error) { return v3l.lockWithTimeout(key, v3l.lockTimeout) } + +var errLockedLocally = errors.New("locked locally") + +// Timeout in this case means how long the client will wait to determine +// whether the lock can be obtained. This call will return immediately once +// another client is known to hold the lock. There is no waiting for the lock +// to be released. func (v3l *v3Locker) lockWithTimeout(key string, timeout time.Duration) (ruleLock, error) { + fmt.Println("***lockWithTimeout called") + if ok := v3l.lLocker.toggle(key, true); !ok { + return nil, errLockedLocally + } s, err := v3l.getSession() if err != nil { return nil, err } m := concurrency.NewMutex(s, key) - // ctx, cancel := context.WithTimeout(context.Background(), timeout) - // defer cancel() - err = m.TryLock( /*ctx*/ context.Background()) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + err = m.TryLock(ctx) if err != nil { return nil, err } return &v3Lock{ - mutex: m, + mutex: m, + locker: v3l, + key: key, }, nil } type v3Lock struct { - mutex *concurrency.Mutex + mutex *concurrency.Mutex + locker *v3Locker + key string } func (v3l *v3Lock) unlock() error { + v3l.locker.lLocker.toggle(v3l.key, false) if v3l.mutex != nil { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() diff --git a/rules/lock_test.go b/rules/lock_test.go index 5400ea1..ce90942 100644 --- a/rules/lock_test.go +++ b/rules/lock_test.go @@ -18,10 +18,12 @@ func TestV3Locker(t *testing.T) { require.NoError(t, err) defer session1.Close() + lLocker := newLocalLocker() + rlckr1 := v3Locker{ - // cl: cl, lockTimeout: time.Minute, getSession: func() (*concurrency.Session, error) { return session1, nil }, + lLocker: lLocker, } rlck, err1 := rlckr1.lock("/test") assert.NoError(t, err1) @@ -32,9 +34,9 @@ func TestV3Locker(t *testing.T) { defer session2.Close() rlckr2 := v3Locker{ - // cl: cl, lockTimeout: time.Minute, getSession: func() (*concurrency.Session, error) { return session2, nil }, + lLocker: lLocker, } _, err2 := rlckr2.lockWithTimeout("/test", 10*time.Second) @@ -50,7 +52,7 @@ func TestV3Locker(t *testing.T) { session, err := concurrency.NewSession(cl) require.NoError(t, err) defer session.Close() - lckr := newV3Locker(c, 5, func() (*concurrency.Session, error) { return session, nil }) + lckr := newV3Locker(c, 5*time.Second, func() (*concurrency.Session, error) { return session, nil }) lck, lErr := lckr.lock("/test1") assert.NoError(t, lErr) done1 <- true @@ -64,3 +66,6 @@ func TestV3Locker(t *testing.T) { assert.Error(t, err) done2 <- true } + +func Test_localLocker(t *testing.T) { +} diff --git a/rules/options.go b/rules/options.go index 45c0f60..bfa6d60 100644 --- a/rules/options.go +++ b/rules/options.go @@ -74,7 +74,7 @@ func makeEngineOptions(options ...EngineOption) engineOptions { contextProvider: defaultContextProvider, syncDelay: 1, lockTimeout: 30, - lockAcquisitionTimeout: 5, + lockAcquisitionTimeout: 5 * time.Second, syncInterval: 300, syncGetTimeout: 
0, watchTimeout: 0, From 8c4231b65902302354123cd58165dfde656506e9 Mon Sep 17 00:00:00 2001 From: John Warren Date: Tue, 12 Oct 2021 09:40:27 -0400 Subject: [PATCH 05/14] Fix linter issues --- cmd/lock/main.go | 85 --------------------------- prunelocks/pruner.go | 1 - rules/concurrency/mutex.go | 6 +- rules/concurrency/{stm.go => stm.go_} | 0 rules/int_crawler.go | 3 +- rules/lock.go | 2 +- rules/lock_test.go | 5 +- rules/worker.go | 2 +- 8 files changed, 11 insertions(+), 93 deletions(-) delete mode 100644 cmd/lock/main.go rename rules/concurrency/{stm.go => stm.go_} (100%) diff --git a/cmd/lock/main.go b/cmd/lock/main.go deleted file mode 100644 index 50be87e..0000000 --- a/cmd/lock/main.go +++ /dev/null @@ -1,85 +0,0 @@ -package main - -import ( - "context" - "fmt" - "sync" - "time" - - "github.com/IBM-Cloud/go-etcd-rules/rules/concurrency" - "go.etcd.io/etcd/clientv3" -) - -func check(err error) { - if err != nil { - panic(err.Error()) - } -} - -var session *concurrency.Session -var sessionMutex sync.Mutex -var sessionDone <-chan struct{} - -func manageSession(client *clientv3.Client) { - initSession(client) -} - -func initSession(client *clientv3.Client) { - sessionMutex.Lock() - defer sessionMutex.Unlock() - var err error - session, err = concurrency.NewSession(client) - check(err) - fmt.Printf("Session lease ID: %x\n", session.Lease()) - sessionDone = session.Done() - go func() { - <-sessionDone - initSession(client) - }() -} - -func main() { - cfg := clientv3.Config{Endpoints: []string{"http://127.0.0.1:2379"}} - cl, err := clientv3.New(cfg) - check(err) - // session, err = concurrency.NewSession(cl) - // check(err) - manageSession(cl) - // mutex := concurrency.NewMutex(session, "/locks/hello") - // err = mutex.TryLock(context.Background()) - // check(err) - // fmt.Println(mutex.Key()) - // time.Sleep(time.Minute) - // mutex.Unlock(context.Background()) - // fmt.Println("Unlocked") - // time.Sleep(time.Minute) - // session.Close() - // fmt.Println("Session closed") - // d := session.Done() - // go func() { - // <-d - // fmt.Println("done") - // }() - for { - sessionMutex.Lock() - mutex := concurrency.NewMutex(session, "/locks/hello") - err = mutex.TryLock(context.Background()) - sessionMutex.Unlock() - if err != nil { - fmt.Println(err) - time.Sleep(time.Second * 10) - continue - // break - } - fmt.Println(mutex.Key()) - time.Sleep(time.Second * 3) - err = mutex.Unlock(context.Background()) - if err == nil { - fmt.Println("Unlocked") - } else { - fmt.Println(err) - // break - } - time.Sleep(time.Second * 3) - } -} diff --git a/prunelocks/pruner.go b/prunelocks/pruner.go index 8fa19a8..5330483 100644 --- a/prunelocks/pruner.go +++ b/prunelocks/pruner.go @@ -18,7 +18,6 @@ type Pruner struct { keys map[string]lockKey timeout time.Duration lockPrefixes []string - client *clientv3.Client kv clientv3.KV lease clientv3.Lease logger *zap.Logger diff --git a/rules/concurrency/mutex.go b/rules/concurrency/mutex.go index 50a87d8..780f690 100644 --- a/rules/concurrency/mutex.go +++ b/rules/concurrency/mutex.go @@ -26,6 +26,8 @@ import ( // ErrLocked is returned by TryLock when Mutex is already locked by another session. var ErrLocked = errors.New("mutex: Locked by another session") + +// ErrSessionExpired is returned by Lock when the the mutex session is expired. 
var ErrSessionExpired = errors.New("mutex: session is expired") // Mutex implements the sync Locker interface with etcd @@ -85,14 +87,14 @@ func (m *Mutex) Lock(ctx context.Context) error { _, werr := waitDeletes(ctx, client, m.pfx, m.myRev-1) // release lock key if wait failed if werr != nil { - m.Unlock(client.Ctx()) + _ = m.Unlock(client.Ctx()) return werr } // make sure the session is not expired, and the owner key still exists. gresp, werr := client.Get(ctx, m.myKey) if werr != nil { - m.Unlock(client.Ctx()) + _ = m.Unlock(client.Ctx()) return werr } diff --git a/rules/concurrency/stm.go b/rules/concurrency/stm.go_ similarity index 100% rename from rules/concurrency/stm.go rename to rules/concurrency/stm.go_ diff --git a/rules/int_crawler.go b/rules/int_crawler.go index 8eef2f9..906d494 100644 --- a/rules/int_crawler.go +++ b/rules/int_crawler.go @@ -127,7 +127,8 @@ func (ic *intCrawler) run() { logger.Debug("Could not obtain mutex; skipping crawler run", zap.Error(err)) } else { ic.singleRun(logger) - lock.unlock() + err := lock.unlock() + logger.Error("Error releasing lock", zap.Error(err)) } } logger.Info("Crawler run complete") diff --git a/rules/lock.go b/rules/lock.go index ddabfe7..87117e9 100644 --- a/rules/lock.go +++ b/rules/lock.go @@ -39,7 +39,7 @@ type localLockItem struct { } type localLocker struct { - once sync.Once + once *sync.Once stopCh chan struct{} lockLocal chan localLockItem } diff --git a/rules/lock_test.go b/rules/lock_test.go index ce90942..d1fab0b 100644 --- a/rules/lock_test.go +++ b/rules/lock_test.go @@ -25,6 +25,7 @@ func TestV3Locker(t *testing.T) { getSession: func() (*concurrency.Session, error) { return session1, nil }, lLocker: lLocker, } + defer lLocker.close() rlck, err1 := rlckr1.lock("/test") assert.NoError(t, err1) require.NotNil(t, rlck) @@ -41,7 +42,7 @@ func TestV3Locker(t *testing.T) { _, err2 := rlckr2.lockWithTimeout("/test", 10*time.Second) assert.Error(t, err2) - rlck.unlock() + assert.NoError(t, rlck.unlock()) // Verify that behavior holds across goroutines @@ -58,7 +59,7 @@ func TestV3Locker(t *testing.T) { done1 <- true <-done2 if lck != nil { - lck.unlock() + assert.NoError(t, lck.unlock()) } }() <-done1 diff --git a/rules/worker.go b/rules/worker.go index dd13e16..ed92463 100644 --- a/rules/worker.go +++ b/rules/worker.go @@ -94,7 +94,7 @@ func (bw *baseWorker) doWork(loggerPtr **zap.Logger, } incLockMetric(metricsInfo.method, metricsInfo.keyPattern, true) bw.metrics.IncLockMetric(metricsInfo.method, metricsInfo.keyPattern, true) - defer l.unlock() + defer func() { _ = l.unlock() }() // Check for a second time, since checking and locking // are not atomic. capi, err1 = bw.api.getCachedAPI(rule.getKeys()) From 993c4410d017b071c5cdbe2cd4e0438e85dd2df3 Mon Sep 17 00:00:00 2001 From: John Warren Date: Tue, 12 Oct 2021 09:49:13 -0400 Subject: [PATCH 06/14] More cleanup --- rules/concurrency/stm.go_ | 387 -------------------------------------- rules/lock.go | 7 +- 2 files changed, 1 insertion(+), 393 deletions(-) delete mode 100644 rules/concurrency/stm.go_ diff --git a/rules/concurrency/stm.go_ b/rules/concurrency/stm.go_ deleted file mode 100644 index ee11510..0000000 --- a/rules/concurrency/stm.go_ +++ /dev/null @@ -1,387 +0,0 @@ -// Copyright 2016 The etcd Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package concurrency - -import ( - "context" - "math" - - v3 "go.etcd.io/etcd/clientv3" -) - -// STM is an interface for software transactional memory. -type STM interface { - // Get returns the value for a key and inserts the key in the txn's read set. - // If Get fails, it aborts the transaction with an error, never returning. - Get(key ...string) string - // Put adds a value for a key to the write set. - Put(key, val string, opts ...v3.OpOption) - // Rev returns the revision of a key in the read set. - Rev(key string) int64 - // Del deletes a key. - Del(key string) - - // commit attempts to apply the txn's changes to the server. - commit() *v3.TxnResponse - reset() -} - -// Isolation is an enumeration of transactional isolation levels which -// describes how transactions should interfere and conflict. -type Isolation int - -const ( - // SerializableSnapshot provides serializable isolation and also checks - // for write conflicts. - SerializableSnapshot Isolation = iota - // Serializable reads within the same transaction attempt return data - // from the at the revision of the first read. - Serializable - // RepeatableReads reads within the same transaction attempt always - // return the same data. - RepeatableReads - // ReadCommitted reads keys from any committed revision. - ReadCommitted -) - -// stmError safely passes STM errors through panic to the STM error channel. -type stmError struct{ err error } - -type stmOptions struct { - iso Isolation - ctx context.Context - prefetch []string -} - -type stmOption func(*stmOptions) - -// WithIsolation specifies the transaction isolation level. -func WithIsolation(lvl Isolation) stmOption { - return func(so *stmOptions) { so.iso = lvl } -} - -// WithAbortContext specifies the context for permanently aborting the transaction. -func WithAbortContext(ctx context.Context) stmOption { - return func(so *stmOptions) { so.ctx = ctx } -} - -// WithPrefetch is a hint to prefetch a list of keys before trying to apply. -// If an STM transaction will unconditionally fetch a set of keys, prefetching -// those keys will save the round-trip cost from requesting each key one by one -// with Get(). -func WithPrefetch(keys ...string) stmOption { - return func(so *stmOptions) { so.prefetch = append(so.prefetch, keys...) } -} - -// NewSTM initiates a new STM instance, using serializable snapshot isolation by default. -func NewSTM(c *v3.Client, apply func(STM) error, so ...stmOption) (*v3.TxnResponse, error) { - opts := &stmOptions{ctx: c.Ctx()} - for _, f := range so { - f(opts) - } - if len(opts.prefetch) != 0 { - f := apply - apply = func(s STM) error { - s.Get(opts.prefetch...) - return f(s) - } - } - return runSTM(mkSTM(c, opts), apply) -} - -func mkSTM(c *v3.Client, opts *stmOptions) STM { - switch opts.iso { - case SerializableSnapshot: - s := &stmSerializable{ - stm: stm{client: c, ctx: opts.ctx}, - prefetch: make(map[string]*v3.GetResponse), - } - s.conflicts = func() []v3.Cmp { - return append(s.rset.cmps(), s.wset.cmps(s.rset.first()+1)...) 
- } - return s - case Serializable: - s := &stmSerializable{ - stm: stm{client: c, ctx: opts.ctx}, - prefetch: make(map[string]*v3.GetResponse), - } - s.conflicts = func() []v3.Cmp { return s.rset.cmps() } - return s - case RepeatableReads: - s := &stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}} - s.conflicts = func() []v3.Cmp { return s.rset.cmps() } - return s - case ReadCommitted: - s := &stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}} - s.conflicts = func() []v3.Cmp { return nil } - return s - default: - panic("unsupported stm") - } -} - -type stmResponse struct { - resp *v3.TxnResponse - err error -} - -func runSTM(s STM, apply func(STM) error) (*v3.TxnResponse, error) { - outc := make(chan stmResponse, 1) - go func() { - defer func() { - if r := recover(); r != nil { - e, ok := r.(stmError) - if !ok { - // client apply panicked - panic(r) - } - outc <- stmResponse{nil, e.err} - } - }() - var out stmResponse - for { - s.reset() - if out.err = apply(s); out.err != nil { - break - } - if out.resp = s.commit(); out.resp != nil { - break - } - } - outc <- out - }() - r := <-outc - return r.resp, r.err -} - -// stm implements repeatable-read software transactional memory over etcd -type stm struct { - client *v3.Client - ctx context.Context - // rset holds read key values and revisions - rset readSet - // wset holds overwritten keys and their values - wset writeSet - // getOpts are the opts used for gets - getOpts []v3.OpOption - // conflicts computes the current conflicts on the txn - conflicts func() []v3.Cmp -} - -type stmPut struct { - val string - op v3.Op -} - -type readSet map[string]*v3.GetResponse - -func (rs readSet) add(keys []string, txnresp *v3.TxnResponse) { - for i, resp := range txnresp.Responses { - rs[keys[i]] = (*v3.GetResponse)(resp.GetResponseRange()) - } -} - -// first returns the store revision from the first fetch -func (rs readSet) first() int64 { - ret := int64(math.MaxInt64 - 1) - for _, resp := range rs { - if rev := resp.Header.Revision; rev < ret { - ret = rev - } - } - return ret -} - -// cmps guards the txn from updates to read set -func (rs readSet) cmps() []v3.Cmp { - cmps := make([]v3.Cmp, 0, len(rs)) - for k, rk := range rs { - cmps = append(cmps, isKeyCurrent(k, rk)) - } - return cmps -} - -type writeSet map[string]stmPut - -func (ws writeSet) get(keys ...string) *stmPut { - for _, key := range keys { - if wv, ok := ws[key]; ok { - return &wv - } - } - return nil -} - -// cmps returns a cmp list testing no writes have happened past rev -func (ws writeSet) cmps(rev int64) []v3.Cmp { - cmps := make([]v3.Cmp, 0, len(ws)) - for key := range ws { - cmps = append(cmps, v3.Compare(v3.ModRevision(key), "<", rev)) - } - return cmps -} - -// puts is the list of ops for all pending writes -func (ws writeSet) puts() []v3.Op { - puts := make([]v3.Op, 0, len(ws)) - for _, v := range ws { - puts = append(puts, v.op) - } - return puts -} - -func (s *stm) Get(keys ...string) string { - if wv := s.wset.get(keys...); wv != nil { - return wv.val - } - return respToValue(s.fetch(keys...)) -} - -func (s *stm) Put(key, val string, opts ...v3.OpOption) { - s.wset[key] = stmPut{val, v3.OpPut(key, val, opts...)} -} - -func (s *stm) Del(key string) { s.wset[key] = stmPut{"", v3.OpDelete(key)} } - -func (s *stm) Rev(key string) int64 { - if resp := s.fetch(key); resp != nil && len(resp.Kvs) != 0 { - return resp.Kvs[0].ModRevision - } - return 0 -} - -func (s *stm) commit() *v3.TxnResponse { - txnresp, err := 
s.client.Txn(s.ctx).If(s.conflicts()...).Then(s.wset.puts()...).Commit() - if err != nil { - panic(stmError{err}) - } - if txnresp.Succeeded { - return txnresp - } - return nil -} - -func (s *stm) fetch(keys ...string) *v3.GetResponse { - if len(keys) == 0 { - return nil - } - ops := make([]v3.Op, len(keys)) - for i, key := range keys { - if resp, ok := s.rset[key]; ok { - return resp - } - ops[i] = v3.OpGet(key, s.getOpts...) - } - txnresp, err := s.client.Txn(s.ctx).Then(ops...).Commit() - if err != nil { - panic(stmError{err}) - } - s.rset.add(keys, txnresp) - return (*v3.GetResponse)(txnresp.Responses[0].GetResponseRange()) -} - -func (s *stm) reset() { - s.rset = make(map[string]*v3.GetResponse) - s.wset = make(map[string]stmPut) -} - -type stmSerializable struct { - stm - prefetch map[string]*v3.GetResponse -} - -func (s *stmSerializable) Get(keys ...string) string { - if wv := s.wset.get(keys...); wv != nil { - return wv.val - } - firstRead := len(s.rset) == 0 - for _, key := range keys { - if resp, ok := s.prefetch[key]; ok { - delete(s.prefetch, key) - s.rset[key] = resp - } - } - resp := s.stm.fetch(keys...) - if firstRead { - // txn's base revision is defined by the first read - s.getOpts = []v3.OpOption{ - v3.WithRev(resp.Header.Revision), - v3.WithSerializable(), - } - } - return respToValue(resp) -} - -func (s *stmSerializable) Rev(key string) int64 { - s.Get(key) - return s.stm.Rev(key) -} - -func (s *stmSerializable) gets() ([]string, []v3.Op) { - keys := make([]string, 0, len(s.rset)) - ops := make([]v3.Op, 0, len(s.rset)) - for k := range s.rset { - keys = append(keys, k) - ops = append(ops, v3.OpGet(k)) - } - return keys, ops -} - -func (s *stmSerializable) commit() *v3.TxnResponse { - keys, getops := s.gets() - txn := s.client.Txn(s.ctx).If(s.conflicts()...).Then(s.wset.puts()...) - // use Else to prefetch keys in case of conflict to save a round trip - txnresp, err := txn.Else(getops...).Commit() - if err != nil { - panic(stmError{err}) - } - if txnresp.Succeeded { - return txnresp - } - // load prefetch with Else data - s.rset.add(keys, txnresp) - s.prefetch = s.rset - s.getOpts = nil - return nil -} - -func isKeyCurrent(k string, r *v3.GetResponse) v3.Cmp { - if len(r.Kvs) != 0 { - return v3.Compare(v3.ModRevision(k), "=", r.Kvs[0].ModRevision) - } - return v3.Compare(v3.ModRevision(k), "=", 0) -} - -func respToValue(resp *v3.GetResponse) string { - if resp == nil || len(resp.Kvs) == 0 { - return "" - } - return string(resp.Kvs[0].Value) -} - -// NewSTMRepeatable is deprecated. -func NewSTMRepeatable(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) { - return NewSTM(c, apply, WithAbortContext(ctx), WithIsolation(RepeatableReads)) -} - -// NewSTMSerializable is deprecated. -func NewSTMSerializable(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) { - return NewSTM(c, apply, WithAbortContext(ctx), WithIsolation(Serializable)) -} - -// NewSTMReadCommitted is deprecated. 
-func NewSTMReadCommitted(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) { - return NewSTM(c, apply, WithAbortContext(ctx), WithIsolation(ReadCommitted)) -} diff --git a/rules/lock.go b/rules/lock.go index 87117e9..9f3d8a2 100644 --- a/rules/lock.go +++ b/rules/lock.go @@ -2,7 +2,6 @@ package rules import ( "errors" - "fmt" "sync" "time" @@ -53,7 +52,6 @@ func (ll localLocker) close() { } func (ll localLocker) toggle(key string, lock bool) bool { - fmt.Println("***toggle called", lock) resp := make(chan bool) item := localLockItem{ key: key, @@ -67,7 +65,6 @@ func (ll localLocker) toggle(key string, lock bool) bool { case ll.lockLocal <- item: } out := <-resp - fmt.Println("***Response received", out) return out } @@ -75,6 +72,7 @@ func newLocalLocker() localLocker { locker := localLocker{ stopCh: make(chan struct{}), lockLocal: make(chan localLockItem), + once: new(sync.Once), } // Thread safety is achieved by allowing only one goroutine to access // this map and having it read from channels that multiple goroutines @@ -84,8 +82,6 @@ func newLocalLocker() localLocker { go func() { for item := range locker.lockLocal { count++ - fmt.Println(locks, count) - fmt.Println("lockLocal", count) // extraneous else's and continue's to make flow clearer. if item.lock { if locks[item.key] { @@ -125,7 +121,6 @@ var errLockedLocally = errors.New("locked locally") // another client is known to hold the lock. There is no waiting for the lock // to be released. func (v3l *v3Locker) lockWithTimeout(key string, timeout time.Duration) (ruleLock, error) { - fmt.Println("***lockWithTimeout called") if ok := v3l.lLocker.toggle(key, true); !ok { return nil, errLockedLocally } From 06bb60f02305fa1b03ce155e29187e66a07e5c46 Mon Sep 17 00:00:00 2001 From: John Warren Date: Tue, 12 Oct 2021 10:05:27 -0400 Subject: [PATCH 07/14] Release local lock if obtaining etcd lock fails --- rules/lock.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/rules/lock.go b/rules/lock.go index 9f3d8a2..f36f2da 100644 --- a/rules/lock.go +++ b/rules/lock.go @@ -126,6 +126,8 @@ func (v3l *v3Locker) lockWithTimeout(key string, timeout time.Duration) (ruleLoc } s, err := v3l.getSession() if err != nil { + // Release the local lock + v3l.lLocker.toggle(key, false) return nil, err } m := concurrency.NewMutex(s, key) @@ -133,6 +135,8 @@ func (v3l *v3Locker) lockWithTimeout(key string, timeout time.Duration) (ruleLoc defer cancel() err = m.TryLock(ctx) if err != nil { + // Release the local lock + v3l.lLocker.toggle(key, false) return nil, err } return &v3Lock{ From 3895e25d59063c6c3bda45f61c8af1d968b29e97 Mon Sep 17 00:00:00 2001 From: John Warren Date: Tue, 12 Oct 2021 15:50:59 -0400 Subject: [PATCH 08/14] Remove old jwt reference to get rid of vuln warning --- go.sum | 2 -- 1 file changed, 2 deletions(-) diff --git a/go.sum b/go.sum index 12ac71f..e40213c 100644 --- a/go.sum +++ b/go.sum @@ -51,8 +51,6 @@ github.com/creack/pty v1.1.11/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= -github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dustin/go-humanize 
v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= From c538cb6571da4920ef1246db02f09571ce8895b9 Mon Sep 17 00:00:00 2001 From: John Warren Date: Wed, 13 Oct 2021 07:48:30 -0400 Subject: [PATCH 09/14] Added nested locker --- go.sum | 1 + rules/lock.go | 4 +- rules/lock_test.go | 16 ++++++ rules/nested_lock.go | 50 +++++++++++++++++ rules/nested_lock_test.go | 113 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 183 insertions(+), 1 deletion(-) create mode 100644 rules/nested_lock.go create mode 100644 rules/nested_lock_test.go diff --git a/go.sum b/go.sum index e40213c..ed656a3 100644 --- a/go.sum +++ b/go.sum @@ -51,6 +51,7 @@ github.com/creack/pty v1.1.11/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= diff --git a/rules/lock.go b/rules/lock.go index f36f2da..14b3d6f 100644 --- a/rules/lock.go +++ b/rules/lock.go @@ -153,7 +153,9 @@ type v3Lock struct { } func (v3l *v3Lock) unlock() error { - v3l.locker.lLocker.toggle(v3l.key, false) + // Unlocking the local lock should be done last, so obtaining the same + // lock can't occur while the etcd lock is still held. + defer v3l.locker.lLocker.toggle(v3l.key, false) if v3l.mutex != nil { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() diff --git a/rules/lock_test.go b/rules/lock_test.go index d1fab0b..17b50c5 100644 --- a/rules/lock_test.go +++ b/rules/lock_test.go @@ -70,3 +70,19 @@ func TestV3Locker(t *testing.T) { func Test_localLocker(t *testing.T) { } + +type mockLocker struct { + lockF func(string) (ruleLock, error) +} + +func (ml mockLocker) lock(key string) (ruleLock, error) { + return ml.lockF(key) +} + +type mockLock struct { + unlockF func() error +} + +func (ml mockLock) unlock() error { + return ml.unlockF() +} diff --git a/rules/nested_lock.go b/rules/nested_lock.go new file mode 100644 index 0000000..22034cf --- /dev/null +++ b/rules/nested_lock.go @@ -0,0 +1,50 @@ +package rules + +type nestedLocker struct { + own ruleLocker + nested ruleLocker +} + +func (nl nestedLocker) lock(key string) (ruleLock, error) { + // Try to obtain own lock first, preempting attempts + // to obtain the nested (more expensive) lock if + // getting the local lock fails. + lock, err := nl.own.lock(key) + if err != nil { + return nil, err + } + // Try to obtain the nested lock + nested, err := nl.nested.lock(key) + if err != nil { + // First unlock own lock + _ = lock.unlock() + return nil, err + } + return nestedLock{ + own: lock, + nested: nested, + }, nil +} + +type nestedLock struct { + own ruleLock + nested ruleLock +} + +func (nl nestedLock) unlock() error { + // Always unlock own lock, but after + // nested lock. 
This prevents attempting + // to get a new instance of the nested lock + // before the own lock is cleared. If the nested + // lock persists due to an error, it should be + // cleared with separate logic. + + err := nl.nested.unlock() + ownError := nl.own.unlock() + // The nested lock is assumed to be more expensive so + // its error takes precedence. + if err == nil { + err = ownError + } + return err +} diff --git a/rules/nested_lock_test.go b/rules/nested_lock_test.go new file mode 100644 index 0000000..826f45c --- /dev/null +++ b/rules/nested_lock_test.go @@ -0,0 +1,113 @@ +package rules + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_nestedLocker_lock(t *testing.T) { + // Set up mock data for mock functions + type testLock struct { + ruleLock + val string // Just something to compare. + } + var ownUnlockCalled bool + testOwnLock := testLock{ + ruleLock: mockLock{ + unlockF: func() error { + ownUnlockCalled = true + return nil + }, + }, + val: "own", + } + testNestedLock := testLock{ + val: "nested", + } + + ownLockErr := errors.New("own lock") + nestedLockErr := errors.New("nested lock") + + testCases := []struct { + name string + + nestedCalled bool + ownUnlockCalled bool + + err error + ownLockErr error + nestedLockErr error + }{ + { + name: "ok", + nestedCalled: true, + }, + { + name: "own_error", + ownLockErr: ownLockErr, + err: ownLockErr, + }, + { + name: "nested_error", + nestedCalled: true, + ownUnlockCalled: true, + nestedLockErr: nestedLockErr, + err: nestedLockErr, + }, + { + name: "both_errors", + ownLockErr: ownLockErr, + nestedLockErr: nestedLockErr, + err: ownLockErr, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Reset from any previous runs + ownUnlockCalled = false + ownCalled := false + nestedCalled := false + nl := nestedLocker{ + own: mockLocker{ + lockF: func(key string) (ruleLock, error) { + assert.Equal(t, "key", key) + ownCalled = true + return testOwnLock, tc.ownLockErr + }, + }, + nested: mockLocker{ + lockF: func(key string) (ruleLock, error) { + // The own locker should have been called first + assert.True(t, ownCalled) + assert.Equal(t, "key", key) + nestedCalled = true + return testNestedLock, tc.nestedLockErr + }, + }, + } + var err error + lock, err := nl.lock("key") + assert.Equal(t, tc.nestedCalled, nestedCalled) + assert.Equal(t, tc.ownUnlockCalled, ownUnlockCalled) + if tc.err != nil { + assert.EqualError(t, err, tc.err.Error()) + return + } + assert.NoError(t, err) + nLock, ok := lock.(nestedLock) + if assert.True(t, ok) { + getVal := func(rl ruleLock) string { + tl, ok := rl.(testLock) + if !ok { + return "" + } + return tl.val + } + assert.Equal(t, testOwnLock.val, getVal(nLock.own)) + assert.Equal(t, testNestedLock.val, getVal(nLock.nested)) + } + }) + } +} From 01befd3abdb596ac47102d9454902de75dd957cf Mon Sep 17 00:00:00 2001 From: John Warren Date: Wed, 13 Oct 2021 08:55:25 -0400 Subject: [PATCH 10/14] Stopping point --- rules/lock.go | 174 +++++++++++++++++++++----------------------- rules/lock_test.go | 7 +- rules/map_locker.go | 103 ++++++++++++++++++++++++++ 3 files changed, 185 insertions(+), 99 deletions(-) create mode 100644 rules/map_locker.go diff --git a/rules/lock.go b/rules/lock.go index 14b3d6f..c914c21 100644 --- a/rules/lock.go +++ b/rules/lock.go @@ -2,7 +2,6 @@ package rules import ( "errors" - "sync" "time" "github.com/IBM-Cloud/go-etcd-rules/rules/concurrency" @@ -19,95 +18,97 @@ type ruleLock interface { } func newV3Locker(cl 
*clientv3.Client, lockTimeout time.Duration, getSessn getSession) ruleLocker { - locker := &v3Locker{ - getSession: getSessn, - lockTimeout: lockTimeout, - lLocker: newLocalLocker(), + locker := nestedLocker{ + own: newMapLocker(), + nested: &v3Locker{ + getSession: getSessn, + lockTimeout: lockTimeout, + }, } return locker } -type localLockItem struct { - // The key to lock - key string - // When lock is true the request is to lock, otherwise it is to unlock - lock bool - // true is sent in the response channel if the operator was successful - // unlocks are always successful. - response chan<- bool -} - -type localLocker struct { - once *sync.Once - stopCh chan struct{} - lockLocal chan localLockItem -} - -func (ll localLocker) close() { - ll.once.Do(func() { - // This is thread safe because no goroutine is writing - // to this channel. - close(ll.stopCh) - }) -} - -func (ll localLocker) toggle(key string, lock bool) bool { - resp := make(chan bool) - item := localLockItem{ - key: key, - response: resp, - lock: lock, - } - select { - case <-ll.stopCh: - // Return false if the locker is closed. - return false - case ll.lockLocal <- item: - } - out := <-resp - return out -} - -func newLocalLocker() localLocker { - locker := localLocker{ - stopCh: make(chan struct{}), - lockLocal: make(chan localLockItem), - once: new(sync.Once), - } - // Thread safety is achieved by allowing only one goroutine to access - // this map and having it read from channels that multiple goroutines - // writing to them. - locks := make(map[string]bool) - count := 0 - go func() { - for item := range locker.lockLocal { - count++ - // extraneous else's and continue's to make flow clearer. - if item.lock { - if locks[item.key] { - item.response <- false - continue - } else { - locks[item.key] = true - item.response <- true - continue - } - } else { - delete(locks, item.key) - item.response <- true - continue - } - } - }() - return locker -} +// type localLockItem struct { +// // The key to lock +// key string +// // When lock is true the request is to lock, otherwise it is to unlock +// lock bool +// // true is sent in the response channel if the operator was successful +// // unlocks are always successful. +// response chan<- bool +// } + +// type localLocker struct { +// once *sync.Once +// stopCh chan struct{} +// lockLocal chan localLockItem +// } + +// func (ll localLocker) close() { +// ll.once.Do(func() { +// // This is thread safe because no goroutine is writing +// // to this channel. +// close(ll.stopCh) +// }) +// } + +// func (ll localLocker) toggle(key string, lock bool) bool { +// resp := make(chan bool) +// item := localLockItem{ +// key: key, +// response: resp, +// lock: lock, +// } +// select { +// case <-ll.stopCh: +// // Return false if the locker is closed. +// return false +// case ll.lockLocal <- item: +// } +// out := <-resp +// return out +// } + +// func newLocalLocker() localLocker { +// locker := localLocker{ +// stopCh: make(chan struct{}), +// lockLocal: make(chan localLockItem), +// once: new(sync.Once), +// } +// // Thread safety is achieved by allowing only one goroutine to access +// // this map and having it read from channels that multiple goroutines +// // writing to them. +// locks := make(map[string]bool) +// count := 0 +// go func() { +// for item := range locker.lockLocal { +// count++ +// // extraneous else's and continue's to make flow clearer. 
+// if item.lock { +// if locks[item.key] { +// item.response <- false +// continue +// } else { +// locks[item.key] = true +// item.response <- true +// continue +// } +// } else { +// delete(locks, item.key) +// item.response <- true +// continue +// } +// } +// }() +// return locker +// } type getSession func() (*concurrency.Session, error) type v3Locker struct { getSession getSession lockTimeout time.Duration - lLocker localLocker + // lLocker localLocker } func (v3l *v3Locker) lock(key string) (ruleLock, error) { @@ -121,22 +122,12 @@ var errLockedLocally = errors.New("locked locally") // another client is known to hold the lock. There is no waiting for the lock // to be released. func (v3l *v3Locker) lockWithTimeout(key string, timeout time.Duration) (ruleLock, error) { - if ok := v3l.lLocker.toggle(key, true); !ok { - return nil, errLockedLocally - } s, err := v3l.getSession() - if err != nil { - // Release the local lock - v3l.lLocker.toggle(key, false) - return nil, err - } m := concurrency.NewMutex(s, key) ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() err = m.TryLock(ctx) if err != nil { - // Release the local lock - v3l.lLocker.toggle(key, false) return nil, err } return &v3Lock{ @@ -153,9 +144,6 @@ type v3Lock struct { } func (v3l *v3Lock) unlock() error { - // Unlocking the local lock should be done last, so obtaining the same - // lock can't occur while the etcd lock is still held. - defer v3l.locker.lLocker.toggle(v3l.key, false) if v3l.mutex != nil { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() diff --git a/rules/lock_test.go b/rules/lock_test.go index 17b50c5..b4d0480 100644 --- a/rules/lock_test.go +++ b/rules/lock_test.go @@ -10,7 +10,7 @@ import ( "go.etcd.io/etcd/clientv3" ) -func TestV3Locker(t *testing.T) { +func Test_v3Locker(t *testing.T) { cfg, cl := initV3Etcd(t) c, err := clientv3.New(cfg) require.NoError(t, err) @@ -18,14 +18,10 @@ func TestV3Locker(t *testing.T) { require.NoError(t, err) defer session1.Close() - lLocker := newLocalLocker() - rlckr1 := v3Locker{ lockTimeout: time.Minute, getSession: func() (*concurrency.Session, error) { return session1, nil }, - lLocker: lLocker, } - defer lLocker.close() rlck, err1 := rlckr1.lock("/test") assert.NoError(t, err1) require.NotNil(t, rlck) @@ -37,7 +33,6 @@ func TestV3Locker(t *testing.T) { rlckr2 := v3Locker{ lockTimeout: time.Minute, getSession: func() (*concurrency.Session, error) { return session2, nil }, - lLocker: lLocker, } _, err2 := rlckr2.lockWithTimeout("/test", 10*time.Second) diff --git a/rules/map_locker.go b/rules/map_locker.go new file mode 100644 index 0000000..7b4c28f --- /dev/null +++ b/rules/map_locker.go @@ -0,0 +1,103 @@ +package rules + +import "sync" + +type mapLocker struct { + once *sync.Once + stopCh chan struct{} + lockLocal chan mapLockItem +} + +type mapLockItem struct { + // The key to lock + key string + // When lock is true the request is to lock, otherwise it is to unlock + lock bool + // true is sent in the response channel if the operator was successful + // unlocks are always successful. + response chan<- bool +} + +func (ml mapLocker) close() { + ml.once.Do(func() { + // This is thread safe because no goroutine is writing + // to this channel. + close(ml.stopCh) + }) +} + +func (ml mapLocker) toggle(key string, lock bool) bool { + resp := make(chan bool) + item := mapLockItem{ + key: key, + response: resp, + lock: lock, + } + select { + case <-ml.stopCh: + // Return false if the locker is closed. 
+ return false + case ml.lockLocal <- item: + } + out := <-resp + return out +} + +func (ml mapLocker) lock(key string) (ruleLock, error) { + ok := ml.toggle(key, true) + if !ok { + return nil, errLockedLocally + } + return mapLock{ + locker: ml, + key: key, + }, nil +} + +func newMapLocker() mapLocker { + locker := mapLocker{ + stopCh: make(chan struct{}), + lockLocal: make(chan mapLockItem), + once: new(sync.Once), + } + // Thread safety is achieved by allowing only one goroutine to access + // this map and having it read from a channel with multiple goroutines + // writing to it. + locks := make(map[string]bool) + count := 0 + go func() { + for item := range locker.lockLocal { + count++ + // extraneous else's and continue's to make flow clearer. + if item.lock { + // Requesting a lock + if locks[item.key] { + // Lock already obtained + item.response <- false + continue + } else { + // Lock available + locks[item.key] = true + item.response <- true + continue + } + } else { + // Requesting an unlock + delete(locks, item.key) + item.response <- true + continue + } + } + }() + return locker +} + +type mapLock struct { + locker mapLocker + key string +} + +func (ml mapLock) unlock() error { + _ = ml.locker.toggle(ml.key, false) + return nil +} From 00be1625717dccc7d906e069096ac2164bc2ff3f Mon Sep 17 00:00:00 2001 From: John Warren Date: Wed, 13 Oct 2021 09:05:20 -0400 Subject: [PATCH 11/14] Move metrics to separate package --- {rules => metrics}/prometheus.go | 29 ++++++++++++++++++--------- {rules => metrics}/prometheus_test.go | 20 +++++++++--------- rules/etcd.go | 3 ++- rules/int_crawler.go | 3 ++- rules/key_processor.go | 5 +++-- rules/worker.go | 13 ++++++------ 6 files changed, 44 insertions(+), 29 deletions(-) rename {rules => metrics}/prometheus.go (74%) rename {rules => metrics}/prometheus_test.go (81%) diff --git a/rules/prometheus.go b/metrics/prometheus.go similarity index 74% rename from rules/prometheus.go rename to metrics/prometheus.go index a78af85..912b396 100644 --- a/rules/prometheus.go +++ b/metrics/prometheus.go @@ -1,4 +1,4 @@ -package rules +package metrics import ( "strconv" @@ -72,34 +72,45 @@ func init() { prometheus.MustRegister(rulesEngineWatcherErrors) } -func incLockMetric(methodName string, pattern string, lockSucceeded bool) { +// IncLockMetric increments the lock count. +func IncLockMetric(methodName string, pattern string, lockSucceeded bool) { rulesEngineLockCount.WithLabelValues(methodName, pattern, strconv.FormatBool(lockSucceeded)).Inc() } -func incSatisfiedThenNot(methodName string, pattern string, phaseName string) { +// IncSatisfiedThenNot increments the count of a rule having initially been +// satisfied and then not satisfied, either after the initial evaluation +// or after the lock was obtained. +func IncSatisfiedThenNot(methodName string, pattern string, phaseName string) { rulesEngineSatisfiedThenNot.WithLabelValues(methodName, pattern, phaseName).Inc() } -func timesEvaluated(methodName string, ruleID string, count int) { +// TimesEvaluated sets the number of times a rule has been evaluated. +func TimesEvaluated(methodName string, ruleID string, count int) { rulesEngineEvaluations.WithLabelValues(methodName, ruleID).Set(float64(count)) } -func workerQueueWaitTime(methodName string, startTime time.Time) { +// WorkerQueueWaitTime tracks the amount of time a work item has been sitting in +// a worker queue. 
+func WorkerQueueWaitTime(methodName string, startTime time.Time) { rulesEngineWorkerQueueWait.WithLabelValues(methodName).Observe(float64(time.Since(startTime).Nanoseconds() / 1e6)) } -func workBufferWaitTime(methodName, pattern string, startTime time.Time) { +// WorkBufferWaitTime tracks the amount of time a work item was in the work buffer. +func WorkBufferWaitTime(methodName, pattern string, startTime time.Time) { rulesEngineWorkBufferWaitTime.WithLabelValues(methodName, pattern).Observe(float64(time.Since(startTime).Nanoseconds() / 1e6)) } -func callbackWaitTime(pattern string, startTime time.Time) { +// CallbackWaitTime tracks how much time elapsed between when the rule was evaluated and the callback called. +func CallbackWaitTime(pattern string, startTime time.Time) { rulesEngineCallbackWaitTime.WithLabelValues(pattern).Observe(float64(time.Since(startTime).Nanoseconds() / 1e6)) } -func keyProcessBufferCap(count int) { +// KeyProcessBufferCap tracks the capacity of the key processor buffer. +func KeyProcessBufferCap(count int) { rulesEngineKeyProcessBufferCap.Set(float64(count)) } -func incWatcherErrMetric(err, prefix string) { +// IncWatcherErrMetric increments the watcher error count. +func IncWatcherErrMetric(err, prefix string) { rulesEngineWatcherErrors.WithLabelValues(err, prefix).Inc() } diff --git a/rules/prometheus_test.go b/metrics/prometheus_test.go similarity index 81% rename from rules/prometheus_test.go rename to metrics/prometheus_test.go index 0fb1537..bab45ad 100644 --- a/rules/prometheus_test.go +++ b/metrics/prometheus_test.go @@ -1,4 +1,4 @@ -package rules +package metrics import ( "net/http" @@ -41,44 +41,44 @@ func checkMetrics(t *testing.T, expectedOutput string) { } func TestIncLockMetric(t *testing.T) { - incLockMetric("getKey", "/key/pattern", true) - incLockMetric("getKey", "/second/pattern", false) + IncLockMetric("getKey", "/key/pattern", true) + IncLockMetric("getKey", "/second/pattern", false) checkMetrics(t, `rules_etcd_lock_count{method="getKey",pattern="/key/pattern",success="true"} 1`) checkMetrics(t, `rules_etcd_lock_count{method="getKey",pattern="/second/pattern",success="false"} 1`) } func TestIncSatisfiedThenNot(t *testing.T) { - incSatisfiedThenNot("getKey", "/key/pattern", "phaseName") + IncSatisfiedThenNot("getKey", "/key/pattern", "phaseName") checkMetrics(t, `rules_etcd_rule_satisfied_then_not{method="getKey",pattern="/key/pattern",phase="phaseName"} 1`) } func TestTimesEvaluated(t *testing.T) { - timesEvaluated("getKey", "rule1234", 5) + TimesEvaluated("getKey", "rule1234", 5) checkMetrics(t, `rules_etcd_evaluations{method="getKey",rule="rule1234"} 5`) } func TestWokerQueueWaitTime(t *testing.T) { - workerQueueWaitTime("getKey", time.Now()) + WorkerQueueWaitTime("getKey", time.Now()) checkMetrics(t, `rules_etcd_worker_queue_wait_ms_count{method="getKey"} 1`) } func TestWorkBufferWaitTime(t *testing.T) { - workBufferWaitTime("getKey", "/desired/key/pattern", time.Now()) + WorkBufferWaitTime("getKey", "/desired/key/pattern", time.Now()) checkMetrics(t, `rules_etcd_work_buffer_wait_ms_count{method="getKey",pattern="/desired/key/pattern"} 1`) } func TestCallbackWaitTime(t *testing.T) { - callbackWaitTime("/desired/key/pattern", time.Now()) + CallbackWaitTime("/desired/key/pattern", time.Now()) checkMetrics(t, `rules_etcd_callback_wait_ms_count{pattern="/desired/key/pattern"} 1`) } func Test_keyProcessBufferCap(t *testing.T) { - keyProcessBufferCap(100) + KeyProcessBufferCap(100) checkMetrics(t, `rules_etcd_key_process_buffer_cap 100`) } func 
Test_incWatcherErrMetric(t *testing.T) { - incWatcherErrMetric("err", "/desired/key/prefix") + IncWatcherErrMetric("err", "/desired/key/prefix") checkMetrics(t, `rules_etcd_watcher_errors{error="err",prefix="/desired/key/prefix"} 1`) } diff --git a/rules/etcd.go b/rules/etcd.go index 3c9ec07..f66f705 100644 --- a/rules/etcd.go +++ b/rules/etcd.go @@ -5,6 +5,7 @@ import ( "sync" "time" + "github.com/IBM-Cloud/go-etcd-rules/metrics" "go.etcd.io/etcd/mvcc/mvccpb" "go.etcd.io/etcd/clientv3" @@ -161,7 +162,7 @@ func (ev3kw *etcdV3KeyWatcher) next() (string, *string, error) { if err != nil { // There is a fixed set of possible errors. // See https://github.com/etcd-io/etcd/blob/release-3.4/clientv3/watch.go#L115-L126 - incWatcherErrMetric(err.Error(), ev3kw.prefix) + metrics.IncWatcherErrMetric(err.Error(), ev3kw.prefix) ev3kw.reset() return "", nil, err } diff --git a/rules/int_crawler.go b/rules/int_crawler.go index 1cc74df..0a73e0d 100644 --- a/rules/int_crawler.go +++ b/rules/int_crawler.go @@ -4,6 +4,7 @@ import ( "sync" "time" + "github.com/IBM-Cloud/go-etcd-rules/metrics" "go.etcd.io/etcd/clientv3" "go.uber.org/zap" "golang.org/x/net/context" @@ -167,7 +168,7 @@ func (ic *intCrawler) singleRun(logger *zap.Logger) { ic.metricMutex.Lock() defer ic.metricMutex.Unlock() for ruleID, count := range ic.rulesProcessedCount { - timesEvaluated(crawlerMethodName, ruleID, count) + metrics.TimesEvaluated(crawlerMethodName, ruleID, count) ic.metrics.TimesEvaluated(crawlerMethodName, ruleID, count) } } diff --git a/rules/key_processor.go b/rules/key_processor.go index cf18417..14e6b64 100644 --- a/rules/key_processor.go +++ b/rules/key_processor.go @@ -4,6 +4,7 @@ import ( "fmt" "time" + "github.com/IBM-Cloud/go-etcd-rules/metrics" "go.uber.org/zap" ) @@ -77,7 +78,7 @@ func (v3kp *v3KeyProcessor) dispatchWork(index int, rule staticRule, logger *zap start := time.Now() v3kp.channel <- work // measures the amount of time work is blocked from being added to the buffer - workBufferWaitTime(work.metricsInfo.method, keyPattern, start) + metrics.WorkBufferWaitTime(work.metricsInfo.method, keyPattern, start) } func newV3KeyProcessor(channel chan v3RuleWork, rm *ruleManager, kpChannel chan *keyTask, concurrency int, logger *zap.Logger) v3KeyProcessor { @@ -116,7 +117,7 @@ func (v3kp *v3KeyProcessor) processKey(key string, value *string, api readAPI, l func (v3kp *v3KeyProcessor) bufferCapacitySampler() { for { - keyProcessBufferCap(cap(v3kp.kpChannel) - len(v3kp.kpChannel)) + metrics.KeyProcessBufferCap(cap(v3kp.kpChannel) - len(v3kp.kpChannel)) time.Sleep(time.Minute) } } diff --git a/rules/worker.go b/rules/worker.go index 30f4e3c..058341f 100644 --- a/rules/worker.go +++ b/rules/worker.go @@ -4,6 +4,7 @@ import ( "sync" "time" + "github.com/IBM-Cloud/go-etcd-rules/metrics" "go.uber.org/zap" ) @@ -80,7 +81,7 @@ func (bw *baseWorker) doWork(loggerPtr **zap.Logger, } if !sat || is(&bw.stopping) { if !sat { - incSatisfiedThenNot(metricsInfo.method, metricsInfo.keyPattern, "worker.doWorkBeforeLock") + metrics.IncSatisfiedThenNot(metricsInfo.method, metricsInfo.keyPattern, "worker.doWorkBeforeLock") bw.metrics.IncSatisfiedThenNot(metricsInfo.method, metricsInfo.keyPattern, "worker.doWorkBeforeLock") } return @@ -88,11 +89,11 @@ func (bw *baseWorker) doWork(loggerPtr **zap.Logger, l, err2 := bw.locker.lock(lockKey, lockTTL) if err2 != nil { logger.Debug("Failed to acquire lock", zap.String("lock_key", lockKey), zap.Error(err2)) - incLockMetric(metricsInfo.method, metricsInfo.keyPattern, false) + 
metrics.IncLockMetric(metricsInfo.method, metricsInfo.keyPattern, false) bw.metrics.IncLockMetric(metricsInfo.method, metricsInfo.keyPattern, false) return } - incLockMetric(metricsInfo.method, metricsInfo.keyPattern, true) + metrics.IncLockMetric(metricsInfo.method, metricsInfo.keyPattern, true) bw.metrics.IncLockMetric(metricsInfo.method, metricsInfo.keyPattern, true) defer l.unlock() // Check for a second time, since checking and locking @@ -108,15 +109,15 @@ func (bw *baseWorker) doWork(loggerPtr **zap.Logger, return } if !sat { - incSatisfiedThenNot(metricsInfo.method, metricsInfo.keyPattern, "worker.doWorkAfterLock") + metrics.IncSatisfiedThenNot(metricsInfo.method, metricsInfo.keyPattern, "worker.doWorkAfterLock") bw.metrics.IncSatisfiedThenNot(metricsInfo.method, metricsInfo.keyPattern, "worker.doWorkAfterLock") } - workerQueueWaitTime(metricsInfo.method, metricsInfo.startTime) + metrics.WorkerQueueWaitTime(metricsInfo.method, metricsInfo.startTime) bw.metrics.WorkerQueueWaitTime(metricsInfo.method, metricsInfo.startTime) if sat && !is(&bw.stopping) { startTime := time.Now() callback() - callbackWaitTime(metricsInfo.keyPattern, startTime) + metrics.CallbackWaitTime(metricsInfo.keyPattern, startTime) } } From 50e229a11a77cc701536751b33ecab797961c8ae Mon Sep 17 00:00:00 2001 From: John Warren Date: Wed, 13 Oct 2021 10:27:00 -0400 Subject: [PATCH 12/14] Rename lock files to prevent conflicts --- rules/{lock.go => old_lock.go} | 0 rules/{lock_test.go => old_lock_test.go} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename rules/{lock.go => old_lock.go} (100%) rename rules/{lock_test.go => old_lock_test.go} (100%) diff --git a/rules/lock.go b/rules/old_lock.go similarity index 100% rename from rules/lock.go rename to rules/old_lock.go diff --git a/rules/lock_test.go b/rules/old_lock_test.go similarity index 100% rename from rules/lock_test.go rename to rules/old_lock_test.go From 6863128019f29455e46d65ae8eff6baec6de8cc3 Mon Sep 17 00:00:00 2001 From: John Warren Date: Wed, 13 Oct 2021 11:28:36 -0400 Subject: [PATCH 13/14] Stopping point --- rules/lock/lock.go | 85 +++++++++++++++++++++++++ rules/lock/lock_test.go | 43 +++++++++++++ rules/lock/map_locker.go | 110 ++++++++++++++++++++++++++++++++ rules/lock/mock.go | 44 +++++++++++++ rules/lock/nested_lock.go | 50 +++++++++++++++ rules/lock/nested_lock_test.go | 113 +++++++++++++++++++++++++++++++++ rules/teststore/etcd.go | 20 ++++++ 7 files changed, 465 insertions(+) create mode 100644 rules/lock/lock.go create mode 100644 rules/lock/lock_test.go create mode 100644 rules/lock/map_locker.go create mode 100644 rules/lock/mock.go create mode 100644 rules/lock/nested_lock.go create mode 100644 rules/lock/nested_lock_test.go create mode 100644 rules/teststore/etcd.go diff --git a/rules/lock/lock.go b/rules/lock/lock.go new file mode 100644 index 0000000..1fcf4bb --- /dev/null +++ b/rules/lock/lock.go @@ -0,0 +1,85 @@ +package lock + +import ( + "errors" + "time" + + "go.etcd.io/etcd/clientv3" + "go.etcd.io/etcd/clientv3/concurrency" + "golang.org/x/net/context" +) + +type RuleLocker interface { + Lock(string, ...Option) (RuleLock, error) +} + +type RuleLock interface { + Unlock() error +} + +type options struct { + // TODO add options +} + +type Option func(lo *options) + +// NewV3Locker creates a locker backed by etcd V3. 
+func NewV3Locker(cl *clientv3.Client, lockTimeout int) RuleLocker { + return &v3Locker{ + cl: cl, + lockTimeout: lockTimeout, + } +} + +type v3Locker struct { + cl *clientv3.Client + lockTimeout int +} + +func (v3l *v3Locker) Lock(key string, options ...Option) (RuleLock, error) { + return v3l.lockWithTimeout(key, v3l.lockTimeout) +} +func (v3l *v3Locker) lockWithTimeout(key string, timeout int) (RuleLock, error) { + // TODO once we switch to a shared session, we can get rid of the TTL option + // and go to the default (60 seconds). This is the TTL for the lease that + // is associated with the session and the lease is renewed before it expires + // while the session is active (not closed). It is not the TTL of any locks; + // those persist until Unlock is called or the process dies and the session + // lease is allowed to expire. + s, err := concurrency.NewSession(v3l.cl, concurrency.WithTTL(30)) + if err != nil { + return nil, err + } + m := concurrency.NewMutex(s, key) + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second) + defer cancel() + err = m.Lock(ctx) + if err != nil { + return nil, err + } + return &v3Lock{ + mutex: m, + session: s, + }, nil +} + +type v3Lock struct { + mutex *concurrency.Mutex + session *concurrency.Session +} + +// ErrNilMutex indicates that the lock has a nil mutex +var ErrNilMutex = errors.New("mutex is nil") + +func (v3l *v3Lock) Unlock() error { + if v3l.mutex != nil { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + err := v3l.mutex.Unlock(ctx) + if err == nil && v3l.session != nil { + v3l.session.Close() + } + return err + } + return ErrNilMutex +} diff --git a/rules/lock/lock_test.go b/rules/lock/lock_test.go new file mode 100644 index 0000000..b01b045 --- /dev/null +++ b/rules/lock/lock_test.go @@ -0,0 +1,43 @@ +package lock + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "go.etcd.io/etcd/clientv3" + + "github.com/IBM-Cloud/go-etcd-rules/rules/teststore" +) + +func Test_V3Locker(t *testing.T) { + cfg, cl := teststore.InitV3Etcd(t) + c, err := clientv3.New(cfg) + assert.NoError(t, err) + rlckr := v3Locker{ + cl: cl, + lockTimeout: 5, + } + rlck, err1 := rlckr.Lock("test") + assert.NoError(t, err1) + _, err2 := rlckr.lockWithTimeout("test", 10) + assert.Error(t, err2) + rlck.Unlock() + + done1 := make(chan bool) + done2 := make(chan bool) + + go func() { + lckr := NewV3Locker(c, 5) + lck, lErr := lckr.Lock("test1") + assert.NoError(t, lErr) + done1 <- true + <-done2 + if lck != nil { + lck.Unlock() + } + }() + <-done1 + _, err = rlckr.Lock("test1") + assert.Error(t, err) + done2 <- true +} diff --git a/rules/lock/map_locker.go b/rules/lock/map_locker.go new file mode 100644 index 0000000..cf5f80e --- /dev/null +++ b/rules/lock/map_locker.go @@ -0,0 +1,110 @@ +package lock + +import ( + "errors" + "sync" +) + +type mapLocker struct { + once *sync.Once + stopCh chan struct{} + lockLocal chan mapLockItem +} + +// ErrLockedLocally indicates that a local goroutine holds the lock +// and no attempt will be made to obtain the lock via etcd. +var ErrLockedLocally = errors.New("locked locally") + +type mapLockItem struct { + // The key to lock + key string + // When lock is true the request is to lock, otherwise it is to unlock + lock bool + // true is sent in the response channel if the operator was successful + // unlocks are always successful. 
+ response chan<- bool +} + +func (ml mapLocker) close() { + ml.once.Do(func() { + // This is thread safe because no goroutine is writing + // to this channel. + close(ml.stopCh) + }) +} + +func (ml mapLocker) toggle(key string, lock bool) bool { + resp := make(chan bool) + item := mapLockItem{ + key: key, + response: resp, + lock: lock, + } + select { + case <-ml.stopCh: + // Return false if the locker is closed. + return false + case ml.lockLocal <- item: + } + out := <-resp + return out +} + +func (ml mapLocker) Lock(key string) (RuleLock, error) { + ok := ml.toggle(key, true) + if !ok { + return nil, ErrLockedLocally + } + return mapLock{ + locker: ml, + key: key, + }, nil +} + +func newMapLocker() mapLocker { + locker := mapLocker{ + stopCh: make(chan struct{}), + lockLocal: make(chan mapLockItem), + once: new(sync.Once), + } + // Thread safety is achieved by allowing only one goroutine to access + // this map and having it read from a channel with multiple goroutines + // writing to it. + locks := make(map[string]bool) + count := 0 + go func() { + for item := range locker.lockLocal { + count++ + // extraneous else's and continue's to make flow clearer. + if item.lock { + // Requesting a lock + if locks[item.key] { + // Lock already obtained + item.response <- false + continue + } else { + // Lock available + locks[item.key] = true + item.response <- true + continue + } + } else { + // Requesting an unlock + delete(locks, item.key) + item.response <- true + continue + } + } + }() + return locker +} + +type mapLock struct { + locker mapLocker + key string +} + +func (ml mapLock) Unlock() error { + _ = ml.locker.toggle(ml.key, false) + return nil +} diff --git a/rules/lock/mock.go b/rules/lock/mock.go new file mode 100644 index 0000000..ade20cf --- /dev/null +++ b/rules/lock/mock.go @@ -0,0 +1,44 @@ +package lock + +import "errors" + +// MockLocker implements the RuleLocker interface. +type MockLocker struct { + Channel chan bool + ErrorMsg *string +} + +func (tlkr *MockLocker) Lock(key string, options ...Option) (RuleLock, error) { + if tlkr.ErrorMsg != nil { + return nil, errors.New(*tlkr.ErrorMsg) + } + tLock := mockLock{ + channel: tlkr.Channel, + } + return &tLock, nil +} + +type mockLock struct { + channel chan bool +} + +func (tl *mockLock) Unlock() error { + tl.channel <- true + return nil +} + +type FuncMockLocker struct { + LockF func(string) (RuleLock, error) +} + +func (ml FuncMockLocker) Lock(key string, options ...Option) (RuleLock, error) { + return ml.LockF(key) +} + +type FuncMockLock struct { + UnlockF func() error +} + +func (ml FuncMockLock) Unlock() error { + return ml.UnlockF() +} diff --git a/rules/lock/nested_lock.go b/rules/lock/nested_lock.go new file mode 100644 index 0000000..e505522 --- /dev/null +++ b/rules/lock/nested_lock.go @@ -0,0 +1,50 @@ +package lock + +type nestedLocker struct { + own RuleLocker + nested RuleLocker +} + +func (nl nestedLocker) Lock(key string) (RuleLock, error) { + // Try to obtain own lock first, preempting attempts + // to obtain the nested (more expensive) lock if + // getting the local lock fails. 
+ lock, err := nl.own.Lock(key) + if err != nil { + return nil, err + } + // Try to obtain the nested lock + nested, err := nl.nested.Lock(key) + if err != nil { + // First unlock own lock + _ = lock.Unlock() + return nil, err + } + return nestedLock{ + own: lock, + nested: nested, + }, nil +} + +type nestedLock struct { + own RuleLock + nested RuleLock +} + +func (nl nestedLock) Unlock() error { + // Always unlock own lock, but after + // nested lock. This prevents attempting + // to get a new instance of the nested lock + // before the own lock is cleared. If the nested + // lock persists due to an error, it should be + // cleared with separate logic. + + err := nl.nested.Unlock() + ownError := nl.own.Unlock() + // The nested lock is assumed to be more expensive so + // its error takes precedence. + if err == nil { + err = ownError + } + return err +} diff --git a/rules/lock/nested_lock_test.go b/rules/lock/nested_lock_test.go new file mode 100644 index 0000000..16ec5eb --- /dev/null +++ b/rules/lock/nested_lock_test.go @@ -0,0 +1,113 @@ +package lock + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_nestedLocker_lock(t *testing.T) { + // Set up mock data for mock functions + type testLock struct { + RuleLock + val string // Just something to compare. + } + var ownUnlockCalled bool + testOwnLock := testLock{ + RuleLock: FuncMockLock{ + UnlockF: func() error { + ownUnlockCalled = true + return nil + }, + }, + val: "own", + } + testNestedLock := testLock{ + val: "nested", + } + + ownLockErr := errors.New("own lock") + nestedLockErr := errors.New("nested lock") + + testCases := []struct { + name string + + nestedCalled bool + ownUnlockCalled bool + + err error + ownLockErr error + nestedLockErr error + }{ + { + name: "ok", + nestedCalled: true, + }, + { + name: "own_error", + ownLockErr: ownLockErr, + err: ownLockErr, + }, + { + name: "nested_error", + nestedCalled: true, + ownUnlockCalled: true, + nestedLockErr: nestedLockErr, + err: nestedLockErr, + }, + { + name: "both_errors", + ownLockErr: ownLockErr, + nestedLockErr: nestedLockErr, + err: ownLockErr, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Reset from any previous runs + ownUnlockCalled = false + ownCalled := false + nestedCalled := false + nl := nestedLocker{ + own: FuncMockLocker{ + LockF: func(key string) (RuleLock, error) { + assert.Equal(t, "key", key) + ownCalled = true + return testOwnLock, tc.ownLockErr + }, + }, + nested: FuncMockLocker{ + LockF: func(key string) (RuleLock, error) { + // The own locker should have been called first + assert.True(t, ownCalled) + assert.Equal(t, "key", key) + nestedCalled = true + return testNestedLock, tc.nestedLockErr + }, + }, + } + var err error + lock, err := nl.Lock("key") + assert.Equal(t, tc.nestedCalled, nestedCalled) + assert.Equal(t, tc.ownUnlockCalled, ownUnlockCalled) + if tc.err != nil { + assert.EqualError(t, err, tc.err.Error()) + return + } + assert.NoError(t, err) + nLock, ok := lock.(nestedLock) + if assert.True(t, ok) { + getVal := func(rl RuleLock) string { + tl, ok := rl.(testLock) + if !ok { + return "" + } + return tl.val + } + assert.Equal(t, testOwnLock.val, getVal(nLock.own)) + assert.Equal(t, testNestedLock.val, getVal(nLock.nested)) + } + }) + } +} diff --git a/rules/teststore/etcd.go b/rules/teststore/etcd.go new file mode 100644 index 0000000..d3a0652 --- /dev/null +++ b/rules/teststore/etcd.go @@ -0,0 +1,20 @@ +package teststore + +import ( + "context" + "testing" + + 
"github.com/stretchr/testify/require" + "go.etcd.io/etcd/clientv3" +) + +// InitV3Etcd initializes etcd for test cases +func InitV3Etcd(t *testing.T) (clientv3.Config, *clientv3.Client) { + cfg := clientv3.Config{ + Endpoints: []string{"http://127.0.0.1:2379"}, + } + c, _ := clientv3.New(cfg) + _, err := c.Delete(context.Background(), "/", clientv3.WithPrefix()) + require.NoError(t, err) + return cfg, c +} From 8dd66c39acb2d1f0edddcf90d8ebecc2831a9d42 Mon Sep 17 00:00:00 2001 From: John Warren Date: Wed, 13 Oct 2021 12:59:35 -0400 Subject: [PATCH 14/14] Stopping point --- rules/lock/map_locker.go | 42 +++++++---- rules/lock/map_locker_test.go | 133 ++++++++++++++++++++++++++++++++++ 2 files changed, 160 insertions(+), 15 deletions(-) create mode 100644 rules/lock/map_locker_test.go diff --git a/rules/lock/map_locker.go b/rules/lock/map_locker.go index cf5f80e..1f559fb 100644 --- a/rules/lock/map_locker.go +++ b/rules/lock/map_locker.go @@ -50,17 +50,6 @@ func (ml mapLocker) toggle(key string, lock bool) bool { return out } -func (ml mapLocker) Lock(key string) (RuleLock, error) { - ok := ml.toggle(key, true) - if !ok { - return nil, ErrLockedLocally - } - return mapLock{ - locker: ml, - key: key, - }, nil -} - func newMapLocker() mapLocker { locker := mapLocker{ stopCh: make(chan struct{}), @@ -99,12 +88,35 @@ func newMapLocker() mapLocker { return locker } -type mapLock struct { - locker mapLocker +type toggleLocker interface { + toggle(key string, lock bool) bool + close() +} +type toggleLockerAdapter struct { + toggle func(key string, lock bool) bool + close func() + errLocked error +} + +func (tla toggleLockerAdapter) Lock(key string) (RuleLock, error) { + ok := tla.toggle(key, true) + if !ok { + return nil, tla.errLocked + } + return toggleLock{ + toggle: tla.toggle, + close: tla.close, + key: key, + }, nil +} + +type toggleLock struct { + toggle func(key string, lock bool) bool + close func() key string } -func (ml mapLock) Unlock() error { - _ = ml.locker.toggle(ml.key, false) +func (tl toggleLock) Unlock() error { + _ = tl.toggle(tl.key, false) return nil } diff --git a/rules/lock/map_locker_test.go b/rules/lock/map_locker_test.go new file mode 100644 index 0000000..1f8c1d9 --- /dev/null +++ b/rules/lock/map_locker_test.go @@ -0,0 +1,133 @@ +package lock + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_mapLocker_toggle(t *testing.T) { + testCases := []struct { + name string + + setup func(ml *mapLocker) + + key string + lock bool + + ok bool + }{ + { + name: "get_available", + key: "/foo", + setup: func(ml *mapLocker) { + ml.toggle("/bar", true) + }, + lock: true, + ok: true, + }, + { + name: "get_unavailable", + key: "/foo", + setup: func(ml *mapLocker) { + ml.toggle("/foo", true) + }, + lock: true, + ok: false, + }, + { + name: "release_existing", + key: "/foo", + setup: func(ml *mapLocker) { + ml.toggle("/foo", true) + }, + lock: false, + ok: true, + }, + { + name: "release_nonexistent", + key: "/foo", + lock: false, + ok: true, + }, + { + name: "get_from_closed", + key: "/foo", + setup: func(ml *mapLocker) { + ml.close() + }, + lock: true, + ok: false, + }, + { + name: "release_from_closed", + key: "/foo", + setup: func(ml *mapLocker) { + ml.close() + }, + lock: false, + ok: false, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ml := newMapLocker() + defer ml.close() + + if tc.setup != nil { + tc.setup(&ml) + } + + assert.Equal(t, tc.ok, ml.toggle(tc.key, tc.lock)) + + }) + } +} + +func 
Test_toggleLockAdapter(t *testing.T) { + const ( + testKey = "/foo" + ) + errLocked := errors.New("locked") + testCases := []struct { + name string + + lock bool + toggleOk bool + + err error + }{ + { + name: "success", + toggleOk: true, + }, + { + name: "failure", + toggleOk: false, + err: errLocked, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + expectedLock := true + var err error + tla := toggleLockerAdapter{ + toggle: func(key string, lock bool) bool { + assert.Equal(t, expectedLock, lock) + assert.Equal(t, testKey, key) + return tc.toggleOk + }, + errLocked: errLocked, + } + lock, err := tla.Lock(testKey) + if tc.err != nil { + assert.EqualError(t, err, tc.err.Error()) + return + } + assert.NoError(t, err) + expectedLock = false + _ = assert.NotNil(t, lock) && assert.NoError(t, lock.Unlock()) + }) + } +}
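
Usage note (not part of any patch above): a minimal sketch of how the etcd-backed locker exported from the new rules/lock package (patches 13 and 14) is meant to be called. The endpoint and lock key are assumptions chosen to match the examples elsewhere in this series, and error handling is reduced to logging.

package main

import (
	"log"

	"go.etcd.io/etcd/clientv3"

	"github.com/IBM-Cloud/go-etcd-rules/rules/lock"
)

func main() {
	cl, err := clientv3.New(clientv3.Config{Endpoints: []string{"http://127.0.0.1:2379"}})
	if err != nil {
		log.Fatal(err)
	}
	// The second argument is the lock-acquisition timeout in seconds; as of
	// patch 13 each Lock call creates its own session with a 30-second lease TTL.
	locker := lock.NewV3Locker(cl, 5)
	l, err := locker.Lock("/locks/example")
	if err != nil {
		log.Fatal(err)
	}
	defer func() {
		if uerr := l.Unlock(); uerr != nil {
			log.Printf("unlock failed: %v", uerr)
		}
	}()
	// ... work performed while holding the lock ...
}

Within the rules package itself, patch 10 already wires the cheaper in-process locker in front of the etcd one: newV3Locker returns nestedLocker{own: newMapLocker(), nested: &v3Locker{...}}, so the local map lock is attempted first, the etcd lock only when that succeeds, and unlock releases the etcd lock before the local one.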