diff --git a/clientv3/kv.go b/clientv3/kv.go
index f887e044102..418f6c32c1c 100644
--- a/clientv3/kv.go
+++ b/clientv3/kv.go
@@ -74,6 +74,26 @@ func (op OpResponse) Get() *GetResponse    { return op.get }
 func (op OpResponse) Del() *DeleteResponse { return op.del }
 func (op OpResponse) Txn() *TxnResponse    { return op.txn }
 
+// ToOpResponse wraps the put response in an OpResponse.
+func (resp *PutResponse) ToOpResponse() OpResponse {
+	return OpResponse{put: resp}
+}
+
+// ToOpResponse wraps the get response in an OpResponse.
+func (resp *GetResponse) ToOpResponse() OpResponse {
+	return OpResponse{get: resp}
+}
+
+// ToOpResponse wraps the delete response in an OpResponse.
+func (resp *DeleteResponse) ToOpResponse() OpResponse {
+	return OpResponse{del: resp}
+}
+
+// ToOpResponse wraps the txn response in an OpResponse.
+func (resp *TxnResponse) ToOpResponse() OpResponse {
+	return OpResponse{txn: resp}
+}
+
 type kv struct {
 	remote pb.KVClient
 }
diff --git a/clientv3/ordering/kv.go b/clientv3/ordering/kv.go
new file mode 100644
index 00000000000..e8bf07b8c74
--- /dev/null
+++ b/clientv3/ordering/kv.go
@@ -0,0 +1,166 @@
+// Copyright 2017 The etcd Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ordering
+
+import (
+	"context"
+	"sync"
+
+	"github.com/coreos/etcd/clientv3"
+)
+
+// kvOrdering wraps a clientv3.KV and checks every Get and Txn response
+// against the highest revision recorded so far, so that a response with
+// a lower revision is reported to orderViolationFunc instead of returned.
+type kvOrdering struct {
+	clientv3.KV
+	orderViolationFunc OrderViolationFunc
+	prevRev            int64 // highest revision returned so far; guarded by revMu
+	revMu              sync.RWMutex
+}
+
+// NewKV wraps kv so that orderViolationFunc is invoked whenever a Get or
+// Txn response arrives with a revision lower than one previously returned.
+func NewKV(kv clientv3.KV, orderViolationFunc OrderViolationFunc) *kvOrdering {
+	return &kvOrdering{KV: kv, orderViolationFunc: orderViolationFunc}
+}
+
+// getPrevRev returns the highest revision recorded so far.
+func (kv *kvOrdering) getPrevRev() int64 {
+	kv.revMu.RLock()
+	defer kv.revMu.RUnlock()
+	return kv.prevRev
+}
+
+// setPrevRev records currRev if it is newer than the recorded revision.
+// The comparison and the store happen under one write lock so that a
+// concurrent caller cannot replace a higher revision with a lower one.
+func (kv *kvOrdering) setPrevRev(currRev int64) {
+	kv.revMu.Lock()
+	defer kv.revMu.Unlock()
+	if currRev > kv.prevRev {
+		kv.prevRev = currRev
+	}
+}
+
+// Get issues the request and returns the response only if its revision is
+// not older than the revision recorded when Get was called. Otherwise it
+// calls orderViolationFunc and retries; an error from the request or from
+// orderViolationFunc ends the loop and is returned.
+func (kv *kvOrdering) Get(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) {
+	// prevRev is stored in a local variable in order to record the prevRev
+	// at the beginning of the Get operation, because concurrent
+	// access to kvOrdering could change the prevRev field in the
+	// middle of the Get operation.
+	prevRev := kv.getPrevRev()
+	op := clientv3.OpGet(key, opts...)
+	for {
+		r, err := kv.KV.Do(ctx, op)
+		if err != nil {
+			return nil, err
+		}
+		resp := r.Get()
+		if resp.Header.Revision >= prevRev {
+			kv.setPrevRev(resp.Header.Revision)
+			return resp, nil
+		}
+		if err = kv.orderViolationFunc(op, r, prevRev); err != nil {
+			return nil, err
+		}
+	}
+}
+
+// Txn returns a transaction whose Commit enforces the same revision
+// ordering as Get.
+func (kv *kvOrdering) Txn(ctx context.Context) clientv3.Txn {
+	return &txnOrdering{
+		Txn:        kv.KV.Txn(ctx),
+		kvOrdering: kv,
+		ctx:        ctx,
+		cmps:       []clientv3.Cmp{},
+		thenOps:    []clientv3.Op{},
+		elseOps:    []clientv3.Op{},
+	}
+}
+
+// txnOrdering ensures that serialized requests do not return
+// txn responses with revisions less than the previous
+// returned revision.
+type txnOrdering struct {
+	clientv3.Txn
+	*kvOrdering
+	ctx     context.Context
+	mu      sync.Mutex // guards cmps, thenOps and elseOps
+	cmps    []clientv3.Cmp
+	thenOps []clientv3.Op
+	elseOps []clientv3.Op
+}
+
+// If records the comparisons for Commit and forwards them to the wrapped Txn.
+func (txn *txnOrdering) If(cs ...clientv3.Cmp) clientv3.Txn {
+	txn.mu.Lock()
+	defer txn.mu.Unlock()
+	txn.cmps = cs
+	txn.Txn.If(cs...)
+	return txn
+}
+
+// Then records the success operations for Commit and forwards them to the wrapped Txn.
+func (txn *txnOrdering) Then(ops ...clientv3.Op) clientv3.Txn {
+	txn.mu.Lock()
+	defer txn.mu.Unlock()
+	txn.thenOps = ops
+	txn.Txn.Then(ops...)
+	return txn
+}
+
+// Else records the failure operations for Commit and forwards them to the wrapped Txn.
+func (txn *txnOrdering) Else(ops ...clientv3.Op) clientv3.Txn {
+	txn.mu.Lock()
+	defer txn.mu.Unlock()
+	txn.elseOps = ops
+	txn.Txn.Else(ops...)
+	return txn
+}
+
+// Commit sends the recorded transaction and returns the response only if
+// its revision is not older than the revision recorded when Commit was
+// called. Otherwise it calls orderViolationFunc and retries; an error from
+// the request or from orderViolationFunc ends the loop and is returned.
+func (txn *txnOrdering) Commit() (*clientv3.TxnResponse, error) {
+	// prevRev is stored in a local variable in order to record the prevRev
+	// at the beginning of the Commit operation, because concurrent
+	// access to txnOrdering could change the prevRev field in the
+	// middle of the Commit operation.
+	prevRev := txn.getPrevRev()
+	// Snapshot the recorded operations under mu, which If/Then/Else hold while writing them.
+	txn.mu.Lock()
+	opTxn := clientv3.OpTxn(txn.cmps, txn.thenOps, txn.elseOps)
+	txn.mu.Unlock()
+	for {
+		opResp, err := txn.KV.Do(txn.ctx, opTxn)
+		if err != nil {
+			return nil, err
+		}
+		txnResp := opResp.Txn()
+		if txnResp.Header.Revision >= prevRev {
+			txn.setPrevRev(txnResp.Header.Revision)
+			return txnResp, nil
+		}
+		if err = txn.orderViolationFunc(opTxn, opResp, prevRev); err != nil {
+			return nil, err
+		}
+	}
+}
diff --git a/clientv3/ordering/kv_test.go b/clientv3/ordering/kv_test.go
new file mode 100644
index 00000000000..b923cab3de4
--- /dev/null
+++ b/clientv3/ordering/kv_test.go
@@ -0,0 +1,320 @@
+// Copyright 2017 The etcd Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ordering
+
+import (
+	"context"
+	"errors"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/coreos/etcd/clientv3"
+	pb "github.com/coreos/etcd/etcdserver/etcdserverpb"
+	"github.com/coreos/etcd/integration"
+	"github.com/coreos/etcd/pkg/testutil"
+	gContext "golang.org/x/net/context"
+)
+
+// TestDetectKvOrderViolation checks that a serializable Get answered by a
+// member whose revision is behind the last revision seen by the ordering KV
+// invokes the order violation function.
+func TestDetectKvOrderViolation(t *testing.T) {
+	var errOrderViolation = errors.New("Detected Order Violation")
+
+	defer testutil.AfterTest(t)
+	clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 3})
+	defer clus.Terminate(t)
+
+	cfg := clientv3.Config{
+		Endpoints: []string{
+			clus.Members[0].GRPCAddr(),
+			clus.Members[1].GRPCAddr(),
+			clus.Members[2].GRPCAddr(),
+		},
+	}
+	cli, err := clientv3.New(cfg)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer cli.Close()
+	ctx := context.TODO()
+
+	cli.SetEndpoints(clus.Members[0].GRPCAddr())
+	_, err = cli.Put(ctx, "foo", "bar")
+	if err != nil {
+		t.Fatal(err)
+	}
+	// ensure that the second member has the current revision for the key foo
+	cli.SetEndpoints(clus.Members[1].GRPCAddr())
+	_, err = cli.Get(ctx, "foo")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// stop third member in order to force the member to have an outdated revision
+	clus.Members[2].Stop(t)
+	time.Sleep(1 * time.Second) // give enough time for operation
+	_, err = cli.Put(ctx, "foo", "buzz")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// perform get request against the first member, in order to
+	// set up kvOrdering to expect "foo" revisions greater than that of
+	// the third member.
+	orderingKv := NewKV(cli.KV,
+		func(op clientv3.Op, resp clientv3.OpResponse, prevRev int64) error {
+			return errOrderViolation
+		})
+	_, err = orderingKv.Get(ctx, "foo")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// ensure that only the third member is queried during requests
+	clus.Members[0].Stop(t)
+	clus.Members[1].Stop(t)
+	clus.Members[2].Restart(t)
+	// force OrderingKv to query the third member
+	cli.SetEndpoints(clus.Members[2].GRPCAddr())
+
+	_, err = orderingKv.Get(ctx, "foo", clientv3.WithSerializable())
+	if err != errOrderViolation {
+		t.Fatalf("expected %v, got %v", errOrderViolation, err)
+	}
+}
+
+// TestDetectTxnOrderViolation checks that a Txn answered by a member whose
+// revision is behind the last revision seen by the ordering KV invokes the
+// order violation function.
+func TestDetectTxnOrderViolation(t *testing.T) {
+	var errOrderViolation = errors.New("Detected Order Violation")
+
+	defer testutil.AfterTest(t)
+	clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 3})
+	defer clus.Terminate(t)
+
+	cfg := clientv3.Config{
+		Endpoints: []string{
+			clus.Members[0].GRPCAddr(),
+			clus.Members[1].GRPCAddr(),
+			clus.Members[2].GRPCAddr(),
+		},
+	}
+	cli, err := clientv3.New(cfg)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer cli.Close()
+	ctx := context.TODO()
+
+	cli.SetEndpoints(clus.Members[0].GRPCAddr())
+	_, err = cli.Put(ctx, "foo", "bar")
+	if err != nil {
+		t.Fatal(err)
+	}
+	// ensure that the second member has the current revision for the key foo
+	cli.SetEndpoints(clus.Members[1].GRPCAddr())
+	_, err = cli.Get(ctx, "foo")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// stop third member in order to force the member to have an outdated revision
+	clus.Members[2].Stop(t)
+	time.Sleep(1 * time.Second) // give enough time for operation
+	_, err = cli.Put(ctx, "foo", "buzz")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// perform get request against the first member, in order to
+	// set up kvOrdering to expect "foo" revisions greater than that of
+	// the third member.
+	orderingKv := NewKV(cli.KV,
+		func(op clientv3.Op, resp clientv3.OpResponse, prevRev int64) error {
+			return errOrderViolation
+		})
+	orderingTxn := orderingKv.Txn(ctx)
+	_, err = orderingTxn.If(
+		clientv3.Compare(clientv3.Value("b"), ">", "a"),
+	).Then(
+		clientv3.OpGet("foo"),
+	).Commit()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// ensure that only the third member is queried during requests
+	clus.Members[0].Stop(t)
+	clus.Members[1].Stop(t)
+	clus.Members[2].Restart(t)
+	// force OrderingKv to query the third member
+	cli.SetEndpoints(clus.Members[2].GRPCAddr())
+
+	// The error of this Get is deliberately discarded; only the Txn below is
+	// asserted. NOTE(review): the call looks like a leftover from
+	// TestDetectKvOrderViolation — confirm whether it is still needed.
+	_, _ = orderingKv.Get(ctx, "foo", clientv3.WithSerializable())
+	orderingTxn = orderingKv.Txn(ctx)
+	_, err = orderingTxn.If(
+		clientv3.Compare(clientv3.Value("b"), ">", "a"),
+	).Then(
+		clientv3.OpGet("foo", clientv3.WithSerializable()),
+	).Commit()
+	if err != errOrderViolation {
+		t.Fatalf("expected %v, got %v", errOrderViolation, err)
+	}
+}
+
+// mockKV is a clientv3.KV whose Do always returns the canned response,
+// regardless of the operation it is given.
+type mockKV struct {
+	clientv3.KV
+	response clientv3.OpResponse
+}
+
+// Do ignores ctx and op and returns the canned response with a nil error.
+func (kv *mockKV) Do(ctx gContext.Context, op clientv3.Op) (clientv3.OpResponse, error) {
+	return kv.response, nil
+}
+
+// rangeTests pairs the revision already recorded by the ordering KV with the
+// Get response that the mocked KV hands back.
+var rangeTests = []struct {
+	prevRev  int64
+	response *clientv3.GetResponse
+}{
+	{
+		5,
+		&clientv3.GetResponse{
+			Header: &pb.ResponseHeader{
+				Revision: 5,
+			},
+		},
+	},
+	{
+		5,
+		&clientv3.GetResponse{
+			Header: &pb.ResponseHeader{
+				Revision: 4,
+			},
+		},
+	},
+	{
+		5,
+		&clientv3.GetResponse{
+			Header: &pb.ResponseHeader{
+				Revision: 6,
+			},
+		},
+	},
+}
+
+// TestKvOrdering checks that Get only returns once the response revision has
+// caught up with the previously recorded revision.
+func TestKvOrdering(t *testing.T) {
+	for i, tt := range rangeTests {
+		mKV := &mockKV{clientv3.NewKVFromKVClient(nil), tt.response.ToOpResponse()}
+		kv := &kvOrdering{
+			mKV,
+			func(r *clientv3.GetResponse) OrderViolationFunc {
+				return func(op clientv3.Op, resp clientv3.OpResponse, prevRev int64) error {
+					r.Header.Revision++
+					return nil
+				}
+			}(tt.response),
+			tt.prevRev,
+			sync.RWMutex{},
+		}
+		res, err := kv.Get(context.TODO(), "mockKey")
+		if err != nil {
+			// res is nil when err is set, so stop before dereferencing it.
+			t.Fatalf("#%d: expected response %+v, got error %+v", i, tt.response, err)
+		}
+		if rev := res.Header.Revision; rev < tt.prevRev {
+			t.Errorf("#%d: expected revision %d, got %d", i, tt.prevRev, rev)
+		}
+	}
+}
+
+// txnTests pairs the revision already recorded by the ordering KV with the
+// Txn response that the mocked KV hands back.
+var txnTests = []struct {
+	prevRev  int64
+	response *clientv3.TxnResponse
+}{
+	{
+		5,
+		&clientv3.TxnResponse{
+			Header: &pb.ResponseHeader{
+				Revision: 5,
+			},
+		},
+	},
+	{
+		5,
+		&clientv3.TxnResponse{
+			Header: &pb.ResponseHeader{
+				Revision: 8,
+			},
+		},
+	},
+	{
+		5,
+		&clientv3.TxnResponse{
+			Header: &pb.ResponseHeader{
+				Revision: 4,
+			},
+		},
+	},
+}
+
+// TestTxnOrdering checks that Commit only returns once the response revision
+// has caught up with the previously recorded revision.
+func TestTxnOrdering(t *testing.T) {
+	for i, tt := range txnTests {
+		mKV := &mockKV{clientv3.NewKVFromKVClient(nil), tt.response.ToOpResponse()}
+		kv := &kvOrdering{
+			mKV,
+			func(r *clientv3.TxnResponse) OrderViolationFunc {
+				return func(op clientv3.Op, resp clientv3.OpResponse, prevRev int64) error {
+					r.Header.Revision++
+					return nil
+				}
+			}(tt.response),
+			tt.prevRev,
+			sync.RWMutex{},
+		}
+		txn := &txnOrdering{
+			kv.Txn(context.Background()),
+			kv,
+			context.Background(),
+			sync.Mutex{},
+			[]clientv3.Cmp{},
+			[]clientv3.Op{},
+			[]clientv3.Op{},
+		}
+		res, err := txn.Commit()
+		if err != nil {
+			// res is nil when err is set, so stop before dereferencing it.
+			t.Fatalf("#%d: expected response %+v, got error %+v", i, tt.response, err)
+		}
+		if rev := res.Header.Revision; rev < tt.prevRev {
+			t.Errorf("#%d: expected revision %d, got %d", i, tt.prevRev, rev)
+		}
+	}
+}
diff --git a/clientv3/ordering/util.go b/clientv3/ordering/util.go
new file mode 100644
index 00000000000..7b151e78e59
--- /dev/null
+++ b/clientv3/ordering/util.go
@@ -0,0 +1,63 @@
+// Copyright 2017 The etcd Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ordering
+
+import (
+	"errors"
+	"sync"
+	"time"
+
+	"github.com/coreos/etcd/clientv3"
+)
+
+// OrderViolationFunc is called by the ordering KV when a response carries a
+// revision lower than prevRev. Returning nil makes the caller retry the
+// request; returning an error aborts the request with that error.
+type OrderViolationFunc func(op clientv3.Op, resp clientv3.OpResponse, prevRev int64) error
+
+// ErrNoGreaterRev is returned by the endpoint-switching closure once it has
+// run out of retries.
+var ErrNoGreaterRev = errors.New("etcdclient: no cluster members have a revision higher than the previously received revision")
+
+// NewOrderViolationSwitchEndpointClosure returns an OrderViolationFunc that
+// reacts to a violation by pinning the client to one endpoint, rotating
+// through the endpoint list on successive calls, and then restoring the
+// full list. After len(endpoints)+1 such switches, or if the client has no
+// endpoints, it returns ErrNoGreaterRev.
+//
+// NOTE(review): c is received by value, so the closure operates on its own
+// copy of the Client struct — confirm this is intended.
+func NewOrderViolationSwitchEndpointClosure(c clientv3.Client) OrderViolationFunc {
+	var mu sync.Mutex
+	violationCount := 0
+	return func(op clientv3.Op, resp clientv3.OpResponse, prevRev int64) error {
+		// violationCount is shared by every call of the returned closure, so
+		// it is only read and written while mu is held.
+		mu.Lock()
+		defer mu.Unlock()
+		eps := c.Endpoints()
+		if len(eps) == 0 || violationCount > len(eps) {
+			return ErrNoGreaterRev
+		}
+		// force client to connect to the specified endpoint by limiting to a single endpoint
+		c.SetEndpoints(eps[violationCount%len(eps)])
+		time.Sleep(1 * time.Second) // give enough time for operation
+		// set available endpoints back to all endpoints in order to ensure
+		// that the client has access to all the endpoints.
+		c.SetEndpoints(eps...)
+		violationCount++
+		return nil
+	}
+}
diff --git a/clientv3/ordering/util_test.go b/clientv3/ordering/util_test.go
new file mode 100644
index 00000000000..37e91b5bed7
--- /dev/null
+++ b/clientv3/ordering/util_test.go
@@ -0,0 +1,157 @@
+// Copyright 2017 The etcd Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ordering
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"github.com/coreos/etcd/clientv3"
+	"github.com/coreos/etcd/integration"
+	"github.com/coreos/etcd/pkg/testutil"
+)
+
+// TestEndpointSwitchResolvesViolation checks that switching endpoints lets a
+// serializable Get succeed even though one member is behind.
+func TestEndpointSwitchResolvesViolation(t *testing.T) {
+	defer testutil.AfterTest(t)
+	clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 3})
+	defer clus.Terminate(t)
+	cfg := clientv3.Config{
+		Endpoints: []string{
+			clus.Members[0].GRPCAddr(),
+			clus.Members[1].GRPCAddr(),
+			clus.Members[2].GRPCAddr(),
+		},
+	}
+	cli, err := clientv3.New(cfg)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer cli.Close()
+	eps := cli.Endpoints()
+	ctx := context.TODO()
+
+	cli.SetEndpoints(clus.Members[0].GRPCAddr())
+	_, err = cli.Put(ctx, "foo", "bar")
+	if err != nil {
+		t.Fatal(err)
+	}
+	// ensure that the second member has current revision for key "foo"
+	cli.SetEndpoints(clus.Members[1].GRPCAddr())
+	_, err = cli.Get(ctx, "foo")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// create partition between third members and the first two members
+	// in order to guarantee that the third member's revision of "foo"
+	// falls behind as updates to "foo" are issued to the first two members.
+	clus.Members[2].InjectPartition(t, clus.Members[:2])
+	time.Sleep(1 * time.Second) // give enough time for the operation
+
+	// update to "foo" will not be replicated to the third member due to the partition
+	_, err = cli.Put(ctx, "foo", "buzz")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// reset client endpoints to all members such that the copy of cli sent to
+	// NewOrderViolationSwitchEndpointClosure will be able to
+	// access the full list of endpoints.
+	cli.SetEndpoints(eps...)
+	orderingKv := NewKV(cli.KV, NewOrderViolationSwitchEndpointClosure(*cli))
+	// set prevRev to the second member's revision of "foo" such that
+	// the revision is higher than the third member's revision of "foo"
+	_, err = orderingKv.Get(ctx, "foo")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	cli.SetEndpoints(clus.Members[2].GRPCAddr())
+	time.Sleep(1 * time.Second) // give enough time for operation
+	_, err = orderingKv.Get(ctx, "foo", clientv3.WithSerializable())
+	if err != nil {
+		t.Fatalf("failed to resolve order violation %v", err)
+	}
+}
+
+// TestUnresolvableOrderViolation checks that ErrNoGreaterRev is returned when
+// every reachable member is behind the previously observed revision.
+func TestUnresolvableOrderViolation(t *testing.T) {
+	defer testutil.AfterTest(t)
+	clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 5})
+	defer clus.Terminate(t)
+	cfg := clientv3.Config{
+		Endpoints: []string{
+			clus.Members[0].GRPCAddr(),
+			clus.Members[1].GRPCAddr(),
+			clus.Members[2].GRPCAddr(),
+			clus.Members[3].GRPCAddr(),
+			clus.Members[4].GRPCAddr(),
+		},
+	}
+	cli, err := clientv3.New(cfg)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer cli.Close()
+	eps := cli.Endpoints()
+	ctx := context.TODO()
+
+	cli.SetEndpoints(clus.Members[0].GRPCAddr())
+	time.Sleep(1 * time.Second)
+	_, err = cli.Put(ctx, "foo", "bar")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// stop fourth member in order to force the member to have an outdated revision
+	clus.Members[3].Stop(t)
+	time.Sleep(1 * time.Second) // give enough time for operation
+	// stop fifth member in order to force the member to have an outdated revision
+	clus.Members[4].Stop(t)
+	time.Sleep(1 * time.Second) // give enough time for operation
+	_, err = cli.Put(ctx, "foo", "buzz")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// reset client endpoints to all members such that the copy of cli sent to
+	// NewOrderViolationSwitchEndpointClosure will be able to
+	// access the full list of endpoints.
+	cli.SetEndpoints(eps...)
+	orderingKv := NewKV(cli.KV, NewOrderViolationSwitchEndpointClosure(*cli))
+	// set prevRev to the first member's revision of "foo" such that
+	// the revision is higher than the fourth and fifth members' revision of "foo"
+	_, err = orderingKv.Get(ctx, "foo")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	clus.Members[0].Stop(t)
+	clus.Members[1].Stop(t)
+	clus.Members[2].Stop(t)
+	clus.Members[3].Restart(t)
+	clus.Members[4].Restart(t)
+	cli.SetEndpoints(clus.Members[3].GRPCAddr())
+	time.Sleep(1 * time.Second) // give enough time for operation
+
+	_, err = orderingKv.Get(ctx, "foo", clientv3.WithSerializable())
+	if err != ErrNoGreaterRev {
+		t.Fatalf("expected %v, got %v", ErrNoGreaterRev, err)
+	}
+}