Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Perform a binary search to find optimal size of DNS responses #4037

Merged
47 changes: 45 additions & 2 deletions agent/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,37 @@ func syncExtra(index map[string]dns.RR, resp *dns.Msg) {
resp.Extra = extra
}

// dnsBinaryTruncate find the optimal number of records using a fast binary search and return
// it in order to return a DNS answer lower than maxSize parameter.
func dnsBinaryTruncate(resp *dns.Msg, maxSize int, index map[string]dns.RR, hasExtra bool) int {
originalAnswser := resp.Answer
originalExtra := resp.Extra
originalIndex := index
startIndex := 0
endIndex := len(resp.Answer) + 1
for endIndex-startIndex > 1 {
median := startIndex + (endIndex-startIndex)/2

resp.Answer = originalAnswser[:median]
if hasExtra {
resp.Extra = originalExtra[:median]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct me if you think otherwise but I think lines 733 and 734 here are unnecessary as well as the originalIndex and originalExtra variables.

The hasExtra block should just be a call to syncExtra. slicing originalExtra should not do anything (other than eat a few cpu cycles) as syncExtra will just overwrite resp.Extra with a brand new slice of its making. Also since syncExtra doesn't modify the RR index there shouldn't be a need to hold the originalIndex and reset index to it.

So I would remove those two vars and those extra lines

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mkeeler Yes, you are right, I was confused by a unit test at a given time, but since no modification is performed, it is safe. Done in next patch

index := originalIndex
syncExtra(index, resp)
}
aLen := resp.Len()
if aLen <= maxSize {
if maxSize-aLen < 10 {
// We are good, increasing will go out of bounds
return median
}
startIndex = median
} else {
endIndex = median
}
}
return startIndex
}

