From 9f2f2101dbf416e5562b33025dc60d4cc253a3f8 Mon Sep 17 00:00:00 2001 From: Matt Dainty Date: Wed, 10 Jan 2018 19:56:16 +0000 Subject: [PATCH 1/4] Add ability to use different TSIG algorithms --- client.go | 58 ++++++++++++++++++------- tsig.go | 128 ++++++++++++++++++++++++++++++++++++------------------ 2 files changed, 128 insertions(+), 58 deletions(-) diff --git a/client.go b/client.go index c52eec1d5..e6abf6747 100644 --- a/client.go +++ b/client.go @@ -24,11 +24,17 @@ const ( dohMimeType = "application/dns-message" ) +type TsigAlgorithm struct { + Generate tsigAlgorithmGenerate + Verify tsigAlgorithmVerify +} + // A Conn represents a connection to a DNS server. type Conn struct { - net.Conn // a net.Conn holding the connection - UDPSize uint16 // minimum receive buffer for UDP messages - TsigSecret map[string]string // secret(s) for Tsig map[], zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2) + net.Conn // a net.Conn holding the connection + UDPSize uint16 // minimum receive buffer for UDP messages + TsigSecret map[string]interface{} // secret(s) for Tsig map[], zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2) + TsigAlgorithm map[string]*TsigAlgorithm tsigRequestMAC string } @@ -42,12 +48,13 @@ type Client struct { // WriteTimeout when non-zero. Can be overridden with net.Dialer.Timeout (see Client.ExchangeWithDialer and // Client.Dialer) or context.Context.Deadline (see the deprecated ExchangeContext) Timeout time.Duration - DialTimeout time.Duration // net.DialTimeout, defaults to 2 seconds, or net.Dialer.Timeout if expiring earlier - overridden by Timeout when that value is non-zero - ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero - WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero - HTTPClient *http.Client // The http.Client to use for DNS-over-HTTPS - TsigSecret map[string]string // secret(s) for Tsig map[], zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2) - SingleInflight bool // if true suppress multiple outstanding queries for the same Qname, Qtype and Qclass + DialTimeout time.Duration // net.DialTimeout, defaults to 2 seconds, or net.Dialer.Timeout if expiring earlier - overridden by Timeout when that value is non-zero + ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero + WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero + HTTPClient *http.Client // The http.Client to use for DNS-over-HTTPS + TsigSecret map[string]interface{} // secret(s) for Tsig map[], zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2) + TsigAlgorithm map[string]*TsigAlgorithm + SingleInflight bool // if true suppress multiple outstanding queries for the same Qname, Qtype and Qclass group singleflight } @@ -194,6 +201,7 @@ func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro } co.TsigSecret = c.TsigSecret + co.TsigAlgorithm = c.TsigAlgorithm t := time.Now() // write with the appropriate write timeout co.SetWriteDeadline(t.Add(c.getTimeoutForRequest(c.writeTimeout()))) @@ -300,11 +308,20 @@ func (co *Conn) ReadMsg() (*Msg, error) { return m, err } if t := m.IsTsig(); t != nil { - if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { - return m, ErrSecret + if a, ok := co.TsigAlgorithm[t.Algorithm]; ok { + if a.Verify != nil { + if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { + return m, ErrSecret + } + err = TsigVerifyByAlgorithm(p, a.Verify, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false) + } + } else { + if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { + return m, ErrSecret + } + // Need to work on the original message p, as that was used to calculate the tsig. + err = TsigVerify(p, co.TsigSecret[t.Hdr.Name].(string), co.tsigRequestMAC, false) } - // Need to work on the original message p, as that was used to calculate the tsig. - err = TsigVerify(p, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false) } return m, err } @@ -437,10 +454,19 @@ func (co *Conn) WriteMsg(m *Msg) (err error) { var out []byte if t := m.IsTsig(); t != nil { mac := "" - if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { - return ErrSecret + if a, ok := co.TsigAlgorithm[t.Algorithm]; ok { + if a.Generate != nil { + if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { + return ErrSecret + } + out, mac, err = TsigGenerateByAlgorithm(m, a.Generate, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false) + } + } else { + if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { + return ErrSecret + } + out, mac, err = TsigGenerate(m, co.TsigSecret[t.Hdr.Name].(string), co.tsigRequestMAC, false) } - out, mac, err = TsigGenerate(m, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false) // Set for the next read, although only used in zone transfers co.tsigRequestMAC = mac } else { diff --git a/tsig.go b/tsig.go index 4837b4ab1..fd7c315eb 100644 --- a/tsig.go +++ b/tsig.go @@ -22,6 +22,9 @@ const ( HmacSHA512 = "hmac-sha512." ) +type tsigAlgorithmGenerate func([]byte, string, interface{}) ([]byte, error) +type tsigAlgorithmVerify func([]byte, *TSIG, interface{}) error + // TSIG is the RR the holds the transaction signature of a message. // See RFC 2845 and RFC 4635. type TSIG struct { @@ -83,6 +86,32 @@ type timerWireFmt struct { Fudge uint16 } +func tsigGenerateHmac(msg []byte, algorithm string, meta interface{}) ([]byte, error) { + secret := meta.(string) + + rawsecret, err := fromBase64([]byte(secret)) + if err != nil { + return nil, err + } + + var h hash.Hash + switch strings.ToLower(algorithm) { + case HmacMD5: + h = hmac.New(md5.New, []byte(rawsecret)) + case HmacSHA1: + h = hmac.New(sha1.New, []byte(rawsecret)) + case HmacSHA256: + h = hmac.New(sha256.New, []byte(rawsecret)) + case HmacSHA512: + h = hmac.New(sha512.New, []byte(rawsecret)) + default: + return nil, ErrKeyAlg + } + h.Write(msg) + + return h.Sum(nil), nil +} + // TsigGenerate fills out the TSIG record attached to the message. // The message should contain // a "stub" TSIG RR with the algorithm, key name (owner name of the RR), @@ -92,14 +121,22 @@ type timerWireFmt struct { // timersOnly is false. // If something goes wrong an error is returned, otherwise it is nil. func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, string, error) { + return TsigGenerateByAlgorithm(m, tsigGenerateHmac, secret, requestMAC, timersOnly) +} + +// TsigGenerateByAlgorithm fills out the TSIG record attached to the message +// using a callback to implement the algorithm-specific generation. +// The message should contain +// a "stub" TSIG RR with the algorithm, key name (owner name of the RR), +// time fudge (defaults to 300 seconds) and the current time +// The TSIG MAC is saved in that Tsig RR. +// When TsigGenerate is called for the first time requestMAC is set to the empty string and +// timersOnly is false. +// If something goes wrong an error is returned, otherwise it is nil. +func TsigGenerateByAlgorithm(m *Msg, cb tsigAlgorithmGenerate, meta interface{}, requestMAC string, timersOnly bool) ([]byte, string, error) { if m.IsTsig() == nil { panic("dns: TSIG not last RR in additional") } - // If we barf here, the caller is to blame - rawsecret, err := fromBase64([]byte(secret)) - if err != nil { - return nil, "", err - } rr := m.Extra[len(m.Extra)-1].(*TSIG) m.Extra = m.Extra[0 : len(m.Extra)-1] // kill the TSIG from the msg @@ -110,21 +147,13 @@ func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, s buf := tsigBuffer(mbuf, rr, requestMAC, timersOnly) t := new(TSIG) - var h hash.Hash - switch strings.ToLower(rr.Algorithm) { - case HmacMD5: - h = hmac.New(md5.New, []byte(rawsecret)) - case HmacSHA1: - h = hmac.New(sha1.New, []byte(rawsecret)) - case HmacSHA256: - h = hmac.New(sha256.New, []byte(rawsecret)) - case HmacSHA512: - h = hmac.New(sha512.New, []byte(rawsecret)) - default: - return nil, "", ErrKeyAlg + + h, err := cb(buf, rr.Algorithm, meta) + if err != nil { + return nil, "", err } - h.Write(buf) - t.MAC = hex.EncodeToString(h.Sum(nil)) + + t.MAC = hex.EncodeToString(h) t.MACSize = uint16(len(t.MAC) / 2) // Size is half! t.Hdr = RR_Header{Name: rr.Hdr.Name, Rrtype: TypeTSIG, Class: ClassANY, Ttl: 0} @@ -146,38 +175,19 @@ func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, s return mbuf, t.MAC, nil } -// TsigVerify verifies the TSIG on a message. -// If the signature does not validate err contains the -// error, otherwise it is nil. -func TsigVerify(msg []byte, secret, requestMAC string, timersOnly bool) error { +func tsigVerifyHmac(msg []byte, tsig *TSIG, meta interface{}) error { + secret := meta.(string) + rawsecret, err := fromBase64([]byte(secret)) if err != nil { return err } - // Strip the TSIG from the incoming msg - stripped, tsig, err := stripTsig(msg) - if err != nil { - return err - } msgMAC, err := hex.DecodeString(tsig.MAC) if err != nil { return err } - buf := tsigBuffer(stripped, tsig, requestMAC, timersOnly) - - // Fudge factor works both ways. A message can arrive before it was signed because - // of clock skew. - now := uint64(time.Now().Unix()) - ti := now - tsig.TimeSigned - if now < tsig.TimeSigned { - ti = tsig.TimeSigned - now - } - if uint64(tsig.Fudge) < ti { - return ErrTime - } - var h hash.Hash switch strings.ToLower(tsig.Algorithm) { case HmacMD5: @@ -191,13 +201,47 @@ func TsigVerify(msg []byte, secret, requestMAC string, timersOnly bool) error { default: return ErrKeyAlg } - h.Write(buf) + h.Write(msg) if !hmac.Equal(h.Sum(nil), msgMAC) { return ErrSig } return nil } +// TsigVerify verifies the TSIG on a message. +// If the signature does not validate err contains the +// error, otherwise it is nil. +func TsigVerify(msg []byte, secret, requestMAC string, timersOnly bool) error { + return TsigVerifyByAlgorithm(msg, tsigVerifyHmac, secret, requestMAC, timersOnly) +} + +// TsigVerifyByAlgorithm verifies the TSIG on a message using a callback to +// implement the algorithm-specific verification. +// If the signature does not validate err contains the +// error, otherwise it is nil. +func TsigVerifyByAlgorithm(msg []byte, cb tsigAlgorithmVerify, meta interface{}, requestMAC string, timersOnly bool) error { + // Strip the TSIG from the incoming msg + stripped, tsig, err := stripTsig(msg) + if err != nil { + return err + } + + buf := tsigBuffer(stripped, tsig, requestMAC, timersOnly) + + // Fudge factor works both ways. A message can arrive before it was signed because + // of clock skew. + now := uint64(time.Now().Unix()) + ti := now - tsig.TimeSigned + if now < tsig.TimeSigned { + ti = tsig.TimeSigned - now + } + if uint64(tsig.Fudge) < ti { + return ErrTime + } + + return cb(buf, tsig, meta) +} + // Create a wiredata buffer for the MAC calculation. func tsigBuffer(msgbuf []byte, rr *TSIG, requestMAC string, timersOnly bool) []byte { var buf []byte From 29b6832e218bb67ac9cc532114da267b375a0586 Mon Sep 17 00:00:00 2001 From: Tom Thorogood Date: Thu, 11 Jan 2018 23:57:23 +1030 Subject: [PATCH 2/4] Restructure pull-request to avoid breaking the API This commit removes the interface{} 'meta' argument, which required breaking the signature of TsigSecret, and instead passes the name directly to the callback. If there is a need to use a non-string secret, this can be looked up from within a closure using the name as a lookup key. This change still permits all the same use-cases as it's parent. --- client.go | 29 ++++++++++++----------------- tsig.go | 29 +++++++++++++++-------------- 2 files changed, 27 insertions(+), 31 deletions(-) diff --git a/client.go b/client.go index e6abf6747..01a40a491 100644 --- a/client.go +++ b/client.go @@ -24,16 +24,11 @@ const ( dohMimeType = "application/dns-message" ) -type TsigAlgorithm struct { - Generate tsigAlgorithmGenerate - Verify tsigAlgorithmVerify -} - // A Conn represents a connection to a DNS server. type Conn struct { - net.Conn // a net.Conn holding the connection - UDPSize uint16 // minimum receive buffer for UDP messages - TsigSecret map[string]interface{} // secret(s) for Tsig map[], zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2) + net.Conn // a net.Conn holding the connection + UDPSize uint16 // minimum receive buffer for UDP messages + TsigSecret map[string]string // secret(s) for Tsig map[], zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2) TsigAlgorithm map[string]*TsigAlgorithm tsigRequestMAC string } @@ -48,11 +43,11 @@ type Client struct { // WriteTimeout when non-zero. Can be overridden with net.Dialer.Timeout (see Client.ExchangeWithDialer and // Client.Dialer) or context.Context.Deadline (see the deprecated ExchangeContext) Timeout time.Duration - DialTimeout time.Duration // net.DialTimeout, defaults to 2 seconds, or net.Dialer.Timeout if expiring earlier - overridden by Timeout when that value is non-zero - ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero - WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero - HTTPClient *http.Client // The http.Client to use for DNS-over-HTTPS - TsigSecret map[string]interface{} // secret(s) for Tsig map[], zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2) + DialTimeout time.Duration // net.DialTimeout, defaults to 2 seconds, or net.Dialer.Timeout if expiring earlier - overridden by Timeout when that value is non-zero + ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero + WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero + HTTPClient *http.Client // The http.Client to use for DNS-over-HTTPS + TsigSecret map[string]string // secret(s) for Tsig map[], zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2) TsigAlgorithm map[string]*TsigAlgorithm SingleInflight bool // if true suppress multiple outstanding queries for the same Qname, Qtype and Qclass group singleflight @@ -313,14 +308,14 @@ func (co *Conn) ReadMsg() (*Msg, error) { if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { return m, ErrSecret } - err = TsigVerifyByAlgorithm(p, a.Verify, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false) + err = TsigVerifyByAlgorithm(p, a.Verify, t.Hdr.Name, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false) } } else { if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { return m, ErrSecret } // Need to work on the original message p, as that was used to calculate the tsig. - err = TsigVerify(p, co.TsigSecret[t.Hdr.Name].(string), co.tsigRequestMAC, false) + err = TsigVerify(p, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false) } } return m, err @@ -459,13 +454,13 @@ func (co *Conn) WriteMsg(m *Msg) (err error) { if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { return ErrSecret } - out, mac, err = TsigGenerateByAlgorithm(m, a.Generate, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false) + out, mac, err = TsigGenerateByAlgorithm(m, a.Generate, t.Hdr.Name, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false) } } else { if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { return ErrSecret } - out, mac, err = TsigGenerate(m, co.TsigSecret[t.Hdr.Name].(string), co.tsigRequestMAC, false) + out, mac, err = TsigGenerate(m, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false) } // Set for the next read, although only used in zone transfers co.tsigRequestMAC = mac diff --git a/tsig.go b/tsig.go index fd7c315eb..3b2357e45 100644 --- a/tsig.go +++ b/tsig.go @@ -22,8 +22,13 @@ const ( HmacSHA512 = "hmac-sha512." ) -type tsigAlgorithmGenerate func([]byte, string, interface{}) ([]byte, error) -type tsigAlgorithmVerify func([]byte, *TSIG, interface{}) error +type TsigAlgorithm struct { + Generate tsigAlgorithmGenerate + Verify tsigAlgorithmVerify +} + +type tsigAlgorithmGenerate func(msg []byte, algorithm, name, secret string) ([]byte, error) +type tsigAlgorithmVerify func(msg []byte, tsig *TSIG, name, secret string) error // TSIG is the RR the holds the transaction signature of a message. // See RFC 2845 and RFC 4635. @@ -86,9 +91,7 @@ type timerWireFmt struct { Fudge uint16 } -func tsigGenerateHmac(msg []byte, algorithm string, meta interface{}) ([]byte, error) { - secret := meta.(string) - +func tsigGenerateHmac(msg []byte, algorithm string, name, secret string) ([]byte, error) { rawsecret, err := fromBase64([]byte(secret)) if err != nil { return nil, err @@ -121,7 +124,7 @@ func tsigGenerateHmac(msg []byte, algorithm string, meta interface{}) ([]byte, e // timersOnly is false. // If something goes wrong an error is returned, otherwise it is nil. func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, string, error) { - return TsigGenerateByAlgorithm(m, tsigGenerateHmac, secret, requestMAC, timersOnly) + return TsigGenerateByAlgorithm(m, tsigGenerateHmac, "", secret, requestMAC, timersOnly) } // TsigGenerateByAlgorithm fills out the TSIG record attached to the message @@ -133,7 +136,7 @@ func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, s // When TsigGenerate is called for the first time requestMAC is set to the empty string and // timersOnly is false. // If something goes wrong an error is returned, otherwise it is nil. -func TsigGenerateByAlgorithm(m *Msg, cb tsigAlgorithmGenerate, meta interface{}, requestMAC string, timersOnly bool) ([]byte, string, error) { +func TsigGenerateByAlgorithm(m *Msg, cb tsigAlgorithmGenerate, name, secret, requestMAC string, timersOnly bool) ([]byte, string, error) { if m.IsTsig() == nil { panic("dns: TSIG not last RR in additional") } @@ -148,7 +151,7 @@ func TsigGenerateByAlgorithm(m *Msg, cb tsigAlgorithmGenerate, meta interface{}, t := new(TSIG) - h, err := cb(buf, rr.Algorithm, meta) + h, err := cb(buf, rr.Algorithm, name, secret) if err != nil { return nil, "", err } @@ -175,9 +178,7 @@ func TsigGenerateByAlgorithm(m *Msg, cb tsigAlgorithmGenerate, meta interface{}, return mbuf, t.MAC, nil } -func tsigVerifyHmac(msg []byte, tsig *TSIG, meta interface{}) error { - secret := meta.(string) - +func tsigVerifyHmac(msg []byte, tsig *TSIG, name, secret string) error { rawsecret, err := fromBase64([]byte(secret)) if err != nil { return err @@ -212,14 +213,14 @@ func tsigVerifyHmac(msg []byte, tsig *TSIG, meta interface{}) error { // If the signature does not validate err contains the // error, otherwise it is nil. func TsigVerify(msg []byte, secret, requestMAC string, timersOnly bool) error { - return TsigVerifyByAlgorithm(msg, tsigVerifyHmac, secret, requestMAC, timersOnly) + return TsigVerifyByAlgorithm(msg, tsigVerifyHmac, "", secret, requestMAC, timersOnly) } // TsigVerifyByAlgorithm verifies the TSIG on a message using a callback to // implement the algorithm-specific verification. // If the signature does not validate err contains the // error, otherwise it is nil. -func TsigVerifyByAlgorithm(msg []byte, cb tsigAlgorithmVerify, meta interface{}, requestMAC string, timersOnly bool) error { +func TsigVerifyByAlgorithm(msg []byte, cb tsigAlgorithmVerify, name, secret, requestMAC string, timersOnly bool) error { // Strip the TSIG from the incoming msg stripped, tsig, err := stripTsig(msg) if err != nil { @@ -239,7 +240,7 @@ func TsigVerifyByAlgorithm(msg []byte, cb tsigAlgorithmVerify, meta interface{}, return ErrTime } - return cb(buf, tsig, meta) + return cb(buf, tsig, name, secret) } // Create a wiredata buffer for the MAC calculation. From 213000833e69cafecb98f9ab663101ce0b48c14a Mon Sep 17 00:00:00 2001 From: Tom Thorogood Date: Fri, 12 Jan 2018 01:17:40 +1030 Subject: [PATCH 3/4] Remove redundant string type in tsigGenerateHmac --- tsig.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tsig.go b/tsig.go index 3b2357e45..12678edbe 100644 --- a/tsig.go +++ b/tsig.go @@ -91,7 +91,7 @@ type timerWireFmt struct { Fudge uint16 } -func tsigGenerateHmac(msg []byte, algorithm string, name, secret string) ([]byte, error) { +func tsigGenerateHmac(msg []byte, algorithm, name, secret string) ([]byte, error) { rawsecret, err := fromBase64([]byte(secret)) if err != nil { return nil, err From 2450047b0e050b2122520a70dc5c853a5b5bbdde Mon Sep 17 00:00:00 2001 From: Matt Dainty Date: Wed, 17 Jan 2018 22:02:25 +0000 Subject: [PATCH 4/4] Add custom TSIG algorithms to server too --- server.go | 51 ++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 38 insertions(+), 13 deletions(-) diff --git a/server.go b/server.go index 2d98f1488..992e0b638 100644 --- a/server.go +++ b/server.go @@ -55,11 +55,12 @@ type response struct { tsigStatus error tsigTimersOnly bool tsigRequestMAC string - tsigSecret map[string]string // the tsig secrets - udp *net.UDPConn // i/o connection if UDP was used - tcp net.Conn // i/o connection if TCP was used - udpSession *SessionUDP // oob data to get egress interface right - writer Writer // writer to output the raw DNS bits + tsigSecret map[string]string // the tsig secrets + tsigAlgorithm map[string]*TsigAlgorithm // custom tsig algorithms + udp *net.UDPConn // i/o connection if UDP was used + tcp net.Conn // i/o connection if TCP was used + udpSession *SessionUDP // oob data to get egress interface right + writer Writer // writer to output the raw DNS bits } // ServeMux is an DNS request multiplexer. It matches the @@ -294,6 +295,8 @@ type Server struct { IdleTimeout func() time.Duration // Secret(s) for Tsig map[]. The zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2). TsigSecret map[string]string + // Non-HMAC Tsig algorithm callbacks map[]{generate(), verify()}. + TsigAlgorithm map[string]*TsigAlgorithm // Unsafe instructs the server to disregard any sanity checks and directly hand the message to // the handler. It will specifically not check if the query has the QR bit not set. Unsafe bool @@ -634,12 +637,22 @@ func (srv *Server) serveDNS(w *response) { } w.tsigStatus = nil - if w.tsigSecret != nil { + if w.tsigAlgorithm != nil || w.tsigSecret != nil { if t := req.IsTsig(); t != nil { - if secret, ok := w.tsigSecret[t.Hdr.Name]; ok { - w.tsigStatus = TsigVerify(w.msg, secret, "", false) + if a, ok := w.tsigAlgorithm[t.Algorithm]; ok { + if a.Verify != nil { + if secret, ok := w.tsigSecret[t.Hdr.Name]; ok { + w.tsigStatus = TsigVerifyByAlgorithm(w.msg, a.Verify, t.Hdr.Name, secret, "", false) + } else { + w.tsigStatus = ErrSecret + } + } } else { - w.tsigStatus = ErrSecret + if secret, ok := w.tsigSecret[t.Hdr.Name]; ok { + w.tsigStatus = TsigVerify(w.msg, secret, "", false) + } else { + w.tsigStatus = ErrSecret + } } w.tsigTimersOnly = false w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC @@ -703,11 +716,23 @@ func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *S // WriteMsg implements the ResponseWriter.WriteMsg method. func (w *response) WriteMsg(m *Msg) (err error) { var data []byte - if w.tsigSecret != nil { // if no secrets, dont check for the tsig (which is a longer check) + if w.tsigAlgorithm != nil || w.tsigSecret != nil { // if no secrets, dont check for the tsig (which is a longer check) if t := m.IsTsig(); t != nil { - data, w.tsigRequestMAC, err = TsigGenerate(m, w.tsigSecret[t.Hdr.Name], w.tsigRequestMAC, w.tsigTimersOnly) - if err != nil { - return err + if a, ok := w.tsigAlgorithm[t.Algorithm]; ok { + if a.Generate != nil { + if _, ok := w.tsigSecret[t.Hdr.Name]; !ok { + return ErrSecret + } + data, w.tsigRequestMAC, err = TsigGenerateByAlgorithm(m, a.Generate, t.Hdr.Name, w.tsigSecret[t.Hdr.Name], w.tsigRequestMAC, w.tsigTimersOnly) + if err != nil { + return err + } + } + } else { + data, w.tsigRequestMAC, err = TsigGenerate(m, w.tsigSecret[t.Hdr.Name], w.tsigRequestMAC, w.tsigTimersOnly) + if err != nil { + return err + } } _, err = w.writer.Write(data) return err