diff --git a/server.go b/server.go index 2d98f1488..992e0b638 100644 --- a/server.go +++ b/server.go @@ -55,11 +55,12 @@ type response struct { tsigStatus error tsigTimersOnly bool tsigRequestMAC string - tsigSecret map[string]string // the tsig secrets - udp *net.UDPConn // i/o connection if UDP was used - tcp net.Conn // i/o connection if TCP was used - udpSession *SessionUDP // oob data to get egress interface right - writer Writer // writer to output the raw DNS bits + tsigSecret map[string]string // the tsig secrets + tsigAlgorithm map[string]*TsigAlgorithm // custom tsig algorithms + udp *net.UDPConn // i/o connection if UDP was used + tcp net.Conn // i/o connection if TCP was used + udpSession *SessionUDP // oob data to get egress interface right + writer Writer // writer to output the raw DNS bits } // ServeMux is an DNS request multiplexer. It matches the @@ -294,6 +295,8 @@ type Server struct { IdleTimeout func() time.Duration // Secret(s) for Tsig map[]. The zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2). TsigSecret map[string]string + // Non-HMAC Tsig algorithm callbacks map[]{generate(), verify()}. + TsigAlgorithm map[string]*TsigAlgorithm // Unsafe instructs the server to disregard any sanity checks and directly hand the message to // the handler. It will specifically not check if the query has the QR bit not set. Unsafe bool @@ -634,12 +637,22 @@ func (srv *Server) serveDNS(w *response) { } w.tsigStatus = nil - if w.tsigSecret != nil { + if w.tsigAlgorithm != nil || w.tsigSecret != nil { if t := req.IsTsig(); t != nil { - if secret, ok := w.tsigSecret[t.Hdr.Name]; ok { - w.tsigStatus = TsigVerify(w.msg, secret, "", false) + if a, ok := w.tsigAlgorithm[t.Algorithm]; ok { + if a.Verify != nil { + if secret, ok := w.tsigSecret[t.Hdr.Name]; ok { + w.tsigStatus = TsigVerifyByAlgorithm(w.msg, a.Verify, t.Hdr.Name, secret, "", false) + } else { + w.tsigStatus = ErrSecret + } + } } else { - w.tsigStatus = ErrSecret + if secret, ok := w.tsigSecret[t.Hdr.Name]; ok { + w.tsigStatus = TsigVerify(w.msg, secret, "", false) + } else { + w.tsigStatus = ErrSecret + } } w.tsigTimersOnly = false w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC @@ -703,11 +716,23 @@ func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *S // WriteMsg implements the ResponseWriter.WriteMsg method. func (w *response) WriteMsg(m *Msg) (err error) { var data []byte - if w.tsigSecret != nil { // if no secrets, dont check for the tsig (which is a longer check) + if w.tsigAlgorithm != nil || w.tsigSecret != nil { // if no secrets, dont check for the tsig (which is a longer check) if t := m.IsTsig(); t != nil { - data, w.tsigRequestMAC, err = TsigGenerate(m, w.tsigSecret[t.Hdr.Name], w.tsigRequestMAC, w.tsigTimersOnly) - if err != nil { - return err + if a, ok := w.tsigAlgorithm[t.Algorithm]; ok { + if a.Generate != nil { + if _, ok := w.tsigSecret[t.Hdr.Name]; !ok { + return ErrSecret + } + data, w.tsigRequestMAC, err = TsigGenerateByAlgorithm(m, a.Generate, t.Hdr.Name, w.tsigSecret[t.Hdr.Name], w.tsigRequestMAC, w.tsigTimersOnly) + if err != nil { + return err + } + } + } else { + data, w.tsigRequestMAC, err = TsigGenerate(m, w.tsigSecret[t.Hdr.Name], w.tsigRequestMAC, w.tsigTimersOnly) + if err != nil { + return err + } } _, err = w.writer.Write(data) return err