Skip to content

feat(p2p): Add a maximum number of connections per IP address #781

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions hathor/p2p/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def buildProtocol(self, addr: IAddress) -> MyServerProtocol:
p2p_manager=self.p2p_manager,
use_ssl=self.use_ssl,
inbound=True,
remote_address=addr,
)
p.factory = self
return p
Expand Down Expand Up @@ -90,6 +91,7 @@ def buildProtocol(self, addr: IAddress) -> MyClientProtocol:
p2p_manager=self.p2p_manager,
use_ssl=self.use_ssl,
inbound=False,
remote_address=addr,
)
p.factory = self
return p
14 changes: 14 additions & 0 deletions hathor/p2p/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def __init__(self,

# Global maximum number of connections.
self.max_connections: int = settings.PEER_MAX_CONNECTIONS
self.max_connections_per_ip: int = 16

# Global rate limiter for all connections.
self.rate_limiter = RateLimiter(self.reactor)
Expand Down Expand Up @@ -350,6 +351,19 @@ def on_peer_connect(self, protocol: HathorProtocol) -> None:
self.log.warn('reached maximum number of connections', max_connections=self.max_connections)
protocol.disconnect(force=True)
return

ip_address = protocol.get_remote_ip_address()
if ip_address:
count = len([1 for conn in self.connections if conn.get_remote_ip_address() == ip_address])
if count >= self.max_connections_per_ip:
self.log.warn(
'reached maximum number of connections per ip address',
ip_address=ip_address,
max_connections_per_ip=self.max_connections_per_ip,
)
protocol.disconnect(force=True)
return

self.connections.add(protocol)
self.handshaking_peers.add(protocol)

Expand Down
14 changes: 10 additions & 4 deletions hathor/p2p/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from structlog import get_logger
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IDelayedCall, ITCPTransport, ITransport
from twisted.internet.interfaces import IAddress, IDelayedCall, ITCPTransport, ITransport
from twisted.internet.protocol import connectionDone
from twisted.protocols.basic import LineReceiver
from twisted.python.failure import Failure
Expand Down Expand Up @@ -92,11 +92,12 @@ class WarningFlags(str, Enum):
capabilities: set[str] # capabilities received from the peer in HelloState

def __init__(self, network: str, my_peer: PeerId, p2p_manager: 'ConnectionsManager',
*, use_ssl: bool, inbound: bool) -> None:
*, use_ssl: bool, inbound: bool, remote_address: IAddress) -> None:
self._settings = get_settings()
self.network = network
self.my_peer = my_peer
self.connections = p2p_manager
self.remote_address = remote_address

assert p2p_manager.manager is not None
self.node = p2p_manager.manager
Expand Down Expand Up @@ -181,8 +182,11 @@ def is_state(self, state_enum: PeerState) -> bool:

def get_short_remote(self) -> str:
"""Get remote for logging."""
assert self.transport is not None
return format_address(self.transport.getPeer())
return format_address(self.remote_address)

def get_remote_ip_address(self) -> Optional[str]:
"""Return remote address (ipv4 or ipv6)."""
return getattr(self.remote_address, 'host', None)

def get_peer_id(self) -> Optional[str]:
"""Get peer id for logging."""
Expand Down Expand Up @@ -230,6 +234,8 @@ def on_connect(self) -> None:
""" Executed when the connection is established.
"""
assert not self.aborting
assert self.transport is not None
assert self.remote_address == self.transport.getPeer()
self.update_log_context()
self.log.debug('new connection')

Expand Down
29 changes: 21 additions & 8 deletions hathor/simulator/fake_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

from OpenSSL.crypto import X509
from structlog import get_logger
from twisted.internet.address import HostnameAddress
from twisted.internet.address import IPv4Address
from twisted.internet.interfaces import IAddress
from twisted.internet.testing import StringTransport

if TYPE_CHECKING:
Expand All @@ -28,8 +29,8 @@


class HathorStringTransport(StringTransport):
def __init__(self, peer: 'PeerId'):
super().__init__()
def __init__(self, peer: 'PeerId', hostAddress: IAddress, peerAddress: IAddress) -> None:
super().__init__(hostAddress=hostAddress, peerAddress=peerAddress)
self.peer = peer

