diff --git a/nameresolution/consul/README.md b/nameresolution/consul/README.md index a53824954b..d04eb870a0 100644 --- a/nameresolution/consul/README.md +++ b/nameresolution/consul/README.md @@ -1,6 +1,6 @@ # Consul Name Resolution -The consul name resolution component gives the ability to register and resolve other "daprized" services registered on a consul estate. It is flexible in that it allows for complex to minimal configurations driving the behaviour on init and resolution. +The consul name resolution component gives the ability to register and resolve other "daprized" services registered on a consul estate. It is flexible in that it allows for complex to minimal configurations driving the behavior on init and resolution. ## How To Use @@ -35,7 +35,7 @@ spec: ``` -## Behaviour +## Behavior On init the consul component will either validate the connection to the configured (or default) agent or register the service if configured to do so. The name resolution interface does not cater for an "on shutdown" pattern so please consider this if using Dapr to register services to consul as it will not deregister services. @@ -54,9 +54,10 @@ As of writing the configuration spec is fixed to v1.3.0 of the consul api | Tags | `[]string` | Configures any tags to include if/when registering services | | Meta | `map[string]string` | Configures any additional metadata to include if/when registering services | | DaprPortMetaKey | `string` | The key used for getting the Dapr sidecar port from consul service metadata during service resolution, it will also be used to set the Dapr sidecar port in metadata during registration. If blank it will default to `DAPR_PORT` | -| SelfRegister | `bool` | Controls if Dapr will register the service to consul. The name resolution interface does not cater for an "on shutdown" pattern so please consider this if using Dapr to register services to consul as it will not deregister services. | +| SelfRegister | `bool` | Controls if Dapr will register the service to consul on startup. If unset it will default to `false` | +| SelfDeregister | `bool` | Controls if Dapr will deregister the service from consul on shutdown. If unset it will default to `false` | | AdvancedRegistration | [*api.AgentServiceRegistration](https://pkg.go.dev/github.com/hashicorp/consul/api@v1.3.0#AgentServiceRegistration) | Gives full control of service registration through configuration. If configured the component will ignore any configuration of Checks, Tags, Meta and SelfRegister. | - +| UseCache | `bool` | Configures if Dapr will cache the resolved services in-memory. This is done using consul [blocking queries](https://www.consul.io/api-docs/features/blocking) which can be configured via the QueryOptions configuration. If unset it will default to `false` | ## Samples Configurations ### Basic diff --git a/nameresolution/consul/configuration.go b/nameresolution/consul/configuration.go index 64ba6d821a..9fba047479 100644 --- a/nameresolution/consul/configuration.go +++ b/nameresolution/consul/configuration.go @@ -23,6 +23,8 @@ import ( "github.com/dapr/kit/config" ) +const defaultDaprPortMetaKey string = "DAPR_PORT" // default key for DaprPort in meta + // The intermediateConfig is based off of the consul api types. User configurations are // deserialized into this type before being converted to the equivalent consul types // that way breaking changes in future versions of the consul api cannot break user configuration. @@ -33,8 +35,10 @@ type intermediateConfig struct { Meta map[string]string QueryOptions *QueryOptions AdvancedRegistration *AgentServiceRegistration // advanced use-case - SelfRegister bool DaprPortMetaKey string + SelfRegister bool + SelfDeregister bool + UseCache bool } type configSpec struct { @@ -44,8 +48,16 @@ type configSpec struct { Meta map[string]string QueryOptions *consul.QueryOptions AdvancedRegistration *consul.AgentServiceRegistration // advanced use-case - SelfRegister bool DaprPortMetaKey string + SelfRegister bool + SelfDeregister bool + UseCache bool +} + +func newIntermediateConfig() intermediateConfig { + return intermediateConfig{ + DaprPortMetaKey: defaultDaprPortMetaKey, + } } func parseConfig(rawConfig interface{}) (configSpec, error) { @@ -60,7 +72,7 @@ func parseConfig(rawConfig interface{}) (configSpec, error) { return result, fmt.Errorf("error serializing to json: %w", err) } - var configuration intermediateConfig + configuration := newIntermediateConfig() err = json.Unmarshal(data, &configuration) if err != nil { return result, fmt.Errorf("error deserializing to configSpec: %w", err) @@ -80,7 +92,9 @@ func mapConfig(config intermediateConfig) configSpec { QueryOptions: mapQueryOptions(config.QueryOptions), AdvancedRegistration: mapAdvancedRegistration(config.AdvancedRegistration), SelfRegister: config.SelfRegister, + SelfDeregister: config.SelfDeregister, DaprPortMetaKey: config.DaprPortMetaKey, + UseCache: config.UseCache, } } diff --git a/nameresolution/consul/consul.go b/nameresolution/consul/consul.go index e37e4db7c4..5ed9e071dc 100644 --- a/nameresolution/consul/consul.go +++ b/nameresolution/consul/consul.go @@ -18,6 +18,8 @@ import ( "math/rand" "net" "strconv" + "sync" + "sync/atomic" consul "github.com/hashicorp/consul/api" @@ -25,8 +27,6 @@ import ( "github.com/dapr/kit/logger" ) -const daprMeta string = "DAPR_PORT" // default key for DAPR_PORT metadata - type client struct { *consul.Client } @@ -59,34 +59,181 @@ type clientInterface interface { type agentInterface interface { Self() (map[string]map[string]interface{}, error) ServiceRegister(service *consul.AgentServiceRegistration) error + ServiceDeregister(serviceID string) error } type healthInterface interface { Service(service, tag string, passingOnly bool, q *consul.QueryOptions) ([]*consul.ServiceEntry, *consul.QueryMeta, error) + State(state string, q *consul.QueryOptions) (consul.HealthChecks, *consul.QueryMeta, error) } type resolver struct { - config resolverConfig - logger logger.Logger - client clientInterface + config resolverConfig + logger logger.Logger + client clientInterface + registry registryInterface + watcherStarted atomic.Bool + watcherStopChannel chan struct{} +} + +type registryInterface interface { + getKeys() []string + get(service string) *registryEntry + expire(service string) // clears slice of instances + expireAll() // clears slice of instances for all entries + remove(service string) // removes entry from registry + removeAll() // removes all entries from registry + addOrUpdate(service string, services []*consul.ServiceEntry) + registrationChannel() chan string +} + +type registry struct { + entries sync.Map + serviceChannel chan string +} + +type registryEntry struct { + services []*consul.ServiceEntry + mu sync.RWMutex +} + +func (r *registry) getKeys() []string { + var keys []string + r.entries.Range(func(key any, value any) bool { + k := key.(string) + keys = append(keys, k) + return true + }) + return keys +} + +func (r *registry) get(service string) *registryEntry { + if result, ok := r.entries.Load(service); ok { + return result.(*registryEntry) + } + + return nil +} + +func (e *registryEntry) next() *consul.ServiceEntry { + e.mu.Lock() + defer e.mu.Unlock() + + if len(e.services) == 0 { + return nil + } + + // gosec is complaining that we are using a non-crypto-safe PRNG. This is fine in this scenario since we are using it only for selecting a random address for load-balancing. + //nolint:gosec + return e.services[rand.Int()%len(e.services)] +} + +func (r *resolver) getService(service string) (*consul.ServiceEntry, error) { + var services []*consul.ServiceEntry + + if r.config.UseCache { + r.startWatcher() + + entry := r.registry.get(service) + if entry != nil { + result := entry.next() + + if result != nil { + return result, nil + } + } else { + r.registry.registrationChannel() <- service + } + } + + options := *r.config.QueryOptions + options.WaitHash = "" + options.WaitIndex = 0 + services, _, err := r.client.Health().Service(service, "", true, &options) + + if err != nil { + return nil, fmt.Errorf("failed to query healthy consul services: %w", err) + } else if len(services) == 0 { + return nil, fmt.Errorf("no healthy services found with AppID '%s'", service) + } + + //nolint:gosec + return services[rand.Int()%len(services)], nil +} + +func (r *registry) addOrUpdate(service string, services []*consul.ServiceEntry) { + // update + entry := r.get(service) + if entry != nil { + entry.mu.Lock() + defer entry.mu.Unlock() + + entry.services = services + + return + } + + // add + r.entries.Store(service, ®istryEntry{ + services: services, + }) +} + +func (r *registry) remove(service string) { + r.entries.Delete(service) +} + +func (r *registry) removeAll() { + r.entries.Range(func(key any, value any) bool { + r.remove(key.(string)) + return true + }) +} + +func (r *registry) expire(service string) { + entry := r.get(service) + if entry == nil { + return + } + + entry.mu.Lock() + defer entry.mu.Unlock() + + entry.services = nil +} + +func (r *registry) expireAll() { + r.entries.Range(func(key any, value any) bool { + r.expire(key.(string)) + return true + }) +} + +func (r *registry) registrationChannel() chan string { + return r.serviceChannel } type resolverConfig struct { - Client *consul.Config - QueryOptions *consul.QueryOptions - Registration *consul.AgentServiceRegistration - DaprPortMetaKey string + Client *consul.Config + QueryOptions *consul.QueryOptions + Registration *consul.AgentServiceRegistration + DeregisterOnClose bool + DaprPortMetaKey string + UseCache bool } // NewResolver creates Consul name resolver. func NewResolver(logger logger.Logger) nr.Resolver { - return newResolver(logger, &client{}) + return newResolver(logger, resolverConfig{}, &client{}, ®istry{serviceChannel: make(chan string, 100)}, make(chan struct{})) } -func newResolver(logger logger.Logger, client clientInterface) *resolver { +func newResolver(logger logger.Logger, resolverConfig resolverConfig, client clientInterface, registry registryInterface, watcherStopChannel chan struct{}) nr.Resolver { return &resolver{ - logger: logger, - client: client, + logger: logger, + config: resolverConfig, + client: client, + registry: registry, + watcherStopChannel: watcherStopChannel, } } @@ -129,23 +276,14 @@ func (r *resolver) Init(metadata nr.Metadata) (err error) { // ResolveID resolves name to address via consul. func (r *resolver) ResolveID(req nr.ResolveRequest) (addr string, err error) { cfg := r.config - services, _, err := r.client.Health().Service(req.ID, "", true, cfg.QueryOptions) + svc, err := r.getService(req.ID) if err != nil { - return "", fmt.Errorf("failed to query healthy consul services: %w", err) + return "", err } - if len(services) == 0 { - return "", fmt.Errorf("no healthy services found with AppID '%s'", req.ID) - } - - // Pick a random service from the result - // Note: we're using math/random here as PRNG and that's ok since we're just using this for selecting a random address from a list for load-balancing, so we don't need a CSPRNG - //nolint:gosec - svc := services[rand.Int()%len(services)] - port := svc.Service.Meta[cfg.DaprPortMetaKey] if port == "" { - return "", fmt.Errorf("target service AppID '%s' found but DAPR_PORT missing from meta", req.ID) + return "", fmt.Errorf("target service AppID '%s' found but %s missing from meta", req.ID, cfg.DaprPortMetaKey) } if svc.Service.Address != "" { @@ -159,6 +297,24 @@ func (r *resolver) ResolveID(req nr.ResolveRequest) (addr string, err error) { return formatAddress(addr, port) } +// Close will stop the watcher and deregister app from consul +func (r *resolver) Close() error { + if r.watcherStarted.Load() { + r.watcherStopChannel <- struct{}{} + } + + if r.config.Registration != nil && r.config.DeregisterOnClose { + err := r.client.Agent().ServiceDeregister(r.config.Registration.ID) + if err != nil { + return fmt.Errorf("failed to deregister consul service: %w", err) + } + + r.logger.Info("deregistered service from consul") + } + + return nil +} + func formatAddress(address string, port string) (addr string, err error) { if net.ParseIP(address).To4() != nil { return address + ":" + port, nil @@ -180,12 +336,9 @@ func getConfig(metadata nr.Metadata) (resolverCfg resolverConfig, err error) { return resolverCfg, err } - // set DaprPortMetaKey used for registring DaprPort and resolving from Consul - if cfg.DaprPortMetaKey == "" { - resolverCfg.DaprPortMetaKey = daprMeta - } else { - resolverCfg.DaprPortMetaKey = cfg.DaprPortMetaKey - } + resolverCfg.DaprPortMetaKey = cfg.DaprPortMetaKey + resolverCfg.DeregisterOnClose = cfg.SelfDeregister + resolverCfg.UseCache = cfg.UseCache resolverCfg.Client = getClientConfig(cfg) resolverCfg.Registration, err = getRegistrationConfig(cfg, metadata.Properties) diff --git a/nameresolution/consul/consul_test.go b/nameresolution/consul/consul_test.go index 3efe3c17a5..87fb3653a3 100644 --- a/nameresolution/consul/consul_test.go +++ b/nameresolution/consul/consul_test.go @@ -17,7 +17,9 @@ import ( "fmt" "net" "strconv" + "sync/atomic" "testing" + "time" consul "github.com/hashicorp/consul/api" "github.com/stretchr/testify/assert" @@ -50,24 +52,58 @@ func (m *mockClient) Agent() agentInterface { } type mockHealth struct { - serviceCalled int - serviceErr error - serviceResult []*consul.ServiceEntry - serviceMeta *consul.QueryMeta + serviceCalled int + serviceErr *error + serviceBehavior func(service, tag string, passingOnly bool, q *consul.QueryOptions) + serviceResult []*consul.ServiceEntry + serviceMeta *consul.QueryMeta + + stateCallStarted atomic.Int32 + stateCalled int + stateError *error + stateBehaviour func(state string, q *consul.QueryOptions) + stateResult consul.HealthChecks + stateMeta *consul.QueryMeta +} + +func (m *mockHealth) State(state string, q *consul.QueryOptions) (consul.HealthChecks, *consul.QueryMeta, error) { + m.stateCallStarted.Add(1) + + if m.stateBehaviour != nil { + m.stateBehaviour(state, q) + } + + m.stateCalled++ + + if m.stateError == nil { + return m.stateResult, m.stateMeta, nil + } + + return m.stateResult, m.stateMeta, *m.stateError } func (m *mockHealth) Service(service, tag string, passingOnly bool, q *consul.QueryOptions) ([]*consul.ServiceEntry, *consul.QueryMeta, error) { + if m.serviceBehavior != nil { + m.serviceBehavior(service, tag, passingOnly, q) + } + m.serviceCalled++ - return m.serviceResult, m.serviceMeta, m.serviceErr + if m.serviceErr == nil { + return m.serviceResult, m.serviceMeta, nil + } + + return m.serviceResult, m.serviceMeta, *m.serviceErr } type mockAgent struct { - selfCalled int - selfErr error - selfResult map[string]map[string]interface{} - serviceRegisterCalled int - serviceRegisterErr error + selfCalled int + selfErr error + selfResult map[string]map[string]interface{} + serviceRegisterCalled int + serviceRegisterErr error + serviceDeregisterCalled int + serviceDeregisterErr error } func (m *mockAgent) Self() (map[string]map[string]interface{}, error) { @@ -82,6 +118,71 @@ func (m *mockAgent) ServiceRegister(service *consul.AgentServiceRegistration) er return m.serviceRegisterErr } +func (m *mockAgent) ServiceDeregister(serviceID string) error { + m.serviceDeregisterCalled++ + + return m.serviceDeregisterErr +} + +type mockRegistry struct { + getKeysCalled atomic.Int32 + getKeysResult *[]string + getKeysBehaviour func() + addOrUpdateCalled atomic.Int32 + addOrUpdateBehaviour func(service string, services []*consul.ServiceEntry) + expireCalled int + expireAllCalled int + removeCalled int + removeAllCalled atomic.Int32 + getCalled int + getResult *registryEntry + registerChannelResult chan string +} + +func (m *mockRegistry) registrationChannel() chan string { + return m.registerChannelResult +} + +func (m *mockRegistry) getKeys() []string { + if m.getKeysBehaviour != nil { + m.getKeysBehaviour() + } + + m.getKeysCalled.Add(1) + + return *m.getKeysResult +} + +func (m *mockRegistry) expireAll() { + m.expireAllCalled++ +} + +func (m *mockRegistry) removeAll() { + m.removeAllCalled.Add(1) +} + +func (m *mockRegistry) addOrUpdate(service string, services []*consul.ServiceEntry) { + if m.addOrUpdateBehaviour != nil { + m.addOrUpdateBehaviour(service, services) + } + + m.addOrUpdateCalled.Add(1) +} + +func (m *mockRegistry) expire(service string) { + m.expireCalled++ +} + +func (m *mockRegistry) remove(service string) { + m.removeCalled++ +} + +func (m *mockRegistry) get(service string) *registryEntry { + m.getCalled++ + + return m.getResult +} + func TestInit(t *testing.T) { t.Parallel() @@ -99,7 +200,7 @@ func TestInit(t *testing.T) { t.Helper() var mock mockClient - resolver := newResolver(logger.NewLogger("test"), &mock) + resolver := newResolver(logger.NewLogger("test"), resolverConfig{}, &mock, ®istry{}, make(chan struct{})) _ = resolver.Init(metadata) @@ -122,7 +223,7 @@ func TestInit(t *testing.T) { t.Helper() var mock mockClient - resolver := newResolver(logger.NewLogger("test"), &mock) + resolver := newResolver(logger.NewLogger("test"), resolverConfig{}, &mock, ®istry{}, make(chan struct{})) _ = resolver.Init(metadata) @@ -144,7 +245,7 @@ func TestInit(t *testing.T) { t.Helper() var mock mockClient - resolver := newResolver(logger.NewLogger("test"), &mock) + resolver := newResolver(logger.NewLogger("test"), resolverConfig{}, &mock, ®istry{}, make(chan struct{})) _ = resolver.Init(metadata) @@ -168,6 +269,7 @@ func TestResolveID(t *testing.T) { t.Parallel() testConfig := resolverConfig{ DaprPortMetaKey: "DAPR_PORT", + QueryOptions: &consul.QueryOptions{}, } tests := []struct { @@ -175,6 +277,521 @@ func TestResolveID(t *testing.T) { req nr.ResolveRequest test func(*testing.T, nr.ResolveRequest) }{ + { + "should use cache when enabled", + nr.ResolveRequest{ + ID: "test-app", + }, + func(t *testing.T, req nr.ResolveRequest) { + t.Helper() + + blockingCall := make(chan uint64) + meta := &consul.QueryMeta{ + LastIndex: 0, + } + + serviceEntries := []*consul.ServiceEntry{ + { + Service: &consul.AgentService{ + Address: "10.3.245.137", + Port: 8600, + Meta: map[string]string{ + "DAPR_PORT": "50005", + }, + }, + }, + } + + cachedEntries := []*consul.ServiceEntry{ + { + Service: &consul.AgentService{ + Address: "10.3.245.137", + Port: 8600, + Meta: map[string]string{ + "DAPR_PORT": "70007", + }, + }, + }, + } + + healthChecks := consul.HealthChecks{ + &consul.HealthCheck{ + Node: "0e1234", + ServiceID: "test-app-10.3.245.137-3500", + ServiceName: "test-app", + Status: consul.HealthPassing, + }, + } + + mock := &mockClient{ + mockHealth: mockHealth{ + // Service() + serviceResult: serviceEntries, + serviceMeta: meta, + serviceBehavior: func(service, tag string, passingOnly bool, q *consul.QueryOptions) { + }, + serviceErr: nil, + + // State() + stateResult: healthChecks, + stateMeta: meta, + stateBehaviour: func(state string, q *consul.QueryOptions) { + meta.LastIndex = <-blockingCall + }, + stateError: nil, + }, + } + + cfg := resolverConfig{ + DaprPortMetaKey: "DAPR_PORT", + UseCache: true, + QueryOptions: &consul.QueryOptions{}, + } + + serviceKeys := make([]string, 0, 10) + + mockReg := &mockRegistry{ + registerChannelResult: make(chan string, 100), + getKeysResult: &serviceKeys, + addOrUpdateBehaviour: func(service string, services []*consul.ServiceEntry) { + if services == nil { + serviceKeys = append(serviceKeys, service) + } + }, + } + resolver := newResolver(logger.NewLogger("test"), cfg, mock, mockReg, make(chan struct{})) + addr, _ := resolver.ResolveID(req) + + // no apps in registry - cache miss, call agent directly + assert.Equal(t, 1, mockReg.getCalled) + waitTillTrueOrTimeout(time.Second, func() bool { return mockReg.getKeysCalled.Load() == 2 }) + assert.Equal(t, 1, mock.mockHealth.serviceCalled) + assert.Equal(t, "10.3.245.137:50005", addr) + + // watcher adds app to registry + assert.Equal(t, int32(1), mockReg.addOrUpdateCalled.Load()) + assert.Equal(t, int32(2), mockReg.getKeysCalled.Load()) + + mockReg.registerChannelResult <- "test-app" + mockReg.getResult = ®istryEntry{ + services: cachedEntries, + } + + // blocking query - return new index + blockingCall <- 2 + waitTillTrueOrTimeout(time.Second, func() bool { return mock.mockHealth.stateCallStarted.Load() == 2 }) + assert.Equal(t, 1, mock.mockHealth.stateCalled) + + // get healthy nodes and update registry for service in result + assert.Equal(t, 2, mock.mockHealth.serviceCalled) + assert.Equal(t, int32(2), mockReg.addOrUpdateCalled.Load()) + + // resolve id should only hit cache now + addr, _ = resolver.ResolveID(req) + assert.Equal(t, "10.3.245.137:70007", addr) + addr, _ = resolver.ResolveID(req) + assert.Equal(t, "10.3.245.137:70007", addr) + addr, _ = resolver.ResolveID(req) + assert.Equal(t, "10.3.245.137:70007", addr) + + assert.Equal(t, 2, mock.mockHealth.serviceCalled) + assert.Equal(t, 4, mockReg.getCalled) + + // no update when no change in index and payload + blockingCall <- 2 + waitTillTrueOrTimeout(time.Second, func() bool { return mock.mockHealth.stateCallStarted.Load() == 3 }) + assert.Equal(t, 2, mock.mockHealth.stateCalled) + assert.Equal(t, 2, mock.mockHealth.serviceCalled) + assert.Equal(t, int32(2), mockReg.addOrUpdateCalled.Load()) + + // no update when no change in payload + blockingCall <- 3 + waitTillTrueOrTimeout(time.Second, func() bool { return mock.mockHealth.stateCallStarted.Load() == 4 }) + assert.Equal(t, 3, mock.mockHealth.stateCalled) + assert.Equal(t, 2, mock.mockHealth.serviceCalled) + assert.Equal(t, int32(2), mockReg.addOrUpdateCalled.Load()) + + // update when change in index and payload + mock.mockHealth.stateResult[0].Status = consul.HealthCritical + blockingCall <- 4 + waitTillTrueOrTimeout(time.Second, func() bool { return mock.mockHealth.stateCallStarted.Load() == 5 }) + assert.Equal(t, 4, mock.mockHealth.stateCalled) + assert.Equal(t, 3, mock.mockHealth.serviceCalled) + assert.Equal(t, int32(3), mockReg.addOrUpdateCalled.Load()) + }, + }, + { + "should only update cache on change", + nr.ResolveRequest{ + ID: "test-app", + }, + func(t *testing.T, req nr.ResolveRequest) { + t.Helper() + + blockingCall := make(chan uint64) + meta := &consul.QueryMeta{} + + var err error + + // Node 1 all checks healthy + node1check1 := &consul.HealthCheck{ + Node: "0e1234", + ServiceID: "test-app-10.3.245.137-3500", + ServiceName: "test-app", + Status: consul.HealthPassing, + CheckID: "1", + } + + node1check2 := &consul.HealthCheck{ + Node: "0e1234", + ServiceID: "test-app-10.3.245.137-3500", + ServiceName: "test-app", + Status: consul.HealthPassing, + CheckID: "2", + } + + // Node 2 all checks unhealthy + node2check1 := &consul.HealthCheck{ + Node: "0e9878", + ServiceID: "test-app-10.3.245.127-3500", + ServiceName: "test-app", + Status: consul.HealthCritical, + CheckID: "1", + } + + node2check2 := &consul.HealthCheck{ + Node: "0e9878", + ServiceID: "test-app-10.3.245.127-3500", + ServiceName: "test-app", + Status: consul.HealthCritical, + CheckID: "2", + } + + mock := mockClient{ + mockHealth: mockHealth{ + // Service() + serviceResult: []*consul.ServiceEntry{ + { + Service: &consul.AgentService{ + Address: "10.3.245.137", + Port: 8600, + Meta: map[string]string{ + "DAPR_PORT": "50005", + }, + }, + }, + }, + serviceMeta: meta, + serviceBehavior: nil, + serviceErr: &err, + + // State() + stateResult: consul.HealthChecks{ + node1check1, + node1check2, + node2check1, + node2check2, + }, + stateMeta: meta, + stateBehaviour: func(state string, q *consul.QueryOptions) { + meta.LastIndex = <-blockingCall + }, + stateError: nil, + }, + } + + cfg := resolverConfig{ + DaprPortMetaKey: "DAPR_PORT", + UseCache: true, + QueryOptions: &consul.QueryOptions{ + WaitIndex: 1, + }, + } + + serviceKeys := make([]string, 0, 10) + + mockReg := &mockRegistry{ + registerChannelResult: make(chan string, 100), + getKeysResult: &serviceKeys, + addOrUpdateBehaviour: func(service string, services []*consul.ServiceEntry) { + if services == nil { + serviceKeys = append(serviceKeys, service) + } + }, + } + resolver := newResolver(logger.NewLogger("test"), cfg, &mock, mockReg, make(chan struct{})) + addr, _ := resolver.ResolveID(req) + + // no apps in registry - cache miss, call agent directly + assert.Equal(t, 1, mockReg.getCalled) + waitTillTrueOrTimeout(time.Second, func() bool { return mockReg.addOrUpdateCalled.Load() == 1 }) + assert.Equal(t, 1, mock.mockHealth.serviceCalled) + assert.Equal(t, "10.3.245.137:50005", addr) + + // watcher adds app to registry + assert.Equal(t, int32(1), mockReg.addOrUpdateCalled.Load()) + assert.Equal(t, int32(2), mockReg.getKeysCalled.Load()) + + // add key to mock registry - trigger watcher + mockReg.registerChannelResult <- "test-app" + mockReg.getResult = ®istryEntry{ + services: mock.mockHealth.serviceResult, + } + + // blocking query - return new index + blockingCall <- 2 + waitTillTrueOrTimeout(time.Second, func() bool { return mockReg.addOrUpdateCalled.Load() == 2 }) + assert.Equal(t, 1, mock.mockHealth.stateCalled) + + // get healthy nodes and update registry for service in result + assert.Equal(t, 2, mock.mockHealth.serviceCalled) + assert.Equal(t, int32(2), mockReg.addOrUpdateCalled.Load()) + + // resolve id should only hit cache now + _, _ = resolver.ResolveID(req) + _, _ = resolver.ResolveID(req) + _, _ = resolver.ResolveID(req) + assert.Equal(t, 2, mock.mockHealth.serviceCalled) + + // change one check for node1 app to critical + node1check1.Status = consul.HealthCritical + + // blocking query - return new index - node1 app is now unhealthy + blockingCall <- 3 + waitTillTrueOrTimeout(time.Second, func() bool { return mock.mockHealth.stateCallStarted.Load() == 3 }) + assert.Equal(t, 2, mock.mockHealth.stateCalled) + assert.Equal(t, 3, mock.mockHealth.serviceCalled) + assert.Equal(t, int32(3), mockReg.addOrUpdateCalled.Load()) + + // change remaining check for node1 app to critical + node1check2.Status = consul.HealthCritical + + // blocking query - return new index - node1 app is still unhealthy, no change + blockingCall <- 4 + waitTillTrueOrTimeout(time.Second, func() bool { return mock.mockHealth.stateCallStarted.Load() == 4 }) + assert.Equal(t, 3, mock.mockHealth.stateCalled) + assert.Equal(t, 3, mock.mockHealth.serviceCalled) + assert.Equal(t, int32(3), mockReg.addOrUpdateCalled.Load()) + + // change one check for node2 app to healthy + node2check1.Status = consul.HealthPassing + + // blocking query - return new index - node2 app is still unhealthy, no change + blockingCall <- 4 + waitTillTrueOrTimeout(time.Second, func() bool { return mock.mockHealth.stateCallStarted.Load() == 5 }) + assert.Equal(t, 4, mock.mockHealth.stateCalled) + assert.Equal(t, 3, mock.mockHealth.serviceCalled) + assert.Equal(t, int32(3), mockReg.addOrUpdateCalled.Load()) + + // change remaining check for node2 app to healthy + node2check2.Status = consul.HealthPassing + + // blocking query - return new index - node2 app is now healthy + blockingCall <- 5 + waitTillTrueOrTimeout(time.Second, func() bool { return mock.mockHealth.stateCallStarted.Load() == 6 }) + assert.Equal(t, 5, mock.mockHealth.stateCalled) + assert.Equal(t, 4, mock.mockHealth.serviceCalled) + assert.Equal(t, int32(4), mockReg.addOrUpdateCalled.Load()) + }, + }, + { + "should expire cache upon blocking call error", + nr.ResolveRequest{ + ID: "test-app", + }, + func(t *testing.T, req nr.ResolveRequest) { + t.Helper() + + blockingCall := make(chan uint64) + meta := &consul.QueryMeta{ + LastIndex: 0, + } + + err := fmt.Errorf("oh no") + + serviceEntries := []*consul.ServiceEntry{ + { + Service: &consul.AgentService{ + Address: "10.3.245.137", + Port: 8600, + Meta: map[string]string{ + "DAPR_PORT": "50005", + }, + }, + }, + } + + healthChecks := consul.HealthChecks{ + &consul.HealthCheck{ + Node: "0e1234", + ServiceID: "test-app-10.3.245.137-3500", + ServiceName: "test-app", + Status: consul.HealthPassing, + }, + } + + mock := &mockClient{ + mockHealth: mockHealth{ + // Service() + serviceResult: serviceEntries, + serviceMeta: meta, + serviceBehavior: func(service, tag string, passingOnly bool, q *consul.QueryOptions) { + }, + serviceErr: nil, + + // State() + stateResult: healthChecks, + stateMeta: meta, + stateBehaviour: func(state string, q *consul.QueryOptions) { + meta.LastIndex = <-blockingCall + }, + stateError: nil, + }, + } + + cfg := resolverConfig{ + DaprPortMetaKey: "DAPR_PORT", + UseCache: true, + QueryOptions: &consul.QueryOptions{}, + } + + serviceKeys := make([]string, 0, 10) + + mockReg := &mockRegistry{ + registerChannelResult: make(chan string, 100), + getKeysResult: &serviceKeys, + addOrUpdateBehaviour: func(service string, services []*consul.ServiceEntry) { + if services == nil { + serviceKeys = append(serviceKeys, service) + } + }, + } + resolver := newResolver(logger.NewLogger("test"), cfg, mock, mockReg, make(chan struct{})) + addr, _ := resolver.ResolveID(req) + + // Cache miss pass through + assert.Equal(t, 1, mockReg.getCalled) + waitTillTrueOrTimeout(time.Second, func() bool { return mockReg.addOrUpdateCalled.Load() == 1 }) + assert.Equal(t, 1, mock.mockHealth.serviceCalled) + assert.Equal(t, int32(1), mockReg.addOrUpdateCalled.Load()) + assert.Equal(t, "10.3.245.137:50005", addr) + + waitTillTrueOrTimeout(time.Second, func() bool { return mock.mockHealth.stateCallStarted.Load() == 1 }) + mockReg.getKeysResult = &serviceKeys + mockReg.registerChannelResult <- "test-app" + mockReg.getResult = ®istryEntry{ + services: serviceEntries, + } + + blockingCall <- 2 + waitTillTrueOrTimeout(time.Second, func() bool { return mockReg.addOrUpdateCalled.Load() == 2 }) + assert.Equal(t, 1, mock.mockHealth.stateCalled) + assert.Equal(t, 2, mock.mockHealth.serviceCalled) + assert.Equal(t, int32(2), mockReg.addOrUpdateCalled.Load()) + + mock.mockHealth.stateError = &err + blockingCall <- 3 + blockingCall <- 3 + waitTillTrueOrTimeout(time.Second, func() bool { return mock.mockHealth.stateCallStarted.Load() == 2 }) + assert.Equal(t, 1, mockReg.expireAllCalled) + }, + }, + { + "should stop watcher on close", + nr.ResolveRequest{ + ID: "test-app", + }, + func(t *testing.T, req nr.ResolveRequest) { + t.Helper() + + blockingCall := make(chan uint64) + meta := &consul.QueryMeta{ + LastIndex: 0, + } + + serviceEntries := []*consul.ServiceEntry{ + { + Service: &consul.AgentService{ + Address: "10.3.245.137", + Port: 8600, + Meta: map[string]string{ + "DAPR_PORT": "50005", + }, + }, + }, + } + + healthChecks := consul.HealthChecks{ + &consul.HealthCheck{ + Node: "0e1234", + ServiceID: "test-app-10.3.245.137-3500", + ServiceName: "test-app", + Status: consul.HealthPassing, + }, + } + + mock := &mockClient{ + mockHealth: mockHealth{ + // Service() + serviceResult: serviceEntries, + serviceMeta: meta, + serviceBehavior: func(service, tag string, passingOnly bool, q *consul.QueryOptions) { + }, + serviceErr: nil, + + // State() + stateResult: healthChecks, + stateMeta: meta, + stateBehaviour: func(state string, q *consul.QueryOptions) { + select { + case meta.LastIndex = <-blockingCall: + case <-q.Context().Done(): + } + }, + stateError: nil, + }, + } + + cfg := resolverConfig{ + DaprPortMetaKey: "DAPR_PORT", + UseCache: true, + QueryOptions: &consul.QueryOptions{}, + } + + serviceKeys := make([]string, 0, 10) + + mockReg := &mockRegistry{ + registerChannelResult: make(chan string, 100), + getKeysResult: &serviceKeys, + addOrUpdateBehaviour: func(service string, services []*consul.ServiceEntry) { + if services == nil { + serviceKeys = append(serviceKeys, service) + } + }, + } + resolver := newResolver(logger.NewLogger("test"), cfg, mock, mockReg, make(chan struct{})).(*resolver) + addr, _ := resolver.ResolveID(req) + + // Cache miss pass through + assert.Equal(t, 1, mockReg.getCalled) + waitTillTrueOrTimeout(time.Second, func() bool { return mockReg.addOrUpdateCalled.Load() == 1 }) + assert.Equal(t, 1, mock.mockHealth.serviceCalled) + assert.Equal(t, int32(1), mockReg.addOrUpdateCalled.Load()) + assert.Equal(t, "10.3.245.137:50005", addr) + + waitTillTrueOrTimeout(time.Second, func() bool { return mock.mockHealth.stateCallStarted.Load() == 1 }) + mockReg.getKeysResult = &serviceKeys + mockReg.registerChannelResult <- "test-app" + mockReg.getResult = ®istryEntry{ + services: serviceEntries, + } + + resolver.Close() + waitTillTrueOrTimeout(time.Second*1, func() bool { return mockReg.removeAllCalled.Load() == 1 }) + assert.Equal(t, int32(1), mockReg.removeAllCalled.Load()) + assert.Equal(t, false, resolver.watcherStarted.Load()) + }, + }, { "error if no healthy services found", nr.ResolveRequest{ @@ -187,8 +804,7 @@ func TestResolveID(t *testing.T) { serviceResult: []*consul.ServiceEntry{}, }, } - resolver := newResolver(logger.NewLogger("test"), &mock) - resolver.config = testConfig + resolver := newResolver(logger.NewLogger("test"), testConfig, &mock, ®istry{}, make(chan struct{})) _, err := resolver.ResolveID(req) assert.Equal(t, 1, mock.mockHealth.serviceCalled) @@ -207,7 +823,7 @@ func TestResolveID(t *testing.T) { serviceResult: []*consul.ServiceEntry{ { Service: &consul.AgentService{ - Address: "123.234.245.255", + Address: "10.3.245.137", Port: 8600, Meta: map[string]string{ "DAPR_PORT": "50005", @@ -217,12 +833,11 @@ func TestResolveID(t *testing.T) { }, }, } - resolver := newResolver(logger.NewLogger("test"), &mock) - resolver.config = testConfig + resolver := newResolver(logger.NewLogger("test"), testConfig, &mock, ®istry{}, make(chan struct{})) addr, _ := resolver.ResolveID(req) - assert.Equal(t, "123.234.245.255:50005", addr) + assert.Equal(t, "10.3.245.137:50005", addr) }, }, { @@ -247,8 +862,7 @@ func TestResolveID(t *testing.T) { }, }, } - resolver := newResolver(logger.NewLogger("test"), &mock) - resolver.config = testConfig + resolver := newResolver(logger.NewLogger("test"), testConfig, &mock, ®istry{}, make(chan struct{})) addr, _ := resolver.ResolveID(req) @@ -267,7 +881,7 @@ func TestResolveID(t *testing.T) { serviceResult: []*consul.ServiceEntry{ { Service: &consul.AgentService{ - Address: "123.234.245.255", + Address: "10.3.245.137", Port: 8600, Meta: map[string]string{ "DAPR_PORT": "50005", @@ -286,15 +900,14 @@ func TestResolveID(t *testing.T) { }, }, } - resolver := newResolver(logger.NewLogger("test"), &mock) - resolver.config = testConfig + resolver := newResolver(logger.NewLogger("test"), testConfig, &mock, ®istry{}, make(chan struct{})) total1 := 0 total2 := 0 for i := 0; i < 100; i++ { addr, _ := resolver.ResolveID(req) - if addr == "123.234.245.255:50005" { + if addr == "10.3.245.137:50005" { total1++ } else if addr == "234.245.255.228:50005" { total2++ @@ -321,7 +934,7 @@ func TestResolveID(t *testing.T) { serviceResult: []*consul.ServiceEntry{ { Node: &consul.Node{ - Address: "123.234.245.255", + Address: "10.3.245.137", }, Service: &consul.AgentService{ Address: "", @@ -333,7 +946,7 @@ func TestResolveID(t *testing.T) { }, { Node: &consul.Node{ - Address: "123.234.245.255", + Address: "10.3.245.137", }, Service: &consul.AgentService{ Address: "", @@ -346,12 +959,11 @@ func TestResolveID(t *testing.T) { }, }, } - resolver := newResolver(logger.NewLogger("test"), &mock) - resolver.config = testConfig + resolver := newResolver(logger.NewLogger("test"), testConfig, &mock, ®istry{}, make(chan struct{})) addr, _ := resolver.ResolveID(req) - assert.Equal(t, "123.234.245.255:50005", addr) + assert.Equal(t, "10.3.245.137:50005", addr) }, }, { @@ -376,8 +988,7 @@ func TestResolveID(t *testing.T) { }, }, } - resolver := newResolver(logger.NewLogger("test"), &mock) - resolver.config = testConfig + resolver := newResolver(logger.NewLogger("test"), testConfig, &mock, ®istry{}, make(chan struct{})) _, err := resolver.ResolveID(req) @@ -403,8 +1014,7 @@ func TestResolveID(t *testing.T) { }, }, } - resolver := newResolver(logger.NewLogger("test"), &mock) - resolver.config = testConfig + resolver := newResolver(logger.NewLogger("test"), testConfig, &mock, ®istry{}, make(chan struct{})) _, err := resolver.ResolveID(req) @@ -416,12 +1026,277 @@ func TestResolveID(t *testing.T) { for _, tt := range tests { tt := tt t.Run(tt.testName, func(t *testing.T) { - t.Parallel() tt.test(t, tt.req) }) } } +func TestClose(t *testing.T) { + t.Parallel() + + tests := []struct { + testName string + metadata nr.Metadata + test func(*testing.T, nr.Metadata) + }{ + { + "should deregister", + nr.Metadata{Base: metadata.Base{ + Properties: getTestPropsWithoutKey(""), + }, Configuration: nil}, + func(t *testing.T, metadata nr.Metadata) { + t.Helper() + + var mock mockClient + cfg := resolverConfig{ + Registration: &consul.AgentServiceRegistration{}, + DeregisterOnClose: true, + } + + resolver := newResolver(logger.NewLogger("test"), cfg, &mock, ®istry{}, make(chan struct{})).(*resolver) + resolver.Close() + + assert.Equal(t, 1, mock.mockAgent.serviceDeregisterCalled) + }, + }, + { + "should not deregister", + nr.Metadata{Base: metadata.Base{ + Properties: getTestPropsWithoutKey(""), + }, Configuration: nil}, + func(t *testing.T, metadata nr.Metadata) { + t.Helper() + + var mock mockClient + cfg := resolverConfig{ + Registration: &consul.AgentServiceRegistration{}, + DeregisterOnClose: false, + } + + resolver := newResolver(logger.NewLogger("test"), cfg, &mock, ®istry{}, make(chan struct{})).(*resolver) + resolver.Close() + + assert.Equal(t, 0, mock.mockAgent.serviceDeregisterCalled) + }, + }, + { + "should not deregister when no registration", + nr.Metadata{Base: metadata.Base{ + Properties: getTestPropsWithoutKey(""), + }, Configuration: nil}, + func(t *testing.T, metadata nr.Metadata) { + t.Helper() + + var mock mockClient + cfg := resolverConfig{ + Registration: nil, + DeregisterOnClose: true, + } + + resolver := newResolver(logger.NewLogger("test"), cfg, &mock, ®istry{}, make(chan struct{})).(*resolver) + resolver.Close() + + assert.Equal(t, 0, mock.mockAgent.serviceDeregisterCalled) + }, + }, + { + "should stop watcher if started", + nr.Metadata{Base: metadata.Base{ + Properties: getTestPropsWithoutKey(""), + }, Configuration: nil}, + func(t *testing.T, metadata nr.Metadata) { + t.Helper() + + var mock mockClient + resolver := newResolver(logger.NewLogger("test"), resolverConfig{}, &mock, ®istry{}, make(chan struct{})).(*resolver) + resolver.watcherStarted.Store(true) + + go resolver.Close() + + sleepTimer := time.NewTimer(time.Second) + watcherStoppedInItem := false + select { + case <-sleepTimer.C: + case <-resolver.watcherStopChannel: + watcherStoppedInItem = true + } + + assert.True(t, watcherStoppedInItem) + }, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.testName, func(t *testing.T) { + t.Parallel() + tt.test(t, tt.metadata) + }) + } +} + +func TestRegistry(t *testing.T) { + t.Parallel() + + appID := "myService" + tests := []struct { + testName string + test func(*testing.T) + }{ + { + "should add and update entry", + func(t *testing.T) { + t.Helper() + + registry := ®istry{} + + result := []*consul.ServiceEntry{ + { + Service: &consul.AgentService{ + Address: "10.3.245.137", + Port: 8600, + }, + }, + } + + registry.addOrUpdate(appID, result) + + entry, _ := registry.entries.Load(appID) + assert.Equal(t, result, entry.(*registryEntry).services) + + update := []*consul.ServiceEntry{ + { + Service: &consul.AgentService{ + Address: "random", + Port: 123, + }, + }, + } + + registry.addOrUpdate(appID, update) + entry, _ = registry.entries.Load(appID) + assert.Equal(t, update, entry.(*registryEntry).services) + }, + }, + { + "should expire entries", + func(t *testing.T) { + t.Helper() + + registry := ®istry{} + registry.entries.Store( + "A", + ®istryEntry{ + services: []*consul.ServiceEntry{ + { + Service: &consul.AgentService{ + Address: "10.3.245.137", + Port: 8600, + }, + }, + }, + }) + + registry.entries.Store( + "B", + ®istryEntry{ + services: []*consul.ServiceEntry{ + { + Service: &consul.AgentService{ + Address: "10.3.245.137", + Port: 8600, + }, + }, + }, + }) + + registry.entries.Store( + "C", + ®istryEntry{ + services: []*consul.ServiceEntry{ + { + Service: &consul.AgentService{ + Address: "10.3.245.137", + Port: 8600, + }, + }, + }, + }) + + result, _ := registry.entries.Load("A") + assert.NotNil(t, result.(*registryEntry).services) + + registry.expire("A") + + result, _ = registry.entries.Load("A") + assert.Nil(t, result.(*registryEntry).services) + + registry.expireAll() + count := 0 + nilCount := 0 + registry.entries.Range(func(key, value any) bool { + count++ + if value.(*registryEntry).services == nil { + nilCount++ + } + return true + }) + + assert.Equal(t, 3, count) + assert.Equal(t, 3, nilCount) + }, + }, + { + "should remove entry", + func(t *testing.T) { + t.Helper() + + registry := ®istry{} + entry := ®istryEntry{ + services: []*consul.ServiceEntry{ + { + Service: &consul.AgentService{ + Address: "10.3.245.137", + Port: 8600, + }, + }, + }, + } + + registry.entries.Store("A", entry) + registry.entries.Store("B", entry) + registry.entries.Store("C", entry) + registry.entries.Store("D", entry) + + registry.remove("A") + + result, _ := registry.entries.Load("A") + assert.Nil(t, result) + + result, _ = registry.entries.Load("B") + assert.NotNil(t, result) + + registry.removeAll() + count := 0 + registry.entries.Range(func(key, value any) bool { + count++ + return true + }) + + assert.Equal(t, 0, count) + }, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.testName, func(t *testing.T) { + t.Parallel() + tt.test(t) + }) + } +} + func TestParseConfig(t *testing.T) { t.Parallel() @@ -456,6 +1331,8 @@ func TestParseConfig(t *testing.T) { "UseCache": true, "Filter": "Checks.ServiceTags contains dapr", }, + "DaprPortMetaKey": "DAPR_PORT", + "UseCache": false, }, configSpec{ Checks: []*consul.AgentServiceCheck{ @@ -479,13 +1356,17 @@ func TestParseConfig(t *testing.T) { UseCache: true, Filter: "Checks.ServiceTags contains dapr", }, + DaprPortMetaKey: "DAPR_PORT", + UseCache: false, }, }, { "empty configuration in metadata", true, nil, - configSpec{}, + configSpec{ + DaprPortMetaKey: defaultDaprPortMetaKey, + }, }, { "fail on unsupported map key", @@ -545,15 +1426,18 @@ func TestGetConfig(t *testing.T) { assert.Equal(t, true, actual.QueryOptions.UseCache) // DaprPortMetaKey - assert.Equal(t, "DAPR_PORT", actual.DaprPortMetaKey) + assert.Equal(t, defaultDaprPortMetaKey, actual.DaprPortMetaKey) + + // Cache + assert.Equal(t, false, actual.UseCache) }, }, { "empty configuration with SelfRegister should default correctly", nr.Metadata{ Base: metadata.Base{Properties: getTestPropsWithoutKey("")}, - Configuration: configSpec{ - SelfRegister: true, + Configuration: map[interface{}]interface{}{ + "SelfRegister": true, }, }, func(t *testing.T, metadata nr.Metadata) { @@ -572,22 +1456,25 @@ func TestGetConfig(t *testing.T) { // Metadata assert.Equal(t, 1, len(actual.Registration.Meta)) - assert.Equal(t, "50001", actual.Registration.Meta["DAPR_PORT"]) + assert.Equal(t, "50001", actual.Registration.Meta[actual.DaprPortMetaKey]) // QueryOptions assert.Equal(t, true, actual.QueryOptions.UseCache) // DaprPortMetaKey - assert.Equal(t, "DAPR_PORT", actual.DaprPortMetaKey) + assert.Equal(t, defaultDaprPortMetaKey, actual.DaprPortMetaKey) + + // Cache + assert.Equal(t, false, actual.UseCache) }, }, { "DaprPortMetaKey should set registration meta and config used for resolve", nr.Metadata{ Base: metadata.Base{Properties: getTestPropsWithoutKey("")}, - Configuration: configSpec{ - SelfRegister: true, - DaprPortMetaKey: "random_key", + Configuration: map[interface{}]interface{}{ + "SelfRegister": true, + "DaprPortMetaKey": "random_key", }, }, func(t *testing.T, metadata nr.Metadata) { @@ -600,12 +1487,28 @@ func TestGetConfig(t *testing.T) { assert.Equal(t, daprPort, actual.Registration.Meta["random_key"]) }, }, + { + "SelfDeregister should set DeregisterOnClose", + nr.Metadata{ + Base: metadata.Base{Properties: getTestPropsWithoutKey("")}, + Configuration: map[interface{}]interface{}{ + "SelfRegister": true, + "SelfDeregister": true, + }, + }, + func(t *testing.T, metadata nr.Metadata) { + t.Helper() + actual, _ := getConfig(metadata) + + assert.Equal(t, true, actual.DeregisterOnClose) + }, + }, { "missing AppID property should error when SelfRegister true", nr.Metadata{ Base: metadata.Base{Properties: getTestPropsWithoutKey(nr.AppID)}, - Configuration: configSpec{ - SelfRegister: true, + Configuration: map[interface{}]interface{}{ + "SelfRegister": true, }, }, func(t *testing.T, metadata nr.Metadata) { @@ -634,8 +1537,8 @@ func TestGetConfig(t *testing.T) { "missing AppPort property should error when SelfRegister true", nr.Metadata{ Base: metadata.Base{Properties: getTestPropsWithoutKey(nr.AppPort)}, - Configuration: configSpec{ - SelfRegister: true, + Configuration: map[interface{}]interface{}{ + "SelfRegister": true, }, }, func(t *testing.T, metadata nr.Metadata) { @@ -664,8 +1567,8 @@ func TestGetConfig(t *testing.T) { "missing HostAddress property should error when SelfRegister true", nr.Metadata{ Base: metadata.Base{Properties: getTestPropsWithoutKey(nr.HostAddress)}, - Configuration: configSpec{ - SelfRegister: true, + Configuration: map[interface{}]interface{}{ + "SelfRegister": true, }, }, func(t *testing.T, metadata nr.Metadata) { @@ -694,8 +1597,8 @@ func TestGetConfig(t *testing.T) { "missing DaprHTTPPort property should error only when SelfRegister true", nr.Metadata{ Base: metadata.Base{Properties: getTestPropsWithoutKey(nr.DaprHTTPPort)}, - Configuration: configSpec{ - SelfRegister: true, + Configuration: map[interface{}]interface{}{ + "SelfRegister": true, }, }, func(t *testing.T, metadata nr.Metadata) { @@ -757,27 +1660,29 @@ func TestGetConfig(t *testing.T) { "registration should configure correctly", nr.Metadata{ Base: metadata.Base{Properties: getTestPropsWithoutKey("")}, - Configuration: configSpec{ - Checks: []*consul.AgentServiceCheck{ - { - Name: "test-app health check name", - CheckID: "test-app health check id", - Interval: "15s", - HTTP: "http://127.0.0.1:3500/health", + Configuration: map[interface{}]interface{}{ + "Checks": []interface{}{ + map[interface{}]interface{}{ + "Name": "test-app health check name", + "CheckID": "test-app health check id", + "Interval": "15s", + "HTTP": "http://127.0.0.1:3500/health", }, }, - Tags: []string{ + "Tags": []interface{}{ "test", }, - Meta: map[string]string{ + "Meta": map[interface{}]interface{}{ "APP_PORT": "8650", "DAPR_GRPC_PORT": "50005", }, - QueryOptions: &consul.QueryOptions{ - UseCache: false, - Filter: "Checks.ServiceTags contains something", + "QueryOptions": map[interface{}]interface{}{ + "UseCache": false, + "Filter": "Checks.ServiceTags contains something", }, - SelfRegister: true, + "SelfRegister": true, + "DaprPortMetaKey": "PORT", + "UseCache": false, }, }, func(t *testing.T, metadata nr.Metadata) { @@ -798,50 +1703,53 @@ func TestGetConfig(t *testing.T) { assert.Equal(t, "test", actual.Registration.Tags[0]) assert.Equal(t, "8650", actual.Registration.Meta["APP_PORT"]) assert.Equal(t, "50005", actual.Registration.Meta["DAPR_GRPC_PORT"]) + assert.Equal(t, metadata.Properties[nr.DaprPort], actual.Registration.Meta["PORT"]) assert.Equal(t, false, actual.QueryOptions.UseCache) assert.Equal(t, "Checks.ServiceTags contains something", actual.QueryOptions.Filter) + assert.Equal(t, "PORT", actual.DaprPortMetaKey) + assert.Equal(t, false, actual.UseCache) }, }, { "advanced registration should override/ignore other configs", nr.Metadata{ Base: metadata.Base{Properties: getTestPropsWithoutKey("")}, - Configuration: configSpec{ - AdvancedRegistration: &consul.AgentServiceRegistration{ - Name: "random-app-id", - Port: 0o00, - Address: "123.345.678", - Tags: []string{"random-tag"}, - Meta: map[string]string{ + Configuration: map[interface{}]interface{}{ + "AdvancedRegistration": map[interface{}]interface{}{ + "Name": "random-app-id", + "Port": 0o00, + "Address": "123.345.678", + "Tags": []string{"random-tag"}, + "Meta": map[string]string{ "APP_PORT": "000", }, - Checks: []*consul.AgentServiceCheck{ - { - Name: "random health check name", - CheckID: "random health check id", - Interval: "15s", - HTTP: "http://127.0.0.1:3500/health", + "Checks": []interface{}{ + map[interface{}]interface{}{ + "Name": "random health check name", + "CheckID": "random health check id", + "Interval": "15s", + "HTTP": "http://127.0.0.1:3500/health", }, }, }, - Checks: []*consul.AgentServiceCheck{ - { - Name: "test-app health check name", - CheckID: "test-app health check id", - Interval: "15s", - HTTP: "http://127.0.0.1:3500/health", + "Checks": []interface{}{ + map[interface{}]interface{}{ + "Name": "test-app health check name", + "CheckID": "test-app health check id", + "Interval": "15s", + "HTTP": "http://127.0.0.1:3500/health", }, }, - Tags: []string{ + "Tags": []string{ "dapr", "test", }, - Meta: map[string]string{ + "Meta": map[string]string{ "APP_PORT": "123", "DAPR_HTTP_PORT": "3500", "DAPR_GRPC_PORT": "50005", }, - SelfRegister: false, + "SelfRegister": false, }, }, func(t *testing.T, metadata nr.Metadata) { @@ -1145,6 +2053,7 @@ func TestMapConfig(t *testing.T) { }, SelfRegister: true, DaprPortMetaKey: "SOMETHINGSOMETHING", + UseCache: false, } actual := mapConfig(expected) @@ -1161,6 +2070,7 @@ func TestMapConfig(t *testing.T) { assert.Equal(t, expected.Meta, actual.Meta) assert.Equal(t, expected.SelfRegister, actual.SelfRegister) assert.Equal(t, expected.DaprPortMetaKey, actual.DaprPortMetaKey) + assert.Equal(t, expected.UseCache, actual.UseCache) }) t.Run("should map empty configuration", func(t *testing.T) { @@ -1317,3 +2227,13 @@ func getTestPropsWithoutKey(removeKey string) map[string]string { return metadata } + +func waitTillTrueOrTimeout(d time.Duration, condition func() bool) { + for i := 0; i < 100; i++ { + if condition() { + return + } + + time.Sleep(d / 100) + } +} diff --git a/nameresolution/consul/watcher.go b/nameresolution/consul/watcher.go new file mode 100644 index 0000000000..ece97ab05d --- /dev/null +++ b/nameresolution/consul/watcher.go @@ -0,0 +1,364 @@ +package consul + +import ( + "context" + "errors" + "strings" + "time" + + backoff "github.com/cenkalti/backoff/v4" + consul "github.com/hashicorp/consul/api" +) + +const ( + // initial back interval. + initialBackOffInternal = 5 * time.Second + + // maximum back off time, this is to prevent exponential runaway. + maxBackOffInternal = 180 * time.Second +) + +// A watchPlan contains all the state tracked in the loop +// that keeps the consul service registry cache fresh +type watchPlan struct { + expired bool + lastParamVal blockingParamVal + lastResult map[serviceIdentifier]bool + options *consul.QueryOptions + healthServiceQueryFilter string + failing bool + backOff *backoff.ExponentialBackOff +} + +type blockingParamVal interface { + equal(other blockingParamVal) bool + next(previous blockingParamVal) blockingParamVal +} + +type waitIndexVal uint64 + +// Equal implements BlockingParamVal. +func (idx waitIndexVal) equal(other blockingParamVal) bool { + if otherIdx, ok := other.(waitIndexVal); ok { + return idx == otherIdx + } + + return false +} + +// Next implements BlockingParamVal. +func (idx waitIndexVal) next(previous blockingParamVal) blockingParamVal { + if previous == nil { + return idx + } + prevIdx, ok := previous.(waitIndexVal) + if ok && prevIdx == idx { + // this value is the same as the previous index, reset + return waitIndexVal(0) + } + + return idx +} + +type serviceIdentifier struct { + serviceName string + serviceID string + node string +} + +func getHealthByService(checks consul.HealthChecks) map[serviceIdentifier]bool { + healthByService := make(map[serviceIdentifier]bool) + for _, check := range checks { + // generate unique identifier for service + id := serviceIdentifier{ + serviceID: check.ServiceID, + serviceName: check.ServiceName, + node: check.Node, + } + + // if the service is not in the map - add and init to healthy + if state, ok := healthByService[id]; !ok { + healthByService[id] = true + } else if !state { + // service exists and is already unhealthy - skip + continue + } + + // if the check is not healthy then set service to unhealthy + if check.Status != consul.HealthPassing { + healthByService[id] = false + } + } + + return healthByService +} + +func (p *watchPlan) getChangedServices(newResult map[serviceIdentifier]bool) map[string]struct{} { + changedServices := make(map[string]struct{}) // service name set + + // foreach new result + for newKey, newValue := range newResult { + // if the service exists in the old result and has the same value - skip + if oldValue, ok := p.lastResult[newKey]; ok && newValue == oldValue { + continue + } + + // service is new or changed - add to set + changedServices[newKey.serviceName] = struct{}{} + } + + // foreach old result + for oldKey := range p.lastResult { + // if the service does not exist in the new result - add to set + if _, ok := newResult[oldKey]; !ok { + changedServices[oldKey.serviceName] = struct{}{} + } + } + + return changedServices +} + +func getServiceNameFilter(services []string) string { + nameFilters := make([]string, len(services)) + + for i, v := range services { + nameFilters[i] = `ServiceName=="` + v + `"` + } + + return strings.Join(nameFilters, " or ") +} + +func (r *resolver) watch(ctx context.Context, p *watchPlan, services []string) (blockingParamVal, consul.HealthChecks, error) { + p.options = p.options.WithContext(ctx) + + if p.lastParamVal != nil { + p.options.WaitIndex = uint64(p.lastParamVal.(waitIndexVal)) + } + + // build service name filter for all keys + p.options.Filter = getServiceNameFilter(services) + + // request health checks for target services using blocking query + checks, meta, err := r.client.Health().State(consul.HealthAny, p.options) + if err != nil { + // if it failed during long poll try again with no wait + if p.options.WaitIndex != uint64(0) { + p.options.WaitIndex = 0 + checks, meta, err = r.client.Health().State(consul.HealthAny, p.options) + } + + if err != nil { + // if the context was canceled + if errors.Is(err, context.Canceled) { + return nil, nil, err + } + + // if it failed with no wait and plan is not expired + if p.options.WaitIndex == uint64(0) && !p.expired { + p.lastResult = nil + p.expired = true + r.registry.expireAll() + } + + return nil, nil, err + } + } + + p.expired = false + return waitIndexVal(meta.LastIndex), checks, err +} + +// runWatchPlan executes the following steps: +// - requests health check changes for the target keys from the consul agent using http long polling +// - compares the results to the previous +// - if there is a change for a given serviceName/appId it invokes the health/service api to get a list of healthy targets +// - signals completion of the watch plan +func (r *resolver) runWatchPlan(ctx context.Context, p *watchPlan, services []string, watchPlanComplete chan struct{}) { + defer func() { + // signal completion of the watch plan to unblock the watch plan loop + watchPlanComplete <- struct{}{} + }() + + // invoke blocking call + blockParam, result, err := r.watch(ctx, p, services) + + // if the ctx was canceled then do nothing + if errors.Is(err, context.Canceled) { + return + } + + // handle an error in the watch function + if err != nil { + // reset the query index so the next attempt does not + p.lastParamVal = waitIndexVal(0) + + // perform an exponential backoff + if !p.failing { + p.failing = true + p.backOff.Reset() + } + + retry := p.backOff.NextBackOff() + + // pause watcher routine until ctx is canceled or retry timer finishes + r.logger.Errorf("consul service-watcher error: %v, retry in %s", err, retry.Round(time.Second)) + sleepTimer := time.NewTimer(retry) + select { + case <-ctx.Done(): + sleepTimer.Stop() + r.logger.Debug("consul service-watcher retry throttling canceled") + case <-sleepTimer.C: + } + + return + } else { + // reset the plan failure flag + p.failing = false + } + + // if the result index is unchanged do nothing + if p.lastParamVal != nil && p.lastParamVal.equal(blockParam) { + return + } else { + // update the plan index + oldParamVal := p.lastParamVal + p.lastParamVal = blockParam.next(oldParamVal) + } + + // compare last and new result to get changed services + healthByService := getHealthByService(result) + changedServices := p.getChangedServices(healthByService) + + // update the plan last result + p.lastResult = healthByService + + // call agent to get updated healthy nodes for each changed service + for k := range changedServices { + p.options.WaitIndex = 0 + p.options.Filter = p.healthServiceQueryFilter + p.options = p.options.WithContext(ctx) + result, meta, err := r.client.Health().Service(k, "", true, p.options) + + if err != nil { + // on failure, expire service from cache, resolver will fall back to agent + r.logger.Errorf("error invoking health service: %v, for service %s", err, k) + r.registry.expire(k) + + // remove healthchecks for service from last result + for key := range p.lastResult { + if k == key.serviceName { + delete(p.lastResult, key) + } + } + + // reset plan query index + p.lastParamVal = waitIndexVal(0) + } else { + // updated service entries in registry + r.logger.Debugf("updating consul nr registry for service:%s last-index:%d", k, meta.LastIndex) + r.registry.addOrUpdate(k, result) + } + } +} + +// runWatchLoop executes the following steps in a forever loop: +// - gets the keys from the registry +// - executes the watch plan with the targets keys +// - waits for (the watch plan to signal completion) or (the resolver to register a new key) +func (r *resolver) runWatchLoop(p *watchPlan) { + defer func() { + r.registry.removeAll() + r.watcherStarted.Store(false) + }() + + watchPlanComplete := make(chan struct{}, 1) + +watchLoop: + for { + ctx, cancel := context.WithCancel(context.Background()) + + // get target keys/app-ids from registry + services := r.registry.getKeys() + watching := false + + if len(services) > 0 { + // run watch plan for targets service with channel to signal completion + go r.runWatchPlan(ctx, p, services, watchPlanComplete) + watching = true + } + + select { + case <-watchPlanComplete: + cancel() + + // wait on channel for new services to track + case service := <-r.registry.registrationChannel(): + // cancel watch plan i.e. blocking query to consul agent + cancel() + + // generate set of keys + serviceKeys := make(map[string]any) + for i := 0; i < len(services); i++ { + serviceKeys[services[i]] = nil + } + + // add service if it's not in the registry + if _, ok := serviceKeys[service]; !ok { + r.registry.addOrUpdate(service, nil) + } + + // check for any more new services in channel and do the same + moreServices := true + for moreServices { + select { + case service := <-r.registry.registrationChannel(): + if _, ok := serviceKeys[service]; !ok { + r.registry.addOrUpdate(service, nil) + } + default: + moreServices = false + } + } + + if watching { + // ensure previous watch plan routine completed before next iteration + <-watchPlanComplete + } + + // reset plan failure count and query index + p.failing = false + p.lastParamVal = waitIndexVal(0) + + // resolver closing + case <-r.watcherStopChannel: + cancel() + break watchLoop + } + } +} + +// startWatcher will configure the watch plan and start the watch loop in a separate routine +func (r *resolver) startWatcher() { + if !r.watcherStarted.CompareAndSwap(false, true) { + return + } + + options := *r.config.QueryOptions + options.UseCache = false // always ignore consul agent cache for watcher + options.Filter = "" // don't use configured filter for State() calls + + // Configure exponential backoff + ebo := backoff.NewExponentialBackOff() + ebo.InitialInterval = initialBackOffInternal + ebo.MaxInterval = maxBackOffInternal + ebo.MaxElapsedTime = 0 + + plan := &watchPlan{ + options: &options, + healthServiceQueryFilter: r.config.QueryOptions.Filter, + lastResult: make(map[serviceIdentifier]bool), + backOff: ebo, + } + + go r.runWatchLoop(plan) +}