Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

migrate snooper tests to use new local DNS server #20615

Merged
merged 25 commits into from
Nov 7, 2023
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions LICENSE-3rdparty.csv
Original file line number Diff line number Diff line change
Expand Up @@ -1175,6 +1175,7 @@ core,github.com/mdlayher/netlink/nlenc,MIT,Copyright (C) 2016-2022 Matt Layher
core,github.com/mdlayher/socket,MIT,Copyright (C) 2021 Matt Layher
core,github.com/mholt/archiver/v3,MIT,Copyright (c) 2016 Matthew Holt
core,github.com/microsoft/go-rustaudit,MIT,Copyright (c) Microsoft Corporation
core,github.com/miekg/dns,BSD-3-Clause,"Alex A. Skinner | Alex Sergeyev | Andrew Tunnell-Jones | Ask Bjørn Hansen | Copyright (c) 2009, The Go Authors. Extensions copyright (c) 2011, Miek Gieben | Copyright 2009 The Go Authors. All rights reserved. Use of this source code | Copyright 2011 Miek Gieben. All rights reserved. Use of this source code is | Copyright 2014 CloudFlare. All rights reserved. Use of this source code is | Dave Cheney | Dusty Wilson | James Hartig | Marek Majkowski | Miek Gieben <[email protected]> | Omri Bahumi | Peter van Dijk | copyright (c) 2011 Miek Gieben"
core,github.com/mitchellh/copystructure,MIT,Copyright (c) 2014 Mitchell Hashimoto
core,github.com/mitchellh/go-homedir,MIT,Copyright (c) 2013 Mitchell Hashimoto
core,github.com/mitchellh/hashstructure/v2,MIT,Copyright (c) 2016 Mitchell Hashimoto
Expand Down Expand Up @@ -1841,6 +1842,7 @@ core,golang.org/x/net/internal/socket,BSD-3-Clause,Copyright (c) 2009 The Go Aut
core,golang.org/x/net/internal/socks,BSD-3-Clause,Copyright (c) 2009 The Go Authors. All rights reserved
core,golang.org/x/net/internal/timeseries,BSD-3-Clause,Copyright (c) 2009 The Go Authors. All rights reserved
core,golang.org/x/net/ipv4,BSD-3-Clause,Copyright (c) 2009 The Go Authors. All rights reserved
core,golang.org/x/net/ipv6,BSD-3-Clause,Copyright (c) 2009 The Go Authors. All rights reserved
core,golang.org/x/net/proxy,BSD-3-Clause,Copyright (c) 2009 The Go Authors. All rights reserved
core,golang.org/x/net/trace,BSD-3-Clause,Copyright (c) 2009 The Go Authors. All rights reserved
core,golang.org/x/net/websocket,BSD-3-Clause,Copyright (c) 2009 The Go Authors. All rights reserved
Expand Down
173 changes: 53 additions & 120 deletions pkg/network/dns/snooper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,19 @@
package dns

import (
"fmt"
"net"
"strconv"
"syscall"
"testing"
"time"

"github.com/google/gopacket/layers"
"github.com/miekg/dns"
mdns "github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/DataDog/datadog-agent/pkg/network/config"
"github.com/DataDog/datadog-agent/pkg/network/tracer/testutil/testdns"
"github.com/DataDog/datadog-agent/pkg/process/util"
)

