Skip to content

Commit

Permalink
feat: Improve service initialization process
Browse files Browse the repository at this point in the history
- Start the Web Server before the trigger initialization.
- Add retry mechanism for MqttFactory.Create()

Signed-off-by: Felix Ting <[email protected]>
  • Loading branch information
FelixTing committed Mar 3, 2022
1 parent 73b9dee commit 7929a1c
Show file tree
Hide file tree
Showing 11 changed files with 198 additions and 63 deletions.
28 changes: 22 additions & 6 deletions internal/app/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ import (
bootstrapContainer "github.com/edgexfoundry/go-mod-bootstrap/v2/bootstrap/container"
"github.com/edgexfoundry/go-mod-bootstrap/v2/bootstrap/flags"
bootstrapInterfaces "github.com/edgexfoundry/go-mod-bootstrap/v2/bootstrap/interfaces"
"github.com/edgexfoundry/go-mod-bootstrap/v2/bootstrap/secret"
"github.com/edgexfoundry/go-mod-bootstrap/v2/bootstrap/startup"
"github.com/edgexfoundry/go-mod-bootstrap/v2/di"

Expand Down Expand Up @@ -163,6 +164,11 @@ func (svc *Service) MakeItRun() error {

svc.ctx.stop = stop

httpErrors := make(chan error)
defer close(httpErrors)

svc.webserver.StartWebServer(httpErrors)

// determine input type and create trigger for it
t := svc.setupTrigger(svc.config)
if t == nil {
Expand All @@ -179,6 +185,15 @@ func (svc *Service) MakeItRun() error {
// deferred is a function that needs to be called when services exits.
svc.addDeferred(deferred)

if secret.IsSecurityEnabled() {
// add a deferred function to close the SecretAddedSignal channel created during service initialization.
svc.addDeferred(func() {
if secretAddedSignal, ok := svc.ctx.appCtx.Value(internal.ContextKeySecretAddedSignal).(chan struct{}); ok {
close(secretAddedSignal)
}
})
}

if svc.config.Writable.StoreAndForward.Enabled {
svc.startStoreForward()
} else {
Expand All @@ -190,11 +205,6 @@ func (svc *Service) MakeItRun() error {
signals := make(chan os.Signal, 1)
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)

httpErrors := make(chan error)
defer close(httpErrors)

svc.webserver.StartWebServer(httpErrors)

select {
case httpError := <-httpErrors:
svc.lc.Info("Http error received: ", httpError.Error())
Expand Down Expand Up @@ -484,6 +494,12 @@ func (svc *Service) Initialize() error {
svc.ctx.appCtx, svc.ctx.appCancelCtx = context.WithCancel(context.Background())
svc.ctx.appWg = &sync.WaitGroup{}

var secretAddedSignal chan struct{}
if secret.IsSecurityEnabled() {
secretAddedSignal = make(chan struct{}, 1)
svc.ctx.appCtx = context.WithValue(svc.ctx.appCtx, internal.ContextKeySecretAddedSignal, secretAddedSignal)
}

var deferred bootstrap.Deferred
var successful bool
var configUpdated config.UpdatedStream = make(chan struct{})
Expand Down Expand Up @@ -531,7 +547,7 @@ func (svc *Service) Initialize() error {
// to wait to be signaled when the configuration has been updated and then process the changes
NewConfigUpdateProcessor(svc).WaitForConfigUpdates(configUpdated)

svc.webserver = webserver.NewWebServer(svc.dic, mux.NewRouter(), svc.serviceKey)
svc.webserver = webserver.NewWebServer(svc.dic, mux.NewRouter(), svc.serviceKey, secretAddedSignal)
svc.webserver.ConfigureStandardRoutes()

svc.lc.Info("Service started in: " + startupTimer.SinceAsString())
Expand Down
2 changes: 1 addition & 1 deletion internal/app/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func IsInstanceOf(objectPtr, typePtr interface{}) bool {
func TestAddRoute(t *testing.T) {
router := mux.NewRouter()

ws := webserver.NewWebServer(dic, router, uuid.NewString())
ws := webserver.NewWebServer(dic, router, uuid.NewString(), nil)

sdk := Service{
webserver: ws,
Expand Down
4 changes: 4 additions & 0 deletions internal/constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,15 @@ import (
"github.com/edgexfoundry/go-mod-core-contracts/v2/common"
)

type contextKey int

const (
ConfigRegistryStem = "edgex/appservices/"

ApiTriggerRoute = common.ApiBase + "/trigger"
ApiAddSecretRoute = common.ApiBase + "/secret"

ContextKeySecretAddedSignal contextKey = iota
)

// SDKVersion indicates the version of the SDK - will be overwritten by build
Expand Down
28 changes: 17 additions & 11 deletions internal/controller/rest/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,23 @@ import (

// Controller controller for V2 REST APIs
type Controller struct {
router *mux.Router
secretProvider interfaces.SecretProvider
lc logger.LoggingClient
config *sdkCommon.ConfigurationStruct
serviceName string
router *mux.Router
secretProvider interfaces.SecretProvider
lc logger.LoggingClient
config *sdkCommon.ConfigurationStruct
serviceName string
secretAddedSignal chan struct{}
}

// NewController creates and initializes an Controller
func NewController(router *mux.Router, dic *di.Container, serviceName string) *Controller {
func NewController(router *mux.Router, dic *di.Container, serviceName string, secretAddedSignal chan struct{}) *Controller {
return &Controller{
router: router,
secretProvider: bootstrapContainer.SecretProviderFrom(dic.Get),
lc: bootstrapContainer.LoggingClientFrom(dic.Get),
config: container.ConfigurationFrom(dic.Get),
serviceName: serviceName,
router: router,
secretProvider: bootstrapContainer.SecretProviderFrom(dic.Get),
lc: bootstrapContainer.LoggingClientFrom(dic.Get),
config: container.ConfigurationFrom(dic.Get),
serviceName: serviceName,
secretAddedSignal: secretAddedSignal,
}
}

Expand Down Expand Up @@ -120,6 +122,10 @@ func (c *Controller) AddSecret(writer http.ResponseWriter, request *http.Request

response := commonDtos.NewBaseResponse(secretRequest.RequestId, "", http.StatusCreated)
c.sendResponse(writer, request, internal.ApiAddSecretRoute, response, http.StatusCreated)

if c.secretAddedSignal != nil {
c.secretAddedSignal <- struct{}{}
}
}

func (c *Controller) sendError(
Expand Down
13 changes: 8 additions & 5 deletions internal/controller/rest/controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func TestMain(m *testing.M) {
func TestPingRequest(t *testing.T) {
serviceName := uuid.NewString()

target := NewController(nil, dic, serviceName)
target := NewController(nil, dic, serviceName, nil)

recorder := doRequest(t, http.MethodGet, common.ApiPingRoute, target.Ping, nil)

Expand All @@ -83,7 +83,7 @@ func TestVersionRequest(t *testing.T) {
internal.ApplicationVersion = expectedAppVersion
internal.SDKVersion = expectedSdkVersion

target := NewController(nil, dic, serviceName)
target := NewController(nil, dic, serviceName, nil)

recorder := doRequest(t, http.MethodGet, common.ApiVersion, target.Version, nil)

Expand All @@ -100,7 +100,7 @@ func TestVersionRequest(t *testing.T) {
func TestMetricsRequest(t *testing.T) {
serviceName := uuid.NewString()

target := NewController(nil, dic, serviceName)
target := NewController(nil, dic, serviceName, nil)

recorder := doRequest(t, http.MethodGet, common.ApiMetricsRoute, target.Metrics, nil)

Expand Down Expand Up @@ -140,7 +140,7 @@ func TestConfigRequest(t *testing.T) {
},
})

target := NewController(nil, dic, serviceName)
target := NewController(nil, dic, serviceName, nil)

recorder := doRequest(t, http.MethodGet, common.ApiConfigRoute, target.Config, nil)

Expand Down Expand Up @@ -176,7 +176,10 @@ func TestAddSecretRequest(t *testing.T) {
mockProvider.On("StoreSecrets", "/mqtt", map[string]string{"password": "password", "username": "username"}).Return(nil)
mockProvider.On("StoreSecrets", "/no", map[string]string{"password": "password", "username": "username"}).Return(errors.New("Invalid w/o Vault"))

target := NewController(nil, dic, uuid.NewString())
ch := make(chan struct{}, 1)
defer close(ch)

target := NewController(nil, dic, uuid.NewString(), ch)
assert.NotNil(t, target)

validRequest := commonDtos.SecretRequest{
Expand Down
15 changes: 13 additions & 2 deletions internal/trigger/mqtt/mqtt.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,19 @@ import (
"context"
"errors"
"fmt"
"github.com/edgexfoundry/app-functions-sdk-go/v2/internal/trigger"
"net/url"
"strings"
"sync"
"time"

"github.com/edgexfoundry/app-functions-sdk-go/v2/internal"
"github.com/edgexfoundry/app-functions-sdk-go/v2/internal/trigger"
"github.com/edgexfoundry/app-functions-sdk-go/v2/pkg/interfaces"
"github.com/edgexfoundry/app-functions-sdk-go/v2/pkg/secure"
"github.com/edgexfoundry/app-functions-sdk-go/v2/pkg/util"

"github.com/edgexfoundry/go-mod-bootstrap/v2/bootstrap"
"github.com/edgexfoundry/go-mod-bootstrap/v2/bootstrap/secret"
commonContracts "github.com/edgexfoundry/go-mod-core-contracts/v2/common"
"github.com/edgexfoundry/go-mod-messaging/v2/pkg/types"

Expand Down Expand Up @@ -59,7 +61,7 @@ func NewTrigger(bnd trigger.ServiceBinding, mp trigger.MessageProcessor) *Trigge
}

// Initialize initializes the Trigger for an external MQTT broker
func (trigger *Trigger) Initialize(_ *sync.WaitGroup, _ context.Context, background <-chan interfaces.BackgroundMessage) (bootstrap.Deferred, error) {
func (trigger *Trigger) Initialize(_ *sync.WaitGroup, ctx context.Context, background <-chan interfaces.BackgroundMessage) (bootstrap.Deferred, error) {
// Convenience short cuts
lc := trigger.serviceBinding.LoggingClient()
config := trigger.serviceBinding.Config()
Expand Down Expand Up @@ -100,12 +102,21 @@ func (trigger *Trigger) Initialize(_ *sync.WaitGroup, _ context.Context, backgro
opts.KeepAlive = brokerConfig.KeepAlive
opts.Servers = []*url.URL{brokerUrl}

var secretAddedSignal chan struct{}
var ok bool
if secret.IsSecurityEnabled() {
if secretAddedSignal, ok = ctx.Value(internal.ContextKeySecretAddedSignal).(chan struct{}); !ok {
return nil, errors.New("the SecretAddedSignal channel cannot be found in the context")
}
}

mqttFactory := secure.NewMqttFactory(
trigger.serviceBinding.SecretProvider(),
trigger.serviceBinding.LoggingClient(),
brokerConfig.AuthMode,
brokerConfig.SecretPath,
brokerConfig.SkipCertVerify,
secretAddedSignal,
)

mqttClient, err := mqttFactory.Create(opts)
Expand Down
4 changes: 2 additions & 2 deletions internal/webserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,12 @@ type Version struct {
}

// NewWebServer returns a new instance of *WebServer
func NewWebServer(dic *di.Container, router *mux.Router, serviceName string) *WebServer {
func NewWebServer(dic *di.Container, router *mux.Router, serviceName string, secretAddedSignal chan struct{}) *WebServer {
ws := &WebServer{
lc: bootstrapContainer.LoggingClientFrom(dic.Get),
config: container.ConfigurationFrom(dic.Get),
router: router,
controller: rest.NewController(router, dic, serviceName),
controller: rest.NewController(router, dic, serviceName, secretAddedSignal),
}

return ws
Expand Down
4 changes: 2 additions & 2 deletions internal/webserver/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func TestAddRoute(t *testing.T) {
routePath := "/testRoute"
testHandler := func(_ http.ResponseWriter, _ *http.Request) {}

webserver := NewWebServer(dic, mux.NewRouter(), uuid.NewString())
webserver := NewWebServer(dic, mux.NewRouter(), uuid.NewString(), nil)
err := webserver.AddRoute(routePath, testHandler)
assert.NoError(t, err, "Not expecting an error")

Expand All @@ -69,7 +69,7 @@ func TestAddRoute(t *testing.T) {
}

func TestSetupTriggerRoute(t *testing.T) {
webserver := NewWebServer(dic, mux.NewRouter(), uuid.NewString())
webserver := NewWebServer(dic, mux.NewRouter(), uuid.NewString(), nil)

handlerFunctionNotCalled := true
handler := func(w http.ResponseWriter, r *http.Request) {
Expand Down
85 changes: 60 additions & 25 deletions pkg/secure/mqttfactory.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,27 +20,34 @@ import (
"crypto/tls"
"crypto/x509"
"errors"
"github.com/eclipse/paho.mqtt.golang"
"fmt"

"github.com/edgexfoundry/go-mod-bootstrap/v2/bootstrap/messaging"
"github.com/edgexfoundry/go-mod-bootstrap/v2/bootstrap/secret"
"github.com/edgexfoundry/go-mod-core-contracts/v2/clients/logger"

"github.com/eclipse/paho.mqtt.golang"
)

type MqttFactory struct {
sp messaging.SecretDataProvider
logger logger.LoggingClient
authMode string
secretPath string
opts *mqtt.ClientOptions
skipCertVerify bool
sp messaging.SecretDataProvider
logger logger.LoggingClient
authMode string
secretPath string
opts *mqtt.ClientOptions
skipCertVerify bool
secretAddedSignal chan struct{}
}

func NewMqttFactory(sp messaging.SecretDataProvider, log logger.LoggingClient, mode string, path string, skipVerify bool) MqttFactory {
func NewMqttFactory(sp messaging.SecretDataProvider, log logger.LoggingClient, mode string, path string, skipVerify bool,
secretAddedSignal chan struct{}) MqttFactory {
return MqttFactory{
sp: sp,
logger: log,
authMode: mode,
secretPath: path,
skipCertVerify: skipVerify,
sp: sp,
logger: log,
authMode: mode,
secretPath: path,
skipCertVerify: skipVerify,
secretAddedSignal: secretAddedSignal,
}
}

Expand All @@ -52,24 +59,34 @@ func (factory MqttFactory) Create(opts *mqtt.ClientOptions) (mqtt.Client, error)

factory.opts = opts

//get the secrets from the secret provider and populate the struct
secretData, err := messaging.GetSecretData(factory.authMode, factory.secretPath, factory.sp)
if err != nil {
return nil, err
}
//ensure that the authmode selected has the required secret values
if secretData != nil {
err = messaging.ValidateSecretData(factory.authMode, factory.secretPath, secretData)
if err != nil {
return nil, err
secretData, err := factory.getValidSecretData()
switch secret.IsSecurityEnabled() {
case true:
if err == nil {
break
}
factory.logger.Error(err.Error())
for {
factory.logger.Info("Waiting for the secret creation API call to seed the proper credentials...")
<-factory.secretAddedSignal
secretData, err = factory.getValidSecretData()
if err != nil {
factory.logger.Error(err.Error())
} else {
break
}
}
// configure the mqtt client with the retrieved secret values
err = factory.configureMQTTClientForAuth(secretData)
case false:
if err != nil {
return nil, err
}
}

err = factory.configureMQTTClientForAuth(secretData)
if err != nil {
return nil, err
}

return mqtt.NewClient(factory.opts), nil
}

Expand Down Expand Up @@ -110,3 +127,21 @@ func (factory MqttFactory) configureMQTTClientForAuth(secretData *messaging.Secr

return nil
}

func (factory MqttFactory) getValidSecretData() (*messaging.SecretData, error) {
//get the secrets from the secret provider and populate the struct
secretData, err := messaging.GetSecretData(factory.authMode, factory.secretPath, factory.sp)
if err != nil {
return nil, fmt.Errorf("failed to get secret data from the secret provider, error: %s", err)
}
if secretData == nil {
return nil, nil
}
//ensure that the authmode selected has the required secret values
err = messaging.ValidateSecretData(factory.authMode, factory.secretPath, secretData)
if err != nil {
return nil, fmt.Errorf("invalid secret data, error: %s", err)
} else {
return secretData, nil
}
}
Loading

0 comments on commit 7929a1c

Please sign in to comment.