Skip to content

Commit

Permalink
Rearrange tlsconfig trace args (#183)
Browse files Browse the repository at this point in the history
Allows for future expansion and aligns params for consistency

Signed-off-by: Andrew Harding <[email protected]>
  • Loading branch information
azdagron authored Mar 10, 2022
1 parent 605de4a commit 4545801
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 9 deletions.
6 changes: 3 additions & 3 deletions v2/spiffetls/tlsconfig/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,13 +207,13 @@ func WrapVerifyPeerCertificate(wrapped func([][]byte, [][]*x509.Certificate) err
func getTLSCertificate(svid x509svid.Source, trace Trace) (*tls.Certificate, error) {
var traceVal interface{}
if trace.GetCertificate != nil {
traceVal = trace.GetCertificate()
traceVal = trace.GetCertificate(GetCertificateInfo{})
}

s, err := svid.GetX509SVID()
if err != nil {
if trace.GotCertificate != nil {
trace.GotCertificate(traceVal, GotCertificateInfo{Err: err})
trace.GotCertificate(GotCertificateInfo{Err: err}, traceVal)
}
return nil, err
}
Expand All @@ -228,7 +228,7 @@ func getTLSCertificate(svid x509svid.Source, trace Trace) (*tls.Certificate, err
}

if trace.GotCertificate != nil {
trace.GotCertificate(traceVal, GotCertificateInfo{Cert: cert})
trace.GotCertificate(GotCertificateInfo{Cert: cert}, traceVal)
}

return cert, nil
Expand Down
8 changes: 4 additions & 4 deletions v2/spiffetls/tlsconfig/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ import (
)

var localTrace = tlsconfig.Trace{
GetCertificate: func() interface{} {
GetCertificate: func(tlsconfig.GetCertificateInfo) interface{} {
fmt.Printf("got start of GetTLSCertificate\n")
return nil
},
GotCertificate: func(interface{}, tlsconfig.GotCertificateInfo) {
GotCertificate: func(tlsconfig.GotCertificateInfo, interface{}) {
fmt.Printf("got end of GetTLSCertificate\n")
},
}
Expand Down Expand Up @@ -263,13 +263,13 @@ func TestHookMTLSWebServerConfig(t *testing.T) {

func hookedTracer(onGetCertificate, onGotCertificate func()) tlsconfig.Trace {
return tlsconfig.Trace{
GetCertificate: func() interface{} {
GetCertificate: func(tlsconfig.GetCertificateInfo) interface{} {
if onGetCertificate != nil {
onGetCertificate()
}
return nil
},
GotCertificate: func(interface{}, tlsconfig.GotCertificateInfo) {
GotCertificate: func(tlsconfig.GotCertificateInfo, interface{}) {
if onGotCertificate != nil {
onGotCertificate()
}
Expand Down
8 changes: 6 additions & 2 deletions v2/spiffetls/tlsconfig/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ import (
"crypto/tls"
)

// GetCertificateInfo is an empty placeholder for future expansion
type GetCertificateInfo struct {
}

// GotCertificateInfo provides err and TLS certificate info to Trace
type GotCertificateInfo struct {
Cert *tls.Certificate
Expand All @@ -13,6 +17,6 @@ type GotCertificateInfo struct {
// Trace is the interface to define what functions are triggered when functions
// in tlsconfig are called
type Trace struct {
GetCertificate func() interface{}
GotCertificate func(interface{}, GotCertificateInfo)
GetCertificate func(GetCertificateInfo) interface{}
GotCertificate func(GotCertificateInfo, interface{})
}

0 comments on commit 4545801

Please sign in to comment.