diff --git a/LICENSE-3rdparty.csv b/LICENSE-3rdparty.csv index dd7a9cd143cd0..9b337580421d8 100644 --- a/LICENSE-3rdparty.csv +++ b/LICENSE-3rdparty.csv @@ -1184,6 +1184,7 @@ core,github.com/mdlayher/netlink/nlenc,MIT,Copyright (C) 2016-2022 Matt Layher core,github.com/mdlayher/socket,MIT,Copyright (C) 2021 Matt Layher core,github.com/mholt/archiver/v3,MIT,Copyright (c) 2016 Matthew Holt core,github.com/microsoft/go-rustaudit,MIT,Copyright (c) Microsoft Corporation +core,github.com/miekg/dns,BSD-3-Clause,"Alex A. Skinner | Alex Sergeyev | Andrew Tunnell-Jones | Ask Bjørn Hansen | Copyright (c) 2009, The Go Authors. Extensions copyright (c) 2011, Miek Gieben | Copyright 2009 The Go Authors. All rights reserved. Use of this source code | Copyright 2011 Miek Gieben. All rights reserved. Use of this source code is | Copyright 2014 CloudFlare. All rights reserved. Use of this source code is | Dave Cheney | Dusty Wilson | James Hartig | Marek Majkowski | Miek Gieben | Omri Bahumi | Peter van Dijk | copyright (c) 2011 Miek Gieben" core,github.com/mitchellh/copystructure,MIT,Copyright (c) 2014 Mitchell Hashimoto core,github.com/mitchellh/go-homedir,MIT,Copyright (c) 2013 Mitchell Hashimoto core,github.com/mitchellh/hashstructure/v2,MIT,Copyright (c) 2016 Mitchell Hashimoto @@ -1850,6 +1851,7 @@ core,golang.org/x/net/internal/socket,BSD-3-Clause,Copyright (c) 2009 The Go Aut core,golang.org/x/net/internal/socks,BSD-3-Clause,Copyright (c) 2009 The Go Authors. All rights reserved core,golang.org/x/net/internal/timeseries,BSD-3-Clause,Copyright (c) 2009 The Go Authors. All rights reserved core,golang.org/x/net/ipv4,BSD-3-Clause,Copyright (c) 2009 The Go Authors. All rights reserved +core,golang.org/x/net/ipv6,BSD-3-Clause,Copyright (c) 2009 The Go Authors. All rights reserved core,golang.org/x/net/proxy,BSD-3-Clause,Copyright (c) 2009 The Go Authors. All rights reserved core,golang.org/x/net/trace,BSD-3-Clause,Copyright (c) 2009 The Go Authors. All rights reserved core,golang.org/x/net/websocket,BSD-3-Clause,Copyright (c) 2009 The Go Authors. All rights reserved diff --git a/pkg/network/dns/snooper_test.go b/pkg/network/dns/snooper_test.go index 25f41a1f90ec5..a0e60f046dbb4 100644 --- a/pkg/network/dns/snooper_test.go +++ b/pkg/network/dns/snooper_test.go @@ -8,7 +8,6 @@ package dns import ( - "fmt" "net" "strconv" "syscall" @@ -16,12 +15,12 @@ import ( "time" "github.com/google/gopacket/layers" - "github.com/miekg/dns" mdns "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/DataDog/datadog-agent/pkg/network/config" + "github.com/DataDog/datadog-agent/pkg/network/tracer/testutil/testdns" "github.com/DataDog/datadog-agent/pkg/process/util" ) @@ -59,7 +58,7 @@ func TestDNSOverUDPSnooping(t *testing.T) { defer reverseDNS.Close() // Connect to golang.org. This will result in a DNS lookup which will be captured by socketFilterSnooper - _, _, reps := sendDNSQueries(t, []string{"golang.org"}, validDNSServerIP, "udp") + _, _, reps := sendDNSQueries(t, []string{"golang.org"}, testdns.GetServerIP(t), "udp") rep := reps[0] require.NotNil(t, rep) require.Equal(t, rep.Rcode, mdns.RcodeSuccess) @@ -80,7 +79,7 @@ func TestDNSOverTCPSnooping(t *testing.T) { reverseDNS := initDNSTestsWithDomainCollection(t, false) defer reverseDNS.Close() - _, _, reps := sendDNSQueries(t, []string{"golang.org"}, validDNSServerIP, "tcp") + _, _, reps := sendDNSQueries(t, []string{"golang.org"}, testdns.GetServerIP(t), "tcp") rep := reps[0] require.NotNil(t, rep) require.Equal(t, rep.Rcode, mdns.RcodeSuccess) @@ -95,11 +94,11 @@ func TestDNSOverTCPSnooping(t *testing.T) { } // Get the preferred outbound IP of this machine -func getOutboundIP(t *testing.T, serverIP string) net.IP { - if parsedIP := net.ParseIP(serverIP); parsedIP.IsLoopback() { - return parsedIP +func getOutboundIP(t *testing.T, serverIP net.IP) net.IP { + if serverIP.IsLoopback() { + return serverIP } - conn, err := net.Dial("udp", serverIP+":80") + conn, err := net.Dial("udp", serverIP.String()+":80") require.NoError(t, err) defer conn.Close() localAddr := conn.LocalAddr().(*net.UDPAddr) @@ -107,8 +106,7 @@ func getOutboundIP(t *testing.T, serverIP string) net.IP { } const ( - localhost = "127.0.0.1" - validDNSServerIP = "8.8.8.8" + localhost = "127.0.0.1" ) func initDNSTestsWithDomainCollection(t *testing.T, localDNS bool) *dnsMonitor { @@ -131,13 +129,13 @@ func initDNSTests(t *testing.T, localDNS bool, collectDomain bool) *dnsMonitor { func sendDNSQueries( t *testing.T, domains []string, - serverIP string, + serverIP net.IP, protocol string, ) (string, int, []*mdns.Msg) { return sendDNSQueriesOnPort(t, domains, serverIP, "53", protocol) } -func sendDNSQueriesOnPort(t *testing.T, domains []string, serverIP string, port string, protocol string) (string, int, []*mdns.Msg) { +func sendDNSQueriesOnPort(t *testing.T, domains []string, serverIP net.IP, port string, protocol string) (string, int, []*mdns.Msg) { // Create a DNS query message msg := new(mdns.Msg) msg.RecursionDesired = true @@ -156,11 +154,10 @@ func sendDNSQueriesOnPort(t *testing.T, domains []string, serverIP string, port } dnsClient := mdns.Client{Net: protocol, Dialer: localAddrDialer} - dnsHost := net.JoinHostPort(serverIP, port) + dnsHost := net.JoinHostPort(serverIP.String(), port) conn, err := dnsClient.Dial(dnsHost) require.NoError(t, err) - var reps []*mdns.Msg var queryPort int if protocol == "tcp" { queryPort = conn.Conn.(*net.TCPConn).LocalAddr().(*net.TCPAddr).Port @@ -168,6 +165,7 @@ func sendDNSQueriesOnPort(t *testing.T, domains []string, serverIP string, port queryPort = conn.Conn.(*net.UDPConn).LocalAddr().(*net.UDPAddr).Port } + var reps []*mdns.Msg for _, domain := range domains { msg.SetQuestion(mdns.Fqdn(domain), mdns.TypeA) rep, _, _ := dnsClient.ExchangeWithConn(msg, conn) @@ -215,7 +213,7 @@ func countDNSResponses(statsByDomain map[Hostname]map[QueryType]Stats) int { } func TestDNSOverTCPSuccessfulResponseCountWithoutDomain(t *testing.T) { - reverseDNS := initDNSTests(t, false, false) + reverseDNS := initDNSTests(t, true, false) defer reverseDNS.Close() statKeeper := reverseDNS.statKeeper domains := []string{ @@ -223,7 +221,7 @@ func TestDNSOverTCPSuccessfulResponseCountWithoutDomain(t *testing.T) { "google.com", "acm.org", } - queryIP, queryPort, reps := sendDNSQueries(t, domains, validDNSServerIP, "tcp") + queryIP, queryPort, reps := sendDNSQueries(t, domains, testdns.GetServerIP(t), "tcp") // Check that all the queries succeeded for _, rep := range reps { @@ -231,7 +229,7 @@ func TestDNSOverTCPSuccessfulResponseCountWithoutDomain(t *testing.T) { require.Equal(t, rep.Rcode, mdns.RcodeSuccess) } - key := getKey(queryIP, queryPort, validDNSServerIP, syscall.IPPROTO_TCP) + key := getKey(queryIP, queryPort, testdns.GetServerIP(t).String(), syscall.IPPROTO_TCP) var allStats StatsByKeyByNameByType require.Eventuallyf(t, func() bool { allStats = statKeeper.Snapshot() @@ -248,7 +246,7 @@ func TestDNSOverTCPSuccessfulResponseCountWithoutDomain(t *testing.T) { } func TestDNSOverTCPSuccessfulResponseCount(t *testing.T) { - reverseDNS := initDNSTestsWithDomainCollection(t, false) + reverseDNS := initDNSTestsWithDomainCollection(t, true) defer reverseDNS.Close() statKeeper := reverseDNS.statKeeper domains := []string{ @@ -256,7 +254,8 @@ func TestDNSOverTCPSuccessfulResponseCount(t *testing.T) { "google.com", "acm.org", } - queryIP, queryPort, reps := sendDNSQueries(t, domains, validDNSServerIP, "tcp") + serverIP := testdns.GetServerIP(t) + queryIP, queryPort, reps := sendDNSQueries(t, domains, serverIP, "tcp") // Check that all the queries succeeded for _, rep := range reps { @@ -265,7 +264,7 @@ func TestDNSOverTCPSuccessfulResponseCount(t *testing.T) { } var allStats StatsByKeyByNameByType - key := getKey(queryIP, queryPort, validDNSServerIP, syscall.IPPROTO_TCP) + key := getKey(queryIP, queryPort, serverIP.String(), syscall.IPPROTO_TCP) require.Eventually(t, func() bool { allStats = statKeeper.Snapshot() return hasDomains(allStats[key], domains...) @@ -282,15 +281,6 @@ func TestDNSOverTCPSuccessfulResponseCount(t *testing.T) { } } -type handler struct{} - -func (h *handler) ServeDNS(w mdns.ResponseWriter, r *mdns.Msg) { - msg := mdns.Msg{} - msg.SetReply(r) - msg.SetRcode(r, mdns.RcodeServerFailure) - _ = w.WriteMsg(&msg) -} - func TestDNSFailedResponseCount(t *testing.T) { reverseDNS := initDNSTestsWithDomainCollection(t, true) defer reverseDNS.Close() @@ -298,23 +288,17 @@ func TestDNSFailedResponseCount(t *testing.T) { domains := []string{ "nonexistenent.net.com", - "aabdgdfsgsdafsdafsad", + "missingdomain.com", } - queryIP, queryPort, reps := sendDNSQueries(t, domains, validDNSServerIP, "tcp") + queryIP, queryPort, reps := sendDNSQueries(t, domains, testdns.GetServerIP(t), "tcp") for _, rep := range reps { require.NotNil(t, rep) - require.NotEqual(t, rep.Rcode, mdns.RcodeSuccess) // All the queries should have failed + require.Equal(t, rep.Rcode, mdns.RcodeNameError) // All the queries should have failed } - key1 := getKey(queryIP, queryPort, validDNSServerIP, syscall.IPPROTO_TCP) - h := handler{} - shutdown, _ := newTestServer(t, localhost, 53, "udp", h.ServeDNS) - defer shutdown() - - queryIP, queryPort, _ = sendDNSQueries(t, domains, localhost, "udp") var allStats StatsByKeyByNameByType - // First check the one sent over TCP. Expected error type: NXDomain + key1 := getKey(queryIP, queryPort, testdns.GetServerIP(t).String(), syscall.IPPROTO_TCP) require.Eventually(t, func() bool { allStats = statKeeper.Snapshot() return hasDomains(allStats[key1], domains...) @@ -324,6 +308,16 @@ func TestDNSFailedResponseCount(t *testing.T) { assert.Equal(t, uint32(1), allStats[key1][ToHostname(d)][TypeA].CountByRcode[uint32(layers.DNSResponseCodeNXDomain)], "expected one NXDOMAIN for %s, got %v", d, allStats[key1][ToHostname(d)]) } + domains = []string{ + "failedserver.com", + "failedservertoo.com", + } + queryIP, queryPort, reps = sendDNSQueries(t, domains, net.ParseIP(localhost), "udp") + for _, rep := range reps { + require.NotNil(t, rep) + require.Equal(t, rep.Rcode, mdns.RcodeServerFailure) // All the queries should have failed + } + // Next check the one sent over UDP. Expected error type: ServFail key2 := getKey(queryIP, queryPort, localhost, syscall.IPPROTO_UDP) require.Eventually(t, func() bool { @@ -342,13 +336,12 @@ func TestDNSOverNonPort53(t *testing.T) { statKeeper := reverseDNS.statKeeper domains := []string{ - "nonexistent.com.net", + "nonexistent.net.com", } - h := &handler{} - shutdown, port := newTestServer(t, localhost, 0, "udp", h.ServeDNS) + shutdown, port := newTestServer(t, localhost, "udp") defer shutdown() - queryIP, queryPort, reps := sendDNSQueriesOnPort(t, domains, localhost, fmt.Sprintf("%d", port), "udp") + queryIP, queryPort, reps := sendDNSQueriesOnPort(t, domains, net.ParseIP(localhost), strconv.Itoa(int(port)), "udp") require.NotNil(t, reps[0]) // we only pick up on port 53 traffic, so we shouldn't ever get stats @@ -360,6 +353,39 @@ func TestDNSOverNonPort53(t *testing.T) { }, 3*time.Second, 10*time.Millisecond, "found DNS data for key %v when it should be missing", key) } +func newTestServer(t *testing.T, ip string, protocol string) (func(), uint16) { + t.Helper() + addr := net.JoinHostPort(ip, "0") + srv := &mdns.Server{ + Addr: addr, + Net: protocol, + Handler: mdns.HandlerFunc(func(w mdns.ResponseWriter, r *mdns.Msg) { + msg := mdns.Msg{} + msg.SetReply(r) + msg.SetRcode(r, mdns.RcodeServerFailure) + _ = w.WriteMsg(&msg) + }), + } + + initChan := make(chan error, 1) + srv.NotifyStartedFunc = func() { + initChan <- nil + } + go func() { + initChan <- srv.ListenAndServe() + close(initChan) + }() + + if err := <-initChan; err != nil { + t.Errorf("could not initialize DNS server: %s", err) + return func() {}, uint16(0) + } + + return func() { + _ = srv.Shutdown() + }, uint16(srv.PacketConn.LocalAddr().(*net.UDPAddr).Port) +} + func TestDNSOverUDPTimeoutCount(t *testing.T) { reverseDNS := initDNSTestsWithDomainCollection(t, false) defer reverseDNS.Close() @@ -367,7 +393,7 @@ func TestDNSOverUDPTimeoutCount(t *testing.T) { invalidServerIP := "8.8.8.90" domainQueried := "agafsdfsdasdfsd" - queryIP, queryPort, reps := sendDNSQueries(t, []string{domainQueried}, invalidServerIP, "udp") + queryIP, queryPort, reps := sendDNSQueries(t, []string{domainQueried}, net.ParseIP(invalidServerIP), "udp") require.Nil(t, reps[0]) var allStats StatsByKeyByNameByType @@ -389,7 +415,7 @@ func TestDNSOverUDPTimeoutCountWithoutDomain(t *testing.T) { invalidServerIP := "8.8.8.90" domainQueried := "agafsdfsdasdfsd" - queryIP, queryPort, reps := sendDNSQueries(t, []string{domainQueried}, invalidServerIP, "udp") + queryIP, queryPort, reps := sendDNSQueries(t, []string{domainQueried}, net.ParseIP(invalidServerIP), "udp") require.Nil(t, reps[0]) key := getKey(queryIP, queryPort, invalidServerIP, syscall.IPPROTO_UDP) @@ -429,23 +455,20 @@ func TestDNSOverIPv6(t *testing.T) { reverseDNS := initDNSTestsWithDomainCollection(t, true) defer reverseDNS.Close() statKeeper := reverseDNS.statKeeper + domain := "missingdomain.com" + serverIP := testdns.GetServerIP(t) - // This DNS server is set up so it always returns a NXDOMAIN answer - serverIP := net.IPv6loopback.String() - closeFn, _ := newTestServer(t, serverIP, 53, "udp", nxDomainHandler) - defer closeFn() - - queryIP, queryPort, reps := sendDNSQueries(t, []string{"nxdomain-123.com"}, serverIP, "udp") + queryIP, queryPort, reps := sendDNSQueries(t, []string{domain}, serverIP, "udp") require.NotNil(t, reps[0]) - key := getKey(queryIP, queryPort, serverIP, syscall.IPPROTO_UDP) + key := getKey(queryIP, queryPort, serverIP.String(), syscall.IPPROTO_UDP) var allStats StatsByKeyByNameByType require.Eventually(t, func() bool { allStats = statKeeper.Snapshot() return allStats[key] != nil }, 3*time.Second, 10*time.Millisecond, "missing DNS data for key %v", key) - stats := allStats[key][ToHostname("nxdomain-123.com")][TypeA] + stats := allStats[key][ToHostname(domain)][TypeA] assert.Equal(t, 1, len(stats.CountByRcode)) assert.Equal(t, uint32(1), stats.CountByRcode[uint32(layers.DNSResponseCodeNXDomain)]) } @@ -458,85 +481,26 @@ func TestDNSNestedCNAME(t *testing.T) { defer reverseDNS.Close() statKeeper := reverseDNS.statKeeper - serverIP := "127.0.0.1" - closeFn, _ := newTestServer(t, serverIP, 53, "udp", func(w dns.ResponseWriter, r *dns.Msg) { - answer := new(dns.Msg) - answer.SetReply(r) + domain := "nestedcname.com" - top := new(dns.CNAME) - top.Hdr = dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: 3600} - top.Target = "www.example.com." + serverIP := testdns.GetServerIP(t) - nested := new(dns.CNAME) - nested.Hdr = dns.RR_Header{Name: "www.example.com.", Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: 3600} - nested.Target = "www2.example.com." - - ip := new(dns.A) - ip.Hdr = dns.RR_Header{Name: "www2.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 3600} - ip.A = net.ParseIP("127.0.0.1") - - answer.Answer = append(answer.Answer, top, nested, ip) - answer.SetRcode(r, dns.RcodeSuccess) - _ = w.WriteMsg(answer) - }) - defer closeFn() - - queryIP, queryPort, reps := sendDNSQueries(t, []string{"example.com"}, serverIP, "udp") + queryIP, queryPort, reps := sendDNSQueries(t, []string{domain}, serverIP, "udp") require.NotNil(t, reps[0]) - key := getKey(queryIP, queryPort, serverIP, syscall.IPPROTO_UDP) + key := getKey(queryIP, queryPort, serverIP.String(), syscall.IPPROTO_UDP) + var allStats StatsByKeyByNameByType require.Eventually(t, func() bool { allStats = statKeeper.Snapshot() return allStats[key] != nil }, 3*time.Second, 10*time.Millisecond, "missing DNS data for key %v", key) - stats := allStats[key][ToHostname("example.com")][TypeA] + stats := allStats[key][ToHostname(domain)][TypeA] assert.Equal(t, 1, len(stats.CountByRcode)) assert.Equal(t, uint32(1), stats.CountByRcode[uint32(layers.DNSResponseCodeNoErr)]) - checkSnooping(t, serverIP, "example.com", reverseDNS) -} - -func newTestServer(t *testing.T, ip string, port uint16, protocol string, handler dns.HandlerFunc) (func(), uint16) { - addr := net.JoinHostPort(ip, strconv.Itoa(int(port))) - srv := &dns.Server{Addr: addr, Net: protocol, Handler: handler} - - initChan := make(chan error, 1) - srv.NotifyStartedFunc = func() { - initChan <- nil - } - - go func() { - initChan <- srv.ListenAndServe() - close(initChan) - }() - - if err := <-initChan; err != nil { - t.Errorf("could not initialize DNS server: %s", err) - return func() {}, port - } - - if port == 0 { - switch protocol { - case "udp": - port = uint16(srv.PacketConn.LocalAddr().(*net.UDPAddr).Port) - case "tcp": - port = uint16(srv.Listener.Addr().(*net.TCPAddr).Port) - } - } - - return func() { - _ = srv.Shutdown() - }, port -} - -// nxDomainHandler returns a NXDOMAIN response for any query -func nxDomainHandler(w dns.ResponseWriter, r *dns.Msg) { - answer := new(dns.Msg) - answer.SetReply(r) - answer.SetRcode(r, dns.RcodeNameError) - _ = w.WriteMsg(answer) + checkSnooping(t, serverIP.String(), domain, reverseDNS) } func testConfig() *config.Config { diff --git a/pkg/network/testutil/server.go b/pkg/network/testutil/server.go index 10222c4b2051c..900fba5f9c267 100644 --- a/pkg/network/testutil/server.go +++ b/pkg/network/testutil/server.go @@ -121,7 +121,6 @@ func StartServerUDP(t *testing.T, ip net.IP, port int) io.Closer { port, err = strconv.Atoi(portStr) assert.Nil(t, err) - require.NoError(t, err) go func() { close(ch) diff --git a/pkg/network/tracer/testutil/testdns/test_dns_server.go b/pkg/network/tracer/testutil/testdns/test_dns_server.go new file mode 100644 index 0000000000000..455d32d6a51a0 --- /dev/null +++ b/pkg/network/tracer/testutil/testdns/test_dns_server.go @@ -0,0 +1,127 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2023-present Datadog, Inc. + +//go:build test + +// Package testdns contains a DNS server for use in testing +package testdns + +import ( + "net" + "sync" + "testing" + + "github.com/miekg/dns" + "github.com/stretchr/testify/require" +) + +var globalTCPError error +var globalUDPError error +var serverOnce sync.Once + +const localhostAddr = "127.0.0.1" + +// GetServerIP returns the IP address of the test DNS server. The test DNS server returns canned responses for several +// known domains that are used in integration tests. +// +// see server#start to see which domains are handled. +func GetServerIP(t *testing.T) net.IP { + var srv *server + serverOnce.Do(func() { + srv = newServer() + globalTCPError = srv.Start("tcp") + globalUDPError = srv.Start("udp") + }) + require.NoError(t, globalTCPError, "error starting local TCP DNS server") + require.NoError(t, globalUDPError, "error starting local UDP DNS server") + return net.ParseIP(localhostAddr) +} + +type server struct{} + +func newServer() *server { + return &server{} +} + +func (s *server) Start(transport string) error { + started := make(chan struct{}, 1) + errChan := make(chan error, 1) + address := localhostAddr + ":53" + srv := dns.Server{ + Addr: address, + Net: transport, + Handler: dns.HandlerFunc(func(writer dns.ResponseWriter, msg *dns.Msg) { + switch msg.Question[0].Name { + case "good.com.": + respond(msg, writer, "good.com. 30 IN A 10.0.0.1") + case "golang.org.": + respond(msg, writer, "golang.org. 30 IN A 10.0.0.2") + case "google.com.": + respond(msg, writer, "google.com. 30 IN A 10.0.0.3") + case "acm.org.": + respond(msg, writer, "acm.org. 30 IN A 10.0.0.4") + case "nonexistenent.net.com.": + resp := &dns.Msg{} + resp.SetReply(msg) + resp.Rcode = dns.RcodeNameError + _ = writer.WriteMsg(resp) + case "missingdomain.com.": + resp := &dns.Msg{} + resp.SetReply(msg) + resp.Rcode = dns.RcodeNameError + _ = writer.WriteMsg(resp) + case "nestedcname.com.": + resp := &dns.Msg{} + resp.SetReply(msg) + top := new(dns.CNAME) + top.Hdr = dns.RR_Header{Name: "nestedcname.com.", Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: 3600} + top.Target = "www.nestedcname.com." + nested := new(dns.CNAME) + nested.Hdr = dns.RR_Header{Name: "www.nestedcname.com.", Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: 3600} + nested.Target = "www2.nestedcname.com." + ip := new(dns.A) + ip.Hdr = dns.RR_Header{Name: "www2.nestedcname.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 3600} + ip.A = net.ParseIP(localhostAddr) + + resp.Answer = append(resp.Answer, top, nested, ip) + resp.SetRcode(msg, dns.RcodeSuccess) + _ = writer.WriteMsg(resp) + default: + resp := &dns.Msg{} + resp.SetReply(msg) + resp.Rcode = dns.RcodeServerFailure + _ = writer.WriteMsg(resp) + } + }), + NotifyStartedFunc: func() { + started <- struct{}{} + }, + } + go func() { + err := srv.ListenAndServe() + if err != nil { + errChan <- err + } + }() + + select { + case <-started: + return nil + case err := <-errChan: + return err + } +} + +func respond(req *dns.Msg, writer dns.ResponseWriter, record string) { + resp := &dns.Msg{} + resp.SetReply(req) + + rr, err := dns.NewRR(record) + if err != nil { + panic(err) + } + resp.Answer = []dns.RR{rr} + _ = writer.WriteMsg(resp) +} diff --git a/pkg/network/tracer/tracer_linux_test.go b/pkg/network/tracer/tracer_linux_test.go index a65e017c0ce27..3ddaf05e00eb3 100644 --- a/pkg/network/tracer/tracer_linux_test.go +++ b/pkg/network/tracer/tracer_linux_test.go @@ -71,9 +71,9 @@ func doDNSQuery(t *testing.T, domain string, serverIP string) (*net.UDPAddr, *ne dnsClient := new(dns.Client) dnsConn, err := dnsClient.Dial(dnsServerAddr.String()) require.NoError(t, err) - defer dnsConn.Close() dnsClientAddr := dnsConn.LocalAddr().(*net.UDPAddr) _, _, err = dnsClient.ExchangeWithConn(queryMsg, dnsConn) + _ = dnsConn.Close() require.NoError(t, err) return dnsClientAddr, dnsServerAddr diff --git a/pkg/network/tracer/tracer_test.go b/pkg/network/tracer/tracer_test.go index 06304c334bad8..52840bdd4cb47 100644 --- a/pkg/network/tracer/tracer_test.go +++ b/pkg/network/tracer/tracer_test.go @@ -36,6 +36,7 @@ import ( "github.com/DataDog/datadog-agent/pkg/ebpf/ebpftest" "github.com/DataDog/datadog-agent/pkg/network" "github.com/DataDog/datadog-agent/pkg/network/config" + "github.com/DataDog/datadog-agent/pkg/network/tracer/testutil/testdns" "github.com/DataDog/datadog-agent/pkg/process/util" "github.com/DataDog/datadog-agent/pkg/util/log" ) @@ -1002,10 +1003,6 @@ func getConnections(t require.TestingT, tr *Tracer) *network.Connections { return connections } -const ( - validDNSServer = "8.8.8.8" -) - func testDNSStats(t *testing.T, tr *Tracer, domain string, success, failure, timeout int, serverIP string) { tr.removeClient(clientID) initTracerState(t, tr) @@ -1060,9 +1057,9 @@ func testDNSStats(t *testing.T, tr *Tracer, domain string, success, failure, tim failedResponses := total - successfulResponses // DNS Stats - assert.Equal(t, uint32(success), successfulResponses) + assert.Equal(t, uint32(success), successfulResponses, "expected %d successful responses but got %d", success, successfulResponses) assert.Equal(t, uint32(failure), failedResponses) - assert.Equal(t, uint32(timeout), timeouts) + assert.Equal(t, uint32(timeout), timeouts, "expected %d timeouts but got %d", timeout, timeouts) } func (s *TracerSuite) TestDNSStats() { @@ -1070,12 +1067,13 @@ func (s *TracerSuite) TestDNSStats() { cfg := testConfig() cfg.CollectDNSStats = true cfg.DNSTimeout = 1 * time.Second + cfg.CollectLocalDNS = true tr := setupTracer(t, cfg) t.Run("valid domain", func(t *testing.T) { - testDNSStats(t, tr, "golang.org", 1, 0, 0, validDNSServer) + testDNSStats(t, tr, "good.com", 1, 0, 0, testdns.GetServerIP(t).String()) }) t.Run("invalid domain", func(t *testing.T) { - testDNSStats(t, tr, "abcdedfg", 0, 1, 0, validDNSServer) + testDNSStats(t, tr, "abcdedfg", 0, 1, 0, testdns.GetServerIP(t).String()) }) t.Run("timeout", func(t *testing.T) { testDNSStats(t, tr, "golang.org", 0, 0, 1, "1.2.3.4")