diff --git a/README.md b/README.md index ab621d0..836dfec 100644 --- a/README.md +++ b/README.md @@ -32,6 +32,8 @@ func main() { } ``` +Or construct with `dnsr.NewExpiring()` to expire cache entries based on TTL. + [Documentation](https://godoc.org/github.com/domainr/dnsr) ## Development diff --git a/cache.go b/cache.go index 6154b80..36f6e1f 100644 --- a/cache.go +++ b/cache.go @@ -1,9 +1,13 @@ package dnsr -import "sync" +import ( + "sync" + "time" +) type cache struct { capacity int + expire bool m sync.RWMutex entries map[string]entry } @@ -14,13 +18,14 @@ const MinCacheCapacity = 1000 // newCache initializes and returns a new cache instance. // Cache capacity defaults to MinCacheCapacity if <= 0. -func newCache(capacity int) *cache { +func newCache(capacity int, expire bool) *cache { if capacity <= 0 { capacity = MinCacheCapacity } return &cache{ capacity: capacity, entries: make(map[string]entry), + expire: expire, } } @@ -91,11 +96,26 @@ func (c *cache) get(qname string) RRs { if len(e) == 0 { return emptyRRs } - i := 0 - rrs := make(RRs, len(e)) - for rr, _ := range e { - rrs[i] = rr - i++ + if c.expire { + i := 0 + rrs := make(RRs, len(e)) + now := time.Now() + for rr, _ := range e { + if !rr.Expiry.IsZero() && now.After(rr.Expiry) { + delete(e, rr) + } else { + rrs[i] = rr + i++ + } + } + return rrs[:i] + } else { + i := 0 + rrs := make(RRs, len(e)) + for rr, _ := range e { + rrs[i] = rr + i++ + } + return rrs } - return rrs } diff --git a/cache_test.go b/cache_test.go index 08adb49..f5d39c5 100644 --- a/cache_test.go +++ b/cache_test.go @@ -2,15 +2,36 @@ package dnsr import ( "testing" + "time" "github.com/nbio/st" ) func TestCache(t *testing.T) { - c := newCache(100) + c := newCache(100, false) c.addNX("hello.") rr := RR{Name: "hello.", Type: "A", Value: "1.2.3.4"} c.add("hello.", rr) rrs := c.get("hello.") st.Expect(t, len(rrs), 1) } + +func TestLiveCacheEntry(t *testing.T) { + c := newCache(100, true) + c.addNX("alive.") + alive := time.Now().Add(time.Minute) + rr := RR{Name: "alive.", Type: "A", Value: "1.2.3.4", Expiry: alive} + c.add("alive.", rr) + rrs := c.get("alive.") + st.Expect(t, len(rrs), 1) +} + +func TestExpiredCacheEntry(t *testing.T) { + c := newCache(100, true) + c.addNX("expired.") + expired := time.Now().Add(-time.Minute) + rr := RR{Name: "expired.", Type: "A", Value: "1.2.3.4", Expiry: expired} + c.add("expired.", rr) + rrs := c.get("expired.") + st.Expect(t, len(rrs), 0) +} diff --git a/resolver.go b/resolver.go index 655c417..ef58412 100644 --- a/resolver.go +++ b/resolver.go @@ -32,6 +32,7 @@ var ( // Resolver implements a primitive, non-recursive, caching DNS resolver. type Resolver struct { cache *cache + expire bool timeout time.Duration } @@ -43,7 +44,23 @@ func New(capacity int) *Resolver { // NewWithTimeout initializes a Resolver with the specified cache size and resolution timeout. func NewWithTimeout(capacity int, timeout time.Duration) *Resolver { r := &Resolver{ - cache: newCache(capacity), + cache: newCache(capacity, false), + expire: false, + timeout: timeout, + } + return r +} + +// NewExpiring initializes an expiring Resolver with the specified cache size. +func NewExpiring(capacity int) *Resolver { + return NewExpiringWithTimeout(capacity, Timeout) +} + +// NewExpiringWithTimeout initializes an expiring Resolved with the specified cache size and resolution timeout. +func NewExpiringWithTimeout(capacity int, timeout time.Duration) *Resolver { + r := &Resolver{ + cache: newCache(capacity, true), + expire: true, timeout: timeout, } return r @@ -243,7 +260,7 @@ func (r *Resolver) exchange(ctx context.Context, host, qname, qtype string, dept var hasSOA bool if qtype == "NS" { for _, drr := range rmsg.Ns { - rr, ok := convertRR(drr) + rr, ok := convertRR(drr, r.expire) if !ok { continue } @@ -293,7 +310,7 @@ func (r *Resolver) saveDNSRR(host, qname string, drrs []dns.RR) RRs { var rrs RRs cl := dns.CountLabel(qname) for _, drr := range drrs { - rr, ok := convertRR(drr) + rr, ok := convertRR(drr, r.expire) if !ok { continue } diff --git a/resolver_test.go b/resolver_test.go index 9bbde24..8884e52 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -187,6 +187,15 @@ func TestBazCoUKAny(t *testing.T) { st.Expect(t, count(rrs, func(rr RR) bool { return rr.Type == "NS" }) >= 2, true) } +func TestTTL(t *testing.T) { + r := NewExpiring(0) + rrs, err := r.ResolveErr("google.com", "A") + st.Expect(t, err, nil) + st.Expect(t, len(rrs) >= 4, true) + rr := rrs[0] + st.Expect(t, !rr.Expiry.IsZero(), true) +} + var testResolver *Resolver func BenchmarkResolve(b *testing.B) { diff --git a/root_cache.go b/root_cache.go index 6dd956b..25a3cee 100644 --- a/root_cache.go +++ b/root_cache.go @@ -13,12 +13,12 @@ var ( ) func init() { - rootCache = newCache(strings.Count(root, "\n")) + rootCache = newCache(strings.Count(root, "\n"), false) for t := range dns.ParseZone(strings.NewReader(root), "", "") { if t.Error != nil { continue } - rr, ok := convertRR(t.RR) + rr, ok := convertRR(t.RR, false) if ok { rootCache.add(rr.Name, rr) } diff --git a/rr.go b/rr.go index 51c7c37..9e57fdb 100644 --- a/rr.go +++ b/rr.go @@ -1,16 +1,20 @@ package dnsr import ( + "fmt" "strings" + "time" "github.com/miekg/dns" ) // RR represents a DNS resource record. type RR struct { - Name string - Type string - Value string + Name string + Type string + Value string + TTL time.Duration + Expiry time.Time } // RRs represents a slice of DNS resource records. @@ -30,32 +34,55 @@ const NameCollision = "127.0.53.53" // String returns a string representation of an RR in zone-file format. func (rr *RR) String() string { - return rr.Name + "\t 3600\tIN\t" + rr.Type + "\t" + rr.Value + if rr.Expiry.IsZero() { + return rr.Name + "\t 3600\tIN\t" + rr.Type + "\t" + rr.Value + } else { + ttl := ttlString(rr.TTL) + return rr.Name + "\t" + ttl + "\t" + rr.Type + "\t" + rr.Value + } +} + +// ttlString constructs the TTL field of an RR string. +func ttlString(ttl time.Duration) string { + seconds := int(ttl.Seconds()) + return fmt.Sprintf("%10d", seconds) } // convertRR converts a dns.RR to an RR. // If the RR is not a type that this package uses, // It will attempt to translate this if there are enough parameters // Should all translation fail, it returns an undefined RR and false. -func convertRR(drr dns.RR) (RR, bool) { +func convertRR(drr dns.RR, expire bool) (RR, bool) { + var ttl time.Duration + var expiry time.Time + if expire { + ttl, expiry = calculateExpiry(drr) + } switch t := drr.(type) { case *dns.SOA: - return RR{toLowerFQDN(t.Hdr.Name), "SOA", toLowerFQDN(t.Ns)}, true + return RR{toLowerFQDN(t.Hdr.Name), "SOA", toLowerFQDN(t.Ns), ttl, expiry}, true case *dns.NS: - return RR{toLowerFQDN(t.Hdr.Name), "NS", toLowerFQDN(t.Ns)}, true + return RR{toLowerFQDN(t.Hdr.Name), "NS", toLowerFQDN(t.Ns), ttl, expiry}, true case *dns.CNAME: - return RR{toLowerFQDN(t.Hdr.Name), "CNAME", toLowerFQDN(t.Target)}, true + return RR{toLowerFQDN(t.Hdr.Name), "CNAME", toLowerFQDN(t.Target), ttl, expiry}, true case *dns.A: - return RR{toLowerFQDN(t.Hdr.Name), "A", t.A.String()}, true + return RR{toLowerFQDN(t.Hdr.Name), "A", t.A.String(), ttl, expiry}, true case *dns.AAAA: - return RR{toLowerFQDN(t.Hdr.Name), "AAAA", t.AAAA.String()}, true + return RR{toLowerFQDN(t.Hdr.Name), "AAAA", t.AAAA.String(), ttl, expiry}, true case *dns.TXT: - return RR{toLowerFQDN(t.Hdr.Name), "TXT", strings.Join(t.Txt, "\t")}, true + return RR{toLowerFQDN(t.Hdr.Name), "TXT", strings.Join(t.Txt, "\t"), ttl, expiry}, true default: fields := strings.Fields(drr.String()) if len(fields) >= 4 { - return RR{toLowerFQDN(fields[0]), fields[3], strings.Join(fields[4:], "\t")}, true + return RR{toLowerFQDN(fields[0]), fields[3], strings.Join(fields[4:], "\t"), ttl, expiry}, true } } return RR{}, false } + +// calculateExpiry calculates the expiry time of an RR. +func calculateExpiry(drr dns.RR) (time.Duration, time.Time) { + ttl := time.Second * time.Duration(drr.Header().Ttl) + expiry := time.Now().Add(ttl) + return ttl, expiry +}