Skip to content

Commit 941d9fb

Browse files
committed
feat(p2p): Add a maximum number of connections per IP address
1 parent b2e670f commit 941d9fb

File tree

5 files changed

+66
-7
lines changed

5 files changed

+66
-7
lines changed

hathor/p2p/factory.py

+2
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def buildProtocol(self, addr: IAddress) -> MyServerProtocol:
5757
p2p_manager=self.p2p_manager,
5858
use_ssl=self.use_ssl,
5959
inbound=True,
60+
remote_address=addr,
6061
)
6162
p.factory = self
6263
return p
@@ -90,6 +91,7 @@ def buildProtocol(self, addr: IAddress) -> MyClientProtocol:
9091
p2p_manager=self.p2p_manager,
9192
use_ssl=self.use_ssl,
9293
inbound=False,
94+
remote_address=addr,
9395
)
9496
p.factory = self
9597
return p

hathor/p2p/manager.py

+13
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def __init__(self,
125125

126126
# Global maximum number of connections.
127127
self.max_connections: int = settings.PEER_MAX_CONNECTIONS
128+
self.max_connections_per_ip: int = 16
128129

129130
# Global rate limiter for all connections.
130131
self.rate_limiter = RateLimiter(self.reactor)
@@ -314,6 +315,18 @@ def on_peer_connect(self, protocol: HathorProtocol) -> None:
314315
self.log.warn('reached maximum number of connections', max_connections=self.max_connections)
315316
protocol.disconnect(force=True)
316317
return
318+
319+
ip_address = protocol.get_remote_ip_address()
320+
if ip_address:
321+
count = len([1 for conn in self.connections if conn.get_remote_ip_address() == ip_address])
322+
if count >= self.max_connections_per_ip:
323+
self.log.warn(
324+
'reached maximum number of connections per ip address',
325+
max_connections_per_ip=self.max_connections_per_ip
326+
)
327+
protocol.disconnect(force=True)
328+
return
329+
317330
self.connections.add(protocol)
318331
self.handshaking_peers.add(protocol)
319332

hathor/p2p/protocol.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from structlog import get_logger
2020
from twisted.internet.defer import Deferred
21-
from twisted.internet.interfaces import IDelayedCall, ITCPTransport, ITransport
21+
from twisted.internet.interfaces import IAddress, IDelayedCall, ITCPTransport, ITransport
2222
from twisted.internet.protocol import connectionDone
2323
from twisted.protocols.basic import LineReceiver
2424
from twisted.python.failure import Failure
@@ -92,11 +92,12 @@ class WarningFlags(str, Enum):
9292
capabilities: set[str] # capabilities received from the peer in HelloState
9393

9494
def __init__(self, network: str, my_peer: PeerId, p2p_manager: 'ConnectionsManager',
95-
*, use_ssl: bool, inbound: bool) -> None:
95+
*, use_ssl: bool, inbound: bool, remote_address: 'IAddress') -> None:
9696
self._settings = get_settings()
9797
self.network = network
9898
self.my_peer = my_peer
9999
self.connections = p2p_manager
100+
self.remote_address = remote_address
100101

101102
assert p2p_manager.manager is not None
102103
self.node = p2p_manager.manager
@@ -181,8 +182,11 @@ def is_state(self, state_enum: PeerState) -> bool:
181182

182183
def get_short_remote(self) -> str:
183184
"""Get remote for logging."""
184-
assert self.transport is not None
185-
return format_address(self.transport.getPeer())
185+
return format_address(self.remote_address)
186+
187+
def get_remote_ip_address(self) -> Optional[str]:
188+
"""Return remote address (ipv4 or ipv6)."""
189+
return getattr(self.remote_address, 'host', None)
186190

187191
def get_peer_id(self) -> Optional[str]:
188192
"""Get peer id for logging."""

hathor/simulator/fake_connection.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
from twisted.internet.testing import StringTransport
2222

2323
if TYPE_CHECKING:
24+
from twisted.internet.interfaces import IAddress
25+
2426
from hathor.manager import HathorManager
2527
from hathor.p2p.peer_id import PeerId
2628

@@ -39,7 +41,8 @@ def getPeerCertificate(self) -> X509:
3941

4042
class FakeConnection:
4143
def __init__(self, manager1: 'HathorManager', manager2: 'HathorManager', *, latency: float = 0,
42-
autoreconnect: bool = False):
44+
autoreconnect: bool = False, address1: Optional['IAddress'] = None,
45+
address2: Optional['IAddress'] = None):
4346
"""
4447
:param: latency: Latency between nodes in seconds
4548
"""
@@ -56,6 +59,9 @@ def __init__(self, manager1: 'HathorManager', manager2: 'HathorManager', *, late
5659
self._buf1: deque[str] = deque()
5760
self._buf2: deque[str] = deque()
5861

62+
self._address1: Optional['IAddress'] = address1
63+
self._address2: Optional['IAddress'] = address2
64+
5965
self.reconnect()
6066

6167
@property
@@ -140,6 +146,10 @@ def can_step(self) -> bool:
140146
return False
141147

142148
def run_one_step(self, debug=False, force=False):
149+
if self.tr1.disconnected:
150+
return
151+
if self.tr2.disconnected:
152+
return
143153
assert self.is_connected, 'not connected'
144154

145155
if debug:
@@ -218,8 +228,12 @@ def reconnect(self) -> None:
218228
self.disconnect(Failure(Exception('forced reconnection')))
219229
self._buf1.clear()
220230
self._buf2.clear()
221-
self._proto1 = self.manager1.connections.server_factory.buildProtocol(HostnameAddress(b'fake', 0))
222-
self._proto2 = self.manager2.connections.client_factory.buildProtocol(HostnameAddress(b'fake', 0))
231+
232+
address1 = self._address1 or HostnameAddress(b'fake', 0)
233+
address2 = self._address2 or HostnameAddress(b'fake', 0)
234+
235+
self._proto1 = self.manager1.connections.server_factory.buildProtocol(address2)
236+
self._proto2 = self.manager2.connections.client_factory.buildProtocol(address1)
223237
self.tr1 = HathorStringTransport(self._proto2.my_peer)
224238
self.tr2 = HathorStringTransport(self._proto1.my_peer)
225239
self._proto1.makeConnection(self.tr1)

tests/p2p/test_max_conn_per_ip.py

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from twisted.internet.address import IPv4Address
2+
3+
from hathor.simulator import FakeConnection
4+
from tests.simulation.base import SimulatorTestCase
5+
6+
7+
class PeerRelayTestCase(SimulatorTestCase):
8+
__test__ = True
9+
10+
def test_max_conn_per_ip(self) -> None:
11+
m0 = self.create_peer(enable_sync_v1=False, enable_sync_v2=True)
12+
13+
max_connections_per_ip = m0.connections.max_connections_per_ip
14+
for i in range(1, max_connections_per_ip + 8):
15+
m1 = self.create_peer(enable_sync_v1=False, enable_sync_v2=True)
16+
17+
address = IPv4Address('TCP', '127.0.0.1', 1234 + i)
18+
conn = FakeConnection(m0, m1, latency=0.05, address2=address)
19+
self.simulator.add_connection(conn)
20+
21+
self.simulator.run(10)
22+
23+
if i <= max_connections_per_ip:
24+
self.assertFalse(conn.tr1.disconnected)
25+
else:
26+
self.assertTrue(conn.tr1.disconnected)

0 commit comments

Comments
 (0)