Expand Down Expand Up @@ -59,7 +58,7 @@ func TestDNSOverUDPSnooping(t *testing.T) {
defer reverseDNS.Close()

// Connect to golang.org. This will result in a DNS lookup which will be captured by socketFilterSnooper
_, _, reps := sendDNSQueries(t, []string{"golang.org"}, validDNSServerIP, "udp")
_, _, reps := sendDNSQueries(t, []string{"golang.org"}, testdns.GetServerIPPort53(t), "udp")
rep := reps[0]
require.NotNil(t, rep)
require.Equal(t, rep.Rcode, mdns.RcodeSuccess)
Expand All @@ -80,7 +79,7 @@ func TestDNSOverTCPSnooping(t *testing.T) {
reverseDNS := initDNSTestsWithDomainCollection(t, false)
defer reverseDNS.Close()

_, _, reps := sendDNSQueries(t, []string{"golang.org"}, validDNSServerIP, "tcp")
_, _, reps := sendDNSQueries(t, []string{"golang.org"}, testdns.GetServerIPPort53(t), "tcp")
rep := reps[0]
require.NotNil(t, rep)
require.Equal(t, rep.Rcode, mdns.RcodeSuccess)
Expand All @@ -95,20 +94,19 @@ func TestDNSOverTCPSnooping(t *testing.T) {
}

// Get the preferred outbound IP of this machine
func getOutboundIP(t *testing.T, serverIP string) net.IP {
if parsedIP := net.ParseIP(serverIP); parsedIP.IsLoopback() {
return parsedIP
func getOutboundIP(t *testing.T, serverIP net.IP) net.IP {
if serverIP.IsLoopback() {
return serverIP
}
conn, err := net.Dial("udp", serverIP+":80")
conn, err := net.Dial("udp", serverIP.String()+":80")
require.NoError(t, err)
defer conn.Close()
localAddr := conn.LocalAddr().(*net.UDPAddr)
return localAddr.IP
}

const (
localhost = "127.0.0.1"
validDNSServerIP = "8.8.8.8"
localhost = "127.0.0.1"
)

func initDNSTestsWithDomainCollection(t *testing.T, localDNS bool) *dnsMonitor {
Expand All @@ -131,13 +129,13 @@ func initDNSTests(t *testing.T, localDNS bool, collectDomain bool) *dnsMonitor {
func sendDNSQueries(
t *testing.T,
domains []string,
serverIP string,
serverIP net.IP,
protocol string,
) (string, int, []*mdns.Msg) {
return sendDNSQueriesOnPort(t, domains, serverIP, "53", protocol)
}

func sendDNSQueriesOnPort(t *testing.T, domains []string, serverIP string, port string, protocol string) (string, int, []*mdns.Msg) {
func sendDNSQueriesOnPort(t *testing.T, domains []string, serverIP net.IP, port string, protocol string) (string, int, []*mdns.Msg) {
// Create a DNS query message
msg := new(mdns.Msg)
msg.RecursionDesired = true
Expand All @@ -156,18 +154,18 @@ func sendDNSQueriesOnPort(t *testing.T, domains []string, serverIP string, port
}

dnsClient := mdns.Client{Net: protocol, Dialer: localAddrDialer}
dnsHost := net.JoinHostPort(serverIP, port)
dnsHost := net.JoinHostPort(serverIP.String(), port)
conn, err := dnsClient.Dial(dnsHost)
require.NoError(t, err)

var reps []*mdns.Msg
var queryPort int
if protocol == "tcp" {
queryPort = conn.Conn.(*net.TCPConn).LocalAddr().(*net.TCPAddr).Port
} else { // UDP
queryPort = conn.Conn.(*net.UDPConn).LocalAddr().(*net.UDPAddr).Port
}

var reps []*mdns.Msg
for _, domain := range domains {
msg.SetQuestion(mdns.Fqdn(domain), mdns.TypeA)
rep, _, _ := dnsClient.ExchangeWithConn(msg, conn)
Expand Down Expand Up @@ -215,23 +213,23 @@ func countDNSResponses(statsByDomain map[Hostname]map[QueryType]Stats) int {
}

func TestDNSOverTCPSuccessfulResponseCountWithoutDomain(t *testing.T) {
reverseDNS := initDNSTests(t, false, false)
reverseDNS := initDNSTests(t, true, false)
defer reverseDNS.Close()
statKeeper := reverseDNS.statKeeper
domains := []string{
"golang.org",
"google.com",
"acm.org",
}
queryIP, queryPort, reps := sendDNSQueries(t, domains, validDNSServerIP, "tcp")
queryIP, queryPort, reps := sendDNSQueries(t, domains, testdns.GetServerIPPort53(t), "tcp")

// Check that all the queries succeeded
for _, rep := range reps {
require.NotNil(t, rep)
require.Equal(t, rep.Rcode, mdns.RcodeSuccess)
}

key := getKey(queryIP, queryPort, validDNSServerIP, syscall.IPPROTO_TCP)
key := getKey(queryIP, queryPort, testdns.GetServerIPPort53(t).String(), syscall.IPPROTO_TCP)
var allStats StatsByKeyByNameByType
require.Eventuallyf(t, func() bool {
allStats = statKeeper.Snapshot()
Expand All @@ -248,15 +246,16 @@ func TestDNSOverTCPSuccessfulResponseCountWithoutDomain(t *testing.T) {
}

func TestDNSOverTCPSuccessfulResponseCount(t *testing.T) {
reverseDNS := initDNSTestsWithDomainCollection(t, false)
reverseDNS := initDNSTestsWithDomainCollection(t, true)
defer reverseDNS.Close()
statKeeper := reverseDNS.statKeeper
domains := []string{
"golang.org",
"google.com",
"acm.org",
}
queryIP, queryPort, reps := sendDNSQueries(t, domains, validDNSServerIP, "tcp")
serverIP := testdns.GetServerIPPort53(t)
queryIP, queryPort, reps := sendDNSQueries(t, domains, serverIP, "tcp")

// Check that all the queries succeeded
for _, rep := range reps {
Expand All @@ -265,7 +264,7 @@ func TestDNSOverTCPSuccessfulResponseCount(t *testing.T) {
}

var allStats StatsByKeyByNameByType
key := getKey(queryIP, queryPort, validDNSServerIP, syscall.IPPROTO_TCP)
key := getKey(queryIP, queryPort, serverIP.String(), syscall.IPPROTO_TCP)
require.Eventually(t, func() bool {
allStats = statKeeper.Snapshot()
return hasDomains(allStats[key], domains...)
Expand All @@ -282,39 +281,24 @@ func TestDNSOverTCPSuccessfulResponseCount(t *testing.T) {
}
}

type handler struct{}

func (h *handler) ServeDNS(w mdns.ResponseWriter, r *mdns.Msg) {
msg := mdns.Msg{}
msg.SetReply(r)
msg.SetRcode(r, mdns.RcodeServerFailure)
_ = w.WriteMsg(&msg)
}

func TestDNSFailedResponseCount(t *testing.T) {
reverseDNS := initDNSTestsWithDomainCollection(t, true)
defer reverseDNS.Close()
statKeeper := reverseDNS.statKeeper

domains := []string{
"nonexistenent.net.com",
"aabdgdfsgsdafsdafsad",
"missingdomain.com",
}
queryIP, queryPort, reps := sendDNSQueries(t, domains, validDNSServerIP, "tcp")
queryIP, queryPort, reps := sendDNSQueries(t, domains, testdns.GetServerIPPort53(t), "tcp")
for _, rep := range reps {
require.NotNil(t, rep)
require.NotEqual(t, rep.Rcode, mdns.RcodeSuccess) // All the queries should have failed
require.Equal(t, rep.Rcode, mdns.RcodeNameError) // All the queries should have failed
}
key1 := getKey(queryIP, queryPort, validDNSServerIP, syscall.IPPROTO_TCP)

h := handler{}
shutdown, _ := newTestServer(t, localhost, 53, "udp", h.ServeDNS)
defer shutdown()

queryIP, queryPort, _ = sendDNSQueries(t, domains, localhost, "udp")
var allStats StatsByKeyByNameByType

// First check the one sent over TCP. Expected error type: NXDomain
key1 := getKey(queryIP, queryPort, testdns.GetServerIPPort53(t).String(), syscall.IPPROTO_TCP)
require.Eventually(t, func() bool {
allStats = statKeeper.Snapshot()
return hasDomains(allStats[key1], domains...)
Expand All @@ -324,6 +308,16 @@ func TestDNSFailedResponseCount(t *testing.T) {
assert.Equal(t, uint32(1), allStats[key1][ToHostname(d)][TypeA].CountByRcode[uint32(layers.DNSResponseCodeNXDomain)], "expected one NXDOMAIN for %s, got %v", d, allStats[key1][ToHostname(d)])
}

domains = []string{
"failedserver.com",
"failedservertoo.com",
}
queryIP, queryPort, reps = sendDNSQueries(t, domains, net.ParseIP(localhost), "udp")
for _, rep := range reps {
require.NotNil(t, rep)
require.Equal(t, rep.Rcode, mdns.RcodeServerFailure) // All the queries should have failed
}

// Next check the one sent over UDP. Expected error type: ServFail
key2 := getKey(queryIP, queryPort, localhost, syscall.IPPROTO_UDP)
require.Eventually(t, func() bool {
Expand All @@ -342,13 +336,14 @@ func TestDNSOverNonPort53(t *testing.T) {
statKeeper := reverseDNS.statKeeper

domains := []string{
"nonexistent.com.net",
"nonexistent.net.com",
}
h := &handler{}
shutdown, port := newTestServer(t, localhost, 0, "udp", h.ServeDNS)
defer shutdown()
ln, _ := net.Listen("tcp", "127.0.0.1:0")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the cases where we need a random port, I think we should just spin up a server and shut it down in the test.

We can also pass in 0 for the port for the server spin up code to pick a port.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that is what we were doing before... I think the idea here was to centralize the server code. but I was on the fence when doing this, wdyt @leeavital

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have strong feelings, if we prefer to have a server-per-test, I'm fine with it. My thinking was it was nice to have a single global test server because:

  • tests run faster when they don't each have to spin up a server
  • tests don't have to deal with managing the lifecycle of
  • the nature of a stubbed server makes having a global easy
  • (obsolete) it was unlikely for the global server to conflict with other test fixtures since it listened on a nonsense ip (10.10.10.10)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was only proposing a local server for this test where we don't use port 53. The other tests can use the global server. In addition, the server creation/run code seems to be using the same var, globalServer, for both servers, which is likely going to be brittle.

port := ln.Addr().(*net.TCPAddr).Port
_ = ln.Close()
serverIP := testdns.GetServerIP(t, port)

queryIP, queryPort, reps := sendDNSQueriesOnPort(t, domains, localhost, fmt.Sprintf("%d", port), "udp")
queryIP, queryPort, reps := sendDNSQueriesOnPort(t, domains, serverIP, strconv.Itoa(port), "udp")
require.NotNil(t, reps[0])

// we only pick up on port 53 traffic, so we shouldn't ever get stats
Expand All @@ -367,7 +362,7 @@ func TestDNSOverUDPTimeoutCount(t *testing.T) {

invalidServerIP := "8.8.8.90"
domainQueried := "agafsdfsdasdfsd"
queryIP, queryPort, reps := sendDNSQueries(t, []string{domainQueried}, invalidServerIP, "udp")
queryIP, queryPort, reps := sendDNSQueries(t, []string{domainQueried}, net.ParseIP(invalidServerIP), "udp")
require.Nil(t, reps[0])

var allStats StatsByKeyByNameByType
Expand All @@ -389,7 +384,7 @@ func TestDNSOverUDPTimeoutCountWithoutDomain(t *testing.T) {

invalidServerIP := "8.8.8.90"
domainQueried := "agafsdfsdasdfsd"
queryIP, queryPort, reps := sendDNSQueries(t, []string{domainQueried}, invalidServerIP, "udp")
queryIP, queryPort, reps := sendDNSQueries(t, []string{domainQueried}, net.ParseIP(invalidServerIP), "udp")
require.Nil(t, reps[0])

key := getKey(queryIP, queryPort, invalidServerIP, syscall.IPPROTO_UDP)
Expand Down Expand Up @@ -429,23 +424,20 @@ func TestDNSOverIPv6(t *testing.T) {
reverseDNS := initDNSTestsWithDomainCollection(t, true)
defer reverseDNS.Close()
statKeeper := reverseDNS.statKeeper
domain := "missingdomain.com"
serverIP := testdns.GetServerIPPort53(t)

// This DNS server is set up so it always returns a NXDOMAIN answer
serverIP := net.IPv6loopback.String()
closeFn, _ := newTestServer(t, serverIP, 53, "udp", nxDomainHandler)
defer closeFn()

queryIP, queryPort, reps := sendDNSQueries(t, []string{"nxdomain-123.com"}, serverIP, "udp")
queryIP, queryPort, reps := sendDNSQueries(t, []string{domain}, serverIP, "udp")
require.NotNil(t, reps[0])

key := getKey(queryIP, queryPort, serverIP, syscall.IPPROTO_UDP)
key := getKey(queryIP, queryPort, serverIP.String(), syscall.IPPROTO_UDP)
var allStats StatsByKeyByNameByType
require.Eventually(t, func() bool {
allStats = statKeeper.Snapshot()
return allStats[key] != nil
}, 3*time.Second, 10*time.Millisecond, "missing DNS data for key %v", key)

stats := allStats[key][ToHostname("nxdomain-123.com")][TypeA]
stats := allStats[key][ToHostname(domain)][TypeA]
assert.Equal(t, 1, len(stats.CountByRcode))
assert.Equal(t, uint32(1), stats.CountByRcode[uint32(layers.DNSResponseCodeNXDomain)])
}
Expand All @@ -458,85 +450,26 @@ func TestDNSNestedCNAME(t *testing.T) {
defer reverseDNS.Close()
statKeeper := reverseDNS.statKeeper

serverIP := "127.0.0.1"
closeFn, _ := newTestServer(t, serverIP, 53, "udp", func(w dns.ResponseWriter, r *dns.Msg) {
answer := new(dns.Msg)
answer.SetReply(r)

top := new(dns.CNAME)
top.Hdr = dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: 3600}
top.Target = "www.example.com."
domain := "nestedcname.com"

nested := new(dns.CNAME)
nested.Hdr = dns.RR_Header{Name: "www.example.com.", Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: 3600}
nested.Target = "www2.example.com."
serverIP := testdns.GetServerIPPort53(t)

ip := new(dns.A)
ip.Hdr = dns.RR_Header{Name: "www2.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 3600}
ip.A = net.ParseIP("127.0.0.1")

answer.Answer = append(answer.Answer, top, nested, ip)
answer.SetRcode(r, dns.RcodeSuccess)
_ = w.WriteMsg(answer)
})
defer closeFn()

queryIP, queryPort, reps := sendDNSQueries(t, []string{"example.com"}, serverIP, "udp")
queryIP, queryPort, reps := sendDNSQueries(t, []string{domain}, serverIP, "udp")
require.NotNil(t, reps[0])

key := getKey(queryIP, queryPort, serverIP, syscall.IPPROTO_UDP)
key := getKey(queryIP, queryPort, serverIP.String(), syscall.IPPROTO_UDP)

var allStats StatsByKeyByNameByType
require.Eventually(t, func() bool {
allStats = statKeeper.Snapshot()
return allStats[key] != nil
}, 3*time.Second, 10*time.Millisecond, "missing DNS data for key %v", key)

stats := allStats[key][ToHostname("example.com")][TypeA]
stats := allStats[key][ToHostname(domain)][TypeA]
assert.Equal(t, 1, len(stats.CountByRcode))
assert.Equal(t, uint32(1), stats.CountByRcode[uint32(layers.DNSResponseCodeNoErr)])

checkSnooping(t, serverIP, "example.com", reverseDNS)
}

func newTestServer(t *testing.T, ip string, port uint16, protocol string, handler dns.HandlerFunc) (func(), uint16) {
addr := net.JoinHostPort(ip, strconv.Itoa(int(port)))
srv := &dns.Server{Addr: addr, Net: protocol, Handler: handler}

initChan := make(chan error, 1)
srv.NotifyStartedFunc = func() {
initChan <- nil
}

go func() {
initChan <- srv.ListenAndServe()
close(initChan)
}()

if err := <-initChan; err != nil {
t.Errorf("could not initialize DNS server: %s", err)
return func() {}, port
}

if port == 0 {
switch protocol {
case "udp":
port = uint16(srv.PacketConn.LocalAddr().(*net.UDPAddr).Port)
case "tcp":
port = uint16(srv.Listener.Addr().(*net.TCPAddr).Port)
}
}

return func() {
_ = srv.Shutdown()
}, port
}

// nxDomainHandler returns a NXDOMAIN response for any query
func nxDomainHandler(w dns.ResponseWriter, r *dns.Msg) {
answer := new(dns.Msg)
answer.SetReply(r)
answer.SetRcode(r, dns.RcodeNameError)
_ = w.WriteMsg(answer)
checkSnooping(t, serverIP.String(), domain, reverseDNS)
}

func testConfig() *config.Config {
Expand Down
Loading