Skip to content

Commit 2ce249b

Browse files
delete entries from the cache when the TTL expires
1 parent 55022d7 commit 2ce249b

File tree

5 files changed

+99
-17
lines changed

5 files changed

+99
-17
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ See what needs to be done and submit a pull request :)
9494
* [x] Browse / Lookup / Register services
9595
* [x] Multiple IPv6 / IPv4 addresses support
9696
* [x] Send multiple probes (exp. back-off) if no service answers (*)
97-
* [ ] Timestamp entries for TTL checks
97+
* [x] Timestamp entries for TTL checks
9898
* [ ] Compare new multicasts with already received services
9999

100100
_Notes:_

client.go

+44-9
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"fmt"
66
"net"
77
"strings"
8+
"sync"
89
"time"
910

1011
"github.com/cenkalti/backoff"
@@ -143,6 +144,9 @@ type client struct {
143144
ipv4conn *ipv4.PacketConn
144145
ipv6conn *ipv6.PacketConn
145146
ifaces []net.Interface
147+
148+
mutex sync.Mutex
149+
sentEntries map[string]*ServiceEntry
146150
}
147151

148152
// Client structure constructor
@@ -177,6 +181,28 @@ func newClient(opts clientOpts) (*client, error) {
177181
}, nil
178182
}
179183

