diff --git a/internal/setup/dns.go b/internal/setup/dns.go index 254d684f..d1f9af3c 100644 --- a/internal/setup/dns.go +++ b/internal/setup/dns.go @@ -32,6 +32,10 @@ func DNS(userSettings settings.Settings, //nolint:ireturn } middlewares = append(middlewares, middlewareMetrics) + // Log middleware should be one of the top most middlewares + // so it actually calls `.WriteMsg` on an actual dns.ResponseWriter + // writing to the network. Having it as the last element of the + // middlewares slice achieves this. logMiddleware, err := logMiddleware(userSettings.MiddlewareLog) if err != nil { return nil, fmt.Errorf("log middleware: %w", err) diff --git a/pkg/middlewares/log/log.go b/pkg/middlewares/log/log.go index 6598bdca..b97fa1d8 100644 --- a/pkg/middlewares/log/log.go +++ b/pkg/middlewares/log/log.go @@ -40,11 +40,15 @@ type Logger interface { Error(id uint16, errorString string) } +// ServeDNS implements the dns.Handler interface. +// Note the response writer passed as argument should be an actual +// IO writer, not a buffered writer, so it can return an actual error. func (h *handler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { - sw := stateful.NewWriter(w) + sw := stateful.NewWriter() h.next.ServeDNS(sw, r) h.logger.Log(w.RemoteAddr(), r, sw.Response) - if err := sw.WriteErr; err != nil { + err := w.WriteMsg(sw.Response) + if err != nil { errString := "cannot write DNS response: " + err.Error() h.logger.Error(r.Id, errString) } diff --git a/pkg/middlewares/log/log_test.go b/pkg/middlewares/log/log_test.go index d492ce11..e7507a1d 100644 --- a/pkg/middlewares/log/log_test.go +++ b/pkg/middlewares/log/log_test.go @@ -36,6 +36,7 @@ func Test_New(t *testing.T) { writer := NewMockResponseWriter(ctrl) writer.EXPECT().RemoteAddr().Return(remoteAddress) + writer.EXPECT().WriteMsg(nil).Return(nil) handler.ServeDNS(writer, request) } diff --git a/pkg/middlewares/metrics/middleware.go b/pkg/middlewares/metrics/middleware.go index c5b0fada..9a8c1a62 100644 --- a/pkg/middlewares/metrics/middleware.go +++ b/pkg/middlewares/metrics/middleware.go @@ -45,7 +45,7 @@ func (h *handler) ServeDNS(w dns.ResponseWriter, request *dns.Msg) { h.metrics.QuestionsInc(class, qType) } - statefulWriter := stateful.NewWriter(w) + statefulWriter := stateful.NewWriter() h.next.ServeDNS(statefulWriter, request) response := statefulWriter.Response @@ -60,6 +60,8 @@ func (h *handler) ServeDNS(w dns.ResponseWriter, request *dns.Msg) { } h.metrics.ResponsesInc() + + _ = w.WriteMsg(statefulWriter.Response) } func rcodeToString(rcode int) (rcodeString string) { diff --git a/pkg/middlewares/stateful/writer.go b/pkg/middlewares/stateful/writer.go index ea7c8c29..051ec91b 100644 --- a/pkg/middlewares/stateful/writer.go +++ b/pkg/middlewares/stateful/writer.go @@ -1,25 +1,62 @@ package stateful import ( + "net" + "github.com/miekg/dns" ) -// Writer wraps the dns writer in order to report -// the dns response written and eventual error. +// Writer is a stateful writer with the Response field +// set when WriteMsg is called. Only the WriteMsg method +// is implemented, calls to the other methods will panic. type Writer struct { - dns.ResponseWriter Response *dns.Msg - WriteErr error } +// NewWriter creates a new stateful writer. +func NewWriter() *Writer { + return &Writer{} +} + +// WriteMsg sets the Response field of the Writer +// to the given response message and always returns +// a nil error. func (w *Writer) WriteMsg(response *dns.Msg) error { w.Response = response - w.WriteErr = w.ResponseWriter.WriteMsg(response) - return w.WriteErr + return nil +} + +// LocalAddr will panic if called. +func (w *Writer) LocalAddr() net.Addr { + panic("not implemented") +} + +// RemoteAddr will panic if called. +func (w *Writer) RemoteAddr() net.Addr { + panic("not implemented") +} + +// Write will panic if called. +func (w *Writer) Write([]byte) (int, error) { + panic("not implemented") +} + +// Close will panic if called. +func (w *Writer) Close() error { + panic("not implemented") +} + +// TsigStatus will panic if called. +func (w *Writer) TsigStatus() error { + panic("not implemented") +} + +// TsigTimersOnly will panic if called. +func (w *Writer) TsigTimersOnly(bool) { + panic("not implemented") } -func NewWriter(w dns.ResponseWriter) *Writer { - return &Writer{ - ResponseWriter: w, - } +// Hijack will panic if called. +func (w *Writer) Hijack() { + panic("not implemented") } diff --git a/pkg/middlewares/stateful/writer_test.go b/pkg/middlewares/stateful/writer_test.go index f67804db..6ca7b411 100644 --- a/pkg/middlewares/stateful/writer_test.go +++ b/pkg/middlewares/stateful/writer_test.go @@ -1,44 +1,26 @@ package stateful import ( - "errors" "net" "testing" "github.com/miekg/dns" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) -type testWriter struct { - t *testing.T - expectedResponse *dns.Msg - err error - // to have methods other than WriteMsg we don't use in our tests - dns.ResponseWriter -} - -func (w *testWriter) WriteMsg(response *dns.Msg) error { - assert.Equal(w.t, w.expectedResponse, response) - return w.err -} - func Test_Writer(t *testing.T) { t.Parallel() var dummyResponse = &dns.Msg{Answer: []dns.RR{ &dns.A{A: net.IP{1, 2, 3, 4}}, }} - var errDummy = errors.New("dummy") testCases := map[string]struct { response *dns.Msg - err error }{ "nil response and nil error": {}, "response and error": { response: dummyResponse, - err: errDummy, }, "response and no error": { response: dummyResponse, @@ -50,22 +32,12 @@ func Test_Writer(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - w := &testWriter{ - t: t, - expectedResponse: testCase.response, - err: testCase.err, - } - - writer := NewWriter(w) + writer := NewWriter() err := writer.WriteMsg(testCase.response) - if testCase.err != nil { - require.Error(t, err) - assert.Equal(t, testCase.err.Error(), err.Error()) - } else { - assert.NoError(t, err) - } + assert.Equal(t, testCase.response, writer.Response) + assert.NoError(t, err) }) } }