From 91c29a2d47e1a3677a5f53f24a32c27fcb75d8ef Mon Sep 17 00:00:00 2001 From: Matt Dainty Date: Tue, 26 Sep 2023 00:45:02 +0100 Subject: [PATCH] test: Add and clean up lint checks --- .github/workflows/build.yml | 25 +++- .golangci.yaml | 87 ++++++++++++++ .pre-commit-config.yaml | 9 ++ dh/dh.go | 162 ++++++++++++------------- gss/apcera.go | 106 ++++++++++------- gss/apcera_test.go | 15 ++- gss/export_test.go | 3 + gss/gokrb5.go | 21 ++-- gss/gokrb5_test.go | 11 +- gss/gss.go | 155 ++++++++++++------------ gss/gss_internal_test.go | 26 ++++ gss/gss_test.go | 178 ++++++++++++++++------------ gss/sspi.go | 68 +++++------ gss/sspi_test.go | 31 +++++ hmac.go | 13 +- hmac_test.go | 155 +++++++++++++++++------- internal/util/util.go | 29 +++-- internal/util/util_internal_test.go | 29 +++++ internal/util/util_test.go | 114 +++++++++++++----- multi.go | 22 ++-- multi_test.go | 72 +++++++---- tsig.go | 5 +- 22 files changed, 874 insertions(+), 462 deletions(-) create mode 100644 .golangci.yaml create mode 100644 .pre-commit-config.yaml create mode 100644 gss/export_test.go create mode 100644 gss/gss_internal_test.go create mode 100644 gss/sspi_test.go create mode 100644 internal/util/util_internal_test.go diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index a98be51..b523fb1 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -40,11 +40,26 @@ jobs: with: go-version: ${{ matrix.go }} - #- name: golangci-lint - # uses: golangci/golangci-lint-action@v3 - # if: github.event_name == 'pull_request' - # with: - # only-new-issues: true + - name: golangci-lint (gokrb5) + uses: golangci/golangci-lint-action@v3 + if: github.event_name == 'pull_request' + with: + only-new-issues: true + + - name: golangci-lint (apcera) + uses: golangci/golangci-lint-action@v3 + if: github.event_name == 'pull_request' + with: + only-new-issues: true + args: --build-tags apcera + + - name: golangci-lint (SSPI) + uses: golangci/golangci-lint-action@v3 + if: github.event_name == 'pull_request' + with: + only-new-issues: true + env: + GOOS: windows - name: Install Kerberos client run: | diff --git a/.golangci.yaml b/.golangci.yaml new file mode 100644 index 0000000..d312b6f --- /dev/null +++ b/.golangci.yaml @@ -0,0 +1,87 @@ +--- +issues: + exclude-use-default: false +linters: + disable-all: true + enable: + - asasalint + - asciicheck + - bidichk + - bodyclose + - containedctx + - contextcheck + - cyclop + - decorder + - dogsled + - dupl + - dupword + - durationcheck + - errcheck + - errchkjson + - errname + - errorlint + - execinquery + - exhaustive + - exportloopref + - forbidigo + - forcetypeassert + - funlen + - gci + - gochecknoglobals + - gochecknoinits + - gocognit + - goconst + - gocritic + - gocyclo + - godot + - gofmt + - gofumpt + - goheader + - goimports + - gomoddirectives + - gomodguard + - goprintffuncname + - gosec + - gosimple + - govet + - grouper + - importas + - ineffassign + - interfacebloat + - lll + - loggercheck + - maintidx + - makezero + - misspell + - nakedret + - nestif + - nilerr + - nilnil + - nlreturn + - noctx + - nolintlint + - nosprintfhostport + - paralleltest + - prealloc + - predeclared + - promlinter + - reassign + - revive + - rowserrcheck + - sqlclosecheck + - staticcheck + - stylecheck + - tagliatelle + - tenv + - testableexamples + - testpackage + - thelper + - tparallel + - typecheck + - unconvert + - unparam + - unused + - usestdlibvars + - wastedassign + - whitespace + - wsl diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..a1c84f8 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,9 @@ +repos: + - repo: https://github.com/commitizen-tools/commitizen + rev: v3.5.3 + hooks: + - id: commitizen + - repo: https://github.com/golangci/golangci-lint + rev: v1.54.1 + hooks: + - id: golangci-lint diff --git a/dh/dh.go b/dh/dh.go index 0bc1279..c2c13e1 100644 --- a/dh/dh.go +++ b/dh/dh.go @@ -3,69 +3,69 @@ Package dh implements RFC 2930 Diffie-Hellman key exchange functions. Example client: - import ( - "fmt" - "time" - - "github.com/bodgit/tsig/dh" - "github.com/miekg/dns" - ) - - func main() { - dnsClient := new(dns.Client) - dnsClient.Net = "tcp" - dnsClient.TsigSecret = map[string]string{"tsig.example.com.": "k9uK5qsPfbBxvVuldwzYww=="} - - dhClient, err := dh.NewClient(dnsClient) - if err != nil { - panic(err) - } - defer dhClient.Close() - - host := "ns.example.com:53" - - // Negotiate a key with the chosen server - keyname, mac, _, err := dhClient.NegotiateKey(host, "tsig.example.com.", dns.HmacMD5, "k9uK5qsPfbBxvVuldwzYww==") - if err != nil { - panic(err) - } - - dnsClient.TsigSecret[keyname] = mac - - // Use the DNS client as normal - - msg := new(dns.Msg) - msg.SetUpdate(dns.Fqdn("example.com")) - - insert, err := dns.NewRR("test.example.com. 300 A 192.0.2.1") - if err != nil { - panic(err) - } - msg.Insert([]dns.RR{insert}) - - msg.SetTsig(keyname, dns.HmacMD5, 300, time.Now().Unix()) - - rr, _, err := dnsClient.Exchange(msg, host) - if err != nil { - panic(err) - } - - if rr.Rcode != dns.RcodeSuccess { - fmt.Printf("DNS error: %s (%d)\n", dns.RcodeToString[rr.Rcode], rr.Rcode) - } - - // Revoke the key - err = dhClient.DeleteKey(keyname) - if err != nil { - panic(err) - } - } + import ( + "fmt" + "time" + + "github.com/bodgit/tsig/dh" + "github.com/miekg/dns" + ) + + func main() { + dnsClient := new(dns.Client) + dnsClient.Net = "tcp" + dnsClient.TsigSecret = map[string]string{"tsig.example.com.": "k9uK5qsPfbBxvVuldwzYww=="} + + dhClient, err := dh.NewClient(dnsClient) + if err != nil { + panic(err) + } + defer dhClient.Close() + + host := "ns.example.com:53" + + // Negotiate a key with the chosen server + keyname, mac, _, err := dhClient.NegotiateKey(host, "tsig.example.com.", dns.HmacMD5, "k9uK5qsPfbBxvVuldwzYww==") + if err != nil { + panic(err) + } + + dnsClient.TsigSecret[keyname] = mac + + // Use the DNS client as normal + + msg := new(dns.Msg) + msg.SetUpdate(dns.Fqdn("example.com")) + + insert, err := dns.NewRR("test.example.com. 300 A 192.0.2.1") + if err != nil { + panic(err) + } + msg.Insert([]dns.RR{insert}) + + msg.SetTsig(keyname, dns.HmacMD5, 300, time.Now().Unix()) + + rr, _, err := dnsClient.Exchange(msg, host) + if err != nil { + panic(err) + } + + if rr.Rcode != dns.RcodeSuccess { + fmt.Printf("DNS error: %s (%d)\n", dns.RcodeToString[rr.Rcode], rr.Rcode) + } + + // Revoke the key + err = dhClient.DeleteKey(keyname) + if err != nil { + panic(err) + } + } */ package dh import ( "bytes" - "crypto/md5" + "crypto/md5" //nolint:gosec "crypto/rand" "encoding/base64" "encoding/binary" @@ -85,7 +85,7 @@ import ( ) const ( - // RFC 2409, section 6.2 + // RFC 2409, section 6.2. modp1024 = "FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD1" + "29024E088A67CC74020BBEA63B139B22514A08798E3404DD" + "EF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245" + @@ -111,7 +111,6 @@ type Client struct { } func dhGroup(group int) (*dh.Group, error) { - switch group { case 2: p, _ := new(big.Int).SetString(modp1024, 16) @@ -121,7 +120,7 @@ func dhGroup(group int) (*dh.Group, error) { G: new(big.Int).SetInt64(2), }, nil default: - return nil, fmt.Errorf("Unsupported DH group %v", group) + return nil, fmt.Errorf("unsupported DH group %v", group) } } @@ -129,7 +128,6 @@ func dhGroup(group int) (*dh.Group, error) { // It returns a context handle for any further functions along with any error // that occurred. func NewClient(dnsClient *dns.Client) (*Client, error) { - client, err := util.CopyDNSClient(dnsClient) if err != nil { return nil, err @@ -147,36 +145,36 @@ func NewClient(dnsClient *dns.Client) (*Client, error) { // necessary. // It returns any error that occurred. func (c *Client) Close() error { - c.m.Lock() + keys := make([]string, 0, len(c.ctx)) for k := range c.ctx { keys = append(keys, k) } + c.m.Unlock() - var errs error + var errs *multierror.Error for _, k := range keys { errs = multierror.Append(errs, c.DeleteKey(k)) } - return errs + return errs.ErrorOrNil() } func readDHKey(raw []byte) (*dhkey, error) { - var key dhkey r := bytes.NewBuffer(raw) - var len uint16 + var l uint16 for _, f := range []*[]byte{&key.prime, &key.generator, &key.key} { - err := binary.Read(r, binary.BigEndian, &len) + err := binary.Read(r, binary.BigEndian, &l) if err != nil { return nil, err } - *f = make([]byte, len) + *f = make([]byte, l) if _, err = io.ReadFull(r, *f); err != nil { return nil, err } @@ -186,13 +184,12 @@ func readDHKey(raw []byte) (*dhkey, error) { } func writeDHKey(key *dhkey) ([]byte, error) { - w := new(bytes.Buffer) for _, f := range []*[]byte{&key.prime, &key.generator, &key.key} { - len := uint16(len(*f)) + l := uint16(len(*f)) - err := binary.Write(w, binary.BigEndian, len) + err := binary.Write(w, binary.BigEndian, l) if err != nil { return nil, err } @@ -206,26 +203,27 @@ func writeDHKey(key *dhkey) ([]byte, error) { } func computeMD5(nonce, secret []byte) []byte { - + //nolint:gosec checksum := md5.Sum(append(nonce, secret...)) return checksum[:] } func computeDHKey(ourNonce, peerNonce, secret []byte) []byte { - operand := append(computeMD5(ourNonce, secret), computeMD5(peerNonce, secret)...) var result []byte if len(secret) > len(operand) { result = make([]byte, len(secret)) copy(result, secret) + for i := 0; i < len(operand); i++ { result[i] ^= operand[i] } } else { result = make([]byte, len(operand)) copy(result, operand) + for i := 0; i < len(secret); i++ { result[i] ^= secret[i] } @@ -239,8 +237,9 @@ func computeDHKey(ourNonce, peerNonce, secret []byte) []byte { // algorithm and MAC. // It returns the negotiated TKEY name, MAC, expiry time, and any error that // occurred. +// +//nolint:cyclop,funlen func (c *Client) NegotiateKey(host, name, algorithm, mac string) (string, string, time.Time, error) { - keyname := "." g, err := dhGroup(2) @@ -288,15 +287,16 @@ func (c *Client) NegotiateKey(host, name, algorithm, mac string) (string, string c.client.TsigSecret[name] = mac defer delete(c.client.TsigSecret, name) + //nolint:lll tkey, keys, err := util.ExchangeTKEY(c.client, host, keyname, dns.HmacMD5, util.TkeyModeDH, 3600, an, extra, name, algorithm) if err != nil { return "", "", time.Time{}, err } var bkey []byte + for _, k := range keys { - switch key := k.(type) { - case *dns.KEY: + if key, ok := k.(*dns.KEY); ok { if key.Header().Name != keyname && key.Algorithm == dns.DH { if bkey, err = base64.StdEncoding.DecodeString(key.PublicKey); err != nil { return "", "", time.Time{}, err @@ -306,13 +306,14 @@ func (c *Client) NegotiateKey(host, name, algorithm, mac string) (string, string } if bkey == nil { - return "", "", time.Time{}, errors.New("No peer KEY record") + return "", "", time.Time{}, errors.New("no peer KEY record") } bdh, err := readDHKey(bkey) if err != nil { return "", "", time.Time{}, err } + by := new(big.Int).SetBytes(bdh.key) err = g.Check(by) @@ -347,19 +348,20 @@ func (c *Client) NegotiateKey(host, name, algorithm, mac string) (string, string // DeleteKey revokes the active key associated with the given TKEY name. // It returns any error that occurred. func (c *Client) DeleteKey(keyname string) error { - c.m.Lock() defer c.m.Unlock() ctx, ok := c.ctx[keyname] if !ok { - return errors.New("No such context") + return errors.New("no such context") } c.client.TsigSecret[keyname] = ctx.mac defer delete(c.client.TsigSecret, keyname) // Delete the key, signing the query with the key itself + // + //nolint:lll if _, _, err := util.ExchangeTKEY(c.client, ctx.host, keyname, ctx.algorithm, util.TkeyModeDelete, 0, nil, nil, keyname, ctx.algorithm); err != nil { return err } diff --git a/gss/apcera.go b/gss/apcera.go index 3c38e55..f9cd6ce 100644 --- a/gss/apcera.go +++ b/gss/apcera.go @@ -5,7 +5,6 @@ package gss import ( "encoding/hex" - "errors" "net" "sync" "time" @@ -28,18 +27,17 @@ type Client struct { logger logr.Logger } -// WithConfig sets the Kerberos configuration used -func WithConfig(config string) func(*Client) error { +// WithConfig sets the Kerberos configuration used. +func WithConfig(_ string) func(*Client) error { return func(c *Client) error { return errNotSupported } } -// New performs any library initialization necessary. +// NewClient performs any library initialization necessary. // It returns a context handle for any further functions along with any error // that occurred. func NewClient(dnsClient *dns.Client, options ...func(*Client) error) (*Client, error) { - client, err := util.CopyDNSClient(dnsClient) if err != nil { return nil, err @@ -70,7 +68,6 @@ func NewClient(dnsClient *dns.Client, options ...func(*Client) error) (*Client, // necessary. // It returns any error that occurred. func (c *Client) Close() error { - return multierror.Append(c.close(), c.lib.Unload()) } @@ -79,8 +76,7 @@ func (c *Client) Close() error { // record containing the algorithm and name which is the negotiated TKEY // for this context. // It returns the bytes for the TSIG MAC and any error that occurred. -func (c *Client) Generate(msg []byte, t *dns.TSIG) ([]byte, error) { - +func (c *Client) Generate(msg []byte, t *dns.TSIG) (b []byte, err error) { if dns.CanonicalName(t.Algorithm) != tsig.GSS { return nil, dns.ErrKeyAlg } @@ -97,13 +93,19 @@ func (c *Client) Generate(msg []byte, t *dns.TSIG) ([]byte, error) { if err != nil { return nil, err } - defer message.Release() + + defer func() { + err = multierror.Append(err, message.Release()).ErrorOrNil() + }() token, err := ctx.GetMIC(gssapi.GSS_C_QOP_DEFAULT, message) if err != nil { return nil, err } - defer token.Release() + + defer func() { + err = multierror.Append(err, token.Release()).ErrorOrNil() + }() return token.Bytes(), nil } @@ -113,8 +115,7 @@ func (c *Client) Generate(msg []byte, t *dns.TSIG) ([]byte, error) { // containing the algorithm, MAC, and name which is the negotiated TKEY // for this context. // It returns any error that occurred. -func (c *Client) Verify(stripped []byte, t *dns.TSIG) error { - +func (c *Client) Verify(stripped []byte, t *dns.TSIG) (err error) { if dns.CanonicalName(t.Algorithm) != tsig.GSS { return dns.ErrKeyAlg } @@ -132,7 +133,10 @@ func (c *Client) Verify(stripped []byte, t *dns.TSIG) error { if err != nil { return err } - defer message.Release() + + defer func() { + err = multierror.Append(err, message.Release()).ErrorOrNil() + }() mac, err := hex.DecodeString(t.MAC) if err != nil { @@ -144,7 +148,10 @@ func (c *Client) Verify(stripped []byte, t *dns.TSIG) error { if err != nil { return err } - defer token.Release() + + defer func() { + err = multierror.Append(err, token.Release()).ErrorOrNil() + }() // This is the actual verification bit if _, err = ctx.VerifyMIC(message, token); err != nil { @@ -158,30 +165,42 @@ func (c *Client) Verify(stripped []byte, t *dns.TSIG) error { // server to establish a security context using the current user. // It returns the negotiated TKEY name, expiration time, and any error that // occurred. -func (c *Client) NegotiateContext(host string) (string, time.Time, error) { - +// +//nolint:cyclop,funlen +func (c *Client) NegotiateContext(host string) (keyname string, expiry time.Time, err error) { hostname, _, err := net.SplitHostPort(host) if err != nil { return "", time.Time{}, err } - keyname := generateTKEYName(hostname) + keyname, err = generateTKEYName(hostname) + if err != nil { + return "", time.Time{}, err + } buffer, err := c.lib.MakeBufferString(generateSPN(hostname)) if err != nil { return "", time.Time{}, err } - defer buffer.Release() + + defer func() { + err = multierror.Append(err, buffer.Release()).ErrorOrNil() + }() service, err := buffer.Name(c.lib.GSS_KRB5_NT_PRINCIPAL_NAME) if err != nil { return "", time.Time{}, err } - defer service.Release() - var input *gssapi.Buffer - var ctx *gssapi.CtxId - var tkey *dns.TKEY + defer func() { + err = multierror.Append(err, service.Release()).ErrorOrNil() + }() + + var ( + input *gssapi.Buffer + ctx *gssapi.CtxId + tkey *dns.TKEY + ) for ok := true; ok; ok = c.lib.LastStatus.Major.ContinueNeeded() { nctx, _, output, _, _, err := c.lib.InitSecContext( @@ -193,8 +212,13 @@ func (c *Client) NegotiateContext(host string) (string, time.Time, error) { 0, c.lib.GSS_C_NO_CHANNEL_BINDINGS, input) - defer output.Release() + ctx = nctx + + defer func() { + err = multierror.Append(err, output.Release()).ErrorOrNil() + }() + if err != nil { if !c.lib.LastStatus.Major.ContinueNeeded() { return "", time.Time{}, err @@ -204,37 +228,30 @@ func (c *Client) NegotiateContext(host string) (string, time.Time, error) { break } - var errs error - - // We don't care about non-TKEY answers, no additional RR's to send, and no signing + //nolint:lll if tkey, _, err = util.ExchangeTKEY(c.client, host, keyname, tsig.GSS, util.TkeyModeGSS, 3600, output.Bytes(), nil, "", ""); err != nil { - errs = multierror.Append(errs, err) - errs = multierror.Append(errs, ctx.DeleteSecContext()) - return "", time.Time{}, errs + return "", time.Time{}, multierror.Append(err, ctx.DeleteSecContext()) } if tkey.Header().Name != keyname { - errs = multierror.Append(errs, errors.New("TKEY name does not match")) - errs = multierror.Append(errs, ctx.DeleteSecContext()) - return "", time.Time{}, errs + return "", time.Time{}, multierror.Append(errDoesNotMatch, ctx.DeleteSecContext()) } key, err := hex.DecodeString(tkey.Key) if err != nil { - errs = multierror.Append(errs, err) - errs = multierror.Append(errs, ctx.DeleteSecContext()) - return "", time.Time{}, errs + return "", time.Time{}, multierror.Append(err, ctx.DeleteSecContext()) } if input, err = c.lib.MakeBufferBytes(key); err != nil { - errs = multierror.Append(errs, err) - errs = multierror.Append(errs, ctx.DeleteSecContext()) - return "", time.Time{}, errs + return "", time.Time{}, multierror.Append(err, ctx.DeleteSecContext()) } - defer input.Release() + + defer func() { + err = multierror.Append(err, input.Release()).ErrorOrNil() + }() } - expiry := time.Unix(int64(tkey.Expiration), 0) + expiry = time.Unix(int64(tkey.Expiration), 0) c.m.Lock() defer c.m.Unlock() @@ -249,8 +266,7 @@ func (c *Client) NegotiateContext(host string) (string, time.Time, error) { // credentials. // It returns the negotiated TKEY name, expiration time, and any error that // occurred. -func (c *Client) NegotiateContextWithCredentials(host, domain, username, password string) (string, time.Time, error) { - +func (c *Client) NegotiateContextWithCredentials(_, _, _, _ string) (string, time.Time, error) { return "", time.Time{}, errNotSupported } @@ -259,8 +275,7 @@ func (c *Client) NegotiateContextWithCredentials(host, domain, username, passwor // keytab. // It returns the negotiated TKEY name, expiration time, and any error that // occurred. -func (c *Client) NegotiateContextWithKeytab(host, domain, username, path string) (string, time.Time, error) { - +func (c *Client) NegotiateContextWithKeytab(_, _, _, _ string) (string, time.Time, error) { return "", time.Time{}, errNotSupported } @@ -268,13 +283,12 @@ func (c *Client) NegotiateContextWithKeytab(host, domain, username, path string) // TKEY name. // It returns any error that occurred. func (c *Client) DeleteContext(keyname string) error { - c.m.Lock() defer c.m.Unlock() ctx, ok := c.ctx[keyname] if !ok { - return errors.New("No such context") + return errNoSuchContext } if err := ctx.DeleteSecContext(); err != nil { diff --git a/gss/apcera_test.go b/gss/apcera_test.go index 8656fc5..f7bd073 100644 --- a/gss/apcera_test.go +++ b/gss/apcera_test.go @@ -1,24 +1,31 @@ //go:build !windows && apcera // +build !windows,apcera -package gss +package gss_test import ( "testing" + "github.com/bodgit/tsig/gss" "github.com/miekg/dns" "github.com/stretchr/testify/assert" ) func TestExchangeCredentials(t *testing.T) { - assert.Equal(t, errNotSupported, testExchangeCredentials(t)) + t.Parallel() + + assert.ErrorIs(t, testExchangeCredentials(t), gss.ErrNotSupported) } func TestExchangeKeytab(t *testing.T) { - assert.Equal(t, errNotSupported, testExchangeKeytab(t)) + t.Parallel() + + assert.ErrorIs(t, testExchangeKeytab(t), gss.ErrNotSupported) } func TestNewClientWithConfig(t *testing.T) { - _, err := NewClient(new(dns.Client), WithConfig("")) + t.Parallel() + + _, err := gss.NewClient(new(dns.Client), gss.WithConfig("")) assert.NotNil(t, err) } diff --git a/gss/export_test.go b/gss/export_test.go new file mode 100644 index 0000000..f4f1479 --- /dev/null +++ b/gss/export_test.go @@ -0,0 +1,3 @@ +package gss + +var ErrNotSupported = errNotSupported diff --git a/gss/gokrb5.go b/gss/gokrb5.go index 1529cc9..eec7bc0 100644 --- a/gss/gokrb5.go +++ b/gss/gokrb5.go @@ -5,7 +5,6 @@ package gss import ( "encoding/hex" - "errors" "net" "sync" "time" @@ -28,10 +27,11 @@ type Client struct { logger logr.Logger } -// WithConfig sets the Kerberos configuration used +// WithConfig sets the Kerberos configuration used. func WithConfig(config string) func(*Client) error { return func(c *Client) error { c.config = config + return nil } } @@ -40,7 +40,6 @@ func WithConfig(config string) func(*Client) error { // It returns a context handle for any further functions along with any error // that occurred. func NewClient(dnsClient *dns.Client, options ...func(*Client) error) (*Client, error) { - client, err := util.CopyDNSClient(dnsClient) if err != nil { return nil, err @@ -65,7 +64,6 @@ func NewClient(dnsClient *dns.Client, options ...func(*Client) error) (*Client, // necessary. // It returns any error that occurred. func (c *Client) Close() error { - return c.close() } @@ -96,7 +94,6 @@ func (c *Client) Generate(msg []byte, t *dns.TSIG) ([]byte, error) { // for this context. // It returns any error that occurred. func (c *Client) Verify(stripped []byte, t *dns.TSIG) error { - if dns.CanonicalName(t.Algorithm) != tsig.GSS { return dns.ErrKeyAlg } @@ -130,7 +127,10 @@ func (c *Client) negotiateContext(host string, options []wrapper.Option[wrapper. return "", time.Time{}, err } - keyname := generateTKEYName(hostname) + keyname, err := generateTKEYName(hostname) + if err != nil { + return "", time.Time{}, err + } spn := generateSPN(hostname) @@ -151,7 +151,7 @@ func (c *Client) negotiateContext(host string, options []wrapper.Option[wrapper. } if tkey.Header().Name != keyname { - return "", time.Time{}, errors.New("TKEY name does not match") + return "", time.Time{}, errDoesNotMatch } var input []byte @@ -218,16 +218,17 @@ func (c *Client) NegotiateContextWithKeytab(host, domain, username, path string) // TKEY name. // It returns any error that occurred. func (c *Client) DeleteContext(keyname string) error { - c.m.Lock() defer c.m.Unlock() ctx, ok := c.ctx[keyname] if !ok { - return errors.New("No such context") + return errNoSuchContext } - ctx.Close() + if err := ctx.Close(); err != nil { + return err + } delete(c.ctx, keyname) diff --git a/gss/gokrb5_test.go b/gss/gokrb5_test.go index 7c7a8be..04e0de8 100644 --- a/gss/gokrb5_test.go +++ b/gss/gokrb5_test.go @@ -1,24 +1,31 @@ //go:build !windows && !apcera // +build !windows,!apcera -package gss +package gss_test import ( "testing" + "github.com/bodgit/tsig/gss" "github.com/miekg/dns" "github.com/stretchr/testify/assert" ) func TestExchangeCredentials(t *testing.T) { + t.Parallel() + assert.Nil(t, testExchangeCredentials(t)) } func TestExchangeKeytab(t *testing.T) { + t.Parallel() + assert.Nil(t, testExchangeKeytab(t)) } func TestNewClientWithConfig(t *testing.T) { - _, err := NewClient(new(dns.Client), WithConfig("")) + t.Parallel() + + _, err := gss.NewClient(new(dns.Client), gss.WithConfig("")) assert.Nil(t, err) } diff --git a/gss/gss.go b/gss/gss.go index f2e3e8c..cd7f658 100644 --- a/gss/gss.go +++ b/gss/gss.go @@ -5,66 +5,66 @@ require "Secure only" updates. Example client: - import ( - "fmt" - "time" - - "github.com/bodgit/tsig" - "github.com/bodgit/tsig/gss" - "github.com/miekg/dns" - ) - - func main() { - dnsClient := new(dns.Client) - dnsClient.Net = "tcp" - - gssClient, err := gss.NewClient(dnsClient) - if err != nil { - panic(err) - } - defer gssClient.Close() - - host := "ns.example.com:53" - - // Negotiate a context with the chosen server using the - // current user. See also - // gssClient.NegotiateContextWithCredentials() and - // gssClient.NegotiateContextWithKeytab() for alternatives - keyname, _, err := gssClient.NegotiateContext(host) - if err != nil { - panic(err) - } - - dnsClient.TsigProvider = gssClient - - // Use the DNS client as normal - - msg := new(dns.Msg) - msg.SetUpdate(dns.Fqdn("example.com")) - - insert, err := dns.NewRR("test.example.com. 300 A 192.0.2.1") - if err != nil { - panic(err) - } - msg.Insert([]dns.RR{insert}) - - msg.SetTsig(keyname, tsig.GSS, 300, time.Now().Unix()) - - rr, _, err := dnsClient.Exchange(msg, host) - if err != nil { - panic(err) - } - - if rr.Rcode != dns.RcodeSuccess { - fmt.Printf("DNS error: %s (%d)\n", dns.RcodeToString[rr.Rcode], rr.Rcode) - } - - // Cleanup the context - err = gssClient.DeleteContext(keyname) - if err != nil { - panic(err) - } - } + import ( + "fmt" + "time" + + "github.com/bodgit/tsig" + "github.com/bodgit/tsig/gss" + "github.com/miekg/dns" + ) + + func main() { + dnsClient := new(dns.Client) + dnsClient.Net = "tcp" + + gssClient, err := gss.NewClient(dnsClient) + if err != nil { + panic(err) + } + defer gssClient.Close() + + host := "ns.example.com:53" + + // Negotiate a context with the chosen server using the + // current user. See also + // gssClient.NegotiateContextWithCredentials() and + // gssClient.NegotiateContextWithKeytab() for alternatives + keyname, _, err := gssClient.NegotiateContext(host) + if err != nil { + panic(err) + } + + dnsClient.TsigProvider = gssClient + + // Use the DNS client as normal + + msg := new(dns.Msg) + msg.SetUpdate(dns.Fqdn("example.com")) + + insert, err := dns.NewRR("test.example.com. 300 A 192.0.2.1") + if err != nil { + panic(err) + } + msg.Insert([]dns.RR{insert}) + + msg.SetTsig(keyname, tsig.GSS, 300, time.Now().Unix()) + + rr, _, err := dnsClient.Exchange(msg, host) + if err != nil { + panic(err) + } + + if rr.Rcode != dns.RcodeSuccess { + fmt.Printf("DNS error: %s (%d)\n", dns.RcodeToString[rr.Rcode], rr.Rcode) + } + + // Cleanup the context + err = gssClient.DeleteContext(keyname) + if err != nil { + panic(err) + } + } Under the hood, GSSAPI is used on platforms other than Windows whilst Windows uses native SSPI which has a similar API. @@ -72,10 +72,10 @@ uses native SSPI which has a similar API. package gss import ( + "crypto/rand" "errors" "fmt" - "math/rand" - "time" + "math/big" "github.com/bodgit/tsig" "github.com/go-logr/logr" @@ -84,7 +84,9 @@ import ( ) var ( - errNotSupported = errors.New("not supported") + errNotSupported = errors.New("not supported") //nolint:nolintlint,unused + errDoesNotMatch = errors.New("TKEY name does not match") + errNoSuchContext = errors.New("no such context") ) // gssNoVerify is a dns.TsigProvider that skips any GSS-TSIG verification. @@ -99,6 +101,7 @@ func (*gssNoVerify) Generate(_ []byte, t *dns.TSIG) ([]byte, error) { if dns.CanonicalName(t.Algorithm) != tsig.GSS { return nil, dns.ErrKeyAlg } + return nil, dns.ErrSecret } @@ -106,19 +109,20 @@ func (*gssNoVerify) Verify(_ []byte, t *dns.TSIG) error { if dns.CanonicalName(t.Algorithm) != tsig.GSS { return dns.ErrKeyAlg } + return nil } -func generateTKEYName(host string) string { - - seed := rand.NewSource(time.Now().UnixNano()) - rng := rand.New(seed) +func generateTKEYName(host string) (string, error) { + i, err := rand.Int(rand.Reader, big.NewInt(0x7fffffff)) + if err != nil { + return "", err + } - return dns.Fqdn(fmt.Sprintf("%d.sig-%s", rng.Int31(), host)) + return dns.Fqdn(fmt.Sprintf("%d.sig-%s", i.Int64(), host)), nil } func generateSPN(host string) string { - if dns.IsFqdn(host) { return fmt.Sprintf("DNS/%s", host[:len(host)-1]) } @@ -127,12 +131,13 @@ func generateSPN(host string) string { } func (c *Client) close() error { - c.m.RLock() + keys := make([]string, 0, len(c.ctx)) for k := range c.ctx { keys = append(keys, k) } + c.m.RUnlock() var errs error @@ -149,23 +154,25 @@ func (c *Client) setOption(options ...func(*Client) error) error { return err } } + return nil } -// SetConfig sets the Kerberos configuration used by c +// SetConfig sets the Kerberos configuration used by c. func (c *Client) SetConfig(config string) error { return c.setOption(WithConfig(config)) } -// WithLogger sets the logger used +// WithLogger sets the logger used. func WithLogger(logger logr.Logger) func(*Client) error { return func(c *Client) error { - c.logger = logger + c.logger = logger.WithName("client") + return nil } } -// SetLogger sets the logger used by c +// SetLogger sets the logger used by c. func (c *Client) SetLogger(logger logr.Logger) error { return c.setOption(WithLogger(logger)) } diff --git a/gss/gss_internal_test.go b/gss/gss_internal_test.go new file mode 100644 index 0000000..2aa3840 --- /dev/null +++ b/gss/gss_internal_test.go @@ -0,0 +1,26 @@ +package gss + +import ( + "regexp" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGenerateTKEYName(t *testing.T) { + t.Parallel() + + tkey, err := generateTKEYName("host.example.com") + assert.Nil(t, err) + assert.Regexp(t, regexp.MustCompile(`^\d+\.sig-host\.example\.com\.$`), tkey) +} + +func TestGenerateSPN(t *testing.T) { + t.Parallel() + + spn := generateSPN("host.example.com") + assert.Equal(t, "DNS/host.example.com", spn) + + spn = generateSPN("host.example.com.") + assert.Equal(t, "DNS/host.example.com", spn) +} diff --git a/gss/gss_test.go b/gss/gss_test.go index 824f689..305b87d 100644 --- a/gss/gss_test.go +++ b/gss/gss_test.go @@ -1,84 +1,108 @@ -package gss +package gss_test import ( "fmt" "net" "os" - "regexp" + "runtime" "testing" "time" "github.com/bodgit/tsig" + "github.com/bodgit/tsig/gss" "github.com/go-logr/logr" "github.com/go-logr/logr/testr" + multierror "github.com/hashicorp/go-multierror" "github.com/miekg/dns" "github.com/stretchr/testify/assert" ) -func TestGenerateTKEYName(t *testing.T) { - - tkey := generateTKEYName("host.example.com") - assert.Regexp(t, regexp.MustCompile("^\\d+\\.sig-host\\.example\\.com\\.$"), tkey) -} - -func TestGenerateSPN(t *testing.T) { - - spn := generateSPN("host.example.com") - assert.Equal(t, "DNS/host.example.com", spn) - - spn = generateSPN("host.example.com.") - assert.Equal(t, "DNS/host.example.com", spn) -} +const dnsClientTransport = "tcp" func testEnvironmentVariables(t *testing.T) (string, string, string, string, string, string) { - host, ok := os.LookupEnv("DNS_HOST") - if !ok { - t.Fatal("$DNS_HOST not set") - } - - port, ok := os.LookupEnv("DNS_PORT") - if !ok { - port = "53" - } - - realm, ok := os.LookupEnv("DNS_REALM") - if !ok { - t.Fatal("$DNS_REALM not set") - } - - username, ok := os.LookupEnv("DNS_USERNAME") - if !ok { - t.Fatal("$DNS_USERNAME not set") - } - - password, ok := os.LookupEnv("DNS_PASSWORD") - if !ok { - t.Fatal("$DNS_PASSWORD not set") - } - - keytab, ok := os.LookupEnv("DNS_KEYTAB") - if !ok { - t.Fatal("$DNS_KEYTAB not set") + t.Helper() + + var ( + host string + port = "53" + realm string + username string + password string + keytab string + errs *multierror.Error + ) + + for _, env := range []struct { + ptr *string + name string + optional bool + }{ + { + &host, + "DNS_HOST", + false, + }, + { + &port, + "DNS_PORT", + true, + }, + { + &realm, + "DNS_REALM", + false, + }, + { + &username, + "DNS_USERNAME", + false, + }, + { + &password, + "DNS_PASSWORD", + false, + }, + { + &keytab, + "DNS_KEYTAB", + runtime.GOOS == "windows", + }, + } { + if v, ok := os.LookupEnv(env.name); ok { + *env.ptr = v + } else if !env.optional { + errs = multierror.Append(errs, fmt.Errorf("%s is not set", env.name)) + } + } + + if errs.ErrorOrNil() != nil { + t.Fatal(errs) } return host, port, realm, username, password, keytab } -func testExchange(t *testing.T) error { +func testExchange(t *testing.T) (err error) { + t.Helper() + if testing.Short() { t.Skip("skipping integration test") } + //nolint:dogsled host, port, _, _, _, _ := testEnvironmentVariables(t) dnsClient := new(dns.Client) - dnsClient.Net = "tcp" + dnsClient.Net = dnsClientTransport - gssClient, err := NewClient(dnsClient, WithLogger(testr.New(t))) + gssClient, err := gss.NewClient(dnsClient, gss.WithLogger(testr.New(t))) if err != nil { return err } - defer gssClient.Close() + + defer func() { + err = multierror.Append(err, gssClient.Close()).ErrorOrNil() + }() keyname, _, err := gssClient.NegotiateContext(net.JoinHostPort(host, port)) if err != nil { @@ -92,30 +116,28 @@ func testExchange(t *testing.T) error { insert, err := dns.NewRR("test.example.com. 300 A 192.0.2.1") if err != nil { - panic(err) + return err } + msg.Insert([]dns.RR{insert}) msg.SetTsig(keyname, tsig.GSS, 300, time.Now().Unix()) rr, _, err := dnsClient.Exchange(msg, net.JoinHostPort(host, port)) if err != nil { - panic(err) + return err } if rr.Rcode != dns.RcodeSuccess { return fmt.Errorf("DNS error: %s (%d)", dns.RcodeToString[rr.Rcode], rr.Rcode) } - err = gssClient.DeleteContext(keyname) - if err != nil { - return err - } - - return nil + return gssClient.DeleteContext(keyname) } -func testExchangeCredentials(t *testing.T) error { +func testExchangeCredentials(t *testing.T) (err error) { + t.Helper() + if testing.Short() { t.Skip("skipping integration test") } @@ -123,28 +145,32 @@ func testExchangeCredentials(t *testing.T) error { host, port, realm, username, password, _ := testEnvironmentVariables(t) dnsClient := new(dns.Client) - dnsClient.Net = "tcp" + dnsClient.Net = dnsClientTransport - gssClient, err := NewClient(dnsClient, WithLogger(testr.New(t))) + gssClient, err := gss.NewClient(dnsClient) if err != nil { return err } - defer gssClient.Close() - keyname, _, err := gssClient.NegotiateContextWithCredentials(net.JoinHostPort(host, port), realm, username, password) - if err != nil { + defer func() { + err = multierror.Append(err, gssClient.Close()).ErrorOrNil() + }() + + if err = gssClient.SetLogger(testr.New(t)); err != nil { return err } - err = gssClient.DeleteContext(keyname) + keyname, _, err := gssClient.NegotiateContextWithCredentials(net.JoinHostPort(host, port), realm, username, password) if err != nil { return err } - return nil + return gssClient.DeleteContext(keyname) } -func testExchangeKeytab(t *testing.T) error { +func testExchangeKeytab(t *testing.T) (err error) { + t.Helper() + if testing.Short() { t.Skip("skipping integration test") } @@ -152,32 +178,34 @@ func testExchangeKeytab(t *testing.T) error { host, port, realm, username, _, keytab := testEnvironmentVariables(t) dnsClient := new(dns.Client) - dnsClient.Net = "tcp" + dnsClient.Net = dnsClientTransport - gssClient, err := NewClient(dnsClient, WithLogger(testr.New(t))) + gssClient, err := gss.NewClient(dnsClient, gss.WithLogger(testr.New(t))) if err != nil { return err } - defer gssClient.Close() - keyname, _, err := gssClient.NegotiateContextWithKeytab(net.JoinHostPort(host, port), realm, username, keytab) - if err != nil { - return err - } + defer func() { + err = multierror.Append(err, gssClient.Close()).ErrorOrNil() + }() - err = gssClient.DeleteContext(keyname) + keyname, _, err := gssClient.NegotiateContextWithKeytab(net.JoinHostPort(host, port), realm, username, keytab) if err != nil { return err } - return nil + return gssClient.DeleteContext(keyname) } func TestExchange(t *testing.T) { + t.Parallel() + assert.Nil(t, testExchange(t)) } func TestNewClientWithLogger(t *testing.T) { - _, err := NewClient(new(dns.Client), WithLogger(logr.Discard())) + t.Parallel() + + _, err := gss.NewClient(new(dns.Client), gss.WithLogger(logr.Discard())) assert.Nil(t, err) } diff --git a/gss/sspi.go b/gss/sspi.go index 51d5146..e513087 100644 --- a/gss/sspi.go +++ b/gss/sspi.go @@ -5,7 +5,6 @@ package gss import ( "encoding/hex" - "errors" "net" "sync" "time" @@ -28,18 +27,17 @@ type Client struct { logger logr.Logger } -// WithConfig sets the Kerberos configuration used -func WithConfig(config string) func(*Client) error { +// WithConfig sets the Kerberos configuration used. +func WithConfig(_ string) func(*Client) error { return func(c *Client) error { return errNotSupported } } -// New performs any library initialization necessary. +// NewClient performs any library initialization necessary. // It returns a context handle for any further functions along with any error // that occurred. func NewClient(dnsClient *dns.Client, options ...func(*Client) error) (*Client, error) { - client, err := util.CopyDNSClient(dnsClient) if err != nil { return nil, err @@ -64,7 +62,6 @@ func NewClient(dnsClient *dns.Client, options ...func(*Client) error) (*Client, // necessary. // It returns any error that occurred. func (c *Client) Close() error { - return c.close() } @@ -74,7 +71,6 @@ func (c *Client) Close() error { // for this context. // It returns the bytes for the TSIG MAC and any error that occurred. func (c *Client) Generate(msg []byte, t *dns.TSIG) ([]byte, error) { - if dns.CanonicalName(t.Algorithm) != tsig.GSS { return nil, dns.ErrKeyAlg } @@ -101,7 +97,6 @@ func (c *Client) Generate(msg []byte, t *dns.TSIG) ([]byte, error) { // for this context. // It returns any error that occurred. func (c *Client) Verify(stripped []byte, t *dns.TSIG) error { - if dns.CanonicalName(t.Algorithm) != tsig.GSS { return dns.ErrKeyAlg } @@ -127,50 +122,43 @@ func (c *Client) Verify(stripped []byte, t *dns.TSIG) error { } func (c *Client) negotiateContext(host string, creds *sspi.Credentials) (string, time.Time, error) { - hostname, _, err := net.SplitHostPort(host) if err != nil { return "", time.Time{}, err } - keyname := generateTKEYName(hostname) + keyname, err := generateTKEYName(hostname) + if err != nil { + return "", time.Time{}, err + } ctx, output, err := negotiate.NewClientContext(creds, generateSPN(hostname)) if err != nil { return "", time.Time{}, err } - var completed bool - var tkey *dns.TKEY + var ( + completed bool + tkey *dns.TKEY + ) for ok := false; !ok; ok = completed { - - var errs error - - // We don't care about non-TKEY answers, no additional RR's to send, and no signing + //nolint:lll if tkey, _, err = util.ExchangeTKEY(c.client, host, keyname, tsig.GSS, util.TkeyModeGSS, 3600, output, nil, "", ""); err != nil { - errs = multierror.Append(errs, err) - errs = multierror.Append(errs, ctx.Release()) - return "", time.Time{}, errs + return "", time.Time{}, multierror.Append(err, ctx.Release()) } if tkey.Header().Name != keyname { - errs = multierror.Append(errs, errors.New("TKEY name does not match")) - errs = multierror.Append(errs, ctx.Release()) - return "", time.Time{}, errs + return "", time.Time{}, multierror.Append(errDoesNotMatch, ctx.Release()) } input, err := hex.DecodeString(tkey.Key) if err != nil { - errs = multierror.Append(errs, err) - errs = multierror.Append(errs, ctx.Release()) - return "", time.Time{}, errs + return "", time.Time{}, multierror.Append(err, ctx.Release()) } if completed, output, err = ctx.Update(input); err != nil { - errs = multierror.Append(errs, err) - errs = multierror.Append(errs, ctx.Release()) - return "", time.Time{}, errs + return "", time.Time{}, multierror.Append(err, ctx.Release()) } } @@ -188,13 +176,15 @@ func (c *Client) negotiateContext(host string, creds *sspi.Credentials) (string, // server to establish a security context using the current user. // It returns the negotiated TKEY name, expiration time, and any error that // occurred. -func (c *Client) NegotiateContext(host string) (string, time.Time, error) { - +func (c *Client) NegotiateContext(host string) (keyname string, expiry time.Time, err error) { creds, err := negotiate.AcquireCurrentUserCredentials() if err != nil { return "", time.Time{}, err } - defer creds.Release() + + defer func() { + err = multierror.Append(err, creds.Release()).ErrorOrNil() + }() return c.negotiateContext(host, creds) } @@ -204,13 +194,17 @@ func (c *Client) NegotiateContext(host string) (string, time.Time, error) { // credentials. // It returns the negotiated TKEY name, expiration time, and any error that // occurred. -func (c *Client) NegotiateContextWithCredentials(host, domain, username, password string) (string, time.Time, error) { - +// +//nolint:lll +func (c *Client) NegotiateContextWithCredentials(host, domain, username, password string) (keyname string, expiry time.Time, err error) { creds, err := negotiate.AcquireUserCredentials(domain, username, password) if err != nil { return "", time.Time{}, err } - defer creds.Release() + + defer func() { + err = multierror.Append(err, creds.Release()).ErrorOrNil() + }() return c.negotiateContext(host, creds) } @@ -220,8 +214,7 @@ func (c *Client) NegotiateContextWithCredentials(host, domain, username, passwor // keytab. // It returns the negotiated TKEY name, expiration time, and any error that // occurred. -func (c *Client) NegotiateContextWithKeytab(host, domain, username, path string) (string, time.Time, error) { - +func (c *Client) NegotiateContextWithKeytab(_, _, _, _ string) (string, time.Time, error) { return "", time.Time{}, errNotSupported } @@ -229,13 +222,12 @@ func (c *Client) NegotiateContextWithKeytab(host, domain, username, path string) // TKEY name. // It returns any error that occurred. func (c *Client) DeleteContext(keyname string) error { - c.m.Lock() defer c.m.Unlock() ctx, ok := c.ctx[keyname] if !ok { - return errors.New("No such context") + return errNoSuchContext } if err := ctx.Release(); err != nil { diff --git a/gss/sspi_test.go b/gss/sspi_test.go new file mode 100644 index 0000000..acd7829 --- /dev/null +++ b/gss/sspi_test.go @@ -0,0 +1,31 @@ +//go:build windows +// +build windows + +package gss_test + +import ( + "testing" + + "github.com/bodgit/tsig/gss" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" +) + +func TestExchangeCredentials(t *testing.T) { + t.Parallel() + + assert.Nil(t, testExchangeCredentials(t)) +} + +func TestExchangeKeytab(t *testing.T) { + t.Parallel() + + assert.ErrorIs(t, testExchangeKeytab(t), gss.ErrNotSupported) +} + +func TestNewClientWithConfig(t *testing.T) { + t.Parallel() + + _, err := gss.NewClient(new(dns.Client), gss.WithConfig("")) + assert.NotNil(t, err) +} diff --git a/hmac.go b/hmac.go index ffee12c..1395755 100644 --- a/hmac.go +++ b/hmac.go @@ -2,8 +2,8 @@ package tsig import ( "crypto/hmac" - "crypto/md5" - "crypto/sha1" + "crypto/md5" //nolint:gosec + "crypto/sha1" //nolint:gosec "crypto/sha256" "crypto/sha512" "encoding/base64" @@ -23,6 +23,7 @@ func fromBase64(s []byte) (buf []byte, err error) { buf = make([]byte, buflen) n, err := base64.StdEncoding.Decode(buf, s) buf = buf[:n] + return } @@ -31,6 +32,7 @@ func fromBase64(s []byte) (buf []byte, err error) { // It returns the bytes for the TSIG MAC and any error that occurred. func (h HMAC) Generate(msg []byte, t *dns.TSIG) ([]byte, error) { var f func() hash.Hash + switch dns.CanonicalName(t.Algorithm) { case dns.HmacMD5: f = md5.New @@ -47,16 +49,20 @@ func (h HMAC) Generate(msg []byte, t *dns.TSIG) ([]byte, error) { default: return nil, dns.ErrKeyAlg } + secret, ok := h[t.Hdr.Name] if !ok { return nil, dns.ErrSecret } + rawsecret, err := fromBase64([]byte(secret)) if err != nil { return nil, err } + m := hmac.New(f, rawsecret) m.Write(msg) + return m.Sum(nil), nil } @@ -68,12 +74,15 @@ func (h HMAC) Verify(msg []byte, t *dns.TSIG) error { if err != nil { return err } + mac, err := hex.DecodeString(t.MAC) if err != nil { return err } + if !hmac.Equal(b, mac) { return dns.ErrSig } + return nil } diff --git a/hmac_test.go b/hmac_test.go index 1fa3c50..9a756fd 100644 --- a/hmac_test.go +++ b/hmac_test.go @@ -1,24 +1,30 @@ -package tsig +package tsig_test import ( "encoding/base64" "encoding/hex" "testing" + "github.com/bodgit/tsig" "github.com/miekg/dns" "github.com/stretchr/testify/assert" ) +//nolint:funlen func TestHMACGenerate(t *testing.T) { - tables := map[string]struct { - provider HMAC + t.Parallel() + + tables := []struct { + name string + provider tsig.HMAC msg []byte tsig *dns.TSIG b []byte err error }{ - "md5": { - HMAC{"example.": "DRwIYZn6exnhof/mcV/aEQ=="}, + { + "md5", + tsig.HMAC{"example.": "DRwIYZn6exnhof/mcV/aEQ=="}, []byte("message"), &dns.TSIG{ Hdr: dns.RR_Header{ @@ -26,11 +32,15 @@ func TestHMACGenerate(t *testing.T) { }, Algorithm: dns.HmacMD5, }, - []byte{0xb, 0x78, 0x2f, 0xf6, 0xac, 0xb3, 0xf6, 0xbe, 0x52, 0xdb, 0x22, 0xc7, 0xce, 0x8, 0x11, 0x77}, + []byte{ + 0x0b, 0x78, 0x2f, 0xf6, 0xac, 0xb3, 0xf6, 0xbe, + 0x52, 0xdb, 0x22, 0xc7, 0xce, 0x08, 0x11, 0x77, + }, nil, }, - "sha1": { - HMAC{"example.": "dZFRPtLqbQXGs7SdraTJJSGNSCU="}, + { + "sha1", + tsig.HMAC{"example.": "dZFRPtLqbQXGs7SdraTJJSGNSCU="}, []byte("message"), &dns.TSIG{ Hdr: dns.RR_Header{ @@ -38,11 +48,16 @@ func TestHMACGenerate(t *testing.T) { }, Algorithm: dns.HmacSHA1, }, - []byte{0xb8, 0xb5, 0xdf, 0xd4, 0x27, 0x85, 0x7, 0x6f, 0x2f, 0x3a, 0xa9, 0xc6, 0xf9, 0xfe, 0x98, 0x68, 0xc5, 0xbd, 0x9b, 0x7a}, + []byte{ + 0xb8, 0xb5, 0xdf, 0xd4, 0x27, 0x85, 0x07, 0x6f, + 0x2f, 0x3a, 0xa9, 0xc6, 0xf9, 0xfe, 0x98, 0x68, + 0xc5, 0xbd, 0x9b, 0x7a, + }, nil, }, - "sha224": { - HMAC{"example.": "NaDGqfyc2/Fc0muCPB78CyGPlveTursOxrPVVQ=="}, + { + "sha224", + tsig.HMAC{"example.": "NaDGqfyc2/Fc0muCPB78CyGPlveTursOxrPVVQ=="}, []byte("message"), &dns.TSIG{ Hdr: dns.RR_Header{ @@ -50,11 +65,17 @@ func TestHMACGenerate(t *testing.T) { }, Algorithm: dns.HmacSHA224, }, - []byte{0xfc, 0x1c, 0xf5, 0xd9, 0x5e, 0x1f, 0xb0, 0xd5, 0xad, 0x2d, 0x53, 0x5a, 0x69, 0x2e, 0x47, 0x5c, 0x3a, 0xa8, 0xed, 0x52, 0x41, 0x4c, 0x71, 0x7d, 0xd9, 0x87, 0x3a, 0xcb}, + []byte{ + 0xfc, 0x1c, 0xf5, 0xd9, 0x5e, 0x1f, 0xb0, 0xd5, + 0xad, 0x2d, 0x53, 0x5a, 0x69, 0x2e, 0x47, 0x5c, + 0x3a, 0xa8, 0xed, 0x52, 0x41, 0x4c, 0x71, 0x7d, + 0xd9, 0x87, 0x3a, 0xcb, + }, nil, }, - "sha256": { - HMAC{"example.": "BduxMlVUsrEhdgfOLKSLhNE4D3qzDx7dwyRjt7+BDNE="}, + { + "sha256", + tsig.HMAC{"example.": "BduxMlVUsrEhdgfOLKSLhNE4D3qzDx7dwyRjt7+BDNE="}, []byte("message"), &dns.TSIG{ Hdr: dns.RR_Header{ @@ -62,11 +83,17 @@ func TestHMACGenerate(t *testing.T) { }, Algorithm: dns.HmacSHA256, }, - []byte{0xdc, 0x76, 0x7, 0x57, 0xa5, 0x92, 0x1, 0x55, 0x1d, 0x57, 0xdc, 0xaf, 0x43, 0x6a, 0x45, 0xdc, 0xec, 0xa9, 0xb7, 0x1b, 0x63, 0x37, 0x63, 0x90, 0x4b, 0x63, 0x5d, 0xc3, 0x96, 0xeb, 0x42, 0xd6}, + []byte{ + 0xdc, 0x76, 0x07, 0x57, 0xa5, 0x92, 0x01, 0x55, + 0x1d, 0x57, 0xdc, 0xaf, 0x43, 0x6a, 0x45, 0xdc, + 0xec, 0xa9, 0xb7, 0x1b, 0x63, 0x37, 0x63, 0x90, + 0x4b, 0x63, 0x5d, 0xc3, 0x96, 0xeb, 0x42, 0xd6, + }, nil, }, - "sha384": { - HMAC{"example.": "xqbc2K8kfLDw3yNOOw9kloxrLPX0ILoGK4sxZwVOgDnGzcp9DZu5nDQMZBofAIYf"}, + { + "sha384", + tsig.HMAC{"example.": "xqbc2K8kfLDw3yNOOw9kloxrLPX0ILoGK4sxZwVOgDnGzcp9DZu5nDQMZBofAIYf"}, []byte("message"), &dns.TSIG{ Hdr: dns.RR_Header{ @@ -74,11 +101,19 @@ func TestHMACGenerate(t *testing.T) { }, Algorithm: dns.HmacSHA384, }, - []byte{0x21, 0x29, 0xfa, 0x1c, 0x10, 0x4b, 0x12, 0x81, 0x95, 0x98, 0x36, 0x5a, 0x92, 0x88, 0x1e, 0x5a, 0x26, 0x76, 0x28, 0x5a, 0xc, 0xe7, 0x53, 0xa5, 0x3c, 0xb6, 0xad, 0x12, 0xc2, 0x7b, 0xb9, 0xd5, 0x88, 0x2f, 0x24, 0xae, 0x39, 0x54, 0xd5, 0xbb, 0x95, 0x7f, 0x30, 0x1c, 0x42, 0x61, 0x22, 0xc5}, + []byte{ + 0x21, 0x29, 0xfa, 0x1c, 0x10, 0x4b, 0x12, 0x81, + 0x95, 0x98, 0x36, 0x5a, 0x92, 0x88, 0x1e, 0x5a, + 0x26, 0x76, 0x28, 0x5a, 0x0c, 0xe7, 0x53, 0xa5, + 0x3c, 0xb6, 0xad, 0x12, 0xc2, 0x7b, 0xb9, 0xd5, + 0x88, 0x2f, 0x24, 0xae, 0x39, 0x54, 0xd5, 0xbb, + 0x95, 0x7f, 0x30, 0x1c, 0x42, 0x61, 0x22, 0xc5, + }, nil, }, - "sha512": { - HMAC{"example.": "WCltYAUyQQjslkIIOXnvJkC3bSlCPEsl6gYEzkIyUbnXbmJZA5PTgSL8fLlwfDKYJl/SiFMTOzQxWvH7AmUvSw=="}, + { + "sha512", + tsig.HMAC{"example.": "WCltYAUyQQjslkIIOXnvJkC3bSlCPEsl6gYEzkIyUbnXbmJZA5PTgSL8fLlwfDKYJl/SiFMTOzQxWvH7AmUvSw=="}, []byte("message"), &dns.TSIG{ Hdr: dns.RR_Header{ @@ -86,23 +121,34 @@ func TestHMACGenerate(t *testing.T) { }, Algorithm: dns.HmacSHA512, }, - []byte{0xdb, 0x3e, 0x97, 0x64, 0x17, 0x8a, 0x93, 0x60, 0x19, 0x6b, 0x80, 0xe4, 0xac, 0xba, 0xbd, 0xb7, 0x1e, 0xe9, 0xb4, 0xf6, 0xc3, 0xe, 0xc0, 0x2c, 0xcd, 0xcf, 0xf3, 0xff, 0x29, 0x8c, 0x3, 0xfa, 0x4b, 0x58, 0xf0, 0xfe, 0xaa, 0x15, 0x6e, 0x77, 0x8f, 0x98, 0x65, 0x72, 0x3c, 0x94, 0x4e, 0x3f, 0xc9, 0xdc, 0x4c, 0x88, 0x7c, 0x4d, 0xfb, 0x23, 0x8a, 0xad, 0xe5, 0x4f, 0xcc, 0x73, 0x50, 0x59}, + []byte{ + 0xdb, 0x3e, 0x97, 0x64, 0x17, 0x8a, 0x93, 0x60, + 0x19, 0x6b, 0x80, 0xe4, 0xac, 0xba, 0xbd, 0xb7, + 0x1e, 0xe9, 0xb4, 0xf6, 0xc3, 0x0e, 0xc0, 0x2c, + 0xcd, 0xcf, 0xf3, 0xff, 0x29, 0x8c, 0x03, 0xfa, + 0x4b, 0x58, 0xf0, 0xfe, 0xaa, 0x15, 0x6e, 0x77, + 0x8f, 0x98, 0x65, 0x72, 0x3c, 0x94, 0x4e, 0x3f, + 0xc9, 0xdc, 0x4c, 0x88, 0x7c, 0x4d, 0xfb, 0x23, + 0x8a, 0xad, 0xe5, 0x4f, 0xcc, 0x73, 0x50, 0x59, + }, nil, }, - "algorithm": { - HMAC{"example.": ""}, + { + "algorithm", + tsig.HMAC{"example.": ""}, []byte("message"), &dns.TSIG{ Hdr: dns.RR_Header{ Name: "example.", }, - Algorithm: GSS, + Algorithm: tsig.GSS, }, nil, dns.ErrKeyAlg, }, - "secret": { - HMAC{}, + { + "secret", + tsig.HMAC{}, []byte("message"), &dns.TSIG{ Hdr: dns.RR_Header{ @@ -113,8 +159,9 @@ func TestHMACGenerate(t *testing.T) { nil, dns.ErrSecret, }, - "garbage": { - HMAC{"example.": "garbage"}, + { + "garbage", + tsig.HMAC{"example.": "garbage"}, []byte("message"), &dns.TSIG{ Hdr: dns.RR_Header{ @@ -127,8 +174,10 @@ func TestHMACGenerate(t *testing.T) { }, } - for name, table := range tables { - t.Run(name, func(t *testing.T) { + for _, table := range tables { + table := table + t.Run(table.name, func(t *testing.T) { + t.Parallel() b, err := table.provider.Generate(table.msg, table.tsig) assert.Equal(t, table.b, b) assert.Equal(t, table.err, err) @@ -136,39 +185,49 @@ func TestHMACGenerate(t *testing.T) { } } +//nolint:funlen func TestHMACVerify(t *testing.T) { - tables := map[string]struct { - provider HMAC + t.Parallel() + + tables := []struct { + name string + provider tsig.HMAC msg []byte tsig *dns.TSIG err error }{ - "md5": { - HMAC{"example.": "DRwIYZn6exnhof/mcV/aEQ=="}, + { + "md5", + tsig.HMAC{"example.": "DRwIYZn6exnhof/mcV/aEQ=="}, []byte("message"), &dns.TSIG{ Hdr: dns.RR_Header{ Name: "example.", }, Algorithm: dns.HmacMD5, - MAC: hex.EncodeToString([]byte{0xb, 0x78, 0x2f, 0xf6, 0xac, 0xb3, 0xf6, 0xbe, 0x52, 0xdb, 0x22, 0xc7, 0xce, 0x8, 0x11, 0x77}), + MAC: hex.EncodeToString([]byte{ + 0x0b, 0x78, 0x2f, 0xf6, 0xac, 0xb3, 0xf6, 0xbe, + 0x52, 0xdb, 0x22, 0xc7, 0xce, 0x08, 0x11, 0x77, + }), }, nil, }, - "algorithm": { - HMAC{"example.": "DRwIYZn6exnhof/mcV/aEQ=="}, + { + "algorithm", + tsig.HMAC{"example.": "DRwIYZn6exnhof/mcV/aEQ=="}, []byte("message"), &dns.TSIG{ Hdr: dns.RR_Header{ Name: "example.", }, - Algorithm: GSS, + Algorithm: tsig.GSS, MAC: "", }, dns.ErrKeyAlg, }, - "garbage": { - HMAC{"example.": "DRwIYZn6exnhof/mcV/aEQ=="}, + { + "garbage", + tsig.HMAC{"example.": "DRwIYZn6exnhof/mcV/aEQ=="}, []byte("message"), &dns.TSIG{ Hdr: dns.RR_Header{ @@ -179,22 +238,28 @@ func TestHMACVerify(t *testing.T) { }, hex.InvalidByteError(0x67), }, - "signature": { - HMAC{"example.": "DRwIYZn6exnhof/mcV/aEQ=="}, + { + "signature", + tsig.HMAC{"example.": "DRwIYZn6exnhof/mcV/aEQ=="}, []byte("different"), &dns.TSIG{ Hdr: dns.RR_Header{ Name: "example.", }, Algorithm: dns.HmacMD5, - MAC: hex.EncodeToString([]byte{0xb, 0x78, 0x2f, 0xf6, 0xac, 0xb3, 0xf6, 0xbe, 0x52, 0xdb, 0x22, 0xc7, 0xce, 0x8, 0x11, 0x77}), + MAC: hex.EncodeToString([]byte{ + 0x0b, 0x78, 0x2f, 0xf6, 0xac, 0xb3, 0xf6, 0xbe, + 0x52, 0xdb, 0x22, 0xc7, 0xce, 0x08, 0x11, 0x77, + }), }, dns.ErrSig, }, } - for name, table := range tables { - t.Run(name, func(t *testing.T) { + for _, table := range tables { + table := table + t.Run(table.name, func(t *testing.T) { + t.Parallel() err := table.provider.Verify(table.msg, table.tsig) assert.Equal(t, table.err, err) }) diff --git a/internal/util/util.go b/internal/util/util.go index 710244b..2d32dc4 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -1,7 +1,11 @@ +/* +Package util contains utility routines. +*/ package util import ( "encoding/hex" + "errors" "fmt" "time" @@ -12,15 +16,15 @@ import ( const ( _ uint16 = iota // Reserved, RFC 2930, section 2.5 - // TkeyModeServer is used for server assigned keying + // TkeyModeServer is used for server assigned keying. TkeyModeServer - // TkeyModeDH is used for Diffie-Hellman exchanged keying + // TkeyModeDH is used for Diffie-Hellman exchanged keying. TkeyModeDH - // TkeyModeGSS is used for GSS-API establishment + // TkeyModeGSS is used for GSS-API establishment. TkeyModeGSS - // TkeyModeResolver is used for resolver assigned keying + // TkeyModeResolver is used for resolver assigned keying. TkeyModeResolver - // TkeyModeDelete is used for key deletion + // TkeyModeDelete is used for key deletion. TkeyModeDelete ) @@ -33,7 +37,6 @@ type Exchanger interface { // TCP. If the existing network is configured to only use IPv4 or IPv6 then // the appropriate network is chosen to maintain this choice. func CopyDNSClient(dnsClient *dns.Client) (*dns.Client, error) { - client := new(dns.Client) if err := copier.Copy(client, dnsClient); err != nil { return nil, err @@ -60,17 +63,17 @@ func CopyDNSClient(dnsClient *dns.Client) (*dns.Client, error) { } func calculateTimes(mode uint16, lifetime uint32) (uint32, uint32, error) { - switch mode { case TkeyModeDH: fallthrough case TkeyModeGSS: now := time.Now().Unix() + return uint32(now), uint32(now) + lifetime, nil case TkeyModeDelete: return 0, 0, nil default: - return 0, 0, fmt.Errorf("Unsupported TKEY mode %d", mode) + return 0, 0, fmt.Errorf("unsupported TKEY mode %d", mode) } } @@ -80,8 +83,9 @@ func calculateTimes(mode uint16, lifetime uint32) (uint32, uint32, error) { // with TSIG if a key name, algorithm and MAC are provided. // The TKEY record is returned along with any other DNS records in the // response along with any error that occurred. -func ExchangeTKEY(client Exchanger, host, keyname, algorithm string, mode uint16, lifetime uint32, input []byte, extra []dns.RR, tsigname, tsigalgo string) (*dns.TKEY, []dns.RR, error) { - +// +//nolint:cyclop,funlen +func ExchangeTKEY(client Exchanger, host, keyname, algorithm string, mode uint16, lifetime uint32, input []byte, extra []dns.RR, tsigname, tsigalgo string) (*dns.TKEY, []dns.RR, error) { //nolint:lll msg := &dns.Msg{ MsgHdr: dns.MsgHdr{ Id: dns.Id(), @@ -141,8 +145,9 @@ func ExchangeTKEY(client Exchanger, host, keyname, algorithm string, mode uint16 case *dns.TKEY: // There mustn't be more than one TKEY answer RR if tkey != nil { - return nil, nil, fmt.Errorf("Multiple TKEY responses") + return nil, nil, errors.New("multiple TKEY responses") } + tkey = t default: additional = append(additional, ans) @@ -151,7 +156,7 @@ func ExchangeTKEY(client Exchanger, host, keyname, algorithm string, mode uint16 // There should always be at least a TKEY answer RR if tkey == nil { - return nil, nil, fmt.Errorf("Received no TKEY response") + return nil, nil, errors.New("received no TKEY response") } if tkey.Error != 0 { diff --git a/internal/util/util_internal_test.go b/internal/util/util_internal_test.go new file mode 100644 index 0000000..5be41b0 --- /dev/null +++ b/internal/util/util_internal_test.go @@ -0,0 +1,29 @@ +package util + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCalculateTimes(t *testing.T) { + t.Parallel() + + lifetime := uint32(3600) + + t0, t1, err := calculateTimes(TkeyModeDH, lifetime) + assert.Nil(t, err) + assert.Equal(t, lifetime, t1-t0) + + t0, t1, err = calculateTimes(TkeyModeGSS, lifetime) + assert.Nil(t, err) + assert.Equal(t, lifetime, t1-t0) + + t0, t1, err = calculateTimes(TkeyModeDelete, lifetime) + assert.Nil(t, err) + assert.Equal(t, uint32(0), t0) + assert.Equal(t, uint32(0), t1) + + _, _, err = calculateTimes(TkeyModeServer, lifetime) + assert.NotNil(t, err) +} diff --git a/internal/util/util_test.go b/internal/util/util_test.go index b3dc131..802d642 100644 --- a/internal/util/util_test.go +++ b/internal/util/util_test.go @@ -1,10 +1,12 @@ -package util +package util_test import ( + "errors" "testing" "time" "github.com/bodgit/tsig" + "github.com/bodgit/tsig/internal/util" "github.com/miekg/dns" "github.com/stretchr/testify/assert" ) @@ -15,8 +17,7 @@ type FakeClient struct { Err error } -func (c *FakeClient) Exchange(m *dns.Msg, address string) (*dns.Msg, time.Duration, error) { - +func (c *FakeClient) Exchange(_ *dns.Msg, _ string) (*dns.Msg, time.Duration, error) { if c.Err != nil { return nil, 0, c.Err } @@ -24,28 +25,9 @@ func (c *FakeClient) Exchange(m *dns.Msg, address string) (*dns.Msg, time.Durati return c.Msg, c.Duration, nil } -func TestCalculateTimes(t *testing.T) { - - lifetime := uint32(3600) - - t0, t1, err := calculateTimes(TkeyModeDH, lifetime) - assert.Nil(t, err) - assert.Equal(t, lifetime, t1-t0) - - t0, t1, err = calculateTimes(TkeyModeGSS, lifetime) - assert.Nil(t, err) - assert.Equal(t, lifetime, t1-t0) - - t0, t1, err = calculateTimes(TkeyModeDelete, lifetime) - assert.Nil(t, err) - assert.Equal(t, uint32(0), t0) - assert.Equal(t, uint32(0), t1) - - _, _, err = calculateTimes(TkeyModeServer, lifetime) - assert.NotNil(t, err) -} - +//nolint:funlen func TestExchangeTKEY(t *testing.T) { + t.Parallel() now := uint32(time.Now().Unix()) @@ -57,14 +39,15 @@ func TestExchangeTKEY(t *testing.T) { Ttl: 0, }, Algorithm: tsig.GSS, - Mode: TkeyModeGSS, + Mode: util.TkeyModeGSS, Inception: now, Expiration: now + 3600, KeySize: 4, Key: "deadbeef", } - tables := map[string]struct { + tables := []struct { + name string client FakeClient host string keyname string @@ -79,7 +62,8 @@ func TestExchangeTKEY(t *testing.T) { expectedAdditional []dns.RR expectedErr error }{ - "ok": { + { + name: "ok", client: FakeClient{ Msg: &dns.Msg{ Answer: []dns.RR{ @@ -92,7 +76,7 @@ func TestExchangeTKEY(t *testing.T) { host: "ns.example.com.", keyname: "test.example.com.", algorithm: tsig.GSS, - mode: TkeyModeGSS, + mode: util.TkeyModeGSS, lifetime: 3600, expectedTKEY: goodTKEY, expectedAdditional: []dns.RR{}, @@ -100,12 +84,80 @@ func TestExchangeTKEY(t *testing.T) { }, } - for name, table := range tables { - t.Run(name, func(t *testing.T) { - tkey, additional, err := ExchangeTKEY(&table.client, table.host, table.keyname, table.algorithm, table.mode, table.lifetime, table.input, table.extra, table.tsigname, table.tsigalgo) + for _, table := range tables { + table := table + t.Run(table.name, func(t *testing.T) { + t.Parallel() + //nolint:lll + tkey, additional, err := util.ExchangeTKEY(&table.client, table.host, table.keyname, table.algorithm, table.mode, table.lifetime, table.input, table.extra, table.tsigname, table.tsigalgo) assert.Equal(t, table.expectedTKEY, tkey) assert.Equal(t, table.expectedAdditional, additional) assert.Equal(t, table.expectedErr, err) }) } } + +//nolint:funlen +func TestCopyDNSClient(t *testing.T) { + t.Parallel() + + tables := []struct { + name string + client dns.Client + net string + err error + }{ + { + "tcp", + dns.Client{ + Net: "tcp", + }, + "tcp", + nil, + }, + { + "udp", + dns.Client{ + Net: "udp", + }, + "tcp", + nil, + }, + { + "udp4", + dns.Client{ + Net: "udp4", + }, + "tcp4", + nil, + }, + { + "udp6", + dns.Client{ + Net: "udp6", + }, + "tcp6", + nil, + }, + { + "invalid", + dns.Client{ + Net: "invalid", + }, + "tcp6", + errors.New("unsupported transport 'invalid'"), + }, + } + + for _, table := range tables { + table := table + t.Run(table.name, func(t *testing.T) { + t.Parallel() + client, err := util.CopyDNSClient(&table.client) + if table.err == nil { + assert.Equal(t, table.net, client.Net) + } + assert.Equal(t, table.err, err) + }) + } +} diff --git a/multi.go b/multi.go index 4889bd5..0137cf3 100644 --- a/multi.go +++ b/multi.go @@ -1,6 +1,10 @@ package tsig -import "github.com/miekg/dns" +import ( + "errors" + + "github.com/miekg/dns" +) type multiProvider struct { providers []dns.TsigProvider @@ -8,27 +12,21 @@ type multiProvider struct { func (mp *multiProvider) Generate(msg []byte, t *dns.TSIG) (b []byte, err error) { for _, p := range mp.providers { - b, err = p.Generate(msg, t) - switch err { - case dns.ErrKeyAlg: - break - default: + if b, err = p.Generate(msg, t); err == nil || !errors.Is(err, dns.ErrKeyAlg) { return } } + return nil, dns.ErrKeyAlg } func (mp *multiProvider) Verify(msg []byte, t *dns.TSIG) (err error) { for _, p := range mp.providers { - err = p.Verify(msg, t) - switch err { - case dns.ErrKeyAlg: - break - default: + if err = p.Verify(msg, t); err == nil || !errors.Is(err, dns.ErrKeyAlg) { return } } + return dns.ErrKeyAlg } @@ -40,6 +38,7 @@ func (mp *multiProvider) Verify(msg []byte, t *dns.TSIG) (err error) { // returned; it does not continue down the list. func MultiProvider(providers ...dns.TsigProvider) dns.TsigProvider { allProviders := make([]dns.TsigProvider, 0, len(providers)) + for _, p := range providers { if mp, ok := p.(*multiProvider); ok { allProviders = append(allProviders, mp.providers...) @@ -47,5 +46,6 @@ func MultiProvider(providers ...dns.TsigProvider) dns.TsigProvider { allProviders = append(allProviders, p) } } + return &multiProvider{allProviders} } diff --git a/multi_test.go b/multi_test.go index 6f11d53..f88815b 100644 --- a/multi_test.go +++ b/multi_test.go @@ -1,16 +1,17 @@ -package tsig +package tsig_test import ( "errors" "testing" + "github.com/bodgit/tsig" "github.com/miekg/dns" "github.com/stretchr/testify/assert" ) var ( errProvider = errors.New("provider error") - testSignature = []byte("a good signature") + testSignature = []byte("a good signature") //nolint:gochecknoglobals ) type unsupportedProvider struct{} @@ -44,40 +45,50 @@ func (testProvider) Verify(_ []byte, _ *dns.TSIG) error { } func TestMultiProviderGenerate(t *testing.T) { - tables := map[string]struct { + t.Parallel() + + tables := []struct { + name string provider dns.TsigProvider signature []byte err error }{ - "good": { - MultiProvider(new(testProvider)), + { + "good", + tsig.MultiProvider(new(testProvider)), testSignature, nil, }, - "unsupported good": { - MultiProvider(new(unsupportedProvider), new(testProvider)), + { + "unsupported good", + tsig.MultiProvider(new(unsupportedProvider), new(testProvider)), testSignature, nil, }, - "error good": { - MultiProvider(new(errorProvider), new(testProvider)), + { + "error good", + tsig.MultiProvider(new(errorProvider), new(testProvider)), nil, errProvider, }, - "all unsupported": { - MultiProvider(new(unsupportedProvider)), + { + "all unsupported", + tsig.MultiProvider(new(unsupportedProvider)), nil, dns.ErrKeyAlg, }, - "nested": { - MultiProvider(MultiProvider(new(testProvider))), + { + "nested", + tsig.MultiProvider(tsig.MultiProvider(new(testProvider))), testSignature, nil, }, } - for name, table := range tables { - t.Run(name, func(t *testing.T) { + for _, table := range tables { + table := table + t.Run(table.name, func(t *testing.T) { + t.Parallel() b, err := table.provider.Generate(nil, nil) assert.Equal(t, table.signature, b) assert.Equal(t, table.err, err) @@ -86,30 +97,39 @@ func TestMultiProviderGenerate(t *testing.T) { } func TestMultiProviderVerify(t *testing.T) { - tables := map[string]struct { + t.Parallel() + + tables := []struct { + name string provider dns.TsigProvider err error }{ - "good": { - MultiProvider(new(testProvider)), + { + "good", + tsig.MultiProvider(new(testProvider)), nil, }, - "unsupported good": { - MultiProvider(new(unsupportedProvider), new(testProvider)), + { + "unsupported good", + tsig.MultiProvider(new(unsupportedProvider), new(testProvider)), nil, }, - "error good": { - MultiProvider(new(errorProvider), new(testProvider)), + { + "error good", + tsig.MultiProvider(new(errorProvider), new(testProvider)), errProvider, }, - "all unsuppored": { - MultiProvider(new(unsupportedProvider)), + { + "all unsuppored", + tsig.MultiProvider(new(unsupportedProvider)), dns.ErrKeyAlg, }, } - for name, table := range tables { - t.Run(name, func(t *testing.T) { + for _, table := range tables { + table := table + t.Run(table.name, func(t *testing.T) { + t.Parallel() err := table.provider.Verify(nil, nil) assert.Equal(t, table.err, err) }) diff --git a/tsig.go b/tsig.go index 3767779..382cb86 100644 --- a/tsig.go +++ b/tsig.go @@ -1,6 +1,9 @@ +/* +Package tsig adds support for additional TSIG methods used in DNS queries. +*/ package tsig const ( - // GSS is the RFC 3645 defined algorithm name + // GSS is the RFC 3645 defined algorithm name. GSS = "gss-tsig." )