From 3e21eff4b7bb1e59a2a37ff021dc6f154b8f52da Mon Sep 17 00:00:00 2001 From: rodrigozhou Date: Tue, 15 Oct 2024 15:43:16 -0700 Subject: [PATCH 1/2] Add a lock to nexus.Service to project access to operations map --- nexus/operation.go | 47 +++++++++++++++++++++++++++++++++------------- 1 file changed, 34 insertions(+), 13 deletions(-) diff --git a/nexus/operation.go b/nexus/operation.go index 13cd614..ca8e902 100644 --- a/nexus/operation.go +++ b/nexus/operation.go @@ -6,6 +6,7 @@ import ( "fmt" "reflect" "strings" + "sync" ) // NoValue is a marker type for an operations that do not accept any input or return a value (nil). @@ -132,6 +133,8 @@ func (h *syncOperation[I, O]) Start(ctx context.Context, input I, options StartO type Service struct { Name string + // Lock to protect operations map in case new services are added while workflows are running. + lock sync.RWMutex operations map[string]RegisterableOperation } @@ -147,17 +150,15 @@ func NewService(name string) *Service { // Returns an error if duplicate operations were registered with the same name or when trying to register an operation // with no name. // -// Can be called multiple times and is not thread safe. +// Can be called multiple times. func (s *Service) Register(operations ...RegisterableOperation) error { var dups []string for _, op := range operations { if op.Name() == "" { return fmt.Errorf("tried to register an operation with no name") } - if _, found := s.operations[op.Name()]; found { + if !s.addOperation(op) { dups = append(dups, op.Name()) - } else { - s.operations[op.Name()] = op } } if len(dups) > 0 { @@ -166,11 +167,31 @@ func (s *Service) Register(operations ...RegisterableOperation) error { return nil } +// addOperation adds the operation if not found. Returns boolean indicating if the operation was added. +func (s *Service) addOperation(op RegisterableOperation) bool { + s.lock.Lock() + defer s.lock.Unlock() + if _, ok := s.operations[op.Name()]; ok { + return false + } + s.operations[op.Name()] = op + return true +} + // Operation returns an operation by name or nil if not found. func (s *Service) Operation(name string) RegisterableOperation { + s.lock.RLock() + defer s.lock.RUnlock() return s.operations[name] } +// NumOperations returns the number of operations registered. +func (s *Service) NumOperations() int { + s.lock.RLock() + defer s.lock.RUnlock() + return len(s.operations) +} + // A ServiceRegistry registers services and constructs a [Handler] that dispatches operations requests to those services. type ServiceRegistry struct { services map[string]*Service @@ -209,7 +230,7 @@ func (r *ServiceRegistry) NewHandler() (Handler, error) { return nil, errors.New("must register at least one service") } for _, service := range r.services { - if len(service.operations) == 0 { + if service.NumOperations() == 0 { return nil, fmt.Errorf("service %q has no operations registered", service.Name) } } @@ -229,8 +250,8 @@ func (r *registryHandler) CancelOperation(ctx context.Context, service, operatio if !ok { return HandlerErrorf(HandlerErrorTypeNotFound, "service %q not found", service) } - h, ok := s.operations[operation] - if !ok { + h := s.Operation(operation) + if h == nil { return HandlerErrorf(HandlerErrorTypeNotFound, "operation %q not found", operation) } @@ -250,8 +271,8 @@ func (r *registryHandler) GetOperationInfo(ctx context.Context, service, operati if !ok { return nil, HandlerErrorf(HandlerErrorTypeNotFound, "service %q not found", service) } - h, ok := s.operations[operation] - if !ok { + h := s.Operation(operation) + if h == nil { return nil, HandlerErrorf(HandlerErrorTypeNotFound, "operation %q not found", operation) } @@ -272,8 +293,8 @@ func (r *registryHandler) GetOperationResult(ctx context.Context, service, opera if !ok { return nil, HandlerErrorf(HandlerErrorTypeNotFound, "service %q not found", service) } - h, ok := s.operations[operation] - if !ok { + h := s.Operation(operation) + if h == nil { return nil, HandlerErrorf(HandlerErrorTypeNotFound, "operation %q not found", operation) } @@ -292,8 +313,8 @@ func (r *registryHandler) StartOperation(ctx context.Context, service, operation if !ok { return nil, HandlerErrorf(HandlerErrorTypeNotFound, "service %q not found", service) } - h, ok := s.operations[operation] - if !ok { + h := s.Operation(operation) + if h == nil { return nil, HandlerErrorf(HandlerErrorTypeNotFound, "operation %q not found", operation) } From 7715d5af21608a3b83736166e2b2992c70b82890 Mon Sep 17 00:00:00 2001 From: rodrigozhou Date: Wed, 16 Oct 2024 09:13:38 -0700 Subject: [PATCH 2/2] address comments --- nexus/operation.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/nexus/operation.go b/nexus/operation.go index ca8e902..0e6b913 100644 --- a/nexus/operation.go +++ b/nexus/operation.go @@ -133,7 +133,7 @@ func (h *syncOperation[I, O]) Start(ctx context.Context, input I, options StartO type Service struct { Name string - // Lock to protect operations map in case new services are added while workflows are running. + // Lock to protect the operations map in case new operations are added while workflows are running. lock sync.RWMutex operations map[string]RegisterableOperation } @@ -167,7 +167,7 @@ func (s *Service) Register(operations ...RegisterableOperation) error { return nil } -// addOperation adds the operation if not found. Returns boolean indicating if the operation was added. +// addOperation adds the operation if not found. Returns a boolean indicating if the operation was added. func (s *Service) addOperation(op RegisterableOperation) bool { s.lock.Lock() defer s.lock.Unlock() @@ -185,8 +185,8 @@ func (s *Service) Operation(name string) RegisterableOperation { return s.operations[name] } -// NumOperations returns the number of operations registered. -func (s *Service) NumOperations() int { +// numOperations returns the number of operations registered. +func (s *Service) numOperations() int { s.lock.RLock() defer s.lock.RUnlock() return len(s.operations) @@ -230,7 +230,7 @@ func (r *ServiceRegistry) NewHandler() (Handler, error) { return nil, errors.New("must register at least one service") } for _, service := range r.services { - if service.NumOperations() == 0 { + if service.numOperations() == 0 { return nil, fmt.Errorf("service %q has no operations registered", service.Name) } }