diff --git a/coredns/plugin/handler.go b/coredns/plugin/handler.go index eb8b71e32..8804eee67 100644 --- a/coredns/plugin/handler.go +++ b/coredns/plugin/handler.go @@ -44,14 +44,14 @@ func (lh *Lighthouse) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns zone := plugin.Zones(lh.Zones).Matches(qname) if zone == "" { log.Debugf("Request does not match configured zones %v", lh.Zones) - return lh.nextOrFailure(ctx, state.Name(), w, r, dns.RcodeNotZone, "No matching zone found") + return lh.nextOrFailure(ctx, state, r, dns.RcodeNotZone) } if state.QType() != dns.TypeA && state.QType() != dns.TypeAAAA && state.QType() != dns.TypeSRV { msg := fmt.Sprintf("Query of type %d is not supported", state.QType()) log.Debugf(msg) - return lh.nextOrFailure(ctx, state.Name(), w, r, dns.RcodeNotImplemented, msg) + return lh.nextOrFailure(ctx, state, r, dns.RcodeNotImplemented) } zone = qname[len(qname)-len(zone):] // maintain case of original query @@ -61,7 +61,7 @@ func (lh *Lighthouse) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns if pErr != nil || pReq.podOrSvc != Svc { // We only support svc type queries i.e. *.svc.* log.Debugf("Request type %q is not a 'svc' type query - err was %v", pReq.podOrSvc, pErr) - return lh.nextOrFailure(ctx, state.Name(), w, r, dns.RcodeNameError, "Only services supported") + return lh.nextOrFailure(ctx, state, r, dns.RcodeNameError) } return lh.getDNSRecord(ctx, zone, state, w, r, pReq) @@ -83,7 +83,7 @@ func (lh *Lighthouse) getDNSRecord(ctx context.Context, zone string, state *requ pReq.service, lh.ClusterStatus.IsConnected) if !found { log.Debugf("No record found for %q", state.QName()) - return lh.nextOrFailure(ctx, state.Name(), w, r, dns.RcodeNameError, "record not found") + return lh.nextOrFailure(ctx, state, r, dns.RcodeNameError) } isHeadless = true @@ -141,11 +141,15 @@ func (lh *Lighthouse) getDNSRecord(ctx context.Context, zone string, state *requ func (lh *Lighthouse) emptyResponse(state *request.Request) (int, error) { a := new(dns.Msg) a.SetReply(state.Req) + + return lh.writeResponse(state, a) +} + +func (lh *Lighthouse) writeResponse(state *request.Request, a *dns.Msg) (int, error) { a.Authoritative = true wErr := state.W.WriteMsg(a) if wErr != nil { - // Error writing reply msg log.Errorf("Failed to write message %#v: %v", a, wErr) return dns.RcodeServerFailure, lh.error("failed to write response") } @@ -162,10 +166,13 @@ func (lh *Lighthouse) error(str string) error { return plugin.Error(lh.Name(), errors.New(str)) // nolint:wrapcheck // Let the caller wrap it. } -func (lh *Lighthouse) nextOrFailure(ctx context.Context, name string, w dns.ResponseWriter, r *dns.Msg, code int, err string) (int, error) { - if lh.Fall.Through(name) { - return plugin.NextOrFailure(lh.Name(), lh.Next, ctx, w, r) // nolint:wrapcheck // Let the caller wrap it. +func (lh *Lighthouse) nextOrFailure(ctx context.Context, state *request.Request, r *dns.Msg, rcode int) (int, error) { + if lh.Fall.Through(state.Name()) { + return plugin.NextOrFailure(lh.Name(), lh.Next, ctx, state.W, r) // nolint:wrapcheck // Let the caller wrap it. } - return code, lh.error(err) + a := new(dns.Msg) + a.SetRcode(r, rcode) + + return lh.writeResponse(state, a) } diff --git a/coredns/plugin/handler_test.go b/coredns/plugin/handler_test.go index e099e2b93..4dabe989f 100644 --- a/coredns/plugin/handler_test.go +++ b/coredns/plugin/handler_test.go @@ -22,6 +22,7 @@ import ( "context" "fmt" + "github.com/coredns/coredns/plugin" "github.com/coredns/coredns/plugin/pkg/dnstest" "github.com/coredns/coredns/plugin/pkg/fall" "github.com/coredns/coredns/plugin/test" @@ -328,7 +329,12 @@ func testWithFallback() { t.mockCs.localClusterID = clusterID t.mockEs.endpointStatusMap[clusterID] = true t.lh.Fall = fall.F{Zones: []string{"clusterset.local."}} - t.lh.Next = test.NextHandler(dns.RcodeBadCookie, errors.New("dummy plugin")) + t.lh.Next = test.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + m := new(dns.Msg) + m.SetRcode(r, dns.RcodeBadCookie) + _ = w.WriteMsg(m) + return dns.RcodeBadCookie, nil + }) rec = dnstest.NewRecorder(&test.ResponseWriter{}) }) @@ -1008,13 +1014,12 @@ func newHandlerTestDriver() *handlerTestDriver { func (t *handlerTestDriver) executeTestCase(rec *dnstest.Recorder, tc test.Case) { code, err := t.lh.ServeDNS(context.TODO(), rec, tc.Msg()) - Expect(code).Should(Equal(tc.Rcode)) - - if tc.Rcode == dns.RcodeSuccess { + if plugin.ClientWrite(tc.Rcode) { Expect(err).To(Succeed()) Expect(test.SortAndCheck(rec.Msg, tc)).To(Succeed()) } else { Expect(err).To(HaveOccurred()) + Expect(code).Should(Equal(tc.Rcode)) } }