diff --git a/connmgr.go b/connmgr.go index aa42457..67ba284 100644 --- a/connmgr.go +++ b/connmgr.go @@ -19,6 +19,8 @@ var SilencePeriod = 10 * time.Second var log = logging.Logger("connmgr") +var connCloseStreamTimeout = 10 * time.Minute + // BasicConnMgr is a ConnManager that trims connections whenever the count exceeds the // high watermark. New connections are given a grace period before they're subject // to trimming. Trims are automatically run on demand, only if the time from the @@ -84,7 +86,7 @@ func (s *segment) tagInfoFor(p peer.ID) *peerInfo { temp: true, tags: make(map[string]int), decaying: make(map[*decayingTag]*connmgr.DecayingValue), - conns: make(map[network.Conn]time.Time), + conns: make(map[network.Conn]*connInfo), } s.peers[p] = pi return pi @@ -193,6 +195,12 @@ func (cm *BasicConnMgr) IsProtected(id peer.ID, tag string) (protected bool) { return protected } +type connInfo struct { + startTime time.Time + lastStreamOpen time.Time + nStreams int +} + // peerInfo stores metadata for a given peer. type peerInfo struct { id peer.ID @@ -202,7 +210,7 @@ type peerInfo struct { value int // cached sum of all tag values temp bool // this is a temporary entry holding early tags, and awaiting connections - conns map[network.Conn]time.Time // start time of each connection + conns map[network.Conn]*connInfo // start time and last stream open time of each connection. firstSeen time.Time // timestamp when we began tracking this peer. 
} @@ -313,10 +321,11 @@ func (cm *BasicConnMgr) getConnsToClose() []network.Conn { return nil } + now := time.Now() npeers := cm.segments.countPeers() candidates := make([]*peerInfo, 0, npeers) ncandidates := 0 - gracePeriodStart := time.Now().Add(-cm.cfg.gracePeriod) + gracePeriodStart := now.Add(-cm.cfg.gracePeriod) cm.plk.RLock() for _, s := range cm.segments { @@ -359,9 +368,41 @@ func (cm *BasicConnMgr) getConnsToClose() []network.Conn { target := ncandidates - cm.cfg.lowWater - // slightly overallocate because we may have more than one conns per peer - selected := make([]network.Conn, 0, target+10) + // overallocate because we may have more than one conn per peer + selected := make([]network.Conn, 0, 2*target) + seen := make(map[network.Conn]struct{}) + + // first select connections that are: + // i) older than 10 minutes but still haven't seen a stream. + // ii) haven't seen a new stream in 10 minutes and have NO streams open. + for _, inf := range candidates { + if target <= 0 { + break + } + + // lock this to protect from concurrent modifications from connect/disconnect events + s := cm.segments.get(inf.id) + s.Lock() + + for c, info := range inf.conns { + // connections that are older than 10 minutes but still haven't seen a stream. + if info.lastStreamOpen.IsZero() && now.Sub(info.startTime) > connCloseStreamTimeout { + selected = append(selected, c) + target-- + seen[c] = struct{}{} + } + + // connections that haven't seen a new stream in 10 minutes and have NO streams open. 
+ if info.nStreams == 0 && !info.lastStreamOpen.IsZero() && now.Sub(info.lastStreamOpen) > connCloseStreamTimeout { + selected = append(selected, c) + target-- + seen[c] = struct{}{} + } + } + s.Unlock() + } + // now select remaining connections if we still haven't hit our target for _, inf := range candidates { if target <= 0 { break @@ -377,10 +418,12 @@ func (cm *BasicConnMgr) getConnsToClose() []network.Conn { delete(s.peers, inf.id) } else { for c := range inf.conns { - selected = append(selected, c) + if _, ok := seen[c]; !ok { + selected = append(selected, c) + target-- + } } } - target -= len(inf.conns) s.Unlock() } @@ -412,8 +455,8 @@ func (cm *BasicConnMgr) GetTagInfo(p peer.ID) *connmgr.TagInfo { for t, v := range pi.decaying { out.Tags[t.name] = v.Value } - for c, t := range pi.conns { - out.Conns[c.RemoteMultiaddr().String()] = t + for c, connInfo := range pi.conns { + out.Conns[c.RemoteMultiaddr().String()] = connInfo.startTime } return out @@ -528,7 +571,7 @@ func (nn *cmNotifee) Connected(n network.Network, c network.Conn) { firstSeen: time.Now(), tags: make(map[string]int), decaying: make(map[*decayingTag]*connmgr.DecayingValue), - conns: make(map[network.Conn]time.Time), + conns: make(map[network.Conn]*connInfo), } s.peers[id] = pinfo } else if pinfo.temp { @@ -545,7 +588,7 @@ func (nn *cmNotifee) Connected(n network.Network, c network.Conn) { return } - pinfo.conns[c] = time.Now() + pinfo.conns[c] = &connInfo{startTime: time.Now()} atomic.AddInt32(&cm.connCount, 1) } @@ -578,14 +621,61 @@ func (nn *cmNotifee) Disconnected(n network.Network, c network.Conn) { atomic.AddInt32(&cm.connCount, -1) } +// OpenedStream is called by notifiers to inform that a new libp2p stream has been opened on a connection. +// The notifee updates the BasicConnMgr accordingly to update the number of streams we have open on a connection. +// We then use this information when deciding which connections to trim. 
+func (nn *cmNotifee) OpenedStream(_ network.Network, stream network.Stream) { + cm := nn.cm() + + p := stream.Conn().RemotePeer() + s := cm.segments.get(p) + s.Lock() + defer s.Unlock() + + cinf, ok := s.peers[p] + if !ok { + log.Error("received stream open notification for peer we are not tracking: ", p) + return + } + + c := stream.Conn() + connInfo, ok := cinf.conns[c] + if !ok { + log.Error("received stream open notification for conn we are not tracking: ", p) + return + } + + connInfo.lastStreamOpen = time.Now() + connInfo.nStreams++ +} + +// ClosedStream is called by notifiers to inform that an existing libp2p stream has been closed. +func (nn *cmNotifee) ClosedStream(_ network.Network, stream network.Stream) { + cm := nn.cm() + + p := stream.Conn().RemotePeer() + s := cm.segments.get(p) + s.Lock() + defer s.Unlock() + + cinf, ok := s.peers[p] + if !ok { + log.Error("received stream close notification for peer we are not tracking: ", p) + return + } + + c := stream.Conn() + connInfo, ok := cinf.conns[c] + if !ok { + log.Error("received stream close notification for conn we are not tracking: ", p) + return + } + + connInfo.nStreams-- +} + // Listen is no-op in this implementation. func (nn *cmNotifee) Listen(n network.Network, addr ma.Multiaddr) {} // ListenClose is no-op in this implementation. func (nn *cmNotifee) ListenClose(n network.Network, addr ma.Multiaddr) {} - -// OpenedStream is no-op in this implementation. -func (nn *cmNotifee) OpenedStream(network.Network, network.Stream) {} - -// ClosedStream is no-op in this implementation. 
-func (nn *cmNotifee) ClosedStream(network.Network, network.Stream) {} diff --git a/connmgr_test.go b/connmgr_test.go index 8e1b4b1..369cedf 100644 --- a/connmgr_test.go +++ b/connmgr_test.go @@ -7,6 +7,7 @@ import ( "time" detectrace "github.com/ipfs/go-detect-race" + "github.com/stretchr/testify/require" "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" @@ -48,6 +49,19 @@ func randConn(t testing.TB, discNotify func(network.Network, network.Conn)) netw return &tconn{peer: pid, disconnectNotify: discNotify} } +type tStream struct { + network.Stream + conn network.Conn +} + +func (s *tStream) Conn() network.Conn { + return s.conn +} + +func randStream(t testing.TB, c network.Conn) network.Stream { + return &tStream{conn: c} +} + // Make sure multiple trim calls block. func TestTrimBlocks(t *testing.T) { cm := NewConnManager(200, 300, 0) @@ -124,6 +138,78 @@ func TestTrimJoin(t *testing.T) { wg.Wait() } +func TestCloseConnsWithNoStreams(t *testing.T) { + copy := connCloseStreamTimeout + connCloseStreamTimeout = 100 * time.Millisecond + defer func() { + connCloseStreamTimeout = copy + }() + + cm := NewConnManager(5, 8, 0) + not := cm.Notifee() + + var conns []network.Conn + for i := 0; i < 8; i++ { + rc := randConn(t, nil) + conns = append(conns, rc) + not.Connected(nil, rc) + } + + time.Sleep(1 * time.Second) + cm.TrimOpenConns(context.Background()) + + nClosed := 0 + // all conns are eligible for closing as they haven't seen a stream. 
+ for _, c := range conns { + if c.(*tconn).closed { + nClosed++ + } + } + require.Equalf(t, 3, nClosed, "expected 3 closed connections, got %d", nClosed) +} + +func TestDontCloseConnsWithOpenStreams(t *testing.T) { + copy := connCloseStreamTimeout + connCloseStreamTimeout = 100 * time.Millisecond + defer func() { + connCloseStreamTimeout = copy + }() + + cm := NewConnManager(5, 8, 0) + not := cm.Notifee() + + var conns []network.Conn + for i := 0; i < 8; i++ { + rc := randConn(t, nil) + conns = append(conns, rc) + not.Connected(nil, rc) + } + + for i, c := range conns { + if i%3 == 0 { + not.OpenedStream(nil, randStream(t, c)) + } + } + + time.Sleep(1 * time.Second) + cm.TrimOpenConns(context.Background()) + + nClosed := 0 + for i, c := range conns { + if i%3 == 0 { + if c.(*tconn).closed { + t.Fatal("these should NOT be closed") + } + } else { + if c.(*tconn).closed { + nClosed++ + } + } + } + + require.Equalf(t, 3, nClosed, "expected 3 closed streams, got %d", nClosed) +} + func TestConnTrimming(t *testing.T) { cm := NewConnManager(200, 300, 0) not := cm.Notifee()