diff --git a/clientv3/kv.go b/clientv3/kv.go index 0197c4840ad3..418f6c32c1c5 100644 --- a/clientv3/kv.go +++ b/clientv3/kv.go @@ -74,9 +74,15 @@ func (op OpResponse) Get() *GetResponse { return op.get } func (op OpResponse) Del() *DeleteResponse { return op.del } func (op OpResponse) Txn() *TxnResponse { return op.txn } +func (resp *PutResponse) ToOpResponse() OpResponse { + return OpResponse{put: resp} +} func (resp *GetResponse) ToOpResponse() OpResponse { return OpResponse{get: resp} } +func (resp *DeleteResponse) ToOpResponse() OpResponse { + return OpResponse{del: resp} +} func (resp *TxnResponse) ToOpResponse() OpResponse { return OpResponse{txn: resp} } diff --git a/clientv3/ordering/kv.go b/clientv3/ordering/kv.go index ec2b7384f66c..8e4d848133d5 100644 --- a/clientv3/ordering/kv.go +++ b/clientv3/ordering/kv.go @@ -68,7 +68,10 @@ func (kv *kvOrdering) Get(ctx context.Context, key string, opts ...clientv3.OpOp kv.setPrevRev(resp.Header.Revision) return resp, nil } - kv.orderViolationFunc(op, r, prevRev) + err = kv.orderViolationFunc(op, r, prevRev) + if err != nil { + return nil, err + } } } @@ -99,23 +102,26 @@ type txnOrdering struct { func (txn *txnOrdering) If(cs ...clientv3.Cmp) clientv3.Txn { txn.mu.Lock() + defer txn.mu.Unlock() txn.cmps = cs - txn.mu.Unlock() - return txn.Txn.If(cs...) + txn.Txn.If(cs...) + return txn } func (txn *txnOrdering) Then(ops ...clientv3.Op) clientv3.Txn { txn.mu.Lock() + defer txn.mu.Unlock() txn.thenOps = ops - txn.mu.Unlock() - return txn.Txn.Then(ops...) + txn.Txn.Then(ops...) + return txn } func (txn *txnOrdering) Else(ops ...clientv3.Op) clientv3.Txn { txn.mu.Lock() + defer txn.mu.Unlock() txn.elseOps = ops - txn.mu.Unlock() - return txn.Txn.Else(ops...) + txn.Txn.Else(ops...) + return txn } func (txn *txnOrdering) Commit() (*clientv3.TxnResponse, error) { @@ -124,8 +130,8 @@ func (txn *txnOrdering) Commit() (*clientv3.TxnResponse, error) { // access to txnOrdering could change the prevRev field in the // middle of the Commit operation. prevRev := txn.getPrevRev() + opTxn := clientv3.OpTxn(txn.cmps, txn.thenOps, txn.elseOps) for { - opTxn := clientv3.OpTxn(txn.cmps, txn.thenOps, txn.elseOps) opResp, err := txn.KV.Do(txn.ctx, opTxn) if err != nil { return nil, err @@ -135,6 +141,9 @@ func (txn *txnOrdering) Commit() (*clientv3.TxnResponse, error) { txn.setPrevRev(txnResp.Header.Revision) return txnResp, nil } - txn.orderViolationFunc(opTxn, opResp, prevRev) + err = txn.orderViolationFunc(opTxn, opResp, prevRev) + if err != nil { + return nil, err + } } } diff --git a/clientv3/ordering/kv_test.go b/clientv3/ordering/kv_test.go index 4d6df13e70ad..b367d91fca81 100644 --- a/clientv3/ordering/kv_test.go +++ b/clientv3/ordering/kv_test.go @@ -15,20 +15,160 @@ package ordering import ( + "context" + "errors" "sync" "testing" + "time" "github.com/coreos/etcd/clientv3" pb "github.com/coreos/etcd/etcdserver/etcdserverpb" - "golang.org/x/net/context" + "github.com/coreos/etcd/integration" + "github.com/coreos/etcd/pkg/testutil" + gContext "golang.org/x/net/context" ) +func TestDetectKvOrderViolation(t *testing.T) { + var errOrderViolation = errors.New("Detected Order Violation") + + defer testutil.AfterTest(t) + clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 3}) + defer clus.Terminate(t) + + cfg := clientv3.Config{ + Endpoints: []string{ + clus.Members[0].GRPCAddr(), + clus.Members[1].GRPCAddr(), + clus.Members[2].GRPCAddr(), + }, + } + cli, err := clientv3.New(cfg) + ctx := context.TODO() + + cli.SetEndpoints(clus.Members[0].GRPCAddr()) + _, err = cli.Put(ctx, "foo", "bar") + if err != nil { + t.Fatal(err) + } + // ensure that the second member has the current revision for the key foo + cli.SetEndpoints(clus.Members[1].GRPCAddr()) + _, err = cli.Get(ctx, "foo") + if err != nil { + t.Fatal(err) + } + + // stop third member in order to force the member to have an outdated revision + clus.Members[2].Stop(t) + time.Sleep(1 * time.Second) // give enough time for operation + _, err = cli.Put(ctx, "foo", "buzz") + if err != nil { + t.Fatal(err) + } + + // perform get request against the first member, in order to + // set up kvOrdering to expect "foo" revisions greater than that of + // the third member. + orderingKv := NewKV(cli.KV, + func(op clientv3.Op, resp clientv3.OpResponse, prevRev int64) error { + return errOrderViolation + }) + _, err = orderingKv.Get(ctx, "foo") + if err != nil { + t.Fatal(err) + } + + // ensure that only the third member is queried during requests + clus.Members[0].Stop(t) + clus.Members[1].Stop(t) + clus.Members[2].Restart(t) + // force OrderingKv to query the third member + cli.SetEndpoints(clus.Members[2].GRPCAddr()) + + _, err = orderingKv.Get(ctx, "foo", clientv3.WithSerializable()) + if err != errOrderViolation { + t.Fatalf("expected %v, got %v", errOrderViolation, err) + } +} + +func TestDetectTxnOrderViolation(t *testing.T) { + var errOrderViolation = errors.New("Detected Order Violation") + + defer testutil.AfterTest(t) + clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 3}) + defer clus.Terminate(t) + + cfg := clientv3.Config{ + Endpoints: []string{ + clus.Members[0].GRPCAddr(), + clus.Members[1].GRPCAddr(), + clus.Members[2].GRPCAddr(), + }, + } + cli, err := clientv3.New(cfg) + ctx := context.TODO() + + cli.SetEndpoints(clus.Members[0].GRPCAddr()) + _, err = cli.Put(ctx, "foo", "bar") + if err != nil { + t.Fatal(err) + } + // ensure that the second member has the current revision for the key foo + cli.SetEndpoints(clus.Members[1].GRPCAddr()) + _, err = cli.Get(ctx, "foo") + if err != nil { + t.Fatal(err) + } + + // stop third member in order to force the member to have an outdated revision + clus.Members[2].Stop(t) + time.Sleep(1 * time.Second) // give enough time for operation + _, err = cli.Put(ctx, "foo", "buzz") + if err != nil { + t.Fatal(err) + } + + // perform get request against the first member, in order to + // set up kvOrdering to expect "foo" revisions greater than that of + // the third member. + orderingKv := NewKV(cli.KV, + func(op clientv3.Op, resp clientv3.OpResponse, prevRev int64) error { + return errOrderViolation + }) + orderingTxn := orderingKv.Txn(ctx) + _, err = orderingTxn.If( + clientv3.Compare(clientv3.Value("b"), ">", "a"), + ).Then( + clientv3.OpGet("foo"), + ).Commit() + if err != nil { + t.Fatal(err) + } + + // ensure that only the third member is queried during requests + clus.Members[0].Stop(t) + clus.Members[1].Stop(t) + clus.Members[2].Restart(t) + // force OrderingKv to query the third member + cli.SetEndpoints(clus.Members[2].GRPCAddr()) + + _, err = orderingKv.Get(ctx, "foo", clientv3.WithSerializable()) + orderingTxn = orderingKv.Txn(ctx) + _, err = orderingTxn.If( + clientv3.Compare(clientv3.Value("b"), ">", "a"), + ).Then( + clientv3.OpGet("foo", clientv3.WithSerializable()), + ).Commit() + if err != errOrderViolation { + t.Fatalf("expected %v, got %v", errOrderViolation, err) + } +} + type mockKV struct { clientv3.KV response clientv3.OpResponse } -func (kv *mockKV) Do(ctx context.Context, op clientv3.Op) (clientv3.OpResponse, error) { +func (kv *mockKV) Do(ctx gContext.Context, op clientv3.Op) (clientv3.OpResponse, error) { return kv.response, nil } @@ -68,8 +208,9 @@ func TestKvOrdering(t *testing.T) { kv := &kvOrdering{ mKV, func(r *clientv3.GetResponse) OrderViolationFunc { - return func(op clientv3.Op, resp clientv3.OpResponse, prevRev int64) { + return func(op clientv3.Op, resp clientv3.OpResponse, prevRev int64) error { r.Header.Revision++ + return nil } }(tt.response), tt.prevRev, @@ -120,8 +261,9 @@ func TestTxnOrdering(t *testing.T) { kv := &kvOrdering{ mKV, func(r *clientv3.TxnResponse) OrderViolationFunc { - return func(op clientv3.Op, resp clientv3.OpResponse, prevRev int64) { + return func(op clientv3.Op, resp clientv3.OpResponse, prevRev int64) error { r.Header.Revision++ + return nil } }(tt.response), tt.prevRev, diff --git a/clientv3/ordering/util.go b/clientv3/ordering/util.go index c71b6ddd4417..7b151e78e592 100644 --- a/clientv3/ordering/util.go +++ b/clientv3/ordering/util.go @@ -15,24 +15,34 @@ package ordering import ( + "errors" "sync" + "time" "github.com/coreos/etcd/clientv3" ) -type OrderViolationFunc func(op clientv3.Op, resp clientv3.OpResponse, prevRev int64) +type OrderViolationFunc func(op clientv3.Op, resp clientv3.OpResponse, prevRev int64) error + +var ErrNoGreaterRev = errors.New("etcdclient: no cluster members have a revision higher than the previously received revision") func NewOrderViolationSwitchEndpointClosure(c clientv3.Client) OrderViolationFunc { - // May need a c.Sync(c.Ctx()) to make sure that the list of endpoints is up to date var mu sync.Mutex violationCount := 0 - return func(op clientv3.Op, resp clientv3.OpResponse, prevRev int64) { + return func(op clientv3.Op, resp clientv3.OpResponse, prevRev int64) error { + if violationCount > len(c.Endpoints()) { + return ErrNoGreaterRev + } mu.Lock() defer mu.Unlock() eps := c.Endpoints() - // force client to connect to the specificied endpoint by limiting to single endpoint + // force client to connect to the specificied endpoint by limiting to a single endpoint c.SetEndpoints(eps[violationCount%len(eps)]) + time.Sleep(1 * time.Second) // give enough time for operation + // set available endpoints back to all endpoints in order to enure + // that the client has access to all the endpoints. c.SetEndpoints(eps...) violationCount++ + return nil } } diff --git a/clientv3/ordering/util_test.go b/clientv3/ordering/util_test.go new file mode 100644 index 000000000000..37e91b5bed71 --- /dev/null +++ b/clientv3/ordering/util_test.go @@ -0,0 +1,145 @@ +// Copyright 2017 The etcd Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ordering + +import ( + "context" + "testing" + "time" + + "github.com/coreos/etcd/clientv3" + "github.com/coreos/etcd/integration" + "github.com/coreos/etcd/pkg/testutil" +) + +func TestEndpointSwitchResolvesViolation(t *testing.T) { + defer testutil.AfterTest(t) + clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 3}) + defer clus.Terminate(t) + cfg := clientv3.Config{ + Endpoints: []string{ + clus.Members[0].GRPCAddr(), + clus.Members[1].GRPCAddr(), + clus.Members[2].GRPCAddr(), + }, + } + cli, err := clientv3.New(cfg) + eps := cli.Endpoints() + ctx := context.TODO() + + cli.SetEndpoints(clus.Members[0].GRPCAddr()) + _, err = cli.Put(ctx, "foo", "bar") + if err != nil { + t.Fatal(err) + } + // ensure that the second member has current revision for key "foo" + cli.SetEndpoints(clus.Members[1].GRPCAddr()) + _, err = cli.Get(ctx, "foo") + if err != nil { + t.Fatal(err) + } + + // create partition between third members and the first two members + // in order to guarantee that the third member's revision of "foo" + // falls behind as updates to "foo" are issued to the first two members. + clus.Members[2].InjectPartition(t, clus.Members[:2]) + time.Sleep(1 * time.Second) // give enough time for the operation + + // update to "foo" will not be replicated to the third member due to the partition + _, err = cli.Put(ctx, "foo", "buzz") + if err != nil { + t.Fatal(err) + } + + // reset client endpoints to all members such that the copy of cli sent to + // NewOrderViolationSwitchEndpointClosure will be able to + // access the full list of endpoints. + cli.SetEndpoints(eps...) + OrderingKv := NewKV(cli.KV, NewOrderViolationSwitchEndpointClosure(*cli)) + // set prevRev to the second member's revision of "foo" such that + // the revision is higher than the third member's revision of "foo" + _, err = OrderingKv.Get(ctx, "foo") + if err != nil { + t.Fatal(err) + } + + cli.SetEndpoints(clus.Members[2].GRPCAddr()) + time.Sleep(1 * time.Second) // give enough time for operation + _, err = OrderingKv.Get(ctx, "foo", clientv3.WithSerializable()) + if err != nil { + t.Fatalf("failed to resolve order violation %v", err) + } +} + +func TestUnresolvableOrderViolation(t *testing.T) { + defer testutil.AfterTest(t) + clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 5}) + defer clus.Terminate(t) + cfg := clientv3.Config{ + Endpoints: []string{ + clus.Members[0].GRPCAddr(), + clus.Members[1].GRPCAddr(), + clus.Members[2].GRPCAddr(), + clus.Members[3].GRPCAddr(), + clus.Members[4].GRPCAddr(), + }, + } + cli, err := clientv3.New(cfg) + eps := cli.Endpoints() + ctx := context.TODO() + + cli.SetEndpoints(clus.Members[0].GRPCAddr()) + time.Sleep(1 * time.Second) + _, err = cli.Put(ctx, "foo", "bar") + if err != nil { + t.Fatal(err) + } + + // stop fourth member in order to force the member to have an outdated revision + clus.Members[3].Stop(t) + time.Sleep(1 * time.Second) // give enough time for operation + // stop fifth member in order to force the member to have an outdated revision + clus.Members[4].Stop(t) + time.Sleep(1 * time.Second) // give enough time for operation + _, err = cli.Put(ctx, "foo", "buzz") + if err != nil { + t.Fatal(err) + } + + // reset client endpoints to all members such that the copy of cli sent to + // NewOrderViolationSwitchEndpointClosure will be able to + // access the full list of endpoints. + cli.SetEndpoints(eps...) + OrderingKv := NewKV(cli.KV, NewOrderViolationSwitchEndpointClosure(*cli)) + // set prevRev to the first member's revision of "foo" such that + // the revision is higher than the fourth and fifth members' revision of "foo" + _, err = OrderingKv.Get(ctx, "foo") + if err != nil { + t.Fatal(err) + } + + clus.Members[0].Stop(t) + clus.Members[1].Stop(t) + clus.Members[2].Stop(t) + clus.Members[3].Restart(t) + clus.Members[4].Restart(t) + cli.SetEndpoints(clus.Members[3].GRPCAddr()) + time.Sleep(1 * time.Second) // give enough time for operation + + _, err = OrderingKv.Get(ctx, "foo", clientv3.WithSerializable()) + if err != ErrNoGreaterRev { + t.Fatalf("expected %v, got %v", ErrNoGreaterRev, err) + } +}