From 65ca515ca51afb57547b18e020930047784dfa3a Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Thu, 24 Oct 2024 10:07:27 +0000 Subject: [PATCH] chore(pkg/server): Start function takes a context argument --- examples/doh-server/main.go | 2 +- examples/dot-server/main.go | 2 +- internal/dns/interfaces.go | 2 +- internal/dns/loop.go | 4 ++-- internal/setup/dns.go | 3 ++- pkg/server/integration_test.go | 10 ++++++---- pkg/server/server.go | 4 ++-- 7 files changed, 15 insertions(+), 12 deletions(-) diff --git a/examples/doh-server/main.go b/examples/doh-server/main.go index 34cb2949..00cadef6 100644 --- a/examples/doh-server/main.go +++ b/examples/doh-server/main.go @@ -43,7 +43,7 @@ func main() { log.Fatal(err) } - runError, err := server.Start() + runError, err := server.Start(ctx) if err != nil { log.Fatal(err) } diff --git a/examples/dot-server/main.go b/examples/dot-server/main.go index 4bf13070..b03235a3 100644 --- a/examples/dot-server/main.go +++ b/examples/dot-server/main.go @@ -43,7 +43,7 @@ func main() { log.Fatal(err) } - runError, err := server.Start() + runError, err := server.Start(ctx) if err != nil { log.Fatal(err) } diff --git a/internal/dns/interfaces.go b/internal/dns/interfaces.go index ddf40a87..59533955 100644 --- a/internal/dns/interfaces.go +++ b/internal/dns/interfaces.go @@ -11,7 +11,7 @@ import ( type Service interface { String() string - Start() (runError <-chan error, startErr error) + Start(ctx context.Context) (runError <-chan error, startErr error) Stop() error } diff --git a/internal/dns/loop.go b/internal/dns/loop.go index c7185e70..e51bc53b 100644 --- a/internal/dns/loop.go +++ b/internal/dns/loop.go @@ -121,7 +121,7 @@ func (l *Loop) runFirst(ctx context.Context) (err error) { } l.logger.Info("starting DNS server") - _, err = l.dnsServer.Start() + _, err = l.dnsServer.Start(ctx) if err != nil { return fmt.Errorf("starting dns server: %w", err) } @@ -155,7 +155,7 @@ func (l *Loop) runSubsequent(ctx context.Context, l.dnsServer = newDNSServer l.logger.Info("starting DNS server") - serverRunError, startErr := l.dnsServer.Start() + serverRunError, startErr := l.dnsServer.Start(ctx) if startErr != nil { return fmt.Errorf("starting dns server: %w", startErr) } diff --git a/internal/setup/dns.go b/internal/setup/dns.go index 7dc6e858..f61df38b 100644 --- a/internal/setup/dns.go +++ b/internal/setup/dns.go @@ -1,6 +1,7 @@ package setup import ( + "context" "fmt" "github.com/qdm12/dns/v2/internal/config" @@ -14,7 +15,7 @@ import ( type Service interface { String() string - Start() (runError <-chan error, startErr error) + Start(ctx context.Context) (runError <-chan error, startErr error) Stop() (err error) } diff --git a/pkg/server/integration_test.go b/pkg/server/integration_test.go index d0ac4150..945d3e73 100644 --- a/pkg/server/integration_test.go +++ b/pkg/server/integration_test.go @@ -30,7 +30,8 @@ func Test_Server(t *testing.T) { }) require.NoError(t, err) - runError, startErr := server.Start() + ctx := context.Background() + runError, startErr := server.Start(ctx) require.NoError(t, startErr) listeningAddress, err := server.ListeningAddress() @@ -46,7 +47,7 @@ func Test_Server(t *testing.T) { } const hostname = "google.com" // we use google.com as github.com doesn't have an IPv6 :( - ips, err := resolver.LookupIPAddr(context.Background(), hostname) + ips, err := resolver.LookupIPAddr(ctx, hostname) require.NoError(t, err) require.NotEmpty(t, ips) @@ -259,7 +260,8 @@ func Test_Server_Mocks(t *testing.T) { }) require.NoError(t, err) - runError, startErr := server.Start() + ctx := context.Background() + runError, startErr := server.Start(ctx) require.NoError(t, startErr) listeningAddress, err := server.ListeningAddress() @@ -275,7 +277,7 @@ func Test_Server_Mocks(t *testing.T) { } const hostname = "google.com" - ips, err := resolver.LookupIPAddr(context.Background(), hostname) + ips, err := resolver.LookupIPAddr(ctx, hostname) assert.NoError(t, err) assert.NotEmpty(t, ips) t.Log(ips) diff --git a/pkg/server/server.go b/pkg/server/server.go index 23a683e7..ccd20007 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -45,7 +45,7 @@ func (s *Server) String() string { return "DNS server" } -func (s *Server) Start() (runError <-chan error, startErr error) { +func (s *Server) Start(_ context.Context) (runError <-chan error, startErr error) { s.startStopMutex.Lock() defer s.startStopMutex.Unlock() @@ -64,7 +64,7 @@ func (s *Server) Start() (runError <-chan error, startErr error) { var handler dns.Handler exchanger := exchanger.New(s.settings.Dialer, s.logger) - handler = newHandler(handlerCtx, exchanger, s.logger) + handler = newHandler(handlerCtx, exchanger, s.logger) //nolint:contextcheck for _, middleware := range s.settings.Middlewares { handler = middleware.Wrap(handler)