Skip to content

Commit 2542e8a

Browse files
committed
feat(p2p): Add a maximum number of connections per IP address
1 parent e2d9278 commit 2542e8a

File tree

5 files changed

+68
-7
lines changed

5 files changed

+68
-7
lines changed

hathor/p2p/factory.py

+2
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def buildProtocol(self, addr: IAddress) -> MyServerProtocol:
5757
p2p_manager=self.p2p_manager,
5858
use_ssl=self.use_ssl,
5959
inbound=True,
60+
remote_address=addr,
6061
)
6162
p.factory = self
6263
return p
@@ -90,6 +91,7 @@ def buildProtocol(self, addr: IAddress) -> MyClientProtocol:
9091
p2p_manager=self.p2p_manager,
9192
use_ssl=self.use_ssl,
9293
inbound=False,
94+
remote_address=addr,
9395
)
9496
p.factory = self
9597
return p

hathor/p2p/manager.py

+14
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def __init__(self,
125125

126126
# Global maximum number of connections.
127127
self.max_connections: int = settings.PEER_MAX_CONNECTIONS
128+
self.max_connections_per_ip: int = 16
128129

129130
# Global rate limiter for all connections.
130131
self.rate_limiter = RateLimiter(self.reactor)
@@ -314,6 +315,19 @@ def on_peer_connect(self, protocol: HathorProtocol) -> None:
314315
self.log.warn('reached maximum number of connections', max_connections=self.max_connections)
315316
protocol.disconnect(force=True)
316317
return
318+
319+
ip_address = protocol.get_remote_ip_address()
320+
if ip_address:
321+
count = len([1 for conn in self.connections if conn.get_remote_ip_address() == ip_address])
322+
if count >= self.max_connections_per_ip:
323+
self.log.warn(
324+
'reached maximum number of connections per ip address',
325+
ip_address=ip_address,
326+
max_connections_per_ip=self.max_connections_per_ip,
327+
)
328+
protocol.disconnect(force=True)
329+
return
330+
317331
self.connections.add(protocol)
318332
self.handshaking_peers.add(protocol)
319333

hathor/p2p/protocol.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from structlog import get_logger
2020
from twisted.internet.defer import Deferred
21-
from twisted.internet.interfaces import IDelayedCall, ITCPTransport, ITransport
21+
from twisted.internet.interfaces import IAddress, IDelayedCall, ITCPTransport, ITransport
2222
from twisted.internet.protocol import connectionDone
2323
from twisted.protocols.basic import LineReceiver
2424
from twisted.python.failure import Failure
@@ -92,11 +92,12 @@ class WarningFlags(str, Enum):
9292
capabilities: set[str] # capabilities received from the peer in HelloState
9393

9494
def __init__(self, network: str, my_peer: PeerId, p2p_manager: 'ConnectionsManager',
95-
*, use_ssl: bool, inbound: bool) -> None:
95+
*, use_ssl: bool, inbound: bool, remote_address: IAddress) -> None:
9696
self._settings = get_settings()
9797
self.network = network
9898
self.my_peer = my_peer
9999
self.connections = p2p_manager
100+
self.remote_address = remote_address
100101

101102
assert p2p_manager.manager is not None
102103
self.node = p2p_manager.manager
@@ -181,8 +182,11 @@ def is_state(self, state_enum: PeerState) -> bool:
181182

182183
def get_short_remote(self) -> str:
183184
"""Get remote for logging."""
184-
assert self.transport is not None
185-
return format_address(self.transport.getPeer())
185+
return format_address(self.remote_address)
186+
187+
def get_remote_ip_address(self) -> Optional[str]:
188+
"""Return remote address (ipv4 or ipv6)."""
189+
return getattr(self.remote_address, 'host', None)
186190

