Skip to content

Commit

Permalink
Write the dns message response body when it is not a ServerFailure (#984
Browse files Browse the repository at this point in the history
)

Coredns assumes ServeDNS has written the response body unless
the returned code indicates it hasn't been written as determined
by ClientWrite (which is only the codes RcodeServerFailure,
RcodeRefused, RcodeFormatError and RcodeNotImplemented),
in which case it writes the response.

If there's no next plugin, it simply returns the code but if
it isn't one of the ClientWrite codes, then no response is written.

If no response,client will try to resend dns request multiple times,
which causes seconds delay.

Fixes #859
Signed-off-by: blue-troy <[email protected]>

Co-authored-by: blue-troy <[email protected]>
  • Loading branch information
aswinsuryan and blue-troy authored Nov 29, 2022
1 parent dff35d1 commit f1f06c6
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 13 deletions.
25 changes: 16 additions & 9 deletions coredns/plugin/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,14 @@ func (lh *Lighthouse) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns
zone := plugin.Zones(lh.Zones).Matches(qname)
if zone == "" {
log.Debugf("Request does not match configured zones %v", lh.Zones)
return lh.nextOrFailure(ctx, state.Name(), w, r, dns.RcodeNotZone, "No matching zone found")
return lh.nextOrFailure(ctx, state, r, dns.RcodeNotZone)
}

if state.QType() != dns.TypeA && state.QType() != dns.TypeAAAA && state.QType() != dns.TypeSRV {
msg := fmt.Sprintf("Query of type %d is not supported", state.QType())
log.Debugf(msg)

return lh.nextOrFailure(ctx, state.Name(), w, r, dns.RcodeNotImplemented, msg)
return lh.nextOrFailure(ctx, state, r, dns.RcodeNotImplemented)
}

zone = qname[len(qname)-len(zone):] // maintain case of original query
Expand All @@ -61,7 +61,7 @@ func (lh *Lighthouse) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns
if pErr != nil || pReq.podOrSvc != Svc {
// We only support svc type queries i.e. *.svc.*
log.Debugf("Request type %q is not a 'svc' type query - err was %v", pReq.podOrSvc, pErr)
return lh.nextOrFailure(ctx, state.Name(), w, r, dns.RcodeNameError, "Only services supported")
return lh.nextOrFailure(ctx, state, r, dns.RcodeNameError)
}

return lh.getDNSRecord(ctx, zone, state, w, r, pReq)
Expand All @@ -83,7 +83,7 @@ func (lh *Lighthouse) getDNSRecord(ctx context.Context, zone string, state *requ
pReq.service, lh.ClusterStatus.IsConnected)
if !found {
log.Debugf("No record found for %q", state.QName())
return lh.nextOrFailure(ctx, state.Name(), w, r, dns.RcodeNameError, "record not found")
return lh.nextOrFailure(ctx, state, r, dns.RcodeNameError)
}

isHeadless = true
Expand Down Expand Up @@ -141,11 +141,15 @@ func (lh *Lighthouse) getDNSRecord(ctx context.Context, zone string, state *requ
func (lh *Lighthouse) emptyResponse(state *request.Request) (int, error) {
a := new(dns.Msg)
a.SetReply(state.Req)

return lh.writeResponse(state, a)
}

func (lh *Lighthouse) writeResponse(state *request.Request, a *dns.Msg) (int, error) {
a.Authoritative = true

wErr := state.W.WriteMsg(a)
if wErr != nil {
// Error writing reply msg
log.Errorf("Failed to write message %#v: %v", a, wErr)
return dns.RcodeServerFailure, lh.error("failed to write response")
}
Expand All @@ -162,10 +166,13 @@ func (lh *Lighthouse) error(str string) error {
return plugin.Error(lh.Name(), errors.New(str)) // nolint:wrapcheck // Let the caller wrap it.
}

func (lh *Lighthouse) nextOrFailure(ctx context.Context, name string, w dns.ResponseWriter, r *dns.Msg, code int, err string) (int, error) {
if lh.Fall.Through(name) {
return plugin.NextOrFailure(lh.Name(), lh.Next, ctx, w, r) // nolint:wrapcheck // Let the caller wrap it.
func (lh *Lighthouse) nextOrFailure(ctx context.Context, state *request.Request, r *dns.Msg, rcode int) (int, error) {
if lh.Fall.Through(state.Name()) {
return plugin.NextOrFailure(lh.Name(), lh.Next, ctx, state.W, r) // nolint:wrapcheck // Let the caller wrap it.
}

return code, lh.error(err)
a := new(dns.Msg)
a.SetRcode(r, rcode)

return lh.writeResponse(state, a)
}
13 changes: 9 additions & 4 deletions coredns/plugin/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"context"
"fmt"

"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/coredns/coredns/plugin/pkg/fall"
"github.com/coredns/coredns/plugin/test"
Expand Down Expand Up @@ -328,7 +329,12 @@ func testWithFallback() {
t.mockCs.localClusterID = clusterID
t.mockEs.endpointStatusMap[clusterID] = true
t.lh.Fall = fall.F{Zones: []string{"clusterset.local."}}
t.lh.Next = test.NextHandler(dns.RcodeBadCookie, errors.New("dummy plugin"))
t.lh.Next = test.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
m := new(dns.Msg)
m.SetRcode(r, dns.RcodeBadCookie)
_ = w.WriteMsg(m)
return dns.RcodeBadCookie, nil
})

rec = dnstest.NewRecorder(&test.ResponseWriter{})
})
Expand Down Expand Up @@ -1008,13 +1014,12 @@ func newHandlerTestDriver() *handlerTestDriver {
func (t *handlerTestDriver) executeTestCase(rec *dnstest.Recorder, tc test.Case) {
code, err := t.lh.ServeDNS(context.TODO(), rec, tc.Msg())

Expect(code).Should(Equal(tc.Rcode))

if tc.Rcode == dns.RcodeSuccess {
if plugin.ClientWrite(tc.Rcode) {
Expect(err).To(Succeed())
Expect(test.SortAndCheck(rec.Msg, tc)).To(Succeed())
} else {
Expect(err).To(HaveOccurred())
Expect(code).Should(Equal(tc.Rcode))
}
}

Expand Down

0 comments on commit f1f06c6

Please sign in to comment.