From 4545801a641121dde6c9f4c636b61ee89c12e9f5 Mon Sep 17 00:00:00 2001 From: Andrew Harding Date: Thu, 10 Mar 2022 15:27:22 -0700 Subject: [PATCH] Rearrange tlsconfig trace args (#183) Allows for future expansion and aligns params for consistency Signed-off-by: Andrew Harding --- v2/spiffetls/tlsconfig/config.go | 6 +++--- v2/spiffetls/tlsconfig/config_test.go | 8 ++++---- v2/spiffetls/tlsconfig/trace.go | 8 ++++++-- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/v2/spiffetls/tlsconfig/config.go b/v2/spiffetls/tlsconfig/config.go index a4e69ab9..53b36ed0 100644 --- a/v2/spiffetls/tlsconfig/config.go +++ b/v2/spiffetls/tlsconfig/config.go @@ -207,13 +207,13 @@ func WrapVerifyPeerCertificate(wrapped func([][]byte, [][]*x509.Certificate) err func getTLSCertificate(svid x509svid.Source, trace Trace) (*tls.Certificate, error) { var traceVal interface{} if trace.GetCertificate != nil { - traceVal = trace.GetCertificate() + traceVal = trace.GetCertificate(GetCertificateInfo{}) } s, err := svid.GetX509SVID() if err != nil { if trace.GotCertificate != nil { - trace.GotCertificate(traceVal, GotCertificateInfo{Err: err}) + trace.GotCertificate(GotCertificateInfo{Err: err}, traceVal) } return nil, err } @@ -228,7 +228,7 @@ func getTLSCertificate(svid x509svid.Source, trace Trace) (*tls.Certificate, err } if trace.GotCertificate != nil { - trace.GotCertificate(traceVal, GotCertificateInfo{Cert: cert}) + trace.GotCertificate(GotCertificateInfo{Cert: cert}, traceVal) } return cert, nil diff --git a/v2/spiffetls/tlsconfig/config_test.go b/v2/spiffetls/tlsconfig/config_test.go index c8a3039f..cb8d76b9 100644 --- a/v2/spiffetls/tlsconfig/config_test.go +++ b/v2/spiffetls/tlsconfig/config_test.go @@ -21,11 +21,11 @@ import ( ) var localTrace = tlsconfig.Trace{ - GetCertificate: func() interface{} { + GetCertificate: func(tlsconfig.GetCertificateInfo) interface{} { fmt.Printf("got start of GetTLSCertificate\n") return nil }, - GotCertificate: func(interface{}, tlsconfig.GotCertificateInfo) { + GotCertificate: func(tlsconfig.GotCertificateInfo, interface{}) { fmt.Printf("got end of GetTLSCertificate\n") }, } @@ -263,13 +263,13 @@ func TestHookMTLSWebServerConfig(t *testing.T) { func hookedTracer(onGetCertificate, onGotCertificate func()) tlsconfig.Trace { return tlsconfig.Trace{ - GetCertificate: func() interface{} { + GetCertificate: func(tlsconfig.GetCertificateInfo) interface{} { if onGetCertificate != nil { onGetCertificate() } return nil }, - GotCertificate: func(interface{}, tlsconfig.GotCertificateInfo) { + GotCertificate: func(tlsconfig.GotCertificateInfo, interface{}) { if onGotCertificate != nil { onGotCertificate() } diff --git a/v2/spiffetls/tlsconfig/trace.go b/v2/spiffetls/tlsconfig/trace.go index 0fdd1269..954d3945 100644 --- a/v2/spiffetls/tlsconfig/trace.go +++ b/v2/spiffetls/tlsconfig/trace.go @@ -4,6 +4,10 @@ import ( "crypto/tls" ) +// GetCertificateInfo is an empty placeholder for future expansion +type GetCertificateInfo struct { +} + // GotCertificateInfo provides err and TLS certificate info to Trace type GotCertificateInfo struct { Cert *tls.Certificate @@ -13,6 +17,6 @@ type GotCertificateInfo struct { // Trace is the interface to define what functions are triggered when functions // in tlsconfig are called type Trace struct { - GetCertificate func() interface{} - GotCertificate func(interface{}, GotCertificateInfo) + GetCertificate func(GetCertificateInfo) interface{} + GotCertificate func(GotCertificateInfo, interface{}) }