184+
var cleanupFreq = 10 * time.Second
185+
186+
// clean up entries whose TTL expired
187+
func (c *client) cleanupSentEntries(ctx context.Context) {
188+
ticker := time.NewTicker(cleanupFreq)
189+
defer ticker.Stop()
190+
for {
191+
select {
192+
case t := <-ticker.C:
193+
c.mutex.Lock()
194+
for k, e := range c.sentEntries {
195+
if t.After(e.Expiry) {
196+
delete(c.sentEntries, k)
197+
}
198+
}
199+
c.mutex.Unlock()
200+
case <-ctx.Done():
201+
return
202+
}
203+
}
204+
}
205+
180206
// Start listeners and waits for the shutdown signal from exit channel
181207
func (c *client) mainloop(ctx context.Context, params *lookupParams) {
182208
// start listening for responses
@@ -189,16 +215,20 @@ func (c *client) mainloop(ctx context.Context, params *lookupParams) {
189215
}
190216

191217
// Iterate through channels from listeners goroutines
192-
var entries, sentEntries map[string]*ServiceEntry
193-
sentEntries = make(map[string]*ServiceEntry)
218+
var entries map[string]*ServiceEntry
219+
c.sentEntries = make(map[string]*ServiceEntry)
220+
go c.cleanupSentEntries(ctx)
221+
194222
for {
223+
var now time.Time
195224
select {
196225
case <-ctx.Done():
197226
// Context expired. Notify subscriber that we are done here.
198227
params.done()
199228
c.shutdown()
200229
return
201230
case msg := <-msgCh:
231+
now = time.Now()
202232
entries = make(map[string]*ServiceEntry)
203233
sections := append(msg.Answer, msg.Ns...)
204234
sections = append(sections, msg.Extra...)
@@ -218,7 +248,7 @@ func (c *client) mainloop(ctx context.Context, params *lookupParams) {
218248
params.Service,
219249
params.Domain)
220250
}
221-
entries[rr.Ptr].TTL = rr.Hdr.Ttl
251+
entries[rr.Ptr].Expiry = now.Add(time.Duration(rr.Hdr.Ttl) * time.Second)
222252
case *dns.SRV:
223253
if params.ServiceInstanceName() != "" && params.ServiceInstanceName() != rr.Hdr.Name {
224254
continue
@@ -233,7 +263,7 @@ func (c *client) mainloop(ctx context.Context, params *lookupParams) {
233263
}
234264
entries[rr.Hdr.Name].HostName = rr.Target
235265
entries[rr.Hdr.Name].Port = int(rr.Port)
236-
entries[rr.Hdr.Name].TTL = rr.Hdr.Ttl
266+
entries[rr.Hdr.Name].Expiry = now.Add(time.Duration(rr.Hdr.Ttl) * time.Second)
237267
case *dns.TXT:
238268
if params.ServiceInstanceName() != "" && params.ServiceInstanceName() != rr.Hdr.Name {
239269
continue
@@ -247,7 +277,7 @@ func (c *client) mainloop(ctx context.Context, params *lookupParams) {
247277
params.Domain)
248278
}
249279
entries[rr.Hdr.Name].Text = rr.Txt
250-
entries[rr.Hdr.Name].TTL = rr.Hdr.Ttl
280+
entries[rr.Hdr.Name].Expiry = now.Add(time.Duration(rr.Hdr.Ttl) * time.Second)
251281
}
252282
}
253283
// Associate IPs in a second round as other fields should be filled by now.
@@ -271,12 +301,15 @@ func (c *client) mainloop(ctx context.Context, params *lookupParams) {
271301

272302
if len(entries) > 0 {
273303
for k, e := range entries {
274-
if e.TTL == 0 {
304+
c.mutex.Lock()
305+
if !e.Expiry.After(now) {
275306
delete(entries, k)
276-
delete(sentEntries, k)
307+
delete(c.sentEntries, k)
308+
c.mutex.Unlock()
277309
continue
278310
}
279-
if _, ok := sentEntries[k]; ok {
311+
if _, ok := c.sentEntries[k]; ok {
312+
c.mutex.Unlock()
280313
continue
281314
}
282315

@@ -286,14 +319,16 @@ func (c *client) mainloop(ctx context.Context, params *lookupParams) {
286319
// Require at least one resolved IP address for ServiceEntry
287320
// TODO: wait some more time as chances are high both will arrive.
288321
if len(e.AddrIPv4) == 0 && len(e.AddrIPv6) == 0 {
322+
c.mutex.Unlock()
289323
continue
290324
}
291325
}
292326
// Submit entry to subscriber and cache it.
293327
// This is also a point to possibly stop probing actively for a
294328
// service entry.
329+
c.sentEntries[k] = e
330+
c.mutex.Unlock()
295331
params.Entries <- e
296-
sentEntries[k] = e
297332
if !params.isBrowsing {
298333
params.disableProbing()
299334
}

server.go

+3-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ const (
2121
multicastRepetitions = 2
2222
)
2323

24+
var defaultTTL uint32 = 3200
25+
2426
// Register a service by given arguments. This call will take the system's hostname
2527
// and lookup IP by that hostname.
2628
func Register(instance, service, domain string, port int, text []string, ifaces []net.Interface) (*Server, error) {
@@ -173,7 +175,7 @@ func newServer(ifaces []net.Interface) (*Server, error) {
173175
ipv4conn: ipv4conn,
174176
ipv6conn: ipv6conn,
175177
ifaces: ifaces,
176-
ttl: 3200,
178+
ttl: defaultTTL,
177179
shouldShutdown: make(chan struct{}),
178180
}
179181

service.go

+7-6
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"fmt"
55
"net"
66
"sync"
7+
"time"
78
)
89

910
// ServiceRecord contains the basic description of a service, which contains instance name, service type & domain
@@ -103,12 +104,12 @@ func (l *lookupParams) disableProbing() {
103104
// used to answer multicast queries.
104105
type ServiceEntry struct {
105106
ServiceRecord
106-
HostName string `json:"hostname"` // Host machine DNS name
107-
Port int `json:"port"` // Service Port
108-
Text []string `json:"text"` // Service info served as a TXT record
109-
TTL uint32 `json:"ttl"` // TTL of the service record
110-
AddrIPv4 []net.IP `json:"-"` // Host machine IPv4 address
111-
AddrIPv6 []net.IP `json:"-"` // Host machine IPv6 address
107+
HostName string `json:"hostname"` // Host machine DNS name
108+
Port int `json:"port"` // Service Port
109+
Text []string `json:"text"` // Service info served as a TXT record
110+
Expiry time.Time `json:"expiry"` // Expiry of the service entry, will be converted to a TTL value
111+
AddrIPv4 []net.IP `json:"-"` // Host machine IPv4 address
112+
AddrIPv6 []net.IP `json:"-"` // Host machine IPv6 address
112113
}
113114

114115
// NewServiceEntry constructs a ServiceEntry.

service_test.go

+44
Original file line numberDiff line numberDiff line change
@@ -184,4 +184,48 @@ func TestSubtype(t *testing.T) {
184184
t.Fatalf("Expected port is %d, but got %d", mdnsPort, expectedResult[0].Port)
185185
}
186186
})
187+
188+
t.Run("ttl", func(t *testing.T) {
189+
origTTL := defaultTTL
190+
origCleanupFreq := cleanupFreq
191+
defer func() {
192+
defaultTTL = origTTL
193+
cleanupFreq = origCleanupFreq
194+
}()
195+
defaultTTL = 2 // 2 seconds
196+
cleanupFreq = 100 * time.Millisecond
197+
198+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
199+
defer cancel()
200+
go startMDNS(ctx, mdnsPort, mdnsName, mdnsSubtype, mdnsDomain)
201+
202+
entries := make(chan *ServiceEntry)
203+
var expectedResult []*ServiceEntry
204+
go func() {
205+
for {
206+
select {
207+
case s := <-entries:
208+
expectedResult = append(expectedResult, s)
209+
case <-ctx.Done():
210+
return
211+
}
212+
}
213+
}()
214+
215+
resolver, err := NewResolver(nil)
216+
if err != nil {
217+
t.Fatalf("Expected create resolver success, but got %v", err)
218+
}
219+
if err := resolver.Browse(ctx, mdnsService, mdnsDomain, entries); err != nil {
220+
t.Fatalf("Expected browse success, but got %v", err)
221+
}
222+
223+
<-ctx.Done()
224+
if len(expectedResult) != 2 {
225+
t.Fatalf("Expected to have received 2 entries, but got %d", len(expectedResult))
226+
}
227+
if expectedResult[0].ServiceInstanceName() != expectedResult[1].ServiceInstanceName() {
228+
t.Fatalf("expected the two entries to be identical")
229+
}
230+
})
187231
}

0 commit comments

Comments
 (0)