Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: improper use of secretAddedSignal channel #1054

Merged
merged 3 commits into from
Mar 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 1 addition & 17 deletions internal/app/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ import (
bootstrapContainer "github.com/edgexfoundry/go-mod-bootstrap/v2/bootstrap/container"
"github.com/edgexfoundry/go-mod-bootstrap/v2/bootstrap/flags"
bootstrapInterfaces "github.com/edgexfoundry/go-mod-bootstrap/v2/bootstrap/interfaces"
"github.com/edgexfoundry/go-mod-bootstrap/v2/bootstrap/secret"
"github.com/edgexfoundry/go-mod-bootstrap/v2/bootstrap/startup"
"github.com/edgexfoundry/go-mod-bootstrap/v2/di"

Expand Down Expand Up @@ -195,15 +194,6 @@ func (svc *Service) MakeItRun() error {
// deferred is a function that needs to be called when services exits.
svc.addDeferred(deferred)

if secret.IsSecurityEnabled() {
// add a deferred function to close the SecretAddedSignal channel created during service initialization.
svc.addDeferred(func() {
if secretAddedSignal, ok := svc.ctx.appCtx.Value(internal.ContextKeySecretAddedSignal).(chan struct{}); ok {
close(secretAddedSignal)
}
})
}

if svc.config.Writable.StoreAndForward.Enabled {
svc.startStoreForward()
} else {
Expand Down Expand Up @@ -505,12 +495,6 @@ func (svc *Service) Initialize() error {
svc.ctx.appCtx, svc.ctx.appCancelCtx = context.WithCancel(context.Background())
svc.ctx.appWg = &sync.WaitGroup{}

var secretAddedSignal chan struct{}
if secret.IsSecurityEnabled() {
secretAddedSignal = make(chan struct{}, 1)
svc.ctx.appCtx = context.WithValue(svc.ctx.appCtx, internal.ContextKeySecretAddedSignal, secretAddedSignal)
}

var deferred bootstrap.Deferred
var successful bool
var configUpdated config.UpdatedStream = make(chan struct{})
Expand Down Expand Up @@ -557,7 +541,7 @@ func (svc *Service) Initialize() error {
// to wait to be signaled when the configuration has been updated and then process the changes
NewConfigUpdateProcessor(svc).WaitForConfigUpdates(configUpdated)

svc.webserver = webserver.NewWebServer(svc.dic, mux.NewRouter(), svc.serviceKey, secretAddedSignal)
svc.webserver = webserver.NewWebServer(svc.dic, mux.NewRouter(), svc.serviceKey)
svc.webserver.ConfigureStandardRoutes()

svc.lc.Info("Service started in: " + startupTimer.SinceAsString())
Expand Down
2 changes: 1 addition & 1 deletion internal/app/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func IsInstanceOf(objectPtr, typePtr interface{}) bool {
func TestAddRoute(t *testing.T) {
router := mux.NewRouter()

ws := webserver.NewWebServer(dic, router, uuid.NewString(), nil)
ws := webserver.NewWebServer(dic, router, uuid.NewString())

sdk := Service{
webserver: ws,
Expand Down
4 changes: 4 additions & 0 deletions internal/common/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,10 @@ type ExternalMqttConfig struct {
// AuthMode indicates what to use when connecting to the broker. Options are "none", "cacert" , "usernamepassword", "clientcert".
// If a CA Cert exists in the SecretPath then it will be used for all modes except "none".
AuthMode string
// RetryDuration indicates how long (in seconds) to wait timing out on the MQTT client creation
RetryDuration int
// RetryInterval indicates the time (in seconds) that will be waited between attempts to create MQTT client
RetryInterval int
}

// PipelineInfo defines the top level data for configurable pipelines
Expand Down
4 changes: 0 additions & 4 deletions internal/constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,11 @@ import (
"github.com/edgexfoundry/go-mod-core-contracts/v2/common"
)

type contextKey int

const (
ConfigRegistryStem = "edgex/appservices/"

ApiTriggerRoute = common.ApiBase + "/trigger"
ApiAddSecretRoute = common.ApiBase + "/secret"

ContextKeySecretAddedSignal contextKey = iota
)

// SDKVersion indicates the version of the SDK - will be overwritten by build
Expand Down
28 changes: 11 additions & 17 deletions internal/controller/rest/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,23 +40,21 @@ import (

// Controller controller for V2 REST APIs
type Controller struct {
router *mux.Router
secretProvider interfaces.SecretProvider
lc logger.LoggingClient
config *sdkCommon.ConfigurationStruct
serviceName string
secretAddedSignal chan struct{}
router *mux.Router
secretProvider interfaces.SecretProvider
lc logger.LoggingClient
config *sdkCommon.ConfigurationStruct
serviceName string
}

// NewController creates and initializes an Controller
func NewController(router *mux.Router, dic *di.Container, serviceName string, secretAddedSignal chan struct{}) *Controller {
func NewController(router *mux.Router, dic *di.Container, serviceName string) *Controller {
return &Controller{
router: router,
secretProvider: bootstrapContainer.SecretProviderFrom(dic.Get),
lc: bootstrapContainer.LoggingClientFrom(dic.Get),
config: container.ConfigurationFrom(dic.Get),
serviceName: serviceName,
secretAddedSignal: secretAddedSignal,
router: router,
secretProvider: bootstrapContainer.SecretProviderFrom(dic.Get),
lc: bootstrapContainer.LoggingClientFrom(dic.Get),
config: container.ConfigurationFrom(dic.Get),
serviceName: serviceName,
}
}

Expand Down Expand Up @@ -122,10 +120,6 @@ func (c *Controller) AddSecret(writer http.ResponseWriter, request *http.Request

response := commonDtos.NewBaseResponse(secretRequest.RequestId, "", http.StatusCreated)
c.sendResponse(writer, request, internal.ApiAddSecretRoute, response, http.StatusCreated)

if c.secretAddedSignal != nil {
c.secretAddedSignal <- struct{}{}
}
}

func (c *Controller) sendError(
Expand Down
13 changes: 5 additions & 8 deletions internal/controller/rest/controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func TestMain(m *testing.M) {
func TestPingRequest(t *testing.T) {
serviceName := uuid.NewString()

target := NewController(nil, dic, serviceName, nil)
target := NewController(nil, dic, serviceName)

recorder := doRequest(t, http.MethodGet, common.ApiPingRoute, target.Ping, nil)

Expand All @@ -83,7 +83,7 @@ func TestVersionRequest(t *testing.T) {
internal.ApplicationVersion = expectedAppVersion
internal.SDKVersion = expectedSdkVersion

target := NewController(nil, dic, serviceName, nil)
target := NewController(nil, dic, serviceName)

recorder := doRequest(t, http.MethodGet, common.ApiVersion, target.Version, nil)

Expand All @@ -100,7 +100,7 @@ func TestVersionRequest(t *testing.T) {
func TestMetricsRequest(t *testing.T) {
serviceName := uuid.NewString()

target := NewController(nil, dic, serviceName, nil)
target := NewController(nil, dic, serviceName)

recorder := doRequest(t, http.MethodGet, common.ApiMetricsRoute, target.Metrics, nil)

Expand Down Expand Up @@ -140,7 +140,7 @@ func TestConfigRequest(t *testing.T) {
},
})

target := NewController(nil, dic, serviceName, nil)
target := NewController(nil, dic, serviceName)

recorder := doRequest(t, http.MethodGet, common.ApiConfigRoute, target.Config, nil)

Expand Down Expand Up @@ -176,10 +176,7 @@ func TestAddSecretRequest(t *testing.T) {
mockProvider.On("StoreSecrets", "/mqtt", map[string]string{"password": "password", "username": "username"}).Return(nil)
mockProvider.On("StoreSecrets", "/no", map[string]string{"password": "password", "username": "username"}).Return(errors.New("Invalid w/o Vault"))

ch := make(chan struct{}, 1)
defer close(ch)

target := NewController(nil, dic, uuid.NewString(), ch)
target := NewController(nil, dic, uuid.NewString())
assert.NotNil(t, target)

validRequest := commonDtos.SecretRequest{
Expand Down
78 changes: 51 additions & 27 deletions internal/trigger/mqtt/mqtt.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,28 @@ import (
"sync"
"time"

"github.com/edgexfoundry/app-functions-sdk-go/v2/internal"
"github.com/edgexfoundry/app-functions-sdk-go/v2/internal/common"
"github.com/edgexfoundry/app-functions-sdk-go/v2/internal/trigger"
"github.com/edgexfoundry/app-functions-sdk-go/v2/pkg/interfaces"
"github.com/edgexfoundry/app-functions-sdk-go/v2/pkg/secure"
"github.com/edgexfoundry/app-functions-sdk-go/v2/pkg/util"

"github.com/edgexfoundry/go-mod-bootstrap/v2/bootstrap"
"github.com/edgexfoundry/go-mod-bootstrap/v2/bootstrap/secret"
"github.com/edgexfoundry/go-mod-bootstrap/v2/bootstrap/messaging"
"github.com/edgexfoundry/go-mod-bootstrap/v2/bootstrap/startup"
"github.com/edgexfoundry/go-mod-core-contracts/v2/clients/logger"
commonContracts "github.com/edgexfoundry/go-mod-core-contracts/v2/common"
"github.com/edgexfoundry/go-mod-messaging/v2/pkg/types"

pahoMqtt "github.com/eclipse/paho.mqtt.golang"
"github.com/google/uuid"
)

const (
defaultRetryDuration = 600
defaultRetryInterval = 5
)

// Trigger implements Trigger to support Triggers
type Trigger struct {
messageProcessor trigger.MessageProcessor
Expand All @@ -61,7 +68,7 @@ func NewTrigger(bnd trigger.ServiceBinding, mp trigger.MessageProcessor) *Trigge
}

// Initialize initializes the Trigger for an external MQTT broker
func (trigger *Trigger) Initialize(_ *sync.WaitGroup, ctx context.Context, background <-chan interfaces.BackgroundMessage) (bootstrap.Deferred, error) {
func (trigger *Trigger) Initialize(_ *sync.WaitGroup, _ context.Context, background <-chan interfaces.BackgroundMessage) (bootstrap.Deferred, error) {
// Convenience short cuts
lc := trigger.serviceBinding.LoggingClient()
config := trigger.serviceBinding.Config()
Expand Down Expand Up @@ -102,35 +109,28 @@ func (trigger *Trigger) Initialize(_ *sync.WaitGroup, ctx context.Context, backg
opts.KeepAlive = brokerConfig.KeepAlive
opts.Servers = []*url.URL{brokerUrl}

var secretAddedSignal chan struct{}
var ok bool
if secret.IsSecurityEnabled() {
if secretAddedSignal, ok = ctx.Value(internal.ContextKeySecretAddedSignal).(chan struct{}); !ok {
return nil, errors.New("the SecretAddedSignal channel cannot be found in the context")
}
if brokerConfig.RetryDuration <= 0 {
brokerConfig.RetryDuration = defaultRetryDuration
}

mqttFactory := secure.NewMqttFactory(
trigger.serviceBinding.SecretProvider(),
trigger.serviceBinding.LoggingClient(),
brokerConfig.AuthMode,
brokerConfig.SecretPath,
brokerConfig.SkipCertVerify,
secretAddedSignal,
)

mqttClient, err := mqttFactory.Create(opts)
if err != nil {
return nil, fmt.Errorf("unable to create secure MQTT Client: %s", err.Error())
if brokerConfig.RetryInterval <= 0 {
brokerConfig.RetryInterval = defaultRetryInterval
}

lc.Infof("Connecting to mqtt broker for MQTT trigger at: %s", brokerUrl)

if token := mqttClient.Connect(); token.Wait() && token.Error() != nil {
return nil, fmt.Errorf("could not connect to broker for MQTT trigger: %s", token.Error().Error())
sp := trigger.serviceBinding.SecretProvider()
var mqttClient pahoMqtt.Client
timer := startup.NewTimer(brokerConfig.RetryDuration, brokerConfig.RetryInterval)
for timer.HasNotElapsed() {
if mqttClient, err = createMqttClient(sp, lc, brokerConfig, opts); err != nil {
lc.Warnf("%s. Attempt to create MQTT client again after %d seconds...", err.Error(), brokerConfig.RetryInterval)
timer.SleepForInterval()
continue
}
break
}

lc.Info("Connected to mqtt server for MQTT trigger")
if err != nil {
return nil, fmt.Errorf("unable to create MQTT Client: %s", err.Error())
}

deferred := func() {
lc.Info("Disconnecting from broker for MQTT trigger")
Expand Down Expand Up @@ -227,3 +227,27 @@ func (trigger *Trigger) responseHandler(appContext interfaces.AppFunctionContext
}
return nil
}

func createMqttClient(sp messaging.SecretDataProvider, lc logger.LoggingClient, config common.ExternalMqttConfig,
opts *pahoMqtt.ClientOptions) (pahoMqtt.Client, error) {
mqttFactory := secure.NewMqttFactory(
sp,
lc,
config.AuthMode,
config.SecretPath,
config.SkipCertVerify,
)
mqttClient, err := mqttFactory.Create(opts)
if err != nil {
return nil, fmt.Errorf("unable to create secure MQTT Client: %s", err.Error())
}

lc.Infof("Connecting to mqtt broker for MQTT trigger at: %s", config.Url)

if token := mqttClient.Connect(); token.Wait() && token.Error() != nil {
return nil, fmt.Errorf("could not connect to broker for MQTT trigger: %s", token.Error().Error())
}

lc.Info("Connected to mqtt server for MQTT trigger")
return mqttClient, nil
}
4 changes: 2 additions & 2 deletions internal/webserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,12 @@ type Version struct {
}

// NewWebServer returns a new instance of *WebServer
func NewWebServer(dic *di.Container, router *mux.Router, serviceName string, secretAddedSignal chan struct{}) *WebServer {
func NewWebServer(dic *di.Container, router *mux.Router, serviceName string) *WebServer {
ws := &WebServer{
lc: bootstrapContainer.LoggingClientFrom(dic.Get),
config: container.ConfigurationFrom(dic.Get),
router: router,
controller: rest.NewController(router, dic, serviceName, secretAddedSignal),
controller: rest.NewController(router, dic, serviceName),
}

return ws
Expand Down
4 changes: 2 additions & 2 deletions internal/webserver/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func TestAddRoute(t *testing.T) {
routePath := "/testRoute"
testHandler := func(_ http.ResponseWriter, _ *http.Request) {}

webserver := NewWebServer(dic, mux.NewRouter(), uuid.NewString(), nil)
webserver := NewWebServer(dic, mux.NewRouter(), uuid.NewString())
err := webserver.AddRoute(routePath, testHandler)
assert.NoError(t, err, "Not expecting an error")

Expand All @@ -69,7 +69,7 @@ func TestAddRoute(t *testing.T) {
}

func TestSetupTriggerRoute(t *testing.T) {
webserver := NewWebServer(dic, mux.NewRouter(), uuid.NewString(), nil)
webserver := NewWebServer(dic, mux.NewRouter(), uuid.NewString())

handlerFunctionNotCalled := true
handler := func(w http.ResponseWriter, r *http.Request) {
Expand Down
Loading