diff --git a/client.go b/client.go index c52eec1d5..01a40a491 100644 --- a/client.go +++ b/client.go @@ -29,6 +29,7 @@ type Conn struct { net.Conn // a net.Conn holding the connection UDPSize uint16 // minimum receive buffer for UDP messages TsigSecret map[string]string // secret(s) for Tsig map[], zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2) + TsigAlgorithm map[string]*TsigAlgorithm tsigRequestMAC string } @@ -47,7 +48,8 @@ type Client struct { WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero HTTPClient *http.Client // The http.Client to use for DNS-over-HTTPS TsigSecret map[string]string // secret(s) for Tsig map[], zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2) - SingleInflight bool // if true suppress multiple outstanding queries for the same Qname, Qtype and Qclass + TsigAlgorithm map[string]*TsigAlgorithm + SingleInflight bool // if true suppress multiple outstanding queries for the same Qname, Qtype and Qclass group singleflight } @@ -194,6 +196,7 @@ func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro } co.TsigSecret = c.TsigSecret + co.TsigAlgorithm = c.TsigAlgorithm t := time.Now() // write with the appropriate write timeout co.SetWriteDeadline(t.Add(c.getTimeoutForRequest(c.writeTimeout()))) @@ -300,11 +303,20 @@ func (co *Conn) ReadMsg() (*Msg, error) { return m, err } if t := m.IsTsig(); t != nil { - if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { - return m, ErrSecret + if a, ok := co.TsigAlgorithm[t.Algorithm]; ok { + if a.Verify != nil { + if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { + return m, ErrSecret + } + err = TsigVerifyByAlgorithm(p, a.Verify, t.Hdr.Name, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false) + } + } else { + if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { + return m, ErrSecret + } + // Need to work on the original message p, as that was used to calculate the tsig. + err = TsigVerify(p, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false) } - // Need to work on the original message p, as that was used to calculate the tsig. - err = TsigVerify(p, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false) } return m, err } @@ -437,10 +449,19 @@ func (co *Conn) WriteMsg(m *Msg) (err error) { var out []byte if t := m.IsTsig(); t != nil { mac := "" - if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { - return ErrSecret + if a, ok := co.TsigAlgorithm[t.Algorithm]; ok { + if a.Generate != nil { + if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { + return ErrSecret + } + out, mac, err = TsigGenerateByAlgorithm(m, a.Generate, t.Hdr.Name, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false) + } + } else { + if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { + return ErrSecret + } + out, mac, err = TsigGenerate(m, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false) } - out, mac, err = TsigGenerate(m, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false) // Set for the next read, although only used in zone transfers co.tsigRequestMAC = mac } else { diff --git a/server.go b/server.go index 2d98f1488..992e0b638 100644 --- a/server.go +++ b/server.go @@ -55,11 +55,12 @@ type response struct { tsigStatus error tsigTimersOnly bool tsigRequestMAC string - tsigSecret map[string]string // the tsig secrets - udp *net.UDPConn // i/o connection if UDP was used - tcp net.Conn // i/o connection if TCP was used - udpSession *SessionUDP // oob data to get egress interface right - writer Writer // writer to output the raw DNS bits + tsigSecret map[string]string // the tsig secrets + tsigAlgorithm map[string]*TsigAlgorithm // custom tsig algorithms + udp *net.UDPConn // i/o connection if UDP was used + tcp net.Conn // i/o connection if TCP was used + udpSession *SessionUDP // oob data to get egress interface right + writer Writer // writer to output the raw DNS bits } // ServeMux is an DNS request multiplexer. It matches the @@ -294,6 +295,8 @@ type Server struct { IdleTimeout func() time.Duration // Secret(s) for Tsig map[]. The zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2). TsigSecret map[string]string + // Non-HMAC Tsig algorithm callbacks map[]{generate(), verify()}. + TsigAlgorithm map[string]*TsigAlgorithm // Unsafe instructs the server to disregard any sanity checks and directly hand the message to // the handler. It will specifically not check if the query has the QR bit not set. Unsafe bool @@ -634,12 +637,22 @@ func (srv *Server) serveDNS(w *response) { } w.tsigStatus = nil - if w.tsigSecret != nil { + if w.tsigAlgorithm != nil || w.tsigSecret != nil { if t := req.IsTsig(); t != nil { - if secret, ok := w.tsigSecret[t.Hdr.Name]; ok { - w.tsigStatus = TsigVerify(w.msg, secret, "", false) + if a, ok := w.tsigAlgorithm[t.Algorithm]; ok { + if a.Verify != nil { + if secret, ok := w.tsigSecret[t.Hdr.Name]; ok { + w.tsigStatus = TsigVerifyByAlgorithm(w.msg, a.Verify, t.Hdr.Name, secret, "", false) + } else { + w.tsigStatus = ErrSecret + } + } } else { - w.tsigStatus = ErrSecret + if secret, ok := w.tsigSecret[t.Hdr.Name]; ok { + w.tsigStatus = TsigVerify(w.msg, secret, "", false) + } else { + w.tsigStatus = ErrSecret + } } w.tsigTimersOnly = false w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC @@ -703,11 +716,23 @@ func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *S // WriteMsg implements the ResponseWriter.WriteMsg method. func (w *response) WriteMsg(m *Msg) (err error) { var data []byte - if w.tsigSecret != nil { // if no secrets, dont check for the tsig (which is a longer check) + if w.tsigAlgorithm != nil || w.tsigSecret != nil { // if no secrets, dont check for the tsig (which is a longer check) if t := m.IsTsig(); t != nil { - data, w.tsigRequestMAC, err = TsigGenerate(m, w.tsigSecret[t.Hdr.Name], w.tsigRequestMAC, w.tsigTimersOnly) - if err != nil { - return err + if a, ok := w.tsigAlgorithm[t.Algorithm]; ok { + if a.Generate != nil { + if _, ok := w.tsigSecret[t.Hdr.Name]; !ok { + return ErrSecret + } + data, w.tsigRequestMAC, err = TsigGenerateByAlgorithm(m, a.Generate, t.Hdr.Name, w.tsigSecret[t.Hdr.Name], w.tsigRequestMAC, w.tsigTimersOnly) + if err != nil { + return err + } + } + } else { + data, w.tsigRequestMAC, err = TsigGenerate(m, w.tsigSecret[t.Hdr.Name], w.tsigRequestMAC, w.tsigTimersOnly) + if err != nil { + return err + } } _, err = w.writer.Write(data) return err diff --git a/tsig.go b/tsig.go index 4837b4ab1..12678edbe 100644 --- a/tsig.go +++ b/tsig.go @@ -22,6 +22,14 @@ const ( HmacSHA512 = "hmac-sha512." ) +type TsigAlgorithm struct { + Generate tsigAlgorithmGenerate + Verify tsigAlgorithmVerify +} + +type tsigAlgorithmGenerate func(msg []byte, algorithm, name, secret string) ([]byte, error) +type tsigAlgorithmVerify func(msg []byte, tsig *TSIG, name, secret string) error + // TSIG is the RR the holds the transaction signature of a message. // See RFC 2845 and RFC 4635. type TSIG struct { @@ -83,6 +91,30 @@ type timerWireFmt struct { Fudge uint16 } +func tsigGenerateHmac(msg []byte, algorithm, name, secret string) ([]byte, error) { + rawsecret, err := fromBase64([]byte(secret)) + if err != nil { + return nil, err + } + + var h hash.Hash + switch strings.ToLower(algorithm) { + case HmacMD5: + h = hmac.New(md5.New, []byte(rawsecret)) + case HmacSHA1: + h = hmac.New(sha1.New, []byte(rawsecret)) + case HmacSHA256: + h = hmac.New(sha256.New, []byte(rawsecret)) + case HmacSHA512: + h = hmac.New(sha512.New, []byte(rawsecret)) + default: + return nil, ErrKeyAlg + } + h.Write(msg) + + return h.Sum(nil), nil +} + // TsigGenerate fills out the TSIG record attached to the message. // The message should contain // a "stub" TSIG RR with the algorithm, key name (owner name of the RR), @@ -92,14 +124,22 @@ type timerWireFmt struct { // timersOnly is false. // If something goes wrong an error is returned, otherwise it is nil. func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, string, error) { + return TsigGenerateByAlgorithm(m, tsigGenerateHmac, "", secret, requestMAC, timersOnly) +} + +// TsigGenerateByAlgorithm fills out the TSIG record attached to the message +// using a callback to implement the algorithm-specific generation. +// The message should contain +// a "stub" TSIG RR with the algorithm, key name (owner name of the RR), +// time fudge (defaults to 300 seconds) and the current time +// The TSIG MAC is saved in that Tsig RR. +// When TsigGenerate is called for the first time requestMAC is set to the empty string and +// timersOnly is false. +// If something goes wrong an error is returned, otherwise it is nil. +func TsigGenerateByAlgorithm(m *Msg, cb tsigAlgorithmGenerate, name, secret, requestMAC string, timersOnly bool) ([]byte, string, error) { if m.IsTsig() == nil { panic("dns: TSIG not last RR in additional") } - // If we barf here, the caller is to blame - rawsecret, err := fromBase64([]byte(secret)) - if err != nil { - return nil, "", err - } rr := m.Extra[len(m.Extra)-1].(*TSIG) m.Extra = m.Extra[0 : len(m.Extra)-1] // kill the TSIG from the msg @@ -110,21 +150,13 @@ func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, s buf := tsigBuffer(mbuf, rr, requestMAC, timersOnly) t := new(TSIG) - var h hash.Hash - switch strings.ToLower(rr.Algorithm) { - case HmacMD5: - h = hmac.New(md5.New, []byte(rawsecret)) - case HmacSHA1: - h = hmac.New(sha1.New, []byte(rawsecret)) - case HmacSHA256: - h = hmac.New(sha256.New, []byte(rawsecret)) - case HmacSHA512: - h = hmac.New(sha512.New, []byte(rawsecret)) - default: - return nil, "", ErrKeyAlg + + h, err := cb(buf, rr.Algorithm, name, secret) + if err != nil { + return nil, "", err } - h.Write(buf) - t.MAC = hex.EncodeToString(h.Sum(nil)) + + t.MAC = hex.EncodeToString(h) t.MACSize = uint16(len(t.MAC) / 2) // Size is half! t.Hdr = RR_Header{Name: rr.Hdr.Name, Rrtype: TypeTSIG, Class: ClassANY, Ttl: 0} @@ -146,38 +178,17 @@ func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, s return mbuf, t.MAC, nil } -// TsigVerify verifies the TSIG on a message. -// If the signature does not validate err contains the -// error, otherwise it is nil. -func TsigVerify(msg []byte, secret, requestMAC string, timersOnly bool) error { +func tsigVerifyHmac(msg []byte, tsig *TSIG, name, secret string) error { rawsecret, err := fromBase64([]byte(secret)) if err != nil { return err } - // Strip the TSIG from the incoming msg - stripped, tsig, err := stripTsig(msg) - if err != nil { - return err - } msgMAC, err := hex.DecodeString(tsig.MAC) if err != nil { return err } - buf := tsigBuffer(stripped, tsig, requestMAC, timersOnly) - - // Fudge factor works both ways. A message can arrive before it was signed because - // of clock skew. - now := uint64(time.Now().Unix()) - ti := now - tsig.TimeSigned - if now < tsig.TimeSigned { - ti = tsig.TimeSigned - now - } - if uint64(tsig.Fudge) < ti { - return ErrTime - } - var h hash.Hash switch strings.ToLower(tsig.Algorithm) { case HmacMD5: @@ -191,13 +202,47 @@ func TsigVerify(msg []byte, secret, requestMAC string, timersOnly bool) error { default: return ErrKeyAlg } - h.Write(buf) + h.Write(msg) if !hmac.Equal(h.Sum(nil), msgMAC) { return ErrSig } return nil } +// TsigVerify verifies the TSIG on a message. +// If the signature does not validate err contains the +// error, otherwise it is nil. +func TsigVerify(msg []byte, secret, requestMAC string, timersOnly bool) error { + return TsigVerifyByAlgorithm(msg, tsigVerifyHmac, "", secret, requestMAC, timersOnly) +} + +// TsigVerifyByAlgorithm verifies the TSIG on a message using a callback to +// implement the algorithm-specific verification. +// If the signature does not validate err contains the +// error, otherwise it is nil. +func TsigVerifyByAlgorithm(msg []byte, cb tsigAlgorithmVerify, name, secret, requestMAC string, timersOnly bool) error { + // Strip the TSIG from the incoming msg + stripped, tsig, err := stripTsig(msg) + if err != nil { + return err + } + + buf := tsigBuffer(stripped, tsig, requestMAC, timersOnly) + + // Fudge factor works both ways. A message can arrive before it was signed because + // of clock skew. + now := uint64(time.Now().Unix()) + ti := now - tsig.TimeSigned + if now < tsig.TimeSigned { + ti = tsig.TimeSigned - now + } + if uint64(tsig.Fudge) < ti { + return ErrTime + } + + return cb(buf, tsig, name, secret) +} + // Create a wiredata buffer for the MAC calculation. func tsigBuffer(msgbuf []byte, rr *TSIG, requestMAC string, timersOnly bool) []byte { var buf []byte