Skip to content

Commit

Permalink
Merge pull request #32 from bookmoons/bookmoons/ttl
Browse files Browse the repository at this point in the history
Respect TTL
  • Loading branch information
ydnar authored Jun 29, 2019
2 parents 92eaed9 + 5dd4216 commit bfb9c6f
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 26 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ func main() {
}
```

Or construct with `dnsr.NewExpiring()` to expire cache entries based on TTL.

[Documentation](https://godoc.org/github.com/domainr/dnsr)

## Development
Expand Down
36 changes: 28 additions & 8 deletions cache.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
package dnsr

import "sync"
import (
"sync"
"time"
)

type cache struct {
capacity int
expire bool
m sync.RWMutex
entries map[string]entry
}
Expand All @@ -14,13 +18,14 @@ const MinCacheCapacity = 1000

// newCache initializes and returns a new cache instance.
// Cache capacity defaults to MinCacheCapacity if <= 0.
func newCache(capacity int) *cache {
func newCache(capacity int, expire bool) *cache {
if capacity <= 0 {
capacity = MinCacheCapacity
}
return &cache{
capacity: capacity,
entries: make(map[string]entry),
expire: expire,
}
}

Expand Down Expand Up @@ -91,11 +96,26 @@ func (c *cache) get(qname string) RRs {
if len(e) == 0 {
return emptyRRs
}
i := 0
rrs := make(RRs, len(e))
for rr, _ := range e {
rrs[i] = rr
i++
if c.expire {
i := 0
rrs := make(RRs, len(e))
now := time.Now()
for rr, _ := range e {
if !rr.Expiry.IsZero() && now.After(rr.Expiry) {
delete(e, rr)
} else {
rrs[i] = rr
i++
}
}
return rrs[:i]
} else {
i := 0
rrs := make(RRs, len(e))
for rr, _ := range e {
rrs[i] = rr
i++
}
return rrs
}
return rrs
}
23 changes: 22 additions & 1 deletion cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,36 @@ package dnsr

import (
"testing"
"time"

"github.com/nbio/st"
)

func TestCache(t *testing.T) {
c := newCache(100)
c := newCache(100, false)
c.addNX("hello.")
rr := RR{Name: "hello.", Type: "A", Value: "1.2.3.4"}
c.add("hello.", rr)
rrs := c.get("hello.")
st.Expect(t, len(rrs), 1)
}

func TestLiveCacheEntry(t *testing.T) {
c := newCache(100, true)
c.addNX("alive.")
alive := time.Now().Add(time.Minute)
rr := RR{Name: "alive.", Type: "A", Value: "1.2.3.4", Expiry: alive}
c.add("alive.", rr)
rrs := c.get("alive.")
st.Expect(t, len(rrs), 1)
}

func TestExpiredCacheEntry(t *testing.T) {
c := newCache(100, true)
c.addNX("expired.")
expired := time.Now().Add(-time.Minute)
rr := RR{Name: "expired.", Type: "A", Value: "1.2.3.4", Expiry: expired}
c.add("expired.", rr)
rrs := c.get("expired.")
st.Expect(t, len(rrs), 0)
}
23 changes: 20 additions & 3 deletions resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ var (
// Resolver implements a primitive, non-recursive, caching DNS resolver.
type Resolver struct {
cache *cache
expire bool
timeout time.Duration
}

Expand All @@ -43,7 +44,23 @@ func New(capacity int) *Resolver {
// NewWithTimeout initializes a Resolver with the specified cache size and resolution timeout.
func NewWithTimeout(capacity int, timeout time.Duration) *Resolver {
r := &Resolver{
cache: newCache(capacity),
cache: newCache(capacity, false),
expire: false,
timeout: timeout,
}
return r
}

// NewExpiring initializes an expiring Resolver with the specified cache size.
func NewExpiring(capacity int) *Resolver {
return NewExpiringWithTimeout(capacity, Timeout)
}

// NewExpiringWithTimeout initializes an expiring Resolved with the specified cache size and resolution timeout.
func NewExpiringWithTimeout(capacity int, timeout time.Duration) *Resolver {
r := &Resolver{
cache: newCache(capacity, true),
expire: true,
timeout: timeout,
}
return r
Expand Down Expand Up @@ -243,7 +260,7 @@ func (r *Resolver) exchange(ctx context.Context, host, qname, qtype string, dept
var hasSOA bool
if qtype == "NS" {
for _, drr := range rmsg.Ns {
rr, ok := convertRR(drr)
rr, ok := convertRR(drr, r.expire)
if !ok {
continue
}
Expand Down Expand Up @@ -293,7 +310,7 @@ func (r *Resolver) saveDNSRR(host, qname string, drrs []dns.RR) RRs {
var rrs RRs
cl := dns.CountLabel(qname)
for _, drr := range drrs {
rr, ok := convertRR(drr)
rr, ok := convertRR(drr, r.expire)
if !ok {
continue
}
Expand Down
9 changes: 9 additions & 0 deletions resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,15 @@ func TestBazCoUKAny(t *testing.T) {
st.Expect(t, count(rrs, func(rr RR) bool { return rr.Type == "NS" }) >= 2, true)
}

func TestTTL(t *testing.T) {
r := NewExpiring(0)
rrs, err := r.ResolveErr("google.com", "A")
st.Expect(t, err, nil)
st.Expect(t, len(rrs) >= 4, true)
rr := rrs[0]
st.Expect(t, !rr.Expiry.IsZero(), true)
}

var testResolver *Resolver

func BenchmarkResolve(b *testing.B) {
Expand Down
4 changes: 2 additions & 2 deletions root_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ var (
)

func init() {
rootCache = newCache(strings.Count(root, "\n"))
rootCache = newCache(strings.Count(root, "\n"), false)
for t := range dns.ParseZone(strings.NewReader(root), "", "") {
if t.Error != nil {
continue
}
rr, ok := convertRR(t.RR)
rr, ok := convertRR(t.RR, false)
if ok {
rootCache.add(rr.Name, rr)
}
Expand Down
51 changes: 39 additions & 12 deletions rr.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
package dnsr

import (
"fmt"
"strings"
"time"

"github.com/miekg/dns"
)

// RR represents a DNS resource record.
type RR struct {
Name string
Type string
Value string
Name string
Type string
Value string
TTL time.Duration
Expiry time.Time
}

// RRs represents a slice of DNS resource records.
Expand All @@ -30,32 +34,55 @@ const NameCollision = "127.0.53.53"

// String returns a string representation of an RR in zone-file format.
func (rr *RR) String() string {
return rr.Name + "\t 3600\tIN\t" + rr.Type + "\t" + rr.Value
if rr.Expiry.IsZero() {
return rr.Name + "\t 3600\tIN\t" + rr.Type + "\t" + rr.Value
} else {
ttl := ttlString(rr.TTL)
return rr.Name + "\t" + ttl + "\t" + rr.Type + "\t" + rr.Value
}
}

// ttlString constructs the TTL field of an RR string.
func ttlString(ttl time.Duration) string {
seconds := int(ttl.Seconds())
return fmt.Sprintf("%10d", seconds)
}

// convertRR converts a dns.RR to an RR.
// If the RR is not a type that this package uses,
// It will attempt to translate this if there are enough parameters
// Should all translation fail, it returns an undefined RR and false.
func convertRR(drr dns.RR) (RR, bool) {
func convertRR(drr dns.RR, expire bool) (RR, bool) {
var ttl time.Duration
var expiry time.Time
if expire {
ttl, expiry = calculateExpiry(drr)
}
switch t := drr.(type) {
case *dns.SOA:
return RR{toLowerFQDN(t.Hdr.Name), "SOA", toLowerFQDN(t.Ns)}, true
return RR{toLowerFQDN(t.Hdr.Name), "SOA", toLowerFQDN(t.Ns), ttl, expiry}, true
case *dns.NS:
return RR{toLowerFQDN(t.Hdr.Name), "NS", toLowerFQDN(t.Ns)}, true
return RR{toLowerFQDN(t.Hdr.Name), "NS", toLowerFQDN(t.Ns), ttl, expiry}, true
case *dns.CNAME:
return RR{toLowerFQDN(t.Hdr.Name), "CNAME", toLowerFQDN(t.Target)}, true
return RR{toLowerFQDN(t.Hdr.Name), "CNAME", toLowerFQDN(t.Target), ttl, expiry}, true
case *dns.A:
return RR{toLowerFQDN(t.Hdr.Name), "A", t.A.String()}, true
return RR{toLowerFQDN(t.Hdr.Name), "A", t.A.String(), ttl, expiry}, true
case *dns.AAAA:
return RR{toLowerFQDN(t.Hdr.Name), "AAAA", t.AAAA.String()}, true
return RR{toLowerFQDN(t.Hdr.Name), "AAAA", t.AAAA.String(), ttl, expiry}, true
case *dns.TXT:
return RR{toLowerFQDN(t.Hdr.Name), "TXT", strings.Join(t.Txt, "\t")}, true
return RR{toLowerFQDN(t.Hdr.Name), "TXT", strings.Join(t.Txt, "\t"), ttl, expiry}, true
default:
fields := strings.Fields(drr.String())
if len(fields) >= 4 {
return RR{toLowerFQDN(fields[0]), fields[3], strings.Join(fields[4:], "\t")}, true
return RR{toLowerFQDN(fields[0]), fields[3], strings.Join(fields[4:], "\t"), ttl, expiry}, true
}
}
return RR{}, false
}

// calculateExpiry calculates the expiry time of an RR.
func calculateExpiry(drr dns.RR) (time.Duration, time.Time) {
ttl := time.Second * time.Duration(drr.Header().Ttl)
expiry := time.Now().Add(ttl)
return ttl, expiry
}

0 comments on commit bfb9c6f

Please sign in to comment.