Skip to content

Commit

Permalink
chore(middlewares): stateful writer does not write to injected respon…
Browse files Browse the repository at this point in the history
…se writer
  • Loading branch information
qdm12 committed Aug 10, 2023
1 parent f7e0a9e commit e694c83
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 44 deletions.
4 changes: 4 additions & 0 deletions internal/setup/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ func DNS(userSettings settings.Settings, //nolint:ireturn
}
middlewares = append(middlewares, middlewareMetrics)

// Log middleware should be one of the top most middlewares
// so it actually calls `.WriteMsg` on an actual dns.ResponseWriter
// writing to the network. Having it as the last element of the
// middlewares slice achieves this.
logMiddleware, err := logMiddleware(userSettings.MiddlewareLog)
if err != nil {
return nil, fmt.Errorf("log middleware: %w", err)
Expand Down
8 changes: 6 additions & 2 deletions pkg/middlewares/log/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,15 @@ type Logger interface {
Error(id uint16, errorString string)
}

// ServeDNS implements the dns.Handler interface.
// Note the response writer passed as argument should be an actual
// IO writer, not a buffered writer, so it can return an actual error.
func (h *handler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
sw := stateful.NewWriter(w)
sw := stateful.NewWriter()
h.next.ServeDNS(sw, r)
h.logger.Log(w.RemoteAddr(), r, sw.Response)
if err := sw.WriteErr; err != nil {
err := w.WriteMsg(sw.Response)
if err != nil {
errString := "cannot write DNS response: " + err.Error()
h.logger.Error(r.Id, errString)
}
Expand Down
1 change: 1 addition & 0 deletions pkg/middlewares/log/log_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ func Test_New(t *testing.T) {

writer := NewMockResponseWriter(ctrl)
writer.EXPECT().RemoteAddr().Return(remoteAddress)
writer.EXPECT().WriteMsg(nil).Return(nil)

handler.ServeDNS(writer, request)
}
Expand Down
4 changes: 3 additions & 1 deletion pkg/middlewares/metrics/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func (h *handler) ServeDNS(w dns.ResponseWriter, request *dns.Msg) {
h.metrics.QuestionsInc(class, qType)
}

statefulWriter := stateful.NewWriter(w)
statefulWriter := stateful.NewWriter()
h.next.ServeDNS(statefulWriter, request)
response := statefulWriter.Response

Expand All @@ -60,6 +60,8 @@ func (h *handler) ServeDNS(w dns.ResponseWriter, request *dns.Msg) {
}

h.metrics.ResponsesInc()

_ = w.WriteMsg(statefulWriter.Response)
}

func rcodeToString(rcode int) (rcodeString string) {
Expand Down
57 changes: 47 additions & 10 deletions pkg/middlewares/stateful/writer.go
Original file line number Diff line number Diff line change
@@ -1,25 +1,62 @@
package stateful

import (
"net"

"github.com/miekg/dns"
)

// Writer wraps the dns writer in order to report
// the dns response written and eventual error.
// Writer is a stateful writer with the Response field
// set when WriteMsg is called. Only the WriteMsg method
// is implemented, calls to the other methods will panic.
type Writer struct {
dns.ResponseWriter
Response *dns.Msg
WriteErr error
}

// NewWriter creates a new stateful writer.
func NewWriter() *Writer {
return &Writer{}
}

// WriteMsg sets the Response field of the Writer
// to the given response message and always returns
// a nil error.
func (w *Writer) WriteMsg(response *dns.Msg) error {
w.Response = response
w.WriteErr = w.ResponseWriter.WriteMsg(response)
return w.WriteErr
return nil
}

// LocalAddr will panic if called.
func (w *Writer) LocalAddr() net.Addr {
panic("not implemented")
}

// RemoteAddr will panic if called.
func (w *Writer) RemoteAddr() net.Addr {
panic("not implemented")
}

// Write will panic if called.
func (w *Writer) Write([]byte) (int, error) {
panic("not implemented")
}

// Close will panic if called.
func (w *Writer) Close() error {
panic("not implemented")
}

// TsigStatus will panic if called.
func (w *Writer) TsigStatus() error {
panic("not implemented")
}

// TsigTimersOnly will panic if called.
func (w *Writer) TsigTimersOnly(bool) {
panic("not implemented")
}

func NewWriter(w dns.ResponseWriter) *Writer {
return &Writer{
ResponseWriter: w,
}
// Hijack will panic if called.
func (w *Writer) Hijack() {
panic("not implemented")
}
34 changes: 3 additions & 31 deletions pkg/middlewares/stateful/writer_test.go
Original file line number Diff line number Diff line change
@@ -1,44 +1,26 @@
package stateful

import (
"errors"
"net"
"testing"

"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

type testWriter struct {
t *testing.T
expectedResponse *dns.Msg
err error
// to have methods other than WriteMsg we don't use in our tests
dns.ResponseWriter
}

func (w *testWriter) WriteMsg(response *dns.Msg) error {
assert.Equal(w.t, w.expectedResponse, response)
return w.err
}

func Test_Writer(t *testing.T) {
t.Parallel()

var dummyResponse = &dns.Msg{Answer: []dns.RR{
&dns.A{A: net.IP{1, 2, 3, 4}},
}}
var errDummy = errors.New("dummy")

testCases := map[string]struct {
response *dns.Msg
err error
}{
"nil response and nil error": {},
"response and error": {
response: dummyResponse,
err: errDummy,
},
"response and no error": {
response: dummyResponse,
Expand All @@ -50,22 +32,12 @@ func Test_Writer(t *testing.T) {
t.Run(name, func(t *testing.T) {
t.Parallel()

w := &testWriter{
t: t,
expectedResponse: testCase.response,
err: testCase.err,
}

writer := NewWriter(w)
writer := NewWriter()

err := writer.WriteMsg(testCase.response)

if testCase.err != nil {
require.Error(t, err)
assert.Equal(t, testCase.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, testCase.response, writer.Response)
assert.NoError(t, err)
})
}
}

0 comments on commit e694c83

Please sign in to comment.