187191
def get_peer_id(self) -> Optional[str]:
188192
"""Get peer id for logging."""
@@ -230,6 +234,8 @@ def on_connect(self) -> None:
230234
""" Executed when the connection is established.
231235
"""
232236
assert not self.aborting
237+
assert self.transport is not None
238+
assert self.remote_address == self.transport.getPeer()
233239
self.update_log_context()
234240
self.log.debug('new connection')
235241

hathor/simulator/fake_connection.py

+16-3
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from OpenSSL.crypto import X509
1919
from structlog import get_logger
2020
from twisted.internet.address import HostnameAddress
21+
from twisted.internet.interfaces import IAddress
2122
from twisted.internet.testing import StringTransport
2223

2324
if TYPE_CHECKING:
@@ -39,7 +40,8 @@ def getPeerCertificate(self) -> X509:
3940

4041
class FakeConnection:
4142
def __init__(self, manager1: 'HathorManager', manager2: 'HathorManager', *, latency: float = 0,
42-
autoreconnect: bool = False):
43+
autoreconnect: bool = False, address1: Optional[IAddress] = None,
44+
address2: Optional[IAddress] = None):
4345
"""
4446
:param: latency: Latency between nodes in seconds
4547
"""
@@ -56,6 +58,9 @@ def __init__(self, manager1: 'HathorManager', manager2: 'HathorManager', *, late
5658
self._buf1: deque[str] = deque()
5759
self._buf2: deque[str] = deque()
5860

61+
self._address1: Optional[IAddress] = address1
62+
self._address2: Optional[IAddress] = address2
63+
5964
self.reconnect()
6065

6166
@property
@@ -140,6 +145,10 @@ def can_step(self) -> bool:
140145
return False
141146

142147
def run_one_step(self, debug=False, force=False):
148+
if self.tr1.disconnected:
149+
return
150+
if self.tr2.disconnected:
151+
return
143152
assert self.is_connected, 'not connected'
144153

145154
if debug:
@@ -218,8 +227,12 @@ def reconnect(self) -> None:
218227
self.disconnect(Failure(Exception('forced reconnection')))
219228
self._buf1.clear()
220229
self._buf2.clear()
221-
self._proto1 = self.manager1.connections.server_factory.buildProtocol(HostnameAddress(b'fake', 0))
222-
self._proto2 = self.manager2.connections.client_factory.buildProtocol(HostnameAddress(b'fake', 0))
230+
231+
address1 = self._address1 or HostnameAddress(b'fake', 0)
232+
address2 = self._address2 or HostnameAddress(b'fake', 0)
233+
234+
self._proto1 = self.manager1.connections.server_factory.buildProtocol(address2)
235+
self._proto2 = self.manager2.connections.client_factory.buildProtocol(address1)
223236
self.tr1 = HathorStringTransport(self._proto2.my_peer)
224237
self.tr2 = HathorStringTransport(self._proto1.my_peer)
225238
self._proto1.makeConnection(self.tr1)

tests/p2p/test_max_conn_per_ip.py

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from twisted.internet.address import IPv4Address
2+
3+
from hathor.simulator import FakeConnection
4+
from tests.simulation.base import SimulatorTestCase
5+
6+
7+
class PeerRelayTestCase(SimulatorTestCase):
8+
__test__ = True
9+
10+
def test_max_conn_per_ip(self) -> None:
11+
m0 = self.create_peer(enable_sync_v1=False, enable_sync_v2=True)
12+
13+
max_connections_per_ip = m0.connections.max_connections_per_ip
14+
for i in range(1, max_connections_per_ip + 8):
15+
m1 = self.create_peer(enable_sync_v1=False, enable_sync_v2=True)
16+
17+
address = IPv4Address('TCP', '127.0.0.1', 1234 + i)
18+
conn = FakeConnection(m0, m1, latency=0.05, address2=address)
19+
self.simulator.add_connection(conn)
20+
21+
self.simulator.run(10)
22+
23+
if i <= max_connections_per_ip:
24+
self.assertFalse(conn.tr1.disconnected)
25+
else:
26+
self.assertTrue(conn.tr1.disconnected)

0 commit comments

Comments
 (0)