From f1047d1cb40a7f42359f0a017870aec277ad0d3b Mon Sep 17 00:00:00 2001 From: Jesse Peterson Date: Wed, 16 Mar 2022 13:52:19 -0700 Subject: [PATCH 1/8] Context logging --- cmd/nanomdm/main.go | 25 ++++++++++++++- http/api.go | 4 +++ http/mdm.go | 3 ++ http/mdm_cert.go | 5 +++ log/ctxlog/ctxlog.go | 62 ++++++++++++++++++++++++++++++++++++ push/service/service.go | 15 +++++++-- service/certauth/certauth.go | 15 +++++---- service/multi/multi.go | 21 ++++++------ service/nanomdm/ctxlog.go | 38 ++++++++++++++++++++++ service/nanomdm/service.go | 50 +++++++++++------------------ storage/allmulti/allmulti.go | 18 +++++++---- storage/allmulti/bstoken.go | 4 +-- storage/allmulti/certauth.go | 8 ++--- storage/allmulti/push.go | 2 +- storage/allmulti/pushcert.go | 6 ++-- storage/allmulti/queue.go | 8 ++--- storage/mysql/mysql.go | 5 ++- 17 files changed, 216 insertions(+), 73 deletions(-) create mode 100644 log/ctxlog/ctxlog.go create mode 100644 service/nanomdm/ctxlog.go diff --git a/cmd/nanomdm/main.go b/cmd/nanomdm/main.go index 3dd0504..3fd7b09 100644 --- a/cmd/nanomdm/main.go +++ b/cmd/nanomdm/main.go @@ -1,20 +1,24 @@ package main import ( + "context" "crypto/subtle" "crypto/x509" "flag" "fmt" "io/ioutil" stdlog "log" + "math/rand" "net" "net/http" "os" + "time" "github.com/micromdm/nanomdm/certverify" "github.com/micromdm/nanomdm/cmd/cli" mdmhttp "github.com/micromdm/nanomdm/http" "github.com/micromdm/nanomdm/log" + "github.com/micromdm/nanomdm/log/ctxlog" "github.com/micromdm/nanomdm/log/stdlogfmt" "github.com/micromdm/nanomdm/push/buford" pushsvc "github.com/micromdm/nanomdm/push/service" @@ -196,6 +200,8 @@ func main() { w.Write([]byte(`{"version":"` + version + `"}`)) }) + rand.Seed(time.Now().UnixNano()) + logger.Info("msg", "starting server", "listen", *flListen) err = http.ListenAndServe(*flListen, simpleLog(mux, logger.With("handler", "log"))) logs := []interface{}{"msg", "server shutdown"} @@ -219,8 +225,25 @@ func basicAuth(next http.Handler, username, password, realm string) http.Handler } } +type ctxKeyTraceID struct{} + +// storeNewTraceID generates a new trace identifier and stores it on +// the context. +func storeNewTraceID(ctx context.Context) context.Context { + // currently this just makes a random string. this would be better + // served by e.g. https://github.com/oklog/ulid or something like + // https://opentelemetry.io/ someday. + b := make([]byte, 8) + rand.Read(b) + id := fmt.Sprintf("%x", b) + return context.WithValue(ctx, ctxKeyTraceID{}, id) +} + func simpleLog(next http.Handler, logger log.Logger) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { + ctx := storeNewTraceID(r.Context()) + ctx = ctxlog.AddFunc(ctx, ctxlog.SimpleStringFunc(ctxKeyTraceID{}, "trace_id")) + logger := ctxlog.Logger(ctx, logger) host, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { host = r.RemoteAddr @@ -235,6 +258,6 @@ func simpleLog(next http.Handler, logger log.Logger) http.HandlerFunc { logs = append(logs, "real_ip", fwdedFor) } logger.Info(logs...) - next.ServeHTTP(w, r) + next.ServeHTTP(w, r.WithContext(ctx)) } } diff --git a/http/api.go b/http/api.go index 79e71bb..e0fa3f7 100644 --- a/http/api.go +++ b/http/api.go @@ -12,6 +12,7 @@ import ( "github.com/micromdm/nanomdm/cryptoutil" "github.com/micromdm/nanomdm/log" + "github.com/micromdm/nanomdm/log/ctxlog" "github.com/micromdm/nanomdm/mdm" "github.com/micromdm/nanomdm/push" "github.com/micromdm/nanomdm/storage" @@ -45,6 +46,7 @@ type apiResult struct { // users. func PushHandlerFunc(pusher push.Pusher, logger log.Logger) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { + logger := ctxlog.Logger(r.Context(), logger) ids := strings.Split(r.URL.Path, ",") output := apiResult{ Status: make(enrolledAPIResults), @@ -88,6 +90,7 @@ func PushHandlerFunc(pusher push.Pusher, logger log.Logger) http.HandlerFunc { // for "API" users. func RawCommandEnqueueHandler(enqueuer storage.CommandEnqueuer, pusher push.Pusher, logger log.Logger) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { + logger := ctxlog.Logger(r.Context(), logger) b, err := ReadAllAndReplaceBody(r) if err != nil { logger.Info("msg", "reading body", "err", err) @@ -169,6 +172,7 @@ func RawCommandEnqueueHandler(enqueuer storage.CommandEnqueuer, pusher push.Push // upload our push certs. func StorePushCertHandlerFunc(storage storage.PushCertStore, logger log.Logger) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { + logger := ctxlog.Logger(r.Context(), logger) b, err := ReadAllAndReplaceBody(r) if err != nil { logger.Info("msg", "reading body", "err", err) diff --git a/http/mdm.go b/http/mdm.go index 0ef8980..e8b6c7b 100644 --- a/http/mdm.go +++ b/http/mdm.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/micromdm/nanomdm/log" + "github.com/micromdm/nanomdm/log/ctxlog" "github.com/micromdm/nanomdm/mdm" "github.com/micromdm/nanomdm/service" ) @@ -26,6 +27,7 @@ func mdmReqFromHTTPReq(r *http.Request) *mdm.Request { // CheckinHandlerFunc decodes an MDM check-in request and adapts it to service. func CheckinHandlerFunc(svc service.Checkin, logger log.Logger) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { + logger := ctxlog.Logger(r.Context(), logger) bodyBytes, err := ReadAllAndReplaceBody(r) if err != nil { logger.Info("msg", "reading body", "err", err) @@ -49,6 +51,7 @@ func CheckinHandlerFunc(svc service.Checkin, logger log.Logger) http.HandlerFunc // CommandAndReportResultsHandlerFunc decodes an MDM command request and adapts it to service. func CommandAndReportResultsHandlerFunc(svc service.CommandAndReportResults, logger log.Logger) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { + logger := ctxlog.Logger(r.Context(), logger) bodyBytes, err := ReadAllAndReplaceBody(r) if err != nil { logger.Info("msg", "reading body", "err", err) diff --git a/http/mdm_cert.go b/http/mdm_cert.go index 0297879..1e92339 100644 --- a/http/mdm_cert.go +++ b/http/mdm_cert.go @@ -8,6 +8,7 @@ import ( "github.com/micromdm/nanomdm/cryptoutil" "github.com/micromdm/nanomdm/log" + "github.com/micromdm/nanomdm/log/ctxlog" ) type contextKeyCert struct{} @@ -21,6 +22,7 @@ type contextKeyCert struct{} // similar header could be used, of course. func CertExtractPEMHeaderMiddleware(next http.Handler, header string, logger log.Logger) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { + logger := ctxlog.Logger(r.Context(), logger) escapedCert := r.Header.Get(header) if escapedCert == "" { logger.Debug("msg", "empty header", "header", header) @@ -49,6 +51,7 @@ func CertExtractPEMHeaderMiddleware(next http.Handler, header string, logger log // at the TLS peer certificate in the request. func CertExtractTLSMiddleware(next http.Handler, logger log.Logger) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { + logger := ctxlog.Logger(r.Context(), logger) if r.TLS == nil || len(r.TLS.PeerCertificates) < 1 { logger.Debug("msg", "no TLS peer certificate") next.ServeHTTP(w, r) @@ -69,6 +72,7 @@ func CertExtractTLSMiddleware(next http.Handler, logger log.Logger) http.Handler // verification fails. func CertExtractMdmSignatureMiddleware(next http.Handler, logger log.Logger) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { + logger := ctxlog.Logger(r.Context(), logger) mdmSig := r.Header.Get("Mdm-Signature") if mdmSig == "" { logger.Debug("msg", "empty Mdm-Signature header") @@ -111,6 +115,7 @@ type CertVerifier interface { // MDM unenrollments in the case of bugs or something going wrong. func CertVerifyMiddleware(next http.Handler, verifier CertVerifier, logger log.Logger) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { + logger := ctxlog.Logger(r.Context(), logger) if err := verifier.Verify(GetCert(r.Context())); err != nil { logger.Info("msg", "error verifying MDM certificate", "err", err) http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) diff --git a/log/ctxlog/ctxlog.go b/log/ctxlog/ctxlog.go new file mode 100644 index 0000000..e73e523 --- /dev/null +++ b/log/ctxlog/ctxlog.go @@ -0,0 +1,62 @@ +// Package ctxlog allows logging data stored with a context. +package ctxlog + +import ( + "context" + "sync" + + "github.com/micromdm/nanomdm/log" +) + +// CtxKVFunc creates logger key-value pairs from a context. +type CtxKVFunc func(context.Context) []interface{} + +// ctxKeyFuncs is the context key for storing and retriveing +// a funcs{} struct on a context. +type ctxKeyFuncs struct{} + +// funcs holds the associated CtxKVFunc functions to run. +type funcs struct { + sync.RWMutex + funcs []CtxKVFunc +} + +// AddFunc associates a new CtxKVFunc function to a context. +func AddFunc(ctx context.Context, f CtxKVFunc) context.Context { + ctxFuncs, ok := ctx.Value(ctxKeyFuncs{}).(*funcs) + if !ok || ctxFuncs == nil { + ctxFuncs = &funcs{} + } + ctxFuncs.Lock() + ctxFuncs.funcs = append(ctxFuncs.funcs, f) + ctxFuncs.Unlock() + return context.WithValue(ctx, ctxKeyFuncs{}, ctxFuncs) +} + +// Logger runs the associated CtxKVFunc functions and returns a new +// logger with the results. +func Logger(ctx context.Context, logger log.Logger) log.Logger { + ctxFuncs, ok := ctx.Value(ctxKeyFuncs{}).(*funcs) + if !ok { + return logger + } + var acc []interface{} + ctxFuncs.RLock() + for _, f := range ctxFuncs.funcs { + acc = append(acc, f(ctx)...) + } + ctxFuncs.RUnlock() + return logger.With(acc...) +} + +// SimpleStringFunc is a helper that generates a simple CtxKVFunc that +// returns a key-value pair if found on the context. +func SimpleStringFunc(ctxKey interface{}, logKey string) CtxKVFunc { + return func(ctx context.Context) (out []interface{}) { + v, _ := ctx.Value(ctxKey).(string) + if v != "" { + out = append(out, logKey, v) + } + return + } +} diff --git a/push/service/service.go b/push/service/service.go index 282f2e3..161cfd9 100644 --- a/push/service/service.go +++ b/push/service/service.go @@ -9,6 +9,7 @@ import ( "sync" "github.com/micromdm/nanomdm/log" + "github.com/micromdm/nanomdm/log/ctxlog" "github.com/micromdm/nanomdm/mdm" "github.com/micromdm/nanomdm/push" "github.com/micromdm/nanomdm/storage" @@ -64,7 +65,10 @@ func (s *PushService) getProvider(ctx context.Context, topic string) (push.PushP if err != nil { return nil, fmt.Errorf("retrieving push cert for topic %q: %w", topic, err) } - s.logger.Info("msg", "retrieved push cert", "topic", topic) + ctxlog.Logger(ctx, s.logger).Info( + "msg", "retrieved push cert", + "topic", topic, + ) newProvider, err := s.providerFactory.NewPushProvider(cert) if err != nil { return nil, fmt.Errorf("creating new push provider: %w", err) @@ -114,7 +118,10 @@ func (s *PushService) pushMulti(ctx context.Context, pushInfos []*mdm.Push) (map for topic, pushInfos := range topicToPushInfos { prov, err := s.getProvider(ctx, topic) if err != nil { - s.logger.Info("msg", "get provider", "err", err) + ctxlog.Logger(ctx, s.logger).Info( + "msg", "get provider", + "err", err, + ) finalErr = err continue } @@ -191,7 +198,9 @@ func (s *PushService) Push(ctx context.Context, ids []string) (map[string]*push. for token, resp := range tokenToResponse { id, ok := tokenToId[token] if !ok { - s.logger.Info("msg", "could not find id by token") + ctxlog.Logger(ctx, s.logger).Info( + "msg", "could not find id by token", + ) continue } idToResponse[id] = resp diff --git a/service/certauth/certauth.go b/service/certauth/certauth.go index 996e792..fda00b0 100644 --- a/service/certauth/certauth.go +++ b/service/certauth/certauth.go @@ -9,6 +9,7 @@ import ( "fmt" "github.com/micromdm/nanomdm/log" + "github.com/micromdm/nanomdm/log/ctxlog" "github.com/micromdm/nanomdm/mdm" "github.com/micromdm/nanomdm/service" "github.com/micromdm/nanomdm/storage" @@ -110,6 +111,7 @@ func (s *CertAuth) associateNewEnrollment(r *mdm.Request) error { if err := r.EnrollID.Validate(); err != nil { return err } + logger := ctxlog.Logger(r.Context, s.logger) hash := hashCert(r.Certificate) if hasHash, err := s.storage.HasCertHash(r, hash); err != nil { return err @@ -124,7 +126,7 @@ func (s *CertAuth) associateNewEnrollment(r *mdm.Request) error { } else if isAssoc { return nil } - s.logger.Info( + logger.Info( "msg", "cert hash exists", "enrollment", "new", "id", r.ID, @@ -138,7 +140,7 @@ func (s *CertAuth) associateNewEnrollment(r *mdm.Request) error { if err := s.storage.AssociateCertHash(r, hash); err != nil { return err } - s.logger.Info( + logger.Info( "msg", "cert associated", "enrollment", "new", "id", r.ID, @@ -154,6 +156,7 @@ func (s *CertAuth) validateAssociateExistingEnrollment(r *mdm.Request) error { if err := r.EnrollID.Validate(); err != nil { return err } + logger := ctxlog.Logger(r.Context, s.logger) hash := hashCert(r.Certificate) if isAssoc, err := s.storage.IsCertHashAssociated(r, hash); err != nil { return err @@ -161,7 +164,7 @@ func (s *CertAuth) validateAssociateExistingEnrollment(r *mdm.Request) error { return nil } if !s.allowRetroactive { - s.logger.Info( + logger.Info( "msg", "no cert association", "enrollment", "existing", "id", r.ID, @@ -178,7 +181,7 @@ func (s *CertAuth) validateAssociateExistingEnrollment(r *mdm.Request) error { if hasHash, err := s.storage.EnrollmentHasCertHash(r, hash); err != nil { return err } else if hasHash { - s.logger.Info( + logger.Info( "msg", "enrollment cannot have associated cert hash", "enrollment", "existing", "id", r.ID, @@ -195,7 +198,7 @@ func (s *CertAuth) validateAssociateExistingEnrollment(r *mdm.Request) error { if hasHash, err := s.storage.HasCertHash(r, hash); err != nil { return err } else if hasHash { - s.logger.Info( + logger.Info( "msg", "cert hash exists", "enrollment", "existing", "id", r.ID, @@ -211,7 +214,7 @@ func (s *CertAuth) validateAssociateExistingEnrollment(r *mdm.Request) error { if err := s.storage.AssociateCertHash(r, hash); err != nil { return err } - s.logger.Info( + logger.Info( "msg", "cert associated", "enrollment", "existing", "id", r.ID, diff --git a/service/multi/multi.go b/service/multi/multi.go index 521f7fc..17df36b 100644 --- a/service/multi/multi.go +++ b/service/multi/multi.go @@ -5,6 +5,7 @@ import ( "context" "github.com/micromdm/nanomdm/log" + "github.com/micromdm/nanomdm/log/ctxlog" "github.com/micromdm/nanomdm/mdm" "github.com/micromdm/nanomdm/service" ) @@ -33,12 +34,12 @@ func New(logger log.Logger, svcs ...service.CheckinAndCommandService) *MultiServ type errorRunner func(service.CheckinAndCommandService) error -func (ms *MultiService) runOthers(r errorRunner) { +func (ms *MultiService) runOthers(ctx context.Context, r errorRunner) { for i, svc := range ms.svcs[1:] { go func(n int, s service.CheckinAndCommandService) { err := r(s) if err != nil { - ms.logger.Info( + ctxlog.Logger(ctx, ms.logger).Info( "sub_service", n, "err", err, ) @@ -57,7 +58,7 @@ func (ms *MultiService) RequestWithContext(r *mdm.Request) *mdm.Request { func (ms *MultiService) Authenticate(r *mdm.Request, m *mdm.Authenticate) error { err := ms.svcs[0].Authenticate(r, m) rc := ms.RequestWithContext(r) - ms.runOthers(func(svc service.CheckinAndCommandService) error { + ms.runOthers(r.Context, func(svc service.CheckinAndCommandService) error { return svc.Authenticate(rc, m) }) return err @@ -66,7 +67,7 @@ func (ms *MultiService) Authenticate(r *mdm.Request, m *mdm.Authenticate) error func (ms *MultiService) TokenUpdate(r *mdm.Request, m *mdm.TokenUpdate) error { err := ms.svcs[0].TokenUpdate(r, m) rc := ms.RequestWithContext(r) - ms.runOthers(func(svc service.CheckinAndCommandService) error { + ms.runOthers(r.Context, func(svc service.CheckinAndCommandService) error { return svc.TokenUpdate(rc, m) }) return err @@ -75,7 +76,7 @@ func (ms *MultiService) TokenUpdate(r *mdm.Request, m *mdm.TokenUpdate) error { func (ms *MultiService) CheckOut(r *mdm.Request, m *mdm.CheckOut) error { err := ms.svcs[0].CheckOut(r, m) rc := ms.RequestWithContext(r) - ms.runOthers(func(svc service.CheckinAndCommandService) error { + ms.runOthers(r.Context, func(svc service.CheckinAndCommandService) error { return svc.CheckOut(rc, m) }) return err @@ -84,7 +85,7 @@ func (ms *MultiService) CheckOut(r *mdm.Request, m *mdm.CheckOut) error { func (ms *MultiService) UserAuthenticate(r *mdm.Request, m *mdm.UserAuthenticate) ([]byte, error) { respBytes, err := ms.svcs[0].UserAuthenticate(r, m) rc := ms.RequestWithContext(r) - ms.runOthers(func(svc service.CheckinAndCommandService) error { + ms.runOthers(r.Context, func(svc service.CheckinAndCommandService) error { _, err := svc.UserAuthenticate(rc, m) return err }) @@ -94,7 +95,7 @@ func (ms *MultiService) UserAuthenticate(r *mdm.Request, m *mdm.UserAuthenticate func (ms *MultiService) SetBootstrapToken(r *mdm.Request, m *mdm.SetBootstrapToken) error { err := ms.svcs[0].SetBootstrapToken(r, m) rc := ms.RequestWithContext(r) - ms.runOthers(func(svc service.CheckinAndCommandService) error { + ms.runOthers(r.Context, func(svc service.CheckinAndCommandService) error { return svc.SetBootstrapToken(rc, m) }) return err @@ -103,7 +104,7 @@ func (ms *MultiService) SetBootstrapToken(r *mdm.Request, m *mdm.SetBootstrapTok func (ms *MultiService) GetBootstrapToken(r *mdm.Request, m *mdm.GetBootstrapToken) (*mdm.BootstrapToken, error) { bsToken, err := ms.svcs[0].GetBootstrapToken(r, m) rc := ms.RequestWithContext(r) - ms.runOthers(func(svc service.CheckinAndCommandService) error { + ms.runOthers(r.Context, func(svc service.CheckinAndCommandService) error { _, err := svc.GetBootstrapToken(rc, m) return err }) @@ -113,7 +114,7 @@ func (ms *MultiService) GetBootstrapToken(r *mdm.Request, m *mdm.GetBootstrapTok func (ms *MultiService) DeclarativeManagement(r *mdm.Request, m *mdm.DeclarativeManagement) ([]byte, error) { retBytes, err := ms.svcs[0].DeclarativeManagement(r, m) rc := ms.RequestWithContext(r) - ms.runOthers(func(svc service.CheckinAndCommandService) error { + ms.runOthers(r.Context, func(svc service.CheckinAndCommandService) error { _, err := svc.DeclarativeManagement(rc, m) return err }) @@ -123,7 +124,7 @@ func (ms *MultiService) DeclarativeManagement(r *mdm.Request, m *mdm.Declarative func (ms *MultiService) CommandAndReportResults(r *mdm.Request, results *mdm.CommandResults) (*mdm.Command, error) { cmd, err := ms.svcs[0].CommandAndReportResults(r, results) rc := ms.RequestWithContext(r) - ms.runOthers(func(svc service.CheckinAndCommandService) error { + ms.runOthers(r.Context, func(svc service.CheckinAndCommandService) error { _, err := svc.CommandAndReportResults(rc, results) return err }) diff --git a/service/nanomdm/ctxlog.go b/service/nanomdm/ctxlog.go new file mode 100644 index 0000000..530a1e3 --- /dev/null +++ b/service/nanomdm/ctxlog.go @@ -0,0 +1,38 @@ +package nanomdm + +import ( + "context" + + "github.com/micromdm/nanomdm/log" + "github.com/micromdm/nanomdm/log/ctxlog" + "github.com/micromdm/nanomdm/mdm" +) + +type ( + ctxKeyID struct{} + ctxKeyType struct{} +) + +func newContext(ctx context.Context, r *mdm.Request) context.Context { + newCtx := context.WithValue(ctx, ctxKeyID{}, r.ID) + return context.WithValue(newCtx, ctxKeyType{}, r.Type) +} + +func ctxKVs(ctx context.Context) (out []interface{}) { + id, ok := ctx.Value(ctxKeyID{}).(string) + if ok { + out = append(out, "id", id) + } + eType, ok := ctx.Value(ctxKeyType{}).(mdm.EnrollType) + if ok { + out = append(out, "type", eType) + } + return +} + +// ctxLogger sets up and returns a new contextual logger +func (s *Service) ctxLogger(r *mdm.Request) log.Logger { + r.Context = newContext(r.Context, r) + r.Context = ctxlog.AddFunc(r.Context, ctxKVs) + return ctxlog.Logger(r.Context, s.logger) +} diff --git a/service/nanomdm/service.go b/service/nanomdm/service.go index b66418c..453aa99 100644 --- a/service/nanomdm/service.go +++ b/service/nanomdm/service.go @@ -97,15 +97,14 @@ func (s *Service) Authenticate(r *mdm.Request, message *mdm.Authenticate) error if err := s.updateEnrollID(r, &message.Enrollment); err != nil { return err } + logger := s.ctxLogger(r) logs := []interface{}{ "msg", "Authenticate", - "id", r.ID, - "type", r.Type, } if message.SerialNumber != "" { logs = append(logs, "serial_number", message.SerialNumber) } - s.logger.Info(logs...) + logger.Info(logs...) if err := s.store.StoreAuthenticate(r, message); err != nil { return err } @@ -125,10 +124,9 @@ func (s *Service) TokenUpdate(r *mdm.Request, message *mdm.TokenUpdate) error { if err := s.updateEnrollID(r, &message.Enrollment); err != nil { return err } - s.logger.Info( + logger := s.ctxLogger(r) + logger.Info( "msg", "TokenUpdate", - "id", r.ID, - "type", r.Type, ) return s.store.StoreTokenUpdate(r, message) } @@ -138,10 +136,9 @@ func (s *Service) CheckOut(r *mdm.Request, message *mdm.CheckOut) error { if err := s.updateEnrollID(r, &message.Enrollment); err != nil { return err } - s.logger.Info( + logger := s.ctxLogger(r) + logger.Info( "msg", "CheckOut", - "id", r.ID, - "type", r.Type, ) return s.store.Disable(r) } @@ -161,6 +158,7 @@ func (s *Service) UserAuthenticate(r *mdm.Request, message *mdm.UserAuthenticate if err := s.updateEnrollID(r, &message.Enrollment); err != nil { return nil, err } + logger := s.ctxLogger(r) if s.sendEmptyDigestChallenge || s.storeRejectedUserAuth { if err := s.store.StoreUserAuthenticate(r, message); err != nil { return nil, err @@ -170,10 +168,8 @@ func (s *Service) UserAuthenticate(r *mdm.Request, message *mdm.UserAuthenticate // UserAuthenticate messages depending on our response if message.DigestResponse == "" { if s.sendEmptyDigestChallenge { - s.logger.Info( + logger.Info( "msg", "sending empty DigestChallenge response to UserAuthenticate", - "id", r.ID, - "type", r.Type, ) return emptyDigestChallengeBytes, nil } @@ -182,10 +178,8 @@ func (s *Service) UserAuthenticate(r *mdm.Request, message *mdm.UserAuthenticate fmt.Errorf("declining management of user: %s", r.ID), ) } - s.logger.Debug( + logger.Debug( "msg", "sending empty response to second UserAuthenticate", - "id", r.ID, - "type", r.Type, ) return nil, nil } @@ -194,10 +188,9 @@ func (s *Service) SetBootstrapToken(r *mdm.Request, message *mdm.SetBootstrapTok if err := s.updateEnrollID(r, &message.Enrollment); err != nil { return err } - s.logger.Info( + logger := s.ctxLogger(r) + logger.Info( "msg", "SetBootstrapToken", - "id", r.ID, - "type", r.Type, ) return s.store.StoreBootstrapToken(r, message) } @@ -206,10 +199,9 @@ func (s *Service) GetBootstrapToken(r *mdm.Request, message *mdm.GetBootstrapTok if err := s.updateEnrollID(r, &message.Enrollment); err != nil { return nil, err } - s.logger.Info( + logger := s.ctxLogger(r) + logger.Info( "msg", "GetBootstrapToken", - "id", r.ID, - "type", r.Type, ) return s.store.RetrieveBootstrapToken(r, message) } @@ -220,10 +212,9 @@ func (s *Service) DeclarativeManagement(r *mdm.Request, message *mdm.Declarative if err := s.updateEnrollID(r, &message.Enrollment); err != nil { return nil, err } - s.logger.Info( + logger := s.ctxLogger(r) + logger.Info( "msg", "DeclarativeManagement", - "id", r.ID, - "type", r.Type, "endpoint", message.Endpoint, ) if s.dm == nil { @@ -237,15 +228,14 @@ func (s *Service) CommandAndReportResults(r *mdm.Request, results *mdm.CommandRe if err := s.updateEnrollID(r, &results.Enrollment); err != nil { return nil, err } + logger := s.ctxLogger(r) logs := []interface{}{ "status", results.Status, - "id", r.ID, - "type", r.Type, } if results.Status != "Idle" { logs = append(logs, "command_uuid", results.CommandUUID) } - s.logger.Info(logs...) + logger.Info(logs...) err := s.store.StoreCommandReport(r, results) if err != nil { return nil, fmt.Errorf("storing command report: %w", err) @@ -255,16 +245,14 @@ func (s *Service) CommandAndReportResults(r *mdm.Request, results *mdm.CommandRe return nil, fmt.Errorf("retrieving next command: %w", err) } if cmd != nil { - s.logger.Debug( + logger.Debug( "msg", "command retrieved", - "id", r.ID, "command_uuid", cmd.CommandUUID, ) return cmd, nil } - s.logger.Debug( + logger.Debug( "msg", "no command retrieved", - "id", r.ID, ) return nil, nil } diff --git a/storage/allmulti/allmulti.go b/storage/allmulti/allmulti.go index 16a0b71..c99268b 100644 --- a/storage/allmulti/allmulti.go +++ b/storage/allmulti/allmulti.go @@ -4,6 +4,7 @@ import ( "context" "github.com/micromdm/nanomdm/log" + "github.com/micromdm/nanomdm/log/ctxlog" "github.com/micromdm/nanomdm/mdm" "github.com/micromdm/nanomdm/storage" ) @@ -32,7 +33,7 @@ type returnCollector struct { type errRunner func(storage.AllStorage) (interface{}, error) -func (ms *MultiAllStorage) execStores(r errRunner) (interface{}, error) { +func (ms *MultiAllStorage) execStores(ctx context.Context, r errRunner) (interface{}, error) { retChan := make(chan *returnCollector) for i, store := range ms.stores { go func(n int, s storage.AllStorage) { @@ -52,42 +53,45 @@ func (ms *MultiAllStorage) execStores(r errRunner) (interface{}, error) { finalErr = sErr.err finalValue = sErr.returnValue } else if sErr.err != nil { - ms.logger.Info("n", sErr.storeNumber, "err", sErr.err) + ctxlog.Logger(ctx, ms.logger).Info( + "n", sErr.storeNumber, + "err", sErr.err, + ) } } return finalValue, finalErr } func (ms *MultiAllStorage) StoreAuthenticate(r *mdm.Request, msg *mdm.Authenticate) error { - _, err := ms.execStores(func(s storage.AllStorage) (interface{}, error) { + _, err := ms.execStores(r.Context, func(s storage.AllStorage) (interface{}, error) { return nil, s.StoreAuthenticate(r, msg) }) return err } func (ms *MultiAllStorage) StoreTokenUpdate(r *mdm.Request, msg *mdm.TokenUpdate) error { - _, err := ms.execStores(func(s storage.AllStorage) (interface{}, error) { + _, err := ms.execStores(r.Context, func(s storage.AllStorage) (interface{}, error) { return nil, s.StoreTokenUpdate(r, msg) }) return err } func (ms *MultiAllStorage) RetrieveTokenUpdateTally(ctx context.Context, id string) (int, error) { - val, err := ms.execStores(func(s storage.AllStorage) (interface{}, error) { + val, err := ms.execStores(ctx, func(s storage.AllStorage) (interface{}, error) { return s.RetrieveTokenUpdateTally(ctx, id) }) return val.(int), err } func (ms *MultiAllStorage) StoreUserAuthenticate(r *mdm.Request, msg *mdm.UserAuthenticate) error { - _, err := ms.execStores(func(s storage.AllStorage) (interface{}, error) { + _, err := ms.execStores(r.Context, func(s storage.AllStorage) (interface{}, error) { return nil, s.StoreUserAuthenticate(r, msg) }) return err } func (ms *MultiAllStorage) Disable(r *mdm.Request) error { - _, err := ms.execStores(func(s storage.AllStorage) (interface{}, error) { + _, err := ms.execStores(r.Context, func(s storage.AllStorage) (interface{}, error) { return nil, s.Disable(r) }) return err diff --git a/storage/allmulti/bstoken.go b/storage/allmulti/bstoken.go index 2ea941c..4876852 100644 --- a/storage/allmulti/bstoken.go +++ b/storage/allmulti/bstoken.go @@ -6,14 +6,14 @@ import ( ) func (ms *MultiAllStorage) StoreBootstrapToken(r *mdm.Request, msg *mdm.SetBootstrapToken) error { - _, err := ms.execStores(func(s storage.AllStorage) (interface{}, error) { + _, err := ms.execStores(r.Context, func(s storage.AllStorage) (interface{}, error) { return nil, s.StoreBootstrapToken(r, msg) }) return err } func (ms *MultiAllStorage) RetrieveBootstrapToken(r *mdm.Request, msg *mdm.GetBootstrapToken) (*mdm.BootstrapToken, error) { - val, err := ms.execStores(func(s storage.AllStorage) (interface{}, error) { + val, err := ms.execStores(r.Context, func(s storage.AllStorage) (interface{}, error) { return s.RetrieveBootstrapToken(r, msg) }) return val.(*mdm.BootstrapToken), err diff --git a/storage/allmulti/certauth.go b/storage/allmulti/certauth.go index ca4e289..3593c6c 100644 --- a/storage/allmulti/certauth.go +++ b/storage/allmulti/certauth.go @@ -6,28 +6,28 @@ import ( ) func (ms *MultiAllStorage) HasCertHash(r *mdm.Request, hash string) (bool, error) { - val, err := ms.execStores(func(s storage.AllStorage) (interface{}, error) { + val, err := ms.execStores(r.Context, func(s storage.AllStorage) (interface{}, error) { return s.HasCertHash(r, hash) }) return val.(bool), err } func (ms *MultiAllStorage) EnrollmentHasCertHash(r *mdm.Request, hash string) (bool, error) { - val, err := ms.execStores(func(s storage.AllStorage) (interface{}, error) { + val, err := ms.execStores(r.Context, func(s storage.AllStorage) (interface{}, error) { return s.EnrollmentHasCertHash(r, hash) }) return val.(bool), err } func (ms *MultiAllStorage) IsCertHashAssociated(r *mdm.Request, hash string) (bool, error) { - val, err := ms.execStores(func(s storage.AllStorage) (interface{}, error) { + val, err := ms.execStores(r.Context, func(s storage.AllStorage) (interface{}, error) { return s.IsCertHashAssociated(r, hash) }) return val.(bool), err } func (ms *MultiAllStorage) AssociateCertHash(r *mdm.Request, hash string) error { - _, err := ms.execStores(func(s storage.AllStorage) (interface{}, error) { + _, err := ms.execStores(r.Context, func(s storage.AllStorage) (interface{}, error) { return nil, s.AssociateCertHash(r, hash) }) return err diff --git a/storage/allmulti/push.go b/storage/allmulti/push.go index f82536b..919c4e9 100644 --- a/storage/allmulti/push.go +++ b/storage/allmulti/push.go @@ -8,7 +8,7 @@ import ( ) func (ms *MultiAllStorage) RetrievePushInfo(ctx context.Context, ids []string) (map[string]*mdm.Push, error) { - val, err := ms.execStores(func(s storage.AllStorage) (interface{}, error) { + val, err := ms.execStores(ctx, func(s storage.AllStorage) (interface{}, error) { return s.RetrievePushInfo(ctx, ids) }) return val.(map[string]*mdm.Push), err diff --git a/storage/allmulti/pushcert.go b/storage/allmulti/pushcert.go index 1814fdf..f277524 100644 --- a/storage/allmulti/pushcert.go +++ b/storage/allmulti/pushcert.go @@ -8,7 +8,7 @@ import ( ) func (ms *MultiAllStorage) IsPushCertStale(ctx context.Context, topic string, staleToken string) (bool, error) { - val, err := ms.execStores(func(s storage.AllStorage) (interface{}, error) { + val, err := ms.execStores(ctx, func(s storage.AllStorage) (interface{}, error) { return s.IsPushCertStale(ctx, topic, staleToken) }) return val.(bool), err @@ -20,7 +20,7 @@ type retrievePushCertReturns struct { } func (ms *MultiAllStorage) RetrievePushCert(ctx context.Context, topic string) (cert *tls.Certificate, staleToken string, err error) { - val, err := ms.execStores(func(s storage.AllStorage) (interface{}, error) { + val, err := ms.execStores(ctx, func(s storage.AllStorage) (interface{}, error) { rets := new(retrievePushCertReturns) var err error rets.cert, rets.staleToken, err = s.RetrievePushCert(ctx, topic) @@ -31,7 +31,7 @@ func (ms *MultiAllStorage) RetrievePushCert(ctx context.Context, topic string) ( } func (ms *MultiAllStorage) StorePushCert(ctx context.Context, pemCert, pemKey []byte) error { - _, err := ms.execStores(func(s storage.AllStorage) (interface{}, error) { + _, err := ms.execStores(ctx, func(s storage.AllStorage) (interface{}, error) { return nil, s.StorePushCert(ctx, pemCert, pemKey) }) return err diff --git a/storage/allmulti/queue.go b/storage/allmulti/queue.go index 087e5ce..2818c81 100644 --- a/storage/allmulti/queue.go +++ b/storage/allmulti/queue.go @@ -8,28 +8,28 @@ import ( ) func (ms *MultiAllStorage) StoreCommandReport(r *mdm.Request, report *mdm.CommandResults) error { - _, err := ms.execStores(func(s storage.AllStorage) (interface{}, error) { + _, err := ms.execStores(r.Context, func(s storage.AllStorage) (interface{}, error) { return nil, s.StoreCommandReport(r, report) }) return err } func (ms *MultiAllStorage) RetrieveNextCommand(r *mdm.Request, skipNotNow bool) (*mdm.Command, error) { - val, err := ms.execStores(func(s storage.AllStorage) (interface{}, error) { + val, err := ms.execStores(r.Context, func(s storage.AllStorage) (interface{}, error) { return s.RetrieveNextCommand(r, skipNotNow) }) return val.(*mdm.Command), err } func (ms *MultiAllStorage) ClearQueue(r *mdm.Request) error { - _, err := ms.execStores(func(s storage.AllStorage) (interface{}, error) { + _, err := ms.execStores(r.Context, func(s storage.AllStorage) (interface{}, error) { return nil, s.ClearQueue(r) }) return err } func (ms *MultiAllStorage) EnqueueCommand(ctx context.Context, id []string, cmd *mdm.Command) (map[string]error, error) { - val, err := ms.execStores(func(s storage.AllStorage) (interface{}, error) { + val, err := ms.execStores(ctx, func(s storage.AllStorage) (interface{}, error) { return s.EnqueueCommand(ctx, id, cmd) }) return val.(map[string]error), err diff --git a/storage/mysql/mysql.go b/storage/mysql/mysql.go index c21187e..0132fcd 100644 --- a/storage/mysql/mysql.go +++ b/storage/mysql/mysql.go @@ -8,6 +8,7 @@ import ( "github.com/micromdm/nanomdm/cryptoutil" "github.com/micromdm/nanomdm/log" + "github.com/micromdm/nanomdm/log/ctxlog" "github.com/micromdm/nanomdm/mdm" ) @@ -117,7 +118,9 @@ func (s *MySQLStorage) storeUserTokenUpdate(r *mdm.Request, msg *mdm.TokenUpdate // there shouldn't be an Unlock Token on the user channel, but // complain if there is to warn an admin if len(msg.UnlockToken) > 0 { - s.logger.Info("msg", "Unlock Token on user channel not stored") + ctxlog.Logger(r.Context, s.logger).Info( + "msg", "Unlock Token on user channel not stored", + ) } _, err := s.db.ExecContext( r.Context, ` From d9b78268e08e0f313ae2b4ab7468a4ce34acd834 Mon Sep 17 00:00:00 2001 From: Jesse Peterson Date: Thu, 17 Mar 2022 13:06:18 -0700 Subject: [PATCH 2/8] Check for safe context use --- log/ctxlog/ctxlog.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/log/ctxlog/ctxlog.go b/log/ctxlog/ctxlog.go index e73e523..a20557a 100644 --- a/log/ctxlog/ctxlog.go +++ b/log/ctxlog/ctxlog.go @@ -23,6 +23,9 @@ type funcs struct { // AddFunc associates a new CtxKVFunc function to a context. func AddFunc(ctx context.Context, f CtxKVFunc) context.Context { + if ctx == nil { + return ctx + } ctxFuncs, ok := ctx.Value(ctxKeyFuncs{}).(*funcs) if !ok || ctxFuncs == nil { ctxFuncs = &funcs{} @@ -36,6 +39,9 @@ func AddFunc(ctx context.Context, f CtxKVFunc) context.Context { // Logger runs the associated CtxKVFunc functions and returns a new // logger with the results. func Logger(ctx context.Context, logger log.Logger) log.Logger { + if ctx == nil { + return logger + } ctxFuncs, ok := ctx.Value(ctxKeyFuncs{}).(*funcs) if !ok { return logger From 6d468ed56938b245a20dbcd6007ad58a20b8de28 Mon Sep 17 00:00:00 2001 From: Jesse Peterson Date: Thu, 17 Mar 2022 14:00:26 -0700 Subject: [PATCH 3/8] Cleanup logging endpoints, collapse single assignments into direct accesses --- cmd/nanomdm/main.go | 3 +-- http/mdm_cert.go | 11 +++++++---- service/nanomdm/service.go | 31 ++++++++++--------------------- storage/allmulti/migrate.go | 10 ++++++++-- 4 files changed, 26 insertions(+), 29 deletions(-) diff --git a/cmd/nanomdm/main.go b/cmd/nanomdm/main.go index 3fd7b09..85781ca 100644 --- a/cmd/nanomdm/main.go +++ b/cmd/nanomdm/main.go @@ -243,7 +243,6 @@ func simpleLog(next http.Handler, logger log.Logger) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx := storeNewTraceID(r.Context()) ctx = ctxlog.AddFunc(ctx, ctxlog.SimpleStringFunc(ctxKeyTraceID{}, "trace_id")) - logger := ctxlog.Logger(ctx, logger) host, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { host = r.RemoteAddr @@ -257,7 +256,7 @@ func simpleLog(next http.Handler, logger log.Logger) http.HandlerFunc { if fwdedFor := r.Header.Get("X-Forwarded-For"); fwdedFor != "" { logs = append(logs, "real_ip", fwdedFor) } - logger.Info(logs...) + ctxlog.Logger(ctx, logger).Info(logs...) next.ServeHTTP(w, r.WithContext(ctx)) } } diff --git a/http/mdm_cert.go b/http/mdm_cert.go index 1e92339..1e06a97 100644 --- a/http/mdm_cert.go +++ b/http/mdm_cert.go @@ -51,9 +51,10 @@ func CertExtractPEMHeaderMiddleware(next http.Handler, header string, logger log // at the TLS peer certificate in the request. func CertExtractTLSMiddleware(next http.Handler, logger log.Logger) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - logger := ctxlog.Logger(r.Context(), logger) if r.TLS == nil || len(r.TLS.PeerCertificates) < 1 { - logger.Debug("msg", "no TLS peer certificate") + ctxlog.Logger(r.Context(), logger).Debug( + "msg", "no TLS peer certificate", + ) next.ServeHTTP(w, r) return } @@ -115,9 +116,11 @@ type CertVerifier interface { // MDM unenrollments in the case of bugs or something going wrong. func CertVerifyMiddleware(next http.Handler, verifier CertVerifier, logger log.Logger) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - logger := ctxlog.Logger(r.Context(), logger) if err := verifier.Verify(GetCert(r.Context())); err != nil { - logger.Info("msg", "error verifying MDM certificate", "err", err) + ctxlog.Logger(r.Context(), logger).Info( + "msg", "error verifying MDM certificate", + "err", err, + ) http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) return } diff --git a/service/nanomdm/service.go b/service/nanomdm/service.go index 453aa99..df09587 100644 --- a/service/nanomdm/service.go +++ b/service/nanomdm/service.go @@ -7,6 +7,7 @@ import ( "net/http" "github.com/micromdm/nanomdm/log" + "github.com/micromdm/nanomdm/log/ctxlog" "github.com/micromdm/nanomdm/mdm" "github.com/micromdm/nanomdm/service" "github.com/micromdm/nanomdm/storage" @@ -86,7 +87,9 @@ func New(store storage.ServiceStore, opts ...Option) *Service { func (s *Service) updateEnrollID(r *mdm.Request, e *mdm.Enrollment) error { if r.EnrollID != nil && r.ID != "" { - s.logger.Debug("msg", "overwriting enrollment id") + ctxlog.Logger(r.Context, s.logger).Debug( + "msg", "overwriting enrollment id", + ) } r.EnrollID = s.normalizer(e) return r.EnrollID.Validate() @@ -97,14 +100,13 @@ func (s *Service) Authenticate(r *mdm.Request, message *mdm.Authenticate) error if err := s.updateEnrollID(r, &message.Enrollment); err != nil { return err } - logger := s.ctxLogger(r) logs := []interface{}{ "msg", "Authenticate", } if message.SerialNumber != "" { logs = append(logs, "serial_number", message.SerialNumber) } - logger.Info(logs...) + s.ctxLogger(r).Info(logs...) if err := s.store.StoreAuthenticate(r, message); err != nil { return err } @@ -124,10 +126,7 @@ func (s *Service) TokenUpdate(r *mdm.Request, message *mdm.TokenUpdate) error { if err := s.updateEnrollID(r, &message.Enrollment); err != nil { return err } - logger := s.ctxLogger(r) - logger.Info( - "msg", "TokenUpdate", - ) + s.ctxLogger(r).Info("msg", "TokenUpdate") return s.store.StoreTokenUpdate(r, message) } @@ -136,10 +135,7 @@ func (s *Service) CheckOut(r *mdm.Request, message *mdm.CheckOut) error { if err := s.updateEnrollID(r, &message.Enrollment); err != nil { return err } - logger := s.ctxLogger(r) - logger.Info( - "msg", "CheckOut", - ) + s.ctxLogger(r).Info("msg", "CheckOut") return s.store.Disable(r) } @@ -188,10 +184,7 @@ func (s *Service) SetBootstrapToken(r *mdm.Request, message *mdm.SetBootstrapTok if err := s.updateEnrollID(r, &message.Enrollment); err != nil { return err } - logger := s.ctxLogger(r) - logger.Info( - "msg", "SetBootstrapToken", - ) + s.ctxLogger(r).Info("msg", "SetBootstrapToken") return s.store.StoreBootstrapToken(r, message) } @@ -199,10 +192,7 @@ func (s *Service) GetBootstrapToken(r *mdm.Request, message *mdm.GetBootstrapTok if err := s.updateEnrollID(r, &message.Enrollment); err != nil { return nil, err } - logger := s.ctxLogger(r) - logger.Info( - "msg", "GetBootstrapToken", - ) + s.ctxLogger(r).Info("msg", "GetBootstrapToken") return s.store.RetrieveBootstrapToken(r, message) } @@ -212,8 +202,7 @@ func (s *Service) DeclarativeManagement(r *mdm.Request, message *mdm.Declarative if err := s.updateEnrollID(r, &message.Enrollment); err != nil { return nil, err } - logger := s.ctxLogger(r) - logger.Info( + s.ctxLogger(r).Info( "msg", "DeclarativeManagement", "endpoint", message.Endpoint, ) diff --git a/storage/allmulti/migrate.go b/storage/allmulti/migrate.go index a75c873..51fb820 100644 --- a/storage/allmulti/migrate.go +++ b/storage/allmulti/migrate.go @@ -1,8 +1,14 @@ package allmulti -import "context" +import ( + "context" + + "github.com/micromdm/nanomdm/log/ctxlog" +) func (ms *MultiAllStorage) RetrieveMigrationCheckins(ctx context.Context, c chan<- interface{}) error { - ms.logger.Info("msg", "only using first store for migration") + ctxlog.Logger(ctx, ms.logger).Info( + "msg", "only using first store for migration", + ) return ms.stores[0].RetrieveMigrationCheckins(ctx, c) } From b63813312e5fda53ce091760039415c4253d46da Mon Sep 17 00:00:00 2001 From: Jesse Peterson Date: Thu, 24 Mar 2022 10:23:45 -0700 Subject: [PATCH 4/8] Minor refactor of context logging in nanomdm service --- service/nanomdm/ctxlog.go | 11 +--------- service/nanomdm/service.go | 41 +++++++++++++++++++++----------------- 2 files changed, 24 insertions(+), 28 deletions(-) diff --git a/service/nanomdm/ctxlog.go b/service/nanomdm/ctxlog.go index 530a1e3..944a03e 100644 --- a/service/nanomdm/ctxlog.go +++ b/service/nanomdm/ctxlog.go @@ -3,8 +3,6 @@ package nanomdm import ( "context" - "github.com/micromdm/nanomdm/log" - "github.com/micromdm/nanomdm/log/ctxlog" "github.com/micromdm/nanomdm/mdm" ) @@ -13,7 +11,7 @@ type ( ctxKeyType struct{} ) -func newContext(ctx context.Context, r *mdm.Request) context.Context { +func newContextWithValues(ctx context.Context, r *mdm.Request) context.Context { newCtx := context.WithValue(ctx, ctxKeyID{}, r.ID) return context.WithValue(newCtx, ctxKeyType{}, r.Type) } @@ -29,10 +27,3 @@ func ctxKVs(ctx context.Context) (out []interface{}) { } return } - -// ctxLogger sets up and returns a new contextual logger -func (s *Service) ctxLogger(r *mdm.Request) log.Logger { - r.Context = newContext(r.Context, r) - r.Context = ctxlog.AddFunc(r.Context, ctxKVs) - return ctxlog.Logger(r.Context, s.logger) -} diff --git a/service/nanomdm/service.go b/service/nanomdm/service.go index df09587..9b2ce5e 100644 --- a/service/nanomdm/service.go +++ b/service/nanomdm/service.go @@ -85,19 +85,24 @@ func New(store storage.ServiceStore, opts ...Option) *Service { return nanomdm } -func (s *Service) updateEnrollID(r *mdm.Request, e *mdm.Enrollment) error { +func (s *Service) setupRequest(r *mdm.Request, e *mdm.Enrollment) error { if r.EnrollID != nil && r.ID != "" { ctxlog.Logger(r.Context, s.logger).Debug( "msg", "overwriting enrollment id", ) } r.EnrollID = s.normalizer(e) - return r.EnrollID.Validate() + if err := r.EnrollID.Validate(); err != nil { + return err + } + r.Context = newContextWithValues(r.Context, r) + r.Context = ctxlog.AddFunc(r.Context, ctxKVs) + return nil } // Authenticate Check-in message implementation. func (s *Service) Authenticate(r *mdm.Request, message *mdm.Authenticate) error { - if err := s.updateEnrollID(r, &message.Enrollment); err != nil { + if err := s.setupRequest(r, &message.Enrollment); err != nil { return err } logs := []interface{}{ @@ -106,7 +111,7 @@ func (s *Service) Authenticate(r *mdm.Request, message *mdm.Authenticate) error if message.SerialNumber != "" { logs = append(logs, "serial_number", message.SerialNumber) } - s.ctxLogger(r).Info(logs...) + ctxlog.Logger(r.Context, s.logger).Info(logs...) if err := s.store.StoreAuthenticate(r, message); err != nil { return err } @@ -123,19 +128,19 @@ func (s *Service) Authenticate(r *mdm.Request, message *mdm.Authenticate) error // TokenUpdate Check-in message implementation. func (s *Service) TokenUpdate(r *mdm.Request, message *mdm.TokenUpdate) error { - if err := s.updateEnrollID(r, &message.Enrollment); err != nil { + if err := s.setupRequest(r, &message.Enrollment); err != nil { return err } - s.ctxLogger(r).Info("msg", "TokenUpdate") + ctxlog.Logger(r.Context, s.logger).Info("msg", "TokenUpdate") return s.store.StoreTokenUpdate(r, message) } // CheckOut Check-in message implementation. func (s *Service) CheckOut(r *mdm.Request, message *mdm.CheckOut) error { - if err := s.updateEnrollID(r, &message.Enrollment); err != nil { + if err := s.setupRequest(r, &message.Enrollment); err != nil { return err } - s.ctxLogger(r).Info("msg", "CheckOut") + ctxlog.Logger(r.Context, s.logger).Info("msg", "CheckOut") return s.store.Disable(r) } @@ -151,10 +156,10 @@ const emptyDigestChallenge = ` var emptyDigestChallengeBytes = []byte(emptyDigestChallenge) func (s *Service) UserAuthenticate(r *mdm.Request, message *mdm.UserAuthenticate) ([]byte, error) { - if err := s.updateEnrollID(r, &message.Enrollment); err != nil { + if err := s.setupRequest(r, &message.Enrollment); err != nil { return nil, err } - logger := s.ctxLogger(r) + logger := ctxlog.Logger(r.Context, s.logger) if s.sendEmptyDigestChallenge || s.storeRejectedUserAuth { if err := s.store.StoreUserAuthenticate(r, message); err != nil { return nil, err @@ -181,28 +186,28 @@ func (s *Service) UserAuthenticate(r *mdm.Request, message *mdm.UserAuthenticate } func (s *Service) SetBootstrapToken(r *mdm.Request, message *mdm.SetBootstrapToken) error { - if err := s.updateEnrollID(r, &message.Enrollment); err != nil { + if err := s.setupRequest(r, &message.Enrollment); err != nil { return err } - s.ctxLogger(r).Info("msg", "SetBootstrapToken") + ctxlog.Logger(r.Context, s.logger).Info("msg", "SetBootstrapToken") return s.store.StoreBootstrapToken(r, message) } func (s *Service) GetBootstrapToken(r *mdm.Request, message *mdm.GetBootstrapToken) (*mdm.BootstrapToken, error) { - if err := s.updateEnrollID(r, &message.Enrollment); err != nil { + if err := s.setupRequest(r, &message.Enrollment); err != nil { return nil, err } - s.ctxLogger(r).Info("msg", "GetBootstrapToken") + ctxlog.Logger(r.Context, s.logger).Info("msg", "GetBootstrapToken") return s.store.RetrieveBootstrapToken(r, message) } // DeclarativeManagement Check-in message implementation. Calls out to // the service's DM handler (if configured). func (s *Service) DeclarativeManagement(r *mdm.Request, message *mdm.DeclarativeManagement) ([]byte, error) { - if err := s.updateEnrollID(r, &message.Enrollment); err != nil { + if err := s.setupRequest(r, &message.Enrollment); err != nil { return nil, err } - s.ctxLogger(r).Info( + ctxlog.Logger(r.Context, s.logger).Info( "msg", "DeclarativeManagement", "endpoint", message.Endpoint, ) @@ -214,10 +219,10 @@ func (s *Service) DeclarativeManagement(r *mdm.Request, message *mdm.Declarative // CommandAndReportResults command report and next-command request implementation. func (s *Service) CommandAndReportResults(r *mdm.Request, results *mdm.CommandResults) (*mdm.Command, error) { - if err := s.updateEnrollID(r, &results.Enrollment); err != nil { + if err := s.setupRequest(r, &results.Enrollment); err != nil { return nil, err } - logger := s.ctxLogger(r) + logger := ctxlog.Logger(r.Context, s.logger) logs := []interface{}{ "status", results.Status, } From 447e183ac91f20609941b4b6f10ac6b6d30a89e3 Mon Sep 17 00:00:00 2001 From: Jesse Peterson Date: Thu, 24 Mar 2022 10:28:43 -0700 Subject: [PATCH 5/8] Init new vs append --- log/ctxlog/ctxlog.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/log/ctxlog/ctxlog.go b/log/ctxlog/ctxlog.go index a20557a..9f420da 100644 --- a/log/ctxlog/ctxlog.go +++ b/log/ctxlog/ctxlog.go @@ -61,7 +61,7 @@ func SimpleStringFunc(ctxKey interface{}, logKey string) CtxKVFunc { return func(ctx context.Context) (out []interface{}) { v, _ := ctx.Value(ctxKey).(string) if v != "" { - out = append(out, logKey, v) + out = []interface{}{logKey, v} } return } From 02d3ea6106e828ed65839071b523982dfeeb9667 Mon Sep 17 00:00:00 2001 From: Jesse Peterson Date: Sat, 2 Apr 2022 13:18:23 -0700 Subject: [PATCH 6/8] Add some docs around CtxKVFunc --- log/ctxlog/ctxlog.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/log/ctxlog/ctxlog.go b/log/ctxlog/ctxlog.go index 9f420da..e15343e 100644 --- a/log/ctxlog/ctxlog.go +++ b/log/ctxlog/ctxlog.go @@ -9,6 +9,10 @@ import ( ) // CtxKVFunc creates logger key-value pairs from a context. +// CtxKVFuncs should aim to be be as efficient as possible—ideally only +// doing the minimum to read context values and generate KV pairs. Each +// associated CtxKVFunc is called every time we adapt a logger with +// Logger. type CtxKVFunc func(context.Context) []interface{} // ctxKeyFuncs is the context key for storing and retriveing From 209d878cbaf76d499418202eb310fb17bc56782c Mon Sep 17 00:00:00 2001 From: Jesse Peterson Date: Sat, 2 Apr 2022 13:19:07 -0700 Subject: [PATCH 7/8] Check nil pointer --- log/ctxlog/ctxlog.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/log/ctxlog/ctxlog.go b/log/ctxlog/ctxlog.go index e15343e..fbe29db 100644 --- a/log/ctxlog/ctxlog.go +++ b/log/ctxlog/ctxlog.go @@ -47,7 +47,7 @@ func Logger(ctx context.Context, logger log.Logger) log.Logger { return logger } ctxFuncs, ok := ctx.Value(ctxKeyFuncs{}).(*funcs) - if !ok { + if !ok || ctxFuncs == nil { return logger } var acc []interface{} From e1eb998a1cae7d3ba79caee511bdfca00a5a28c7 Mon Sep 17 00:00:00 2001 From: Jesse Peterson Date: Sat, 2 Apr 2022 13:22:31 -0700 Subject: [PATCH 8/8] Invert func params --- cmd/nanomdm/main.go | 2 +- log/ctxlog/ctxlog.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/nanomdm/main.go b/cmd/nanomdm/main.go index 85781ca..cad0fd9 100644 --- a/cmd/nanomdm/main.go +++ b/cmd/nanomdm/main.go @@ -242,7 +242,7 @@ func storeNewTraceID(ctx context.Context) context.Context { func simpleLog(next http.Handler, logger log.Logger) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx := storeNewTraceID(r.Context()) - ctx = ctxlog.AddFunc(ctx, ctxlog.SimpleStringFunc(ctxKeyTraceID{}, "trace_id")) + ctx = ctxlog.AddFunc(ctx, ctxlog.SimpleStringFunc("trace_id", ctxKeyTraceID{})) host, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { host = r.RemoteAddr diff --git a/log/ctxlog/ctxlog.go b/log/ctxlog/ctxlog.go index fbe29db..b7d4fbb 100644 --- a/log/ctxlog/ctxlog.go +++ b/log/ctxlog/ctxlog.go @@ -61,7 +61,7 @@ func Logger(ctx context.Context, logger log.Logger) log.Logger { // SimpleStringFunc is a helper that generates a simple CtxKVFunc that // returns a key-value pair if found on the context. -func SimpleStringFunc(ctxKey interface{}, logKey string) CtxKVFunc { +func SimpleStringFunc(logKey string, ctxKey interface{}) CtxKVFunc { return func(ctx context.Context) (out []interface{}) { v, _ := ctx.Value(ctxKey).(string) if v != "" {