diff --git a/pkg/ccl/sqlproxyccl/denylist/BUILD.bazel b/pkg/ccl/sqlproxyccl/denylist/BUILD.bazel index b2d22c24e625..959b10660de5 100644 --- a/pkg/ccl/sqlproxyccl/denylist/BUILD.bazel +++ b/pkg/ccl/sqlproxyccl/denylist/BUILD.bazel @@ -3,6 +3,7 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "denylist", srcs = [ + "file.go", "local_file.go", "mocks_generated.go", "service.go", @@ -17,15 +18,21 @@ go_library( "@com_github_cockroachdb_errors//:errors", "@com_github_golang_mock//gomock", "@com_github_spf13_viper//:viper", + "@in_gopkg_yaml_v2//:yaml_v2", ], ) go_test( name = "denylist_test", - srcs = ["local_file_test.go"], + srcs = [ + "file_test.go", + "local_file_test.go", + ], embed = [":denylist"], deps = [ "//pkg/util/leaktest", + "//pkg/util/timeutil", "@com_github_stretchr_testify//require", + "@in_gopkg_yaml_v2//:yaml_v2", ], ) diff --git a/pkg/ccl/sqlproxyccl/denylist/file.go b/pkg/ccl/sqlproxyccl/denylist/file.go new file mode 100644 index 000000000000..67bdb768ec37 --- /dev/null +++ b/pkg/ccl/sqlproxyccl/denylist/file.go @@ -0,0 +1,185 @@ +// Copyright 2021 The Cockroach Authors. +// +// Licensed as a CockroachDB Enterprise file under the Cockroach Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt + +package denylist + +import ( + "context" + "io" + "os" + "time" + + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/syncutil" + "github.com/cockroachdb/cockroach/pkg/util/timeutil" + "gopkg.in/yaml.v2" +) + +const ( + defaultPollingInterval = time.Minute + defaultEmptyDenylistText = "SequenceNumber: 0" +) + +// File represents a on-disk version of the denylist config. +// This also serves as a spec of expected yaml file format. +type File struct { + Seq int64 `yaml:"SequenceNumber"` + Denylist []*DenyEntry `yaml:"denylist"` +} + +// Deserialize constructs a new DenylistFile from reader. +func Deserialize(reader io.Reader) (*File, error) { + decoder := yaml.NewDecoder(reader) + var denylistFile File + err := decoder.Decode(&denylistFile) + if err != nil { + return nil, err + } + return &denylistFile, nil +} + +// Serialize a File into raw bytes. +func (dlf *File) Serialize() ([]byte, error) { + return yaml.Marshal(dlf) +} + +// DenyEntry records info about one denied entity, +// the reason and the expiration time. +// This also serves as spec for the yaml config format. +type DenyEntry struct { + Entity DenyEntity `yaml:"entity"` + Expiration time.Time `yaml:"expiration"` + Reason string `yaml:"reason"` +} + +// Denylist represents an in-memory cache for the current denylist. +// It also handles the logic of deciding what to be denied. +type Denylist struct { + mu struct { + entries map[DenyEntity]*DenyEntry + *syncutil.RWMutex + } + pollingInterval time.Duration + timeSource timeutil.TimeSource + + ctx context.Context +} + +// NewDenylistWithFile returns a new denylist that automatically watches updates to a file. +// Note: this currently does not return an error. This is by design, since even if we trouble +// initiating a denylist with file, we can always update the file with correct content during +// runtime. We don't want sqlproxy fail to start just because there's something wrong with +// contents of a denylist file. +func NewDenylistWithFile(ctx context.Context, filename string, opts ...Option) *Denylist { + ret := &Denylist{ + pollingInterval: defaultPollingInterval, + timeSource: timeutil.DefaultTimeSource{}, + ctx: ctx, + } + ret.mu.entries = make(map[DenyEntity]*DenyEntry) + ret.mu.RWMutex = &syncutil.RWMutex{} + + for _, opt := range opts { + opt(ret) + } + err := ret.update(filename) + if err != nil { + // don't return just yet; sqlproxy should be able to carry on without a proper denylist + // and we still have a chance to recover. + // TODO(ye): add monitoring for failed updates; we don't want silent failures + log.Errorf(ctx, "error when reading from file %s: %v", filename, err) + } + + ret.watchForUpdate(filename) + + return ret +} + +// Option allows configuration of a denylist service. +type Option func(*Denylist) + +// WithPollingInterval specifies interval between polling for config file changes. +func WithPollingInterval(d time.Duration) Option { + return func(dl *Denylist) { + dl.pollingInterval = d + } +} + +// update the Denylist with content of the file. +func (dl *Denylist) update(filename string) error { + handle, err := os.Open(filename) + if err != nil { + log.Errorf(dl.ctx, "open file %s: %v", filename, err) + return err + } + defer handle.Close() + + dlf, err := Deserialize(handle) + if err != nil { + stat, _ := handle.Stat() + if stat != nil { + log.Errorf(dl.ctx, "error updating denylist from file %s modified at %s: %v", + filename, stat.ModTime(), err) + } else { + log.Errorf(dl.ctx, "error updating denylist from file %s: %v", + filename, err) + } + return err + } + dl.updateWithDenylistFile(dlf) + return nil +} + +func (dl *Denylist) updateWithDenylistFile(dlf *File) { + newEntries := make(map[DenyEntity]*DenyEntry) + for _, entry := range dlf.Denylist { + newEntries[entry.Entity] = entry + } + + dl.mu.Lock() + defer dl.mu.Unlock() + + dl.mu.entries = newEntries +} + +// Denied implements the Service interface. +func (dl *Denylist) Denied(entity DenyEntity) (*Entry, error) { + dl.mu.RLock() + defer dl.mu.RUnlock() + + if ent, ok := dl.mu.entries[entity]; ok && !ent.Expiration.Before(dl.timeSource.Now()) { + return &Entry{ent.Reason}, nil + } + return nil, nil +} + +// WatchForUpdates periodically reloads the denylist file. The daemon is +// canceled on ctx cancellation. +func (dl *Denylist) watchForUpdate(filename string) { + go func() { + // TODO(ye): use notification via SIGHUP instead. + // TODO(ye): use inotify or similar mechanism for watching file updates instead of polling. + t := timeutil.NewTimer() + defer t.Stop() + for { + t.Reset(dl.pollingInterval) + select { + case <-dl.ctx.Done(): + log.Errorf(dl.ctx, "WatchList daemon stopped: %v", dl.ctx.Err()) + return + case <-t.C: + t.Read = true + err := dl.update(filename) + if err != nil { + // TODO(ye): add monitoring for update failures. + continue + } + } + } + }() +} diff --git a/pkg/ccl/sqlproxyccl/denylist/file_test.go b/pkg/ccl/sqlproxyccl/denylist/file_test.go new file mode 100644 index 000000000000..628bfa924aa7 --- /dev/null +++ b/pkg/ccl/sqlproxyccl/denylist/file_test.go @@ -0,0 +1,234 @@ +// Copyright 2021 The Cockroach Authors. +// +// Licensed as a CockroachDB Enterprise file under the Cockroach Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt + +package denylist + +import ( + "bytes" + "context" + "fmt" + "io/ioutil" + "path/filepath" + "testing" + "time" + + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/timeutil" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v2" +) + +func TestDenyListFileParsing(t *testing.T) { + t.Run("test custom marshal code", func(t *testing.T) { + cases := []struct { + t Type + expected string + }{{ + IPAddrType, "ip", + }, { + ClusterType, "cluster", + }} + for _, tc := range cases { + s, err := tc.t.MarshalYAML() + require.NoError(t, err) + require.Equal(t, tc.expected, s) + } + }) + + t.Run("test custom unmarshal code", func(t *testing.T) { + cases := []struct { + raw string + expected Type + }{{ + "ip", IPAddrType, + }, { + "IP", IPAddrType, + }, + { + "Ip", IPAddrType, + }, + { + "Cluster", ClusterType, + }, + { + "cluster", ClusterType, + }, + { + "CLUSTER", ClusterType, + }, + { + "random text", UnknownType, + }, + } + for _, tc := range cases { + var parsed Type + err := yaml.Unmarshal([]byte(tc.raw), &parsed) + require.NoError(t, err) + require.Equal(t, tc.expected, parsed) + } + }) + + t.Run("end to end testing of file parsing", func(t *testing.T) { + defer leaktest.AfterTest(t)() + expirationTimeString := "2021-01-01T15:20:39Z" + expirationTime := time.Date(2021, 1, 1, 15, 20, 39, 0, time.UTC) + + emptyMap := make(map[DenyEntity]*DenyEntry) + + testCases := []struct { + input string + expected map[DenyEntity]*DenyEntry + }{ + {"text: ", emptyMap}, + {"random text\n\n\nmore random text", + emptyMap}, + {defaultEmptyDenylistText, emptyMap}, + {"SequenceNumber: 7", emptyMap}, + { + // old denylist format, making sure it won't break new denylist code + ` + SequenceNumber: 8 + 1.1.1.1: some reason + 61: another reason`, + emptyMap, + }, { + fmt.Sprintf(` + SequenceNumber: 9 + denylist: + - entity: {"item":"1.2.3.4", "type": "ip"} + expiration: %s + reason: over quota +`, expirationTimeString), + map[DenyEntity]*DenyEntry{{"1.2.3.4", IPAddrType}: { + DenyEntity{"1.2.3.4", IPAddrType}, + expirationTime, + "over quota", + }, + }}, + } + + // use cancel to prevent leaked goroutines from file watches + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + tempDir := t.TempDir() + for i, tc := range testCases { + filename := filepath.Join(tempDir, fmt.Sprintf("denylist%d.yaml", i)) + require.NoError(t, ioutil.WriteFile(filename, []byte(tc.input), 0777)) + dl := NewDenylistWithFile(ctx, filename) + require.Equal(t, tc.expected, dl.mu.entries, "should return expected parsed file for %s", + tc.input) + } + }) + + t.Run("test Ser/De of File", func(t *testing.T) { + file := File{ + Seq: 72, + Denylist: []*DenyEntry{ + { + DenyEntity{"63", ClusterType}, + timeutil.Now(), + "over usage", + }, + { + DenyEntity{"8.8.8.8", IPAddrType}, + timeutil.Now().Add(1 * time.Hour), + "malicious IP", + }, + }, + } + + raw, err := file.Serialize() + require.NoError(t, err) + deserialized, err := Deserialize(bytes.NewBuffer(raw)) + require.NoError(t, err) + require.EqualValues(t, file, *deserialized) + }) +} + +func TestDenylistLogic(t *testing.T) { + defer leaktest.AfterTest(t)() + + startTime := time.Date(2021, 1, 1, 15, 20, 39, 0, time.UTC) + expirationTimeString := "2021-01-01T15:30:39Z" + futureTime := startTime.Add(time.Minute * 20) + + type denyIOSpec struct { + entity DenyEntity + outcome *Entry + } + + // This is a time evolution of a denylist + testCases := []struct { + input string + time time.Time + specs []denyIOSpec + }{ + { + fmt.Sprintf(` + SequenceNumber: 9 + denylist: + - entity: {"item": "1.2.3.4", "type": "IP"} + expiration: %s + reason: over quota`, expirationTimeString), + startTime.Add(10 * time.Second), + []denyIOSpec{ + {DenyEntity{"1.2.3.4", IPAddrType}, &Entry{"over quota"}}, + {DenyEntity{"61", ClusterType}, nil}, + {DenyEntity{"1.2.3.5", IPAddrType}, nil}, + }, + }, + { + fmt.Sprintf(` + SequenceNumber: 10 + denylist: + - entity: {"item": "1.2.3.4", "type": "IP"} + expiration: %s + reason: over quota + - entity: {"item": 61, "type": "Cluster"} + expiration: %s + reason: splunk pipeline`, expirationTimeString, expirationTimeString), + startTime.Add(20 * time.Second), + []denyIOSpec{ + {DenyEntity{"1.2.3.4", IPAddrType}, &Entry{"over quota"}}, + {DenyEntity{"61", ClusterType}, &Entry{"splunk pipeline"}}, + {DenyEntity{"1.2.3.5", IPAddrType}, nil}, + }}, + { + fmt.Sprintf(` + SequenceNumber: 11 + denylist: + - entity: {"item": "1.2.3.4", "type": "ip"} + expiration: %s + reason: over quota`, expirationTimeString), + futureTime, + []denyIOSpec{ + {DenyEntity{"1.2.3.4", IPAddrType}, nil}, + {DenyEntity{"61", ClusterType}, nil}, + {DenyEntity{"1.2.3.5", IPAddrType}, nil}, + }}, + } + // use cancel to prevent leaked goroutines from file watches + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + tempDir := t.TempDir() + + filename := filepath.Join(tempDir, "denylist.yaml") + manualTime := timeutil.NewManualTime(startTime) + dl := NewDenylistWithFile(ctx, filename, WithPollingInterval(100*time.Millisecond)) + dl.timeSource = manualTime + for _, tc := range testCases { + require.NoError(t, ioutil.WriteFile(filename, []byte(tc.input), 0777)) + manualTime.AdvanceTo(tc.time) + time.Sleep(500 * time.Millisecond) + for _, ioPairs := range tc.specs { + actual, err := dl.Denied(ioPairs.entity) + require.NoError(t, err) + require.Equal(t, ioPairs.outcome, actual) + } + } +} diff --git a/pkg/ccl/sqlproxyccl/denylist/local_file.go b/pkg/ccl/sqlproxyccl/denylist/local_file.go index 8c4f67c3673c..34d56e04c07a 100644 --- a/pkg/ccl/sqlproxyccl/denylist/local_file.go +++ b/pkg/ccl/sqlproxyccl/denylist/local_file.go @@ -59,13 +59,13 @@ func newViperCfgFromFile(cfgFileName string) (*viper.Viper, error) { } // Denied implements the Service interface using viper to query the deny list. -func (d *viperDenyList) Denied(item string) (*Entry, error) { +func (d *viperDenyList) Denied(entity DenyEntity) (*Entry, error) { if d.mu.viperCfg == nil { return nil, nil } d.mu.Lock() defer d.mu.Unlock() - return d.deniedLocked(item) + return d.deniedLocked(entity.Item) } func (d *viperDenyList) deniedLocked(item string) (*Entry, error) { diff --git a/pkg/ccl/sqlproxyccl/denylist/local_file_test.go b/pkg/ccl/sqlproxyccl/denylist/local_file_test.go index 34551e379e00..749176147c91 100644 --- a/pkg/ccl/sqlproxyccl/denylist/local_file_test.go +++ b/pkg/ccl/sqlproxyccl/denylist/local_file_test.go @@ -16,9 +16,54 @@ import ( "time" "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/stretchr/testify/require" ) +// TestViperForwardCompatibility makes sure that the new config file format +// will not return error when ingested by the old binary, and will not +// cause sqlproxy fail to start. +func TestViperForwardCompatibility(t *testing.T) { + // Make sure that new config file format will not + // cause errors for the old binary. + defer leaktest.AfterTest(t)() + + files := []File{ + { + Seq: 3, + Denylist: []*DenyEntry{ + {DenyEntity{"1.1.1.1", IPAddrType}, timeutil.Now().Add(time.Hour), "some reason"}, + {DenyEntity{"63", ClusterType}, timeutil.Now().Add(2 * time.Hour), "another reason"}, + }, + }, + { + // empty file + }, + { + Seq: 7, + // empty list + }, + } + for _, file := range files { + cfgFile, err := ioutil.TempFile("", "*_denylist.yml") + require.NoError(t, err) + defer func() { _ = os.Remove(cfgFile.Name()) }() + + raw, err := file.Serialize() + require.NoError(t, err) + err = ioutil.WriteFile(cfgFile.Name(), raw, 0777) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // make sure old parser won't break on new config format + _, err = NewViperDenyListFromFile(ctx, cfgFile.Name(), + time.Minute) + require.NoError(t, err) + } +} + func TestViperDenyList(t *testing.T) { defer leaktest.AfterTest(t)() @@ -32,14 +77,14 @@ func TestViperDenyList(t *testing.T) { dl, err := NewViperDenyListFromFile(ctx, cfgFile.Name(), time.Millisecond) require.NoError(t, err) - e, err := dl.Denied("123") + e, err := dl.Denied(DenyEntity{"123", ClusterType}) require.NoError(t, err) require.True(t, e == nil) _, err = cfgFile.Write([]byte("456: denied")) require.NoError(t, err) time.Sleep(50 * time.Millisecond) - e, err = dl.Denied("456") + e, err = dl.Denied(DenyEntity{"456", ClusterType}) require.NoError(t, err) require.Equal(t, &Entry{Reason: "denied"}, e) } diff --git a/pkg/ccl/sqlproxyccl/denylist/mocks_generated.go b/pkg/ccl/sqlproxyccl/denylist/mocks_generated.go index 19b277a5654a..d40c42f4982b 100644 --- a/pkg/ccl/sqlproxyccl/denylist/mocks_generated.go +++ b/pkg/ccl/sqlproxyccl/denylist/mocks_generated.go @@ -34,16 +34,16 @@ func (m *MockService) EXPECT() *MockServiceMockRecorder { } // Denied mocks base method. -func (m *MockService) Denied(id string) (*Entry, error) { +func (m *MockService) Denied(entity DenyEntity) (*Entry, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Denied", id) + ret := m.ctrl.Call(m, "Denied", entity) ret0, _ := ret[0].(*Entry) ret1, _ := ret[1].(error) return ret0, ret1 } // Denied indicates an expected call of Denied. -func (mr *MockServiceMockRecorder) Denied(id interface{}) *gomock.Call { +func (mr *MockServiceMockRecorder) Denied(entity interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Denied", reflect.TypeOf((*MockService)(nil).Denied), id) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Denied", reflect.TypeOf((*MockService)(nil).Denied), entity) } diff --git a/pkg/ccl/sqlproxyccl/denylist/service.go b/pkg/ccl/sqlproxyccl/denylist/service.go index 42cc8a12588a..5e15726b89c5 100644 --- a/pkg/ccl/sqlproxyccl/denylist/service.go +++ b/pkg/ccl/sqlproxyccl/denylist/service.go @@ -8,6 +8,8 @@ package denylist +import "strings" + //go:generate mockgen -package=denylist -destination=mocks_generated.go -source=service.go . Service // Entry records the reason for putting an item on the denylist. @@ -16,11 +18,71 @@ type Entry struct { Reason string } +// DenyEntity represent one denied entity. +// This also serves as the spec for the config format. +type DenyEntity struct { + Item string `yaml:"item"` + Type Type `yaml:"type"` +} + +// Type is the type of the denied entity. +type Type int + +// Enum values for Type. +const ( + IPAddrType Type = iota + 1 + ClusterType + UnknownType +) + +var strToTypeMap = map[string]Type{ + "ip": IPAddrType, + "cluster": ClusterType, +} + +var typeToStrMap = map[Type]string{ + IPAddrType: "ip", + ClusterType: "cluster", +} + +// UnmarshalYAML implements yaml.Unmarshaler interface for type. +func (typ *Type) UnmarshalYAML(unmarshal func(interface{}) error) error { + var raw string + err := unmarshal(&raw) + if err != nil { + return err + } + + normalized := strings.ToLower(raw) + t, ok := strToTypeMap[normalized] + if !ok { + *typ = UnknownType + } else { + *typ = t + } + + return nil +} + +// MarshalYAML implements yaml.Marshaler interface for type. +func (typ Type) MarshalYAML() (interface{}, error) { + return typ.String(), nil +} + +// String implements Stringer interface for type. +func (typ Type) String() string { + s, ok := typeToStrMap[typ] + if !ok { + return "UNKNOWN" + } + return s +} + // Service provides an interface for checking if an id has been denied access. type Service interface { // Denied returns a non-nil Entry if the id is denied. The reason for the // denial will be in Entry. - Denied(id string) (*Entry, error) + Denied(entity DenyEntity) (*Entry, error) // TODO(spaskob): add API for registering listeners to be notified of any // updates (inclusion/exclusion) to the denylist.