Skip to content

Commit

Permalink
chore(pkg/server): Start function takes a context argument
Browse files Browse the repository at this point in the history
  • Loading branch information
qdm12 committed Oct 24, 2024
1 parent abc1371 commit 65ca515
Show file tree
Hide file tree
Showing 7 changed files with 15 additions and 12 deletions.
2 changes: 1 addition & 1 deletion examples/doh-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func main() {
log.Fatal(err)
}

runError, err := server.Start()
runError, err := server.Start(ctx)
if err != nil {
log.Fatal(err)
}
Expand Down
2 changes: 1 addition & 1 deletion examples/dot-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func main() {
log.Fatal(err)
}

runError, err := server.Start()
runError, err := server.Start(ctx)
if err != nil {
log.Fatal(err)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/dns/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (

type Service interface {
String() string
Start() (runError <-chan error, startErr error)
Start(ctx context.Context) (runError <-chan error, startErr error)
Stop() error
}

Expand Down
4 changes: 2 additions & 2 deletions internal/dns/loop.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func (l *Loop) runFirst(ctx context.Context) (err error) {
}

l.logger.Info("starting DNS server")
_, err = l.dnsServer.Start()
_, err = l.dnsServer.Start(ctx)
if err != nil {
return fmt.Errorf("starting dns server: %w", err)
}
Expand Down Expand Up @@ -155,7 +155,7 @@ func (l *Loop) runSubsequent(ctx context.Context,
l.dnsServer = newDNSServer

l.logger.Info("starting DNS server")
serverRunError, startErr := l.dnsServer.Start()
serverRunError, startErr := l.dnsServer.Start(ctx)
if startErr != nil {
return fmt.Errorf("starting dns server: %w", startErr)
}
Expand Down
3 changes: 2 additions & 1 deletion internal/setup/dns.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package setup

import (
"context"
"fmt"

"github.com/qdm12/dns/v2/internal/config"
Expand All @@ -14,7 +15,7 @@ import (

type Service interface {
String() string
Start() (runError <-chan error, startErr error)
Start(ctx context.Context) (runError <-chan error, startErr error)
Stop() (err error)
}

Expand Down
10 changes: 6 additions & 4 deletions pkg/server/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ func Test_Server(t *testing.T) {
})
require.NoError(t, err)

runError, startErr := server.Start()
ctx := context.Background()
runError, startErr := server.Start(ctx)
require.NoError(t, startErr)

listeningAddress, err := server.ListeningAddress()
Expand All @@ -46,7 +47,7 @@ func Test_Server(t *testing.T) {
}

const hostname = "google.com" // we use google.com as github.com doesn't have an IPv6 :(
ips, err := resolver.LookupIPAddr(context.Background(), hostname)
ips, err := resolver.LookupIPAddr(ctx, hostname)

require.NoError(t, err)
require.NotEmpty(t, ips)
Expand Down Expand Up @@ -259,7 +260,8 @@ func Test_Server_Mocks(t *testing.T) {
})
require.NoError(t, err)

runError, startErr := server.Start()
ctx := context.Background()
runError, startErr := server.Start(ctx)
require.NoError(t, startErr)

listeningAddress, err := server.ListeningAddress()
Expand All @@ -275,7 +277,7 @@ func Test_Server_Mocks(t *testing.T) {
}

const hostname = "google.com"
ips, err := resolver.LookupIPAddr(context.Background(), hostname)
ips, err := resolver.LookupIPAddr(ctx, hostname)
assert.NoError(t, err)
assert.NotEmpty(t, ips)
t.Log(ips)
Expand Down
4 changes: 2 additions & 2 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func (s *Server) String() string {
return "DNS server"
}

func (s *Server) Start() (runError <-chan error, startErr error) {
func (s *Server) Start(_ context.Context) (runError <-chan error, startErr error) {
s.startStopMutex.Lock()
defer s.startStopMutex.Unlock()

Expand All @@ -64,7 +64,7 @@ func (s *Server) Start() (runError <-chan error, startErr error) {

var handler dns.Handler
exchanger := exchanger.New(s.settings.Dialer, s.logger)
handler = newHandler(handlerCtx, exchanger, s.logger)
handler = newHandler(handlerCtx, exchanger, s.logger) //nolint:contextcheck

for _, middleware := range s.settings.Middlewares {
handler = middleware.Wrap(handler)
Expand Down

0 comments on commit 65ca515

Please sign in to comment.