def getPeerCertificate(self) -> X509:
Expand All @@ -39,7 +40,8 @@ def getPeerCertificate(self) -> X509:

class FakeConnection:
def __init__(self, manager1: 'HathorManager', manager2: 'HathorManager', *, latency: float = 0,
autoreconnect: bool = False):
autoreconnect: bool = False, address1: Optional[IAddress] = None,
address2: Optional[IAddress] = None):
"""
:param: latency: Latency between nodes in seconds
"""
Expand All @@ -56,6 +58,9 @@ def __init__(self, manager1: 'HathorManager', manager2: 'HathorManager', *, late
self._buf1: deque[str] = deque()
self._buf2: deque[str] = deque()

self._address1: Optional[IAddress] = address1
self._address2: Optional[IAddress] = address2

self.reconnect()

@property
Expand Down Expand Up @@ -148,6 +153,10 @@ def can_step(self) -> bool:
return False

def run_one_step(self, debug=False, force=False):
if self.tr1.disconnected:
return
if self.tr2.disconnected:
return
assert self.is_connected, 'not connected'

if debug:
Expand Down Expand Up @@ -226,10 +235,14 @@ def reconnect(self) -> None:
self.disconnect(Failure(Exception('forced reconnection')))
self._buf1.clear()
self._buf2.clear()
self._proto1 = self.manager1.connections.server_factory.buildProtocol(HostnameAddress(b'fake', 0))
self._proto2 = self.manager2.connections.client_factory.buildProtocol(HostnameAddress(b'fake', 0))
self.tr1 = HathorStringTransport(self._proto2.my_peer)
self.tr2 = HathorStringTransport(self._proto1.my_peer)

address1 = self._address1 or IPv4Address('TCP', '192.168.0.14', 1234)
address2 = self._address2 or IPv4Address('TCP', '192.168.0.72', 5432)

self._proto1 = self.manager1.connections.server_factory.buildProtocol(address2)
self._proto2 = self.manager2.connections.client_factory.buildProtocol(address1)
self.tr1 = HathorStringTransport(self._proto2.my_peer, address1, address2)
self.tr2 = HathorStringTransport(self._proto1.my_peer, address2, address1)
self._proto1.makeConnection(self.tr1)
self._proto2.makeConnection(self.tr2)
self.is_connected = True
Expand Down
4 changes: 3 additions & 1 deletion tests/others/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from unittest.mock import Mock

import pytest
from twisted.internet.address import IPv4Address

from hathor.p2p.manager import PeerConnectionsMetrics
from hathor.p2p.peer_id import PeerId
Expand Down Expand Up @@ -214,7 +215,8 @@ def build_hathor_protocol():
my_peer=my_peer,
p2p_manager=manager.connections,
use_ssl=False,
inbound=False
inbound=False,
remote_address=IPv4Address('TCP', '192.168.0.1', 5000),
)
protocol.peer = PeerId()

Expand Down
26 changes: 26 additions & 0 deletions tests/p2p/test_max_conn_per_ip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from twisted.internet.address import IPv4Address

from hathor.simulator import FakeConnection
from tests.simulation.base import SimulatorTestCase


class PeerRelayTestCase(SimulatorTestCase):
__test__ = True

def test_max_conn_per_ip(self) -> None:
m0 = self.create_peer(enable_sync_v1=False, enable_sync_v2=True)

max_connections_per_ip = m0.connections.max_connections_per_ip
for i in range(1, max_connections_per_ip + 8):
m1 = self.create_peer(enable_sync_v1=False, enable_sync_v2=True)

address = IPv4Address('TCP', '127.0.0.1', 1234 + i)
conn = FakeConnection(m0, m1, latency=0.05, address2=address)
self.simulator.add_connection(conn)

self.simulator.run(10)

if i <= max_connections_per_ip:
self.assertFalse(conn.tr1.disconnected)
else:
self.assertTrue(conn.tr1.disconnected)