Skip to content

Commit

Permalink
Merge pull request #258 from trozet/protect_cache_read
Browse files Browse the repository at this point in the history
Protect cache integrity during reads
  • Loading branch information
dave-tucker authored Nov 5, 2021
2 parents f493ff7 + 91770f8 commit 34a572b
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 47 deletions.
11 changes: 6 additions & 5 deletions client/api.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package client

import (
"context"
"encoding/json"
"errors"
"fmt"
Expand All @@ -18,7 +19,7 @@ type API interface {
// The function parameter must be a pointer to a slice of Models
// If the slice is null, the entire cache will be copied into the slice
// If it has a capacity != 0, only 'capacity' elements will be filled in
List(result interface{}) error
List(ctx context.Context, result interface{}) error

// Create a Conditional API from a Function that is used to filter cached data
// The function must accept a Model implementation and return a boolean. E.g:
Expand All @@ -40,7 +41,7 @@ type API interface {
// provided model and the indexes defined in the associated schema
// For more complex ways of searching for elements in the cache, the
// preferred way is Where({condition}).List()
Get(model.Model) error
Get(context.Context, model.Model) error

// Create returns the operation needed to add the model(s) to the Database
// Only fields with non-default values will be added to the transaction
Expand All @@ -53,7 +54,7 @@ type API interface {
type ConditionalAPI interface {
// List uses the condition to search on the cache and populates
// the slice of Models objects based on their type
List(result interface{}) error
List(ctx context.Context, result interface{}) error

// Mutate returns the operations needed to perform the mutation specified
// By the model and the list of Mutation objects
Expand Down Expand Up @@ -93,7 +94,7 @@ type api struct {
}

// List populates a slice of Models given as parameter based on the configured Condition
func (a api) List(result interface{}) error {
func (a api) List(ctx context.Context, result interface{}) error {
resultPtr := reflect.ValueOf(result)
if resultPtr.Type().Kind() != reflect.Ptr {
return &ErrWrongType{resultPtr.Type(), "Expected pointer to slice of valid Models"}
Expand Down Expand Up @@ -206,7 +207,7 @@ func (a api) conditionFromModel(any bool, model model.Model, cond ...model.Condi
//
// The way the cache is searched depends on the fields already populated in 'result'
// Any table index (including _uuid) will be used for comparison
func (a api) Get(m model.Model) error {
func (a api) Get(ctx context.Context, m model.Model) error {
table, err := a.getTableFromModel(m)
if err != nil {
return err
Expand Down
17 changes: 9 additions & 8 deletions client/api_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package client

import (
"context"
"fmt"
"strings"
"testing"
Expand Down Expand Up @@ -95,7 +96,7 @@ func TestAPIListSimple(t *testing.T) {
result = make([]testLogicalSwitch, tt.initialCap)
}
api := newAPI(tcache, &discardLogger)
err := api.List(&result)
err := api.List(context.Background(), &result)
if tt.err {
assert.NotNil(t, err)
} else {
Expand All @@ -111,22 +112,22 @@ func TestAPIListSimple(t *testing.T) {
t.Run("ApiList: Error wrong type", func(t *testing.T) {
var result []string
api := newAPI(tcache, &discardLogger)
err := api.List(&result)
err := api.List(context.Background(), &result)
assert.NotNil(t, err)
})

t.Run("ApiList: Type Selection", func(t *testing.T) {
var result []testLogicalSwitchPort
api := newAPI(tcache, &discardLogger)
err := api.List(&result)
err := api.List(context.Background(), &result)
assert.Nil(t, err)
assert.Len(t, result, 0, "Should be empty since cache is empty")
})

t.Run("ApiList: Empty List", func(t *testing.T) {
result := []testLogicalSwitch{}
api := newAPI(tcache, &discardLogger)
err := api.List(&result)
err := api.List(context.Background(), &result)
assert.Nil(t, err)
assert.Len(t, result, len(lscacheList))
})
Expand Down Expand Up @@ -213,7 +214,7 @@ func TestAPIListPredicate(t *testing.T) {
var result []testLogicalSwitch
api := newAPI(tcache, &discardLogger)
cond := api.WhereCache(tt.predicate)
err := cond.List(&result)
err := cond.List(context.Background(), &result)
if tt.err {
assert.NotNil(t, err)
} else {
Expand Down Expand Up @@ -301,7 +302,7 @@ func TestAPIListFields(t *testing.T) {
// Clean object
testObj = testLogicalSwitchPort{}
api := newAPI(tcache, &discardLogger)
err := api.Where(&testObj).List(&result)
err := api.Where(&testObj).List(context.Background(), &result)
if tt.err {
assert.NotNil(t, err)
} else {
Expand All @@ -319,7 +320,7 @@ func TestAPIListFields(t *testing.T) {
UUID: aUUID0,
}

err := api.Where(&obj).List(&result)
err := api.Where(&obj).List(context.Background(), &result)
assert.NotNil(t, err)
})
}
Expand Down Expand Up @@ -502,7 +503,7 @@ func TestAPIGet(t *testing.T) {
var result testLogicalSwitchPort
tt.prepare(&result)
api := newAPI(tcache, &discardLogger)
err := api.Get(&result)
err := api.Get(context.Background(), &result)
if tt.err {
assert.NotNil(t, err)
} else {
Expand Down
56 changes: 52 additions & 4 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"reflect"
"strings"
"sync"
"time"

"github.com/cenkalti/backoff/v4"
"github.com/cenkalti/rpc2"
Expand Down Expand Up @@ -1095,13 +1096,56 @@ func (o *ovsdbClient) Close() {
o.rpcClient.Close()
}

// Ensures the cache is consistent by evaluating that the client is connected
// and the monitor is fully setup, with the cache populated
func isCacheConsistent(db *database) bool {
// This works because when a client is disconnected the deferUpdates variable
// will be set to true. deferUpdates is also protected by the db.cacheMutex.
// When the client reconnects and then re-establishes the monitor; the final step
// is to process all deferred updates, set deferUpdates back to false, and unlock cacheMutex
db.cacheMutex.RLock()
defer db.cacheMutex.RUnlock()
return !db.deferUpdates
}

// best effort to ensure cache is in a good state for reading
func waitForCacheConsistent(ctx context.Context, db *database, logger *logr.Logger, dbName string) {
if !hasMonitors(db) {
return
}
ticker := time.NewTicker(50 * time.Millisecond)
for {
select {
case <-ctx.Done():
logger.V(3).Info("warning: unable to ensure cache consistency for reading",
"database", dbName)
return
case <-ticker.C:
if isCacheConsistent(db) {
return
}

}
}
}

func hasMonitors(db *database) bool {
db.monitorsMutex.Lock()
defer db.monitorsMutex.Unlock()
return len(db.monitors) > 0
}

// Client API interface wrapper functions
// We add this wrapper to allow users to access the API directly on the
// client object

//Get implements the API interface's Get function
func (o *ovsdbClient) Get(model model.Model) error {
return o.primaryDB().api.Get(model)
func (o *ovsdbClient) Get(ctx context.Context, model model.Model) error {
primaryDB := o.primaryDB()
waitForCacheConsistent(ctx, primaryDB, o.logger, o.primaryDBName)
primaryDB.cacheMutex.RLock()
defer primaryDB.cacheMutex.RUnlock()
return primaryDB.api.Get(ctx, model)
}

//Create implements the API interface's Create function
Expand All @@ -1110,8 +1154,12 @@ func (o *ovsdbClient) Create(models ...model.Model) ([]ovsdb.Operation, error) {
}

//List implements the API interface's List function
func (o *ovsdbClient) List(result interface{}) error {
return o.primaryDB().api.List(result)
func (o *ovsdbClient) List(ctx context.Context, result interface{}) error {
primaryDB := o.primaryDB()
waitForCacheConsistent(ctx, primaryDB, o.logger, o.primaryDBName)
primaryDB.cacheMutex.RLock()
defer primaryDB.cacheMutex.RUnlock()
return primaryDB.api.List(ctx, result)
}

//Where implements the API interface's Where function
Expand Down
2 changes: 1 addition & 1 deletion cmd/stress/stress.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func cleanup(ctx context.Context) {

// Remove all existing bridges
var bridges []bridgeType
if err := ovs.List(&bridges); err == nil {
if err := ovs.List(context.Background(), &bridges); err == nil {
log.Printf("%d existing bridges found", len(bridges))
for _, bridge := range bridges {
deleteBridge(ctx, ovs, rootUUID, &bridge)
Expand Down
2 changes: 1 addition & 1 deletion example/play_with_ovs/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func play(ovs client.Client) {
} else {
fmt.Printf("Current list of bridges:\n")
var bridges []vswitchd.Bridge
if err := ovs.List(&bridges); err != nil {
if err := ovs.List(context.Background(), &bridges); err != nil {
log.Fatal(err)
}
for _, b := range bridges {
Expand Down
14 changes: 7 additions & 7 deletions server/server_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,12 +145,12 @@ func TestClientServerInsert(t *testing.T) {
uuid := reply[0].UUID.GoUUID
require.Eventually(t, func() bool {
br := &bridgeType{UUID: uuid}
err := ovs.Get(br)
err := ovs.Get(context.Background(), br)
return err == nil
}, 2*time.Second, 500*time.Millisecond)

br := &bridgeType{UUID: uuid}
err = ovs.Get(br)
err = ovs.Get(context.Background(), br)
require.NoError(t, err)

assert.Equal(t, bridgeRow.Name, br.Name)
Expand Down Expand Up @@ -335,7 +335,7 @@ func TestClientServerInsertAndDelete(t *testing.T) {
uuid := reply[0].UUID.GoUUID
assert.Eventually(t, func() bool {
br := &bridgeType{UUID: uuid}
err := ovs.Get(br)
err := ovs.Get(context.Background(), br)
return err == nil
}, 2*time.Second, 500*time.Millisecond)

Expand Down Expand Up @@ -457,7 +457,7 @@ func TestClientServerInsertAndUpdate(t *testing.T) {
uuid := reply[0].UUID.GoUUID
assert.Eventually(t, func() bool {
br := &bridgeType{UUID: uuid}
err := ovs.Get(br)
err := ovs.Get(context.Background(), br)
return err == nil
}, 2*time.Second, 500*time.Millisecond)

Expand All @@ -482,7 +482,7 @@ func TestClientServerInsertAndUpdate(t *testing.T) {

require.Eventually(t, func() bool {
br := &bridgeType{UUID: uuid}
err = ovs.Get(br)
err = ovs.Get(context.Background(), br)
if err != nil {
return false
}
Expand All @@ -500,15 +500,15 @@ func TestClientServerInsertAndUpdate(t *testing.T) {

assert.Eventually(t, func() bool {
br := &bridgeType{UUID: uuid}
err = ovs.Get(br)
err = ovs.Get(context.Background(), br)
if err != nil {
return false
}
return reflect.DeepEqual(br.ExternalIds, bridgeRow.ExternalIds)
}, 2*time.Second, 500*time.Millisecond)

br := &bridgeType{UUID: uuid}
err = ovs.Get(br)
err = ovs.Get(context.Background(), br)
assert.NoError(t, err)

assert.Equal(t, bridgeRow, br)
Expand Down
Loading

0 comments on commit 34a572b

Please sign in to comment.