// trimTCPResponse limit the MaximumSize of messages to 64k as it is the limit
// of DNS responses
func (d *DNSServer) trimTCPResponse(req, resp *dns.Msg) (trimmed bool) {
Expand Down Expand Up @@ -752,7 +783,13 @@ func (d *DNSServer) trimTCPResponse(req, resp *dns.Msg) (trimmed bool) {
// This enforces the given limit on 64k, the max limit for DNS messages
for len(resp.Answer) > 0 && resp.Len() > maxSize {
truncated = true
resp.Answer = resp.Answer[:len(resp.Answer)-1]
// More than 100 bytes, find with a binary search
if resp.Len()-maxSize > 100 {
bestIndex := dnsBinaryTruncate(resp, maxSize, index, hasExtra)
resp.Answer = resp.Answer[:bestIndex]
} else {
resp.Answer = resp.Answer[:len(resp.Answer)-1]
}
if hasExtra {
syncExtra(index, resp)
}
Expand Down Expand Up @@ -809,7 +846,13 @@ func trimUDPResponse(req, resp *dns.Msg, udpAnswerLimit int) (trimmed bool) {
compress := resp.Compress
resp.Compress = false
for len(resp.Answer) > 0 && resp.Len() > maxSize {
resp.Answer = resp.Answer[:len(resp.Answer)-1]
// More than 100 bytes, find with a binary search
if resp.Len()-maxSize > 100 {
bestIndex := dnsBinaryTruncate(resp, maxSize, index, hasExtra)
resp.Answer = resp.Answer[:bestIndex]
} else {
resp.Answer = resp.Answer[:len(resp.Answer)-1]
}
if hasExtra {
syncExtra(index, resp)
}
Expand Down
60 changes: 49 additions & 11 deletions agent/dns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2983,6 +2983,44 @@ func TestDNS_ServiceLookup_Randomize(t *testing.T) {
}
}

func TestBinarySearch(t *testing.T) {
t.Parallel()
msgSrc := new(dns.Msg)
msgSrc.Compress = true
msgSrc.SetQuestion("redis.service.consul.", dns.TypeSRV)

for i := 0; i < 50; i++ {
target := fmt.Sprintf("host-redis-%d-%d.test.acme.com.node.dc1.consul.", i/256, i%256)
msgSrc.Answer = append(msgSrc.Answer, &dns.SRV{Hdr: dns.RR_Header{Name: "redis.service.consul.", Class: 1, Rrtype: dns.TypeSRV, Ttl: 0x3c}, Port: 0x4c57, Target: target})
msgSrc.Extra = append(msgSrc.Extra, &dns.CNAME{Hdr: dns.RR_Header{Name: target, Class: 1, Rrtype: dns.TypeCNAME, Ttl: 0x3c}, Target: fmt.Sprintf("fx.168.%d.%d.", i/256, i%256)})
}
for idx, maxSize := range []int{12, 256, 512, 8192, 65535} {
t.Run(fmt.Sprintf("binarySearch %d", maxSize), func(t *testing.T) {
msg := new(dns.Msg)
msgSrc.Compress = true
msgSrc.SetQuestion("redis.service.consul.", dns.TypeSRV)
msg.Answer = msgSrc.Answer
msg.Extra = msgSrc.Extra
index := make(map[string]dns.RR, len(msg.Extra))
indexRRs(msg.Extra, index)
blen := dnsBinaryTruncate(msg, maxSize, index, true)
msg.Answer = msg.Answer[:blen]
syncExtra(index, msg)
predicted := msg.Len()
buf, err := msg.Pack()
if err != nil {
t.Error(err)
}
if predicted < len(buf) {
t.Fatalf("Bug in DNS library: %d != %d", predicted, len(buf))
}
if len(buf) > maxSize || (idx != 0 && len(buf) < 16) || (maxSize == 65535 && blen != 50) {
t.Fatalf("bad[%d]: %d > %d", idx, len(buf), maxSize)
}
})
}
}

func TestDNS_TCP_and_UDP_Truncate(t *testing.T) {
t.Parallel()
a := NewTestAgent(t.Name(), `
Expand Down Expand Up @@ -3311,17 +3349,17 @@ func testDNSServiceLookupResponseLimits(t *testing.T, answerLimit int, qType uin
case 0:
if (expectedService > 0 && len(in.Answer) != expectedService) ||
(expectedService < -1 && len(in.Answer) < lib.AbsInt(expectedService)) {
return false, fmt.Errorf("%d/%d answers received for type %v for %s", len(in.Answer), answerLimit, qType, question)
return false, fmt.Errorf("%d/%d answers received for type %v for %s, sz:=%d", len(in.Answer), answerLimit, qType, question, in.Len())
}
case 1:
if (expectedQuery > 0 && len(in.Answer) != expectedQuery) ||
(expectedQuery < -1 && len(in.Answer) < lib.AbsInt(expectedQuery)) {
return false, fmt.Errorf("%d/%d answers received for type %v for %s", len(in.Answer), answerLimit, qType, question)
return false, fmt.Errorf("%d/%d answers received for type %v for %s, sz:=%d", len(in.Answer), answerLimit, qType, question, in.Len())
}
case 2:
if (expectedQueryID > 0 && len(in.Answer) != expectedQueryID) ||
(expectedQueryID < -1 && len(in.Answer) < lib.AbsInt(expectedQueryID)) {
return false, fmt.Errorf("%d/%d answers received for type %v for %s", len(in.Answer), answerLimit, qType, question)
return false, fmt.Errorf("%d/%d answers received for type %v for %s, sz:=%d", len(in.Answer), answerLimit, qType, question, in.Len())
}
default:
panic("abort")
Expand Down Expand Up @@ -3473,7 +3511,7 @@ func TestDNS_ServiceLookup_ARecordLimits(t *testing.T) {
t.Parallel()
err := checkDNSService(t, test.numNodesTotal, test.aRecordLimit, qType, test.expectedAResults, test.udpSize, test.udpAnswerLimit)
if err != nil {
t.Errorf("Expected lookup %s to pass: %v", test.name, err)
t.Fatalf("Expected lookup %s to pass: %v", test.name, err)
}
})
}
Expand All @@ -3482,7 +3520,7 @@ func TestDNS_ServiceLookup_ARecordLimits(t *testing.T) {
t.Parallel()
err := checkDNSService(t, test.expectedSRVResults, test.aRecordLimit, dns.TypeSRV, test.numNodesTotal, test.udpSize, test.udpAnswerLimit)
if err != nil {
t.Errorf("Expected service SRV lookup %s to pass: %v", test.name, err)
t.Fatalf("Expected service SRV lookup %s to pass: %v", test.name, err)
}
})
}
Expand Down Expand Up @@ -3528,27 +3566,27 @@ func TestDNS_ServiceLookup_AnswerLimits(t *testing.T) {
}
for _, test := range tests {
test := test // capture loop var
t.Run("A lookup", func(t *testing.T) {
t.Run(fmt.Sprintf("A lookup %v", test), func(t *testing.T) {
t.Parallel()
ok, err := testDNSServiceLookupResponseLimits(t, test.udpAnswerLimit, dns.TypeA, test.expectedAService, test.expectedAQuery, test.expectedAQueryID)
if !ok {
t.Errorf("Expected service A lookup %s to pass: %v", test.name, err)
t.Fatalf("Expected service A lookup %s to pass: %v", test.name, err)
}
})

t.Run("AAAA lookup", func(t *testing.T) {
t.Run(fmt.Sprintf("AAAA lookup %v", test), func(t *testing.T) {
t.Parallel()
ok, err := testDNSServiceLookupResponseLimits(t, test.udpAnswerLimit, dns.TypeAAAA, test.expectedAAAAService, test.expectedAAAAQuery, test.expectedAAAAQueryID)
if !ok {
t.Errorf("Expected service AAAA lookup %s to pass: %v", test.name, err)
t.Fatalf("Expected service AAAA lookup %s to pass: %v", test.name, err)
}
})

t.Run("ANY lookup", func(t *testing.T) {
t.Run(fmt.Sprintf("ANY lookup %v", test), func(t *testing.T) {
t.Parallel()
ok, err := testDNSServiceLookupResponseLimits(t, test.udpAnswerLimit, dns.TypeANY, test.expectedANYService, test.expectedANYQuery, test.expectedANYQueryID)
if !ok {
t.Errorf("Expected service ANY lookup %s to pass: %v", test.name, err)
t.Fatalf("Expected service ANY lookup %s to pass: %v", test.name, err)
}
})
}
Expand Down