From fe84becadd71bf0915436cdc6ea910fad8d98974 Mon Sep 17 00:00:00 2001 From: Gabriel Levcovitz Date: Tue, 5 Nov 2024 15:00:08 -0300 Subject: [PATCH] refactor(p2p): refactor peer address handling --- hathor/builder/cli_builder.py | 4 +- hathor/p2p/entrypoint.py | 215 --------------- hathor/p2p/manager.py | 48 ++-- hathor/p2p/peer.py | 67 ++--- hathor/p2p/peer_discovery/bootstrap.py | 8 +- hathor/p2p/peer_discovery/dns.py | 38 +-- hathor/p2p/peer_discovery/peer_discovery.py | 4 +- hathor/p2p/peer_endpoint.py | 277 ++++++++++++++++++++ hathor/p2p/protocol.py | 61 +++-- hathor/p2p/resources/add_peers.py | 13 +- hathor/p2p/states/peer_id.py | 24 +- hathor/p2p/utils.py | 4 +- hathor/simulator/fake_connection.py | 68 ++++- tests/others/test_metrics.py | 4 +- tests/p2p/test_bootstrap.py | 7 +- tests/p2p/test_connections.py | 7 +- tests/p2p/test_peer_id.py | 74 +++++- tests/p2p/test_protocol.py | 201 +++++++++++++- tests/resources/p2p/test_add_peer.py | 4 +- tests/resources/p2p/test_status.py | 12 +- 20 files changed, 747 insertions(+), 393 deletions(-) delete mode 100644 hathor/p2p/entrypoint.py create mode 100644 hathor/p2p/peer_endpoint.py diff --git a/hathor/builder/cli_builder.py b/hathor/builder/cli_builder.py index b2ffc747a..464d9b319 100644 --- a/hathor/builder/cli_builder.py +++ b/hathor/builder/cli_builder.py @@ -34,9 +34,9 @@ from hathor.indexes import IndexesManager, MemoryIndexesManager, RocksDBIndexesManager from hathor.manager import HathorManager from hathor.mining.cpu_mining_service import CpuMiningService -from hathor.p2p.entrypoint import Entrypoint from hathor.p2p.manager import ConnectionsManager from hathor.p2p.peer import PrivatePeer +from hathor.p2p.peer_endpoint import PeerEndpoint from hathor.p2p.utils import discover_hostname, get_genesis_short_hash from hathor.pubsub import PubSubManager from hathor.reactor import ReactorProtocol as Reactor @@ -420,7 +420,7 @@ def create_manager(self, reactor: Reactor) -> HathorManager: p2p_manager.add_peer_discovery(DNSPeerDiscovery(dns_hosts)) if self._args.bootstrap: - entrypoints = [Entrypoint.parse(desc) for desc in self._args.bootstrap] + entrypoints = [PeerEndpoint.parse(desc) for desc in self._args.bootstrap] p2p_manager.add_peer_discovery(BootstrapPeerDiscovery(entrypoints)) if self._args.x_rocksdb_indexes: diff --git a/hathor/p2p/entrypoint.py b/hathor/p2p/entrypoint.py deleted file mode 100644 index 23ead1199..000000000 --- a/hathor/p2p/entrypoint.py +++ /dev/null @@ -1,215 +0,0 @@ -# Copyright 2024 Hathor Labs -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass -from enum import Enum -from urllib.parse import parse_qs, urlparse - -from twisted.internet.address import IPv4Address, IPv6Address -from twisted.internet.endpoints import TCP4ClientEndpoint -from twisted.internet.interfaces import IStreamClientEndpoint -from typing_extensions import Self - -from hathor.p2p.peer_id import PeerId -from hathor.reactor import ReactorProtocol as Reactor - - -class Protocol(Enum): - TCP = 'tcp' - - -@dataclass(frozen=True, slots=True) -class Entrypoint: - """Endpoint description (returned from DNS query, or received from the p2p network) may contain a peer-id.""" - - protocol: Protocol - host: str - port: int - peer_id: PeerId | None = None - - def __str__(self): - if self.peer_id is None: - return f'{self.protocol.value}://{self.host}:{self.port}' - else: - return f'{self.protocol.value}://{self.host}:{self.port}/?id={self.peer_id}' - - @classmethod - def parse(cls, description: str) -> Self: - """Parse endpoint description into an Entrypoint object. - - Examples: - - >>> str(Entrypoint.parse('tcp://127.0.0.1:40403/')) - 'tcp://127.0.0.1:40403' - - >>> id1 = 'c0f19299c2a4dcbb6613a14011ff07b63d6cb809e4cee25e9c1ccccdd6628696' - >>> Entrypoint.parse(f'tcp://127.0.0.1:40403/?id={id1}') - Entrypoint(protocol=, host='127.0.0.1', port=40403, \ -peer_id=PeerId('c0f19299c2a4dcbb6613a14011ff07b63d6cb809e4cee25e9c1ccccdd6628696')) - - >>> str(Entrypoint.parse(f'tcp://127.0.0.1:40403/?id={id1}')) - 'tcp://127.0.0.1:40403/?id=c0f19299c2a4dcbb6613a14011ff07b63d6cb809e4cee25e9c1ccccdd6628696' - - >>> Entrypoint.parse('tcp://127.0.0.1:40403') - Entrypoint(protocol=, host='127.0.0.1', port=40403, peer_id=None) - - >>> Entrypoint.parse('tcp://127.0.0.1:40403/') - Entrypoint(protocol=, host='127.0.0.1', port=40403, peer_id=None) - - >>> Entrypoint.parse('tcp://foo.bar.baz:40403/') - Entrypoint(protocol=, host='foo.bar.baz', port=40403, peer_id=None) - - >>> str(Entrypoint.parse('tcp://foo.bar.baz:40403/')) - 'tcp://foo.bar.baz:40403' - - >>> Entrypoint.parse('tcp://127.0.0.1:40403/?id=123') - Traceback (most recent call last): - ... - ValueError: non-hexadecimal number found in fromhex() arg at position 3 - - >>> Entrypoint.parse('tcp://127.0.0.1:4040f') - Traceback (most recent call last): - ... - ValueError: Port could not be cast to integer value as '4040f' - - >>> Entrypoint.parse('udp://127.0.0.1:40403/') - Traceback (most recent call last): - ... - ValueError: 'udp' is not a valid Protocol - - >>> Entrypoint.parse('tcp://127.0.0.1/') - Traceback (most recent call last): - ... - ValueError: expected a port - - >>> Entrypoint.parse('tcp://:40403/') - Traceback (most recent call last): - ... - ValueError: expected a host - - >>> Entrypoint.parse('tcp://127.0.0.1:40403/foo') - Traceback (most recent call last): - ... - ValueError: unexpected path: /foo - - >>> id2 = 'bc5119d47bb4ea7c19100bd97fb11f36970482108bd3d45ff101ee4f6bbec872' - >>> Entrypoint.parse(f'tcp://127.0.0.1:40403/?id={id1}&id={id2}') - Traceback (most recent call last): - ... - ValueError: unexpected id count: 2 - """ - url = urlparse(description) - protocol = Protocol(url.scheme) - host = url.hostname - if host is None: - raise ValueError('expected a host') - port = url.port - if port is None: - raise ValueError('expected a port') - if url.path not in {'', '/'}: - raise ValueError(f'unexpected path: {url.path}') - peer_id: PeerId | None = None - - if url.query: - query = parse_qs(url.query) - if 'id' in query: - ids = query['id'] - if len(ids) != 1: - raise ValueError(f'unexpected id count: {len(ids)}') - peer_id = PeerId(ids[0]) - - return cls(protocol, host, port, peer_id) - - @classmethod - def from_hostname_address(cls, hostname: str, address: IPv4Address | IPv6Address) -> Self: - return cls.parse(f'{address.type}://{hostname}:{address.port}') - - def to_client_endpoint(self, reactor: Reactor) -> IStreamClientEndpoint: - """This method generates a twisted client endpoint that has a .connect() method.""" - # XXX: currently we don't support IPv6, but when we do we have to decide between TCP4ClientEndpoint and - # TCP6ClientEndpoint, when the host is an IP address that is easy, but when it is a DNS hostname, we will not - # know which to use until we know which resource records it holds (A or AAAA) - return TCP4ClientEndpoint(reactor, self.host, self.port) - - def equals_ignore_peer_id(self, other: Self) -> bool: - """Compares `self` and `other` ignoring the `peer_id` fields of either. - - Examples: - - >>> ep1 = 'tcp://foo:111' - >>> ep2 = 'tcp://foo:111/?id=c0f19299c2a4dcbb6613a14011ff07b63d6cb809e4cee25e9c1ccccdd6628696' - >>> ep3 = 'tcp://foo:111/?id=bc5119d47bb4ea7c19100bd97fb11f36970482108bd3d45ff101ee4f6bbec872' - >>> ep4 = 'tcp://bar:111/?id=c0f19299c2a4dcbb6613a14011ff07b63d6cb809e4cee25e9c1ccccdd6628696' - >>> ep5 = 'tcp://foo:112/?id=c0f19299c2a4dcbb6613a14011ff07b63d6cb809e4cee25e9c1ccccdd6628696' - >>> Entrypoint.parse(ep1).equals_ignore_peer_id(Entrypoint.parse(ep2)) - True - >>> Entrypoint.parse(ep2).equals_ignore_peer_id(Entrypoint.parse(ep3)) - True - >>> Entrypoint.parse(ep1).equals_ignore_peer_id(Entrypoint.parse(ep4)) - False - >>> Entrypoint.parse(ep2).equals_ignore_peer_id(Entrypoint.parse(ep4)) - False - >>> Entrypoint.parse(ep2).equals_ignore_peer_id(Entrypoint.parse(ep5)) - False - """ - return (self.protocol, self.host, self.port) == (other.protocol, other.host, other.port) - - def peer_id_conflicts_with(self, other: Self) -> bool: - """Returns True if both self and other have a peer_id and they are different, returns False otherwise. - - This method ignores the host. Which is useful for catching the cases where both `self` and `other` have a - declared `peer_id` and they are not equal. - - >>> desc_no_pid = 'tcp://127.0.0.1:40403/' - >>> ep_no_pid = Entrypoint.parse(desc_no_pid) - >>> desc_pid1 = 'tcp://127.0.0.1:40403/?id=c0f19299c2a4dcbb6613a14011ff07b63d6cb809e4cee25e9c1ccccdd6628696' - >>> ep_pid1 = Entrypoint.parse(desc_pid1) - >>> desc_pid2 = 'tcp://127.0.0.1:40403/?id=bc5119d47bb4ea7c19100bd97fb11f36970482108bd3d45ff101ee4f6bbec872' - >>> ep_pid2 = Entrypoint.parse(desc_pid2) - >>> desc2_pid2 = 'tcp://foo.bar:40403/?id=bc5119d47bb4ea7c19100bd97fb11f36970482108bd3d45ff101ee4f6bbec872' - >>> ep2_pid2 = Entrypoint.parse(desc2_pid2) - >>> ep_no_pid.peer_id_conflicts_with(ep_no_pid) - False - >>> ep_no_pid.peer_id_conflicts_with(ep_pid1) - False - >>> ep_pid1.peer_id_conflicts_with(ep_no_pid) - False - >>> ep_pid1.peer_id_conflicts_with(ep_pid2) - True - >>> ep_pid1.peer_id_conflicts_with(ep2_pid2) - True - >>> ep_pid2.peer_id_conflicts_with(ep2_pid2) - False - """ - return self.peer_id is not None and other.peer_id is not None and self.peer_id != other.peer_id - - def is_localhost(self) -> bool: - """Used to determine if the entrypoint host is a localhost address. - - Examples: - - >>> Entrypoint.parse('tcp://127.0.0.1:444').is_localhost() - True - >>> Entrypoint.parse('tcp://localhost:444').is_localhost() - True - >>> Entrypoint.parse('tcp://8.8.8.8:444').is_localhost() - False - >>> Entrypoint.parse('tcp://foo.bar:444').is_localhost() - False - """ - if self.host == '127.0.0.1': - return True - if self.host == 'localhost': - return True - return False diff --git a/hathor/p2p/manager.py b/hathor/p2p/manager.py index 1b94f92d8..d53c7be83 100644 --- a/hathor/p2p/manager.py +++ b/hathor/p2p/manager.py @@ -25,10 +25,10 @@ from twisted.web.client import Agent from hathor.conf.settings import HathorSettings -from hathor.p2p.entrypoint import Entrypoint from hathor.p2p.netfilter.factory import NetfilterFactory from hathor.p2p.peer import PrivatePeer, PublicPeer, UnverifiedPeer from hathor.p2p.peer_discovery import PeerDiscovery +from hathor.p2p.peer_endpoint import PeerAddress, PeerEndpoint from hathor.p2p.peer_id import PeerId from hathor.p2p.peer_storage import UnverifiedPeerStorage, VerifiedPeerStorage from hathor.p2p.protocol import HathorProtocol @@ -60,7 +60,7 @@ class _SyncRotateInfo(NamedTuple): class _ConnectingPeer(NamedTuple): - entrypoint: Entrypoint + entrypoint: PeerEndpoint endpoint_deferred: Deferred @@ -370,7 +370,7 @@ def on_connection_failure(self, failure: Failure, peer: Optional[UnverifiedPeer endpoint: IStreamClientEndpoint) -> None: connecting_peer = self.connecting_peers[endpoint] entrypoint = connecting_peer.entrypoint - self.log.warn('connection failure', entrypoint=entrypoint, failure=failure.getErrorMessage()) + self.log.warn('connection failure', entrypoint=str(entrypoint), failure=failure.getErrorMessage()) self.connecting_peers.pop(endpoint) self.pubsub.publish( @@ -475,7 +475,7 @@ def iter_ready_connections(self) -> Iterable[HathorProtocol]: for conn in self.connected_peers.values(): yield conn - def iter_not_ready_endpoints(self) -> Iterable[Entrypoint]: + def iter_not_ready_endpoints(self) -> Iterable[PeerEndpoint]: """Iterate over not-ready connections.""" for connecting_peer in self.connecting_peers.values(): yield connecting_peer.entrypoint @@ -589,27 +589,28 @@ def connect_to_if_not_connected(self, peer: UnverifiedPeer | PublicPeer, now: in assert peer.id is not None if peer.info.can_retry(now): - self.connect_to(self.rng.choice(peer.info.entrypoints), peer) + addr = self.rng.choice(peer.info.entrypoints) + self.connect_to(addr.with_id(peer.id), peer) def _connect_to_callback( self, protocol: IProtocol, - peer: Optional[UnverifiedPeer | PublicPeer], + peer: UnverifiedPeer | PublicPeer | None, endpoint: IStreamClientEndpoint, - entrypoint: Entrypoint, + entrypoint: PeerEndpoint, ) -> None: """Called when we successfully connect to a peer.""" if isinstance(protocol, HathorProtocol): - protocol.on_outbound_connect(entrypoint) + protocol.on_outbound_connect(entrypoint, peer) else: assert isinstance(protocol, TLSMemoryBIOProtocol) assert isinstance(protocol.wrappedProtocol, HathorProtocol) - protocol.wrappedProtocol.on_outbound_connect(entrypoint) + protocol.wrappedProtocol.on_outbound_connect(entrypoint, peer) self.connecting_peers.pop(endpoint) def connect_to( self, - entrypoint: Entrypoint, + entrypoint: PeerEndpoint, peer: UnverifiedPeer | PublicPeer | None = None, use_ssl: bool | None = None, ) -> None: @@ -618,24 +619,27 @@ def connect_to( If `use_ssl` is True, then the connection will be wraped by a TLS. """ - if entrypoint.peer_id is not None and peer is not None and str(entrypoint.peer_id) != peer.id: + if entrypoint.peer_id is not None and peer is not None and entrypoint.peer_id != peer.id: self.log.debug('skipping because the entrypoint peer_id does not match the actual peer_id', - entrypoint=entrypoint) + entrypoint=str(entrypoint)) return for connecting_peer in self.connecting_peers.values(): - if connecting_peer.entrypoint.equals_ignore_peer_id(entrypoint): - self.log.debug('skipping because we are already connecting to this endpoint', entrypoint=entrypoint) + if connecting_peer.entrypoint.addr == entrypoint.addr: + self.log.debug( + 'skipping because we are already connecting to this endpoint', + entrypoint=str(entrypoint), + ) return - if self.localhost_only and not entrypoint.is_localhost(): - self.log.debug('skip because of simple localhost check', entrypoint=entrypoint) + if self.localhost_only and not entrypoint.addr.is_localhost(): + self.log.debug('skip because of simple localhost check', entrypoint=str(entrypoint)) return if use_ssl is None: use_ssl = self.use_ssl - endpoint = entrypoint.to_client_endpoint(self.reactor) + endpoint = entrypoint.addr.to_client_endpoint(self.reactor) factory: IProtocolFactory if use_ssl: @@ -650,9 +654,9 @@ def connect_to( deferred = endpoint.connect(factory) self.connecting_peers[endpoint] = _ConnectingPeer(entrypoint, deferred) - deferred.addCallback(self._connect_to_callback, peer, endpoint, entrypoint) # type: ignore - deferred.addErrback(self.on_connection_failure, peer, endpoint) # type: ignore - self.log.info('connect to', entrypoint=str(entrypoint), peer=str(peer)) + deferred.addCallback(self._connect_to_callback, peer, endpoint, entrypoint) + deferred.addErrback(self.on_connection_failure, peer, endpoint) + self.log.info('connecting to', entrypoint=str(entrypoint), peer=str(peer)) self.pubsub.publish( HathorEvents.NETWORK_PEER_CONNECTING, peer=peer, @@ -708,13 +712,13 @@ def update_hostname_entrypoints(self, *, old_hostname: str | None, new_hostname: assert self.manager is not None for address in self._listen_addresses: if old_hostname is not None: - old_entrypoint = Entrypoint.from_hostname_address(old_hostname, address) + old_entrypoint = PeerAddress.from_hostname_address(old_hostname, address) if old_entrypoint in self.my_peer.info.entrypoints: self.my_peer.info.entrypoints.remove(old_entrypoint) self._add_hostname_entrypoint(new_hostname, address) def _add_hostname_entrypoint(self, hostname: str, address: IPv4Address | IPv6Address) -> None: - hostname_entrypoint = Entrypoint.from_hostname_address(hostname, address) + hostname_entrypoint = PeerAddress.from_hostname_address(hostname, address) self.my_peer.info.entrypoints.append(hostname_entrypoint) def get_connection_to_drop(self, protocol: HathorProtocol) -> HathorProtocol: diff --git a/hathor/p2p/peer.py b/hathor/p2p/peer.py index 33aa617db..53f43369d 100644 --- a/hathor/p2p/peer.py +++ b/hathor/p2p/peer.py @@ -55,7 +55,7 @@ from hathor.conf.get_settings import get_global_settings from hathor.conf.settings import HathorSettings from hathor.daa import DifficultyAdjustmentAlgorithm -from hathor.p2p.entrypoint import Entrypoint +from hathor.p2p.peer_endpoint import PeerAddress, PeerEndpoint from hathor.p2p.peer_id import PeerId from hathor.p2p.utils import discover_dns, generate_certificate from hathor.util import not_none @@ -74,14 +74,6 @@ class PeerFlags(str, Enum): RETRIES_EXCEEDED = 'retries_exceeded' -def _parse_entrypoint(entrypoint_string: str) -> Entrypoint: - """ Helper function to parse an entrypoint from string.""" - entrypoint = Entrypoint.parse(entrypoint_string) - if entrypoint.peer_id is not None: - raise ValueError('do not add id= to peer.json entrypoints') - return entrypoint - - def _parse_pubkey(pubkey_string: str) -> rsa.RSAPublicKey: """ Helper function to parse a public key from string.""" public_key_der = base64.b64decode(pubkey_string) @@ -114,7 +106,7 @@ class PeerInfo: """ Stores entrypoint and connection attempts information. """ - entrypoints: list[Entrypoint] = field(default_factory=list) + entrypoints: list[PeerAddress] = field(default_factory=list) retry_timestamp: int = 0 # should only try connecting to this peer after this timestamp retry_interval: int = 5 # how long to wait for next connection retry. It will double for each failure retry_attempts: int = 0 # how many retries were made @@ -136,13 +128,11 @@ def _merge(self, other: PeerInfo) -> None: async def validate_entrypoint(self, protocol: HathorProtocol) -> bool: """ Validates if connection entrypoint is one of the peer entrypoints """ - found_entrypoint = False - # If has no entrypoints must be behind a NAT, so we add the flag to the connection if len(self.entrypoints) == 0: protocol.warning_flags.add(protocol.WarningFlags.NO_ENTRYPOINTS) # If there are no entrypoints, we don't need to validate it - found_entrypoint = True + return True # Entrypoint validation with connection string and connection host # Entrypoints have the format tcp://IP|name:port @@ -150,19 +140,13 @@ async def validate_entrypoint(self, protocol: HathorProtocol) -> bool: if protocol.entrypoint is not None: # Connection string has the format tcp://IP:port # So we must consider that the entrypoint could be in name format - if protocol.entrypoint.equals_ignore_peer_id(entrypoint): - # XXX: wrong peer-id should not make it into self.entrypoints - assert not protocol.entrypoint.peer_id_conflicts_with(entrypoint), 'wrong peer-id was added before' - # Found the entrypoint - found_entrypoint = True - break + if protocol.entrypoint.addr == entrypoint: + return True # TODO: don't use `daa.TEST_MODE` for this test_mode = not_none(DifficultyAdjustmentAlgorithm.singleton).TEST_MODE result = await discover_dns(entrypoint.host, test_mode) - if protocol.entrypoint in result: - # Found the entrypoint - found_entrypoint = True - break + if protocol.entrypoint.addr in [endpoint.addr for endpoint in result]: + return True else: # When the peer is the server part of the connection we don't have the full entrypoint description # So we can only validate the host from the protocol @@ -174,20 +158,13 @@ async def validate_entrypoint(self, protocol: HathorProtocol) -> bool: # Connection host has only the IP # So we must consider that the entrypoint could be in name format and we just validate the host if connection_host == entrypoint.host: - found_entrypoint = True - break + return True test_mode = not_none(DifficultyAdjustmentAlgorithm.singleton).TEST_MODE result = await discover_dns(entrypoint.host, test_mode) - if connection_host in [entrypoint.host for entrypoint in result]: - # Found the entrypoint - found_entrypoint = True - break + if connection_host in [entrypoint.addr.host for entrypoint in result]: + return True - if not found_entrypoint: - # In case the validation fails - return False - - return True + return False def increment_retry_attempt(self, now: int) -> None: """ Updates timestamp for next retry. @@ -242,9 +219,20 @@ def create_from_json(cls, data: dict[str, Any]) -> Self: It is to create an UnverifiedPeer from a peer connection. """ + peer_id = PeerId(data['id']) + endpoints = [] + + for endpoint_str in data.get('entrypoints', []): + # We have to parse using PeerEndpoint to be able to support older peers that still + # send the id in entrypoints, but we validate that they're sending the correct id. + endpoint = PeerEndpoint.parse(endpoint_str) + if endpoint.peer_id is not None and endpoint.peer_id != peer_id: + raise ValueError(f'conflicting peer_id: {endpoint.peer_id} != {peer_id}') + endpoints.append(endpoint.addr) + return cls( - id=PeerId(data['id']), - info=PeerInfo(entrypoints=[_parse_entrypoint(e) for e in data.get('entrypoints', [])]), + id=peer_id, + info=PeerInfo(entrypoints=endpoints), ) def merge(self, other: UnverifiedPeer) -> None: @@ -364,12 +352,7 @@ def verify_signature(self, signature: bytes, data: bytes) -> bool: return True def validate(self) -> None: - """ Return `True` if the following conditions are valid: - (i) public key and private key matches; - (ii) the id matches with the public key. - - TODO(epnichols): Update docs. Only raises exceptions; doesn't return anything. - """ + """Calculate the PeerId based on the public key and raise an exception if it does not match.""" if self.id != self.calculate_id(): raise InvalidPeerIdException('id does not match public key') diff --git a/hathor/p2p/peer_discovery/bootstrap.py b/hathor/p2p/peer_discovery/bootstrap.py index a30970ae2..55b5e9f16 100644 --- a/hathor/p2p/peer_discovery/bootstrap.py +++ b/hathor/p2p/peer_discovery/bootstrap.py @@ -17,7 +17,7 @@ from structlog import get_logger from typing_extensions import override -from hathor.p2p.entrypoint import Entrypoint +from hathor.p2p.peer_endpoint import PeerEndpoint from .peer_discovery import PeerDiscovery @@ -28,15 +28,15 @@ class BootstrapPeerDiscovery(PeerDiscovery): """ It implements a bootstrap peer discovery, which receives a static list of peers. """ - def __init__(self, entrypoints: list[Entrypoint]): + def __init__(self, entrypoints: list[PeerEndpoint]): """ - :param descriptions: Descriptions of peers to connect to. + :param entrypoints: Addresses of peers to connect to. """ super().__init__() self.log = logger.new() self.entrypoints = entrypoints @override - async def discover_and_connect(self, connect_to: Callable[[Entrypoint], None]) -> None: + async def discover_and_connect(self, connect_to: Callable[[PeerEndpoint], None]) -> None: for entrypoint in self.entrypoints: connect_to(entrypoint) diff --git a/hathor/p2p/peer_discovery/dns.py b/hathor/p2p/peer_discovery/dns.py index b946fc9eb..c5dfe74d6 100644 --- a/hathor/p2p/peer_discovery/dns.py +++ b/hathor/p2p/peer_discovery/dns.py @@ -15,7 +15,7 @@ import socket from collections.abc import Iterator from itertools import chain -from typing import Callable, TypeAlias, cast +from typing import Callable, TypeAlias from structlog import get_logger from twisted.internet.defer import Deferred, gatherResults @@ -23,7 +23,7 @@ from twisted.names.dns import Record_A, Record_TXT, RRHeader from typing_extensions import override -from hathor.p2p.entrypoint import Entrypoint, Protocol +from hathor.p2p.peer_endpoint import PeerAddress, PeerEndpoint, Protocol from .peer_discovery import PeerDiscovery @@ -53,7 +53,7 @@ def do_lookup_text(self, host: str) -> Deferred[LookupResult]: return lookupText(host) @override - async def discover_and_connect(self, connect_to: Callable[[Entrypoint], None]) -> None: + async def discover_and_connect(self, connect_to: Callable[[PeerEndpoint], None]) -> None: """ Run DNS lookup for host and connect to it This is executed when starting the DNS Peer Discovery and first connecting to the network """ @@ -61,26 +61,26 @@ async def discover_and_connect(self, connect_to: Callable[[Entrypoint], None]) - for entrypoint in (await self.dns_seed_lookup(host)): connect_to(entrypoint) - async def dns_seed_lookup(self, host: str) -> set[Entrypoint]: + async def dns_seed_lookup(self, host: str) -> set[PeerEndpoint]: """ Run a DNS lookup for TXT, A, and AAAA records and return a list of connection strings. """ if self.test_mode: # Useful for testing purposes, so we don't need to execute a DNS query - return {Entrypoint.parse('tcp://127.0.0.1:40403')} + return {PeerEndpoint.parse('tcp://127.0.0.1:40403')} - deferreds = [] + deferreds: list[Deferred[Iterator[PeerEndpoint]]] = [] - d1 = self.do_lookup_text(host) - d1.addCallback(self.dns_seed_lookup_text) - d1.addErrback(self.errback) - deferreds.append(cast(Deferred[Iterator[Entrypoint]], d1)) # mypy doesn't know how addCallback affects d1 + d1 = self.do_lookup_text(host) \ + .addCallback(self.dns_seed_lookup_text) \ + .addErrback(self.errback) + deferreds.append(d1) - d2 = self.do_lookup_address(host) - d2.addCallback(self.dns_seed_lookup_address) - d2.addErrback(self.errback) - deferreds.append(cast(Deferred[Iterator[Entrypoint]], d2)) # mypy doesn't know how addCallback affects d2 + d2 = self.do_lookup_address(host) \ + .addCallback(self.dns_seed_lookup_address) \ + .addErrback(self.errback) + deferreds.append(d2) - results: list[Iterator[Entrypoint]] = await gatherResults(deferreds) + results: list[Iterator[PeerEndpoint]] = await gatherResults(deferreds) return set(chain(*results)) def errback(self, result): @@ -89,7 +89,7 @@ def errback(self, result): self.log.error('errback', result=result) return [] - def dns_seed_lookup_text(self, results: LookupResult) -> Iterator[Entrypoint]: + def dns_seed_lookup_text(self, results: LookupResult) -> Iterator[PeerEndpoint]: """ Run a DNS lookup for TXT records to discover new peers. The `results` has three lists that contain answer records, authority records, and additional records. @@ -100,14 +100,14 @@ def dns_seed_lookup_text(self, results: LookupResult) -> Iterator[Entrypoint]: for txt in record.payload.data: raw_entrypoint = txt.decode('utf-8') try: - entrypoint = Entrypoint.parse(raw_entrypoint) + entrypoint = PeerEndpoint.parse(raw_entrypoint) except ValueError: self.log.warning('could not parse entrypoint, skipping it', raw_entrypoint=raw_entrypoint) continue self.log.info('seed DNS TXT found', entrypoint=str(entrypoint)) yield entrypoint - def dns_seed_lookup_address(self, results: LookupResult) -> Iterator[Entrypoint]: + def dns_seed_lookup_address(self, results: LookupResult) -> Iterator[PeerEndpoint]: """ Run a DNS lookup for A records to discover new peers. The `results` has three lists that contain answer records, authority records, and additional records. @@ -118,6 +118,6 @@ def dns_seed_lookup_address(self, results: LookupResult) -> Iterator[Entrypoint] address = record.payload.address assert address is not None host = socket.inet_ntoa(address) - entrypoint = Entrypoint(Protocol.TCP, host, self.default_port) + entrypoint = PeerAddress(Protocol.TCP, host, self.default_port).with_id() self.log.info('seed DNS A found', entrypoint=str(entrypoint)) yield entrypoint diff --git a/hathor/p2p/peer_discovery/peer_discovery.py b/hathor/p2p/peer_discovery/peer_discovery.py index a6ff799ed..7d040fae2 100644 --- a/hathor/p2p/peer_discovery/peer_discovery.py +++ b/hathor/p2p/peer_discovery/peer_discovery.py @@ -15,7 +15,7 @@ from abc import ABC, abstractmethod from typing import Callable -from hathor.p2p.entrypoint import Entrypoint +from hathor.p2p.peer_endpoint import PeerEndpoint class PeerDiscovery(ABC): @@ -23,7 +23,7 @@ class PeerDiscovery(ABC): """ @abstractmethod - async def discover_and_connect(self, connect_to: Callable[[Entrypoint], None]) -> None: + async def discover_and_connect(self, connect_to: Callable[[PeerEndpoint], None]) -> None: """ This method must discover the peers and call `connect_to` for each of them. :param connect_to: Function which will be called for each discovered peer. diff --git a/hathor/p2p/peer_endpoint.py b/hathor/p2p/peer_endpoint.py new file mode 100644 index 000000000..c7cafce20 --- /dev/null +++ b/hathor/p2p/peer_endpoint.py @@ -0,0 +1,277 @@ +# Copyright 2024 Hathor Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import Any +from urllib.parse import parse_qs, urlparse + +from twisted.internet.address import IPv4Address, IPv6Address +from twisted.internet.endpoints import TCP4ClientEndpoint +from twisted.internet.interfaces import IAddress, IStreamClientEndpoint +from typing_extensions import Self + +from hathor.p2p.peer_id import PeerId +from hathor.reactor import ReactorProtocol as Reactor + +COMPARISON_ERROR_MESSAGE = ( + 'never compare PeerAddress with PeerEndpoint or two PeerEndpoint instances directly! ' + 'instead, compare the addr attribute explicitly, and if relevant, the peer_id too.' +) + + +class Protocol(Enum): + TCP = 'tcp' + + +@dataclass(frozen=True, slots=True) +class PeerAddress: + """Peer address as received when a connection is made.""" + + protocol: Protocol + host: str + port: int + + def __str__(self) -> str: + return f'{self.protocol.value}://{self.host}:{self.port}' + + def __eq__(self, other: Any) -> bool: + """ + This function implements strict comparison between two PeerAddress insteances. Comparison between a PeerAddress + and a PeerEndpoint, or between two PeerEndpoint instances, purposefully throws a ValueError. + + Instead, in those cases users should explicity compare the underlying PeerAddress instances using the `addr` + attribute. This ensures we don't have issues with implicit equality checks,such as when using the `in` operator + + Examples: + + >>> ep1 = 'tcp://foo:111' + >>> ep2 = 'tcp://foo:111/?id=c0f19299c2a4dcbb6613a14011ff07b63d6cb809e4cee25e9c1ccccdd6628696' + >>> ep3 = 'tcp://foo:111/?id=bc5119d47bb4ea7c19100bd97fb11f36970482108bd3d45ff101ee4f6bbec872' + >>> ep4 = 'tcp://bar:111/?id=c0f19299c2a4dcbb6613a14011ff07b63d6cb809e4cee25e9c1ccccdd6628696' + >>> ep5 = 'tcp://foo:112/?id=c0f19299c2a4dcbb6613a14011ff07b63d6cb809e4cee25e9c1ccccdd6628696' + >>> ep6 = 'tcp://localhost:111' + >>> ep7 = 'tcp://127.0.0.1:111' + >>> PeerEndpoint.parse(ep1).addr == PeerEndpoint.parse(ep2).addr + True + >>> PeerEndpoint.parse(ep2).addr == PeerEndpoint.parse(ep3).addr + True + >>> PeerEndpoint.parse(ep1).addr == PeerEndpoint.parse(ep4).addr + False + >>> PeerEndpoint.parse(ep2).addr == PeerEndpoint.parse(ep4).addr + False + >>> PeerEndpoint.parse(ep2).addr == PeerEndpoint.parse(ep5).addr + False + >>> PeerEndpoint.parse(ep6).addr == PeerEndpoint.parse(ep7).addr + True + >>> PeerEndpoint.parse(ep1) == PeerEndpoint.parse(ep1) + Traceback (most recent call last): + ... + ValueError: never compare PeerAddress with PeerEndpoint or two PeerEndpoint instances directly! \ +instead, compare the addr attribute explicitly, and if relevant, the peer_id too. + >>> PeerEndpoint.parse(ep1) == PeerEndpoint.parse(ep1).addr + Traceback (most recent call last): + ... + ValueError: never compare PeerAddress with PeerEndpoint or two PeerEndpoint instances directly! \ +instead, compare the addr attribute explicitly, and if relevant, the peer_id too. + >>> PeerEndpoint.parse(ep1).addr == PeerEndpoint.parse(ep1) + Traceback (most recent call last): + ... + ValueError: never compare PeerAddress with PeerEndpoint or two PeerEndpoint instances directly! \ +instead, compare the addr attribute explicitly, and if relevant, the peer_id too. + >>> PeerEndpoint.parse(ep1) != PeerEndpoint.parse(ep4).addr + Traceback (most recent call last): + ... + ValueError: never compare PeerAddress with PeerEndpoint or two PeerEndpoint instances directly! \ +instead, compare the addr attribute explicitly, and if relevant, the peer_id too. + >>> PeerEndpoint.parse(ep1) in [PeerEndpoint.parse(ep1)] + Traceback (most recent call last): + ... + ValueError: never compare PeerAddress with PeerEndpoint or two PeerEndpoint instances directly! \ +instead, compare the addr attribute explicitly, and if relevant, the peer_id too. + >>> PeerEndpoint.parse(ep1).addr in [PeerEndpoint.parse(ep1).addr] + True + >>> PeerEndpoint.parse(ep1).addr != PeerEndpoint.parse(ep4).addr + True + """ + if not isinstance(other, PeerAddress): + raise ValueError(COMPARISON_ERROR_MESSAGE) + + if self.is_localhost() and other.is_localhost(): + return (self.protocol, self.port) == (other.protocol, other.port) + + return (self.protocol, self.host, self.port) == (other.protocol, other.host, other.port) + + def __ne__(self, other: Any) -> bool: + return not self == other + + @classmethod + def parse(cls, description: str) -> Self: + protocol, host, port, query = _parse_address_parts(description) + if query: + raise ValueError(f'unexpected query: "{description}". did you incorrectly add an id=?') + return cls(protocol, host, port) + + @classmethod + def from_hostname_address(cls, hostname: str, address: IPv4Address | IPv6Address) -> Self: + return cls.parse(f'{address.type}://{hostname}:{address.port}') + + @classmethod + def from_address(cls, address: IAddress) -> Self: + """Create an Entrypoint from a Twisted IAddress.""" + if not isinstance(address, (IPv4Address, IPv6Address)): + raise NotImplementedError(f'address: {address}') + return cls.parse(f'{address.type}://{address.host}:{address.port}') + + def to_client_endpoint(self, reactor: Reactor) -> IStreamClientEndpoint: + """This method generates a twisted client endpoint that has a .connect() method.""" + # XXX: currently we don't support IPv6, but when we do we have to decide between TCP4ClientEndpoint and + # TCP6ClientEndpoint, when the host is an IP address that is easy, but when it is a DNS hostname, we will not + # know which to use until we know which resource records it holds (A or AAAA) + return TCP4ClientEndpoint(reactor, self.host, self.port) + + def is_localhost(self) -> bool: + """Used to determine if the address host is a localhost address. + + Examples: + + >>> PeerAddress.parse('tcp://127.0.0.1:444').is_localhost() + True + >>> PeerAddress.parse('tcp://localhost:444').is_localhost() + True + >>> PeerAddress.parse('tcp://8.8.8.8:444').is_localhost() + False + >>> PeerAddress.parse('tcp://foo.bar:444').is_localhost() + False + """ + return self.host in ('127.0.0.1', 'localhost') + + def with_id(self, peer_id: PeerId | None = None) -> PeerEndpoint: + """Create a PeerEndpoint instance with self as the address and with the provided peer_id, or None.""" + return PeerEndpoint(self, peer_id) + + +@dataclass(frozen=True, slots=True) +class PeerEndpoint: + """Peer endpoint description (returned from DNS query, or received from the p2p network) may contain a peer-id.""" + + addr: PeerAddress + peer_id: PeerId | None = None + + def __str__(self) -> str: + return str(self.addr) if self.peer_id is None else f'{self.addr}/?id={self.peer_id}' + + def __eq__(self, other: Any) -> bool: + """See PeerAddress.__eq__""" + raise ValueError(COMPARISON_ERROR_MESSAGE) + + def __ne__(self, other: Any) -> bool: + """See PeerAddress.__eq__""" + raise ValueError(COMPARISON_ERROR_MESSAGE) + + @classmethod + def parse(cls, description: str) -> PeerEndpoint: + """Parse endpoint description into an PeerEndpoint object. + + Examples: + + >>> str(PeerEndpoint.parse('tcp://127.0.0.1:40403/')) + 'tcp://127.0.0.1:40403' + + >>> id1 = 'c0f19299c2a4dcbb6613a14011ff07b63d6cb809e4cee25e9c1ccccdd6628696' + >>> PeerEndpoint.parse(f'tcp://127.0.0.1:40403/?id={id1}') + PeerEndpoint(addr=PeerAddress(protocol=, host='127.0.0.1', port=40403), \ +peer_id=PeerId('c0f19299c2a4dcbb6613a14011ff07b63d6cb809e4cee25e9c1ccccdd6628696')) + + >>> str(PeerEndpoint.parse(f'tcp://127.0.0.1:40403/?id={id1}')) + 'tcp://127.0.0.1:40403/?id=c0f19299c2a4dcbb6613a14011ff07b63d6cb809e4cee25e9c1ccccdd6628696' + + >>> PeerEndpoint.parse('tcp://127.0.0.1:40403') + PeerEndpoint(addr=PeerAddress(protocol=, host='127.0.0.1', port=40403), peer_id=None) + + >>> PeerEndpoint.parse('tcp://127.0.0.1:40403/') + PeerEndpoint(addr=PeerAddress(protocol=, host='127.0.0.1', port=40403), peer_id=None) + + >>> PeerEndpoint.parse('tcp://foo.bar.baz:40403/') + PeerEndpoint(addr=PeerAddress(protocol=, host='foo.bar.baz', port=40403), \ +peer_id=None) + + >>> str(PeerEndpoint.parse('tcp://foo.bar.baz:40403/')) + 'tcp://foo.bar.baz:40403' + + >>> PeerEndpoint.parse('tcp://127.0.0.1:40403/?id=123') + Traceback (most recent call last): + ... + ValueError: non-hexadecimal number found in fromhex() arg at position 3 + + >>> PeerEndpoint.parse('tcp://127.0.0.1:4040f') + Traceback (most recent call last): + ... + ValueError: Port could not be cast to integer value as '4040f' + + >>> PeerEndpoint.parse('udp://127.0.0.1:40403/') + Traceback (most recent call last): + ... + ValueError: 'udp' is not a valid Protocol + + >>> PeerEndpoint.parse('tcp://127.0.0.1/') + Traceback (most recent call last): + ... + ValueError: expected a port: "tcp://127.0.0.1/" + + >>> PeerEndpoint.parse('tcp://:40403/') + Traceback (most recent call last): + ... + ValueError: expected a host: "tcp://:40403/" + + >>> PeerEndpoint.parse('tcp://127.0.0.1:40403/foo') + Traceback (most recent call last): + ... + ValueError: unexpected path: "tcp://127.0.0.1:40403/foo" + + >>> id2 = 'bc5119d47bb4ea7c19100bd97fb11f36970482108bd3d45ff101ee4f6bbec872' + >>> PeerEndpoint.parse(f'tcp://127.0.0.1:40403/?id={id1}&id={id2}') + Traceback (most recent call last): + ... + ValueError: unexpected id count: 2 + """ + protocol, host, port, query_str = _parse_address_parts(description) + peer_id: PeerId | None = None + + if query_str: + query = parse_qs(query_str) + if 'id' in query: + ids = query['id'] + if len(ids) != 1: + raise ValueError(f'unexpected id count: {len(ids)}') + peer_id = PeerId(ids[0]) + + return PeerAddress(protocol, host, port).with_id(peer_id) + + +def _parse_address_parts(description: str) -> tuple[Protocol, str, int, str]: + url = urlparse(description) + protocol = Protocol(url.scheme) + host = url.hostname + if host is None: + raise ValueError(f'expected a host: "{description}"') + port = url.port + if port is None: + raise ValueError(f'expected a port: "{description}"') + if url.path not in {'', '/'}: + raise ValueError(f'unexpected path: "{description}"') + + return protocol, host, port, url.query diff --git a/hathor/p2p/protocol.py b/hathor/p2p/protocol.py index f2bda9cc4..e05e63b55 100644 --- a/hathor/p2p/protocol.py +++ b/hathor/p2p/protocol.py @@ -14,9 +14,10 @@ import time from enum import Enum -from typing import TYPE_CHECKING, Any, Coroutine, Generator, Optional, cast +from typing import TYPE_CHECKING, Optional, cast from structlog import get_logger +from twisted.internet import defer from twisted.internet.defer import Deferred from twisted.internet.interfaces import IDelayedCall, ITCPTransport, ITransport from twisted.internet.protocol import connectionDone @@ -24,9 +25,9 @@ from twisted.python.failure import Failure from hathor.conf.settings import HathorSettings -from hathor.p2p.entrypoint import Entrypoint from hathor.p2p.messages import ProtocolMessages -from hathor.p2p.peer import PrivatePeer, PublicPeer +from hathor.p2p.peer import PrivatePeer, PublicPeer, UnverifiedPeer +from hathor.p2p.peer_endpoint import PeerEndpoint from hathor.p2p.peer_id import PeerId from hathor.p2p.rate_limiter import RateLimiter from hathor.p2p.states import BaseState, HelloState, PeerIdState, ReadyState @@ -70,7 +71,6 @@ class RateLimitKeys(str, Enum): GLOBAL = 'global' class WarningFlags(str, Enum): - NO_PEER_ID_URL = 'no_peer_id_url' NO_ENTRYPOINTS = 'no_entrypoints' my_peer: PrivatePeer @@ -83,7 +83,7 @@ class WarningFlags(str, Enum): state: Optional[BaseState] connection_time: float _state_instances: dict[PeerState, BaseState] - entrypoint: Optional[Entrypoint] + entrypoint: Optional[PeerEndpoint] warning_flags: set[str] aborting: bool diff_timestamp: Optional[int] @@ -149,10 +149,10 @@ def __init__( # Connection string of the peer # Used to validate if entrypoints has this string - self.entrypoint: Optional[Entrypoint] = None + self.entrypoint: Optional[PeerEndpoint] = None # Peer id sent in the connection url that is expected to connect (optional) - self.expected_peer_id: Optional[str] = None + self.expected_peer_id: PeerId | None = None # Set of warning flags that may be added during the connection process self.warning_flags: set[str] = set() @@ -254,9 +254,13 @@ def on_connect(self) -> None: if self.connections: self.connections.on_peer_connect(self) - def on_outbound_connect(self, entrypoint: Entrypoint) -> None: + def on_outbound_connect(self, entrypoint: PeerEndpoint, peer: UnverifiedPeer | PublicPeer | None) -> None: """Called when we successfully establish an outbound connection to a peer.""" # Save the used entrypoint in protocol so we can validate that it matches the entrypoints data + if entrypoint.peer_id is not None and peer is not None: + assert entrypoint.peer_id == peer.id + + self.expected_peer_id = peer.id if peer else entrypoint.peer_id self.entrypoint = entrypoint def on_peer_ready(self) -> None: @@ -292,7 +296,7 @@ def send_message(self, cmd: ProtocolMessages, payload: Optional[str] = None) -> raise NotImplementedError @cpu.profiler(key=lambda self, cmd: 'p2p-cmd!{}'.format(str(cmd))) - def recv_message(self, cmd: ProtocolMessages, payload: str) -> Optional[Deferred[None]]: + def recv_message(self, cmd: ProtocolMessages, payload: str) -> None: """ Executed when a new message arrives. """ assert self.state is not None @@ -301,7 +305,6 @@ def recv_message(self, cmd: ProtocolMessages, payload: str) -> Optional[Deferred self.last_message = now if self._peer is not None: self.peer.info.last_seen = now - self.reset_idle_timeout() if not self.ratelimit.add_hit(self.RateLimitKeys.GLOBAL): # XXX: on Python 3.11 the result of the following expression: @@ -310,21 +313,22 @@ def recv_message(self, cmd: ProtocolMessages, payload: str) -> Optional[Deferred # that something like `str(value)` is called which results in a different value (usually not the case # for regular strings, but it is for enum+str), using `enum_variant.value` side-steps this problem self.state.send_throttle(self.RateLimitKeys.GLOBAL.value) - return None - - fn = self.state.cmd_map.get(cmd) - if fn is not None: - try: - result = fn(payload) - return Deferred.fromCoroutine(result) if isinstance(result, Coroutine) else result - except Exception: - self.log.warn('recv_message processing error', exc_info=True) - raise - else: + return + + cmd_handler = self.state.cmd_map.get(cmd) + if cmd_handler is None: self.log.debug('cmd not found', cmd=cmd, payload=payload, available=list(self.state.cmd_map.keys())) self.send_error_and_close_connection('Invalid Command: {} {}'.format(cmd, payload)) + return - return None + deferred_result: Deferred[None] = defer.maybeDeferred(cmd_handler, payload) + deferred_result \ + .addCallback(lambda _: self.reset_idle_timeout()) \ + .addErrback(self._on_cmd_handler_error, cmd) + + def _on_cmd_handler_error(self, failure: Failure, cmd: ProtocolMessages) -> None: + self.log.warn('recv_message processing error', reason=failure.getErrorMessage(), exc_info=True) + self.send_error_and_close_connection(f'Error processing "{cmd.value}" command') def send_error(self, msg: str) -> None: """ Send an error message to the peer. @@ -411,7 +415,7 @@ def lineLengthExceeded(self, line: str) -> None: super(HathorLineReceiver, self).lineLengthExceeded(line) @cpu.profiler(key=lambda self: 'p2p!{}'.format(self.get_short_remote())) - def lineReceived(self, line: bytes) -> Optional[Generator[Any, Any, None]]: + def lineReceived(self, line: bytes) -> None: assert self.transport is not None if self.aborting: @@ -420,7 +424,7 @@ def lineReceived(self, line: bytes) -> Optional[Generator[Any, Any, None]]: # abort and close the connection, HathorLineReceive.lineReceived will still be called for the buffered # lines. If that happens we just ignore those messages. self.log.debug('ignore received messager after abort') - return None + return self.metrics.received_messages += 1 self.metrics.received_bytes += len(line) @@ -429,17 +433,16 @@ def lineReceived(self, line: bytes) -> Optional[Generator[Any, Any, None]]: sline = line.decode('utf-8') except UnicodeDecodeError: self.transport.loseConnection() - return None + return msgtype, _, msgdata = sline.partition(' ') try: cmd = ProtocolMessages(msgtype) except ValueError: self.transport.loseConnection() - return None - else: - self.recv_message(cmd, msgdata) - return None + return + + self.recv_message(cmd, msgdata) def send_message(self, cmd_enum: ProtocolMessages, payload: Optional[str] = None) -> None: cmd = cmd_enum.value diff --git a/hathor/p2p/resources/add_peers.py b/hathor/p2p/resources/add_peers.py index aeb92208c..c8faeb5dc 100644 --- a/hathor/p2p/resources/add_peers.py +++ b/hathor/p2p/resources/add_peers.py @@ -20,8 +20,8 @@ from hathor.api_util import Resource, render_options, set_cors from hathor.cli.openapi_files.register import register_resource from hathor.manager import HathorManager -from hathor.p2p.entrypoint import Entrypoint from hathor.p2p.peer_discovery import BootstrapPeerDiscovery +from hathor.p2p.peer_endpoint import PeerEndpoint from hathor.util import json_dumpb, json_loadb @@ -60,7 +60,7 @@ def render_POST(self, request: Request) -> bytes: }) try: - entrypoints = list(map(Entrypoint.parse, raw_entrypoints)) + entrypoints = list(map(PeerEndpoint.parse, raw_entrypoints)) except ValueError: return json_dumpb({ 'success': False, @@ -69,14 +69,15 @@ def render_POST(self, request: Request) -> bytes: known_peers = self.manager.connections.verified_peer_storage.values() - def already_connected(entrypoint: Entrypoint) -> bool: + def already_connected(endpoint: PeerEndpoint) -> bool: # ignore peers that we're already trying to connect - if entrypoint in self.manager.connections.iter_not_ready_endpoints(): - return True + for ready_endpoint in self.manager.connections.iter_not_ready_endpoints(): + if endpoint.addr == ready_endpoint.addr: + return True # remove peers we already know about for peer in known_peers: - if entrypoint in peer.entrypoints: + if endpoint.addr in peer.info.entrypoints: return True return False diff --git a/hathor/p2p/states/peer_id.py b/hathor/p2p/states/peer_id.py index 2aca0a9db..77e8a051e 100644 --- a/hathor/p2p/states/peer_id.py +++ b/hathor/p2p/states/peer_id.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from structlog import get_logger @@ -63,17 +63,19 @@ def handle_ready(self, payload: str) -> None: # So it was just waiting for the ready message from the other peer to change the state to READY self.protocol.change_state(self.protocol.PeerState.READY) + def _get_peer_id_data(self) -> dict[str, Any]: + my_peer = self.protocol.my_peer + return dict( + id=str(my_peer.id), + pubKey=my_peer.get_public_key(), + entrypoints=my_peer.info.entrypoints_as_str(), + ) + def send_peer_id(self) -> None: """ Send a PEER-ID message, identifying the peer. """ - protocol = self.protocol - my_peer = protocol.my_peer - hello = { - 'id': str(my_peer.id), - 'pubKey': my_peer.get_public_key(), - 'entrypoints': my_peer.info.entrypoints_as_str(), - } - self.send_message(ProtocolMessages.PEER_ID, json_dumps(hello)) + data = self._get_peer_id_data() + self.send_message(ProtocolMessages.PEER_ID, json_dumps(data)) async def handle_peer_id(self, payload: str) -> None: """ Executed when a PEER-ID is received. It basically checks @@ -89,7 +91,6 @@ async def handle_peer_id(self, payload: str) -> None: data = json_loads(payload) peer = PublicPeer.create_from_json(data) - peer.validate() assert peer.id is not None # If the connection URL had a peer-id parameter we need to check it's the same @@ -119,6 +120,9 @@ async def handle_peer_id(self, payload: str) -> None: protocol.send_error_and_close_connection('Connection string is not in the entrypoints.') return + if protocol.entrypoint is not None and protocol.entrypoint.peer_id is not None: + assert protocol.entrypoint.peer_id == peer.id + if protocol.use_ssl: certificate_valid = peer.validate_certificate(protocol) if not certificate_valid: diff --git a/hathor/p2p/utils.py b/hathor/p2p/utils.py index c0a25f3d8..55f9b9591 100644 --- a/hathor/p2p/utils.py +++ b/hathor/p2p/utils.py @@ -29,8 +29,8 @@ from hathor.conf.get_settings import get_global_settings from hathor.conf.settings import HathorSettings from hathor.indexes.height_index import HeightInfo -from hathor.p2p.entrypoint import Entrypoint from hathor.p2p.peer_discovery import DNSPeerDiscovery +from hathor.p2p.peer_endpoint import PeerEndpoint from hathor.p2p.peer_id import PeerId from hathor.transaction.genesis import get_representation_for_all_genesis @@ -78,7 +78,7 @@ def get_settings_hello_dict(settings: HathorSettings) -> dict[str, Any]: return settings_dict -async def discover_dns(host: str, test_mode: int = 0) -> list[Entrypoint]: +async def discover_dns(host: str, test_mode: int = 0) -> list[PeerEndpoint]: """ Start a DNS peer discovery object and execute a search for the host Returns the DNS string from the requested host diff --git a/hathor/simulator/fake_connection.py b/hathor/simulator/fake_connection.py index c993302db..b3a29afc9 100644 --- a/hathor/simulator/fake_connection.py +++ b/hathor/simulator/fake_connection.py @@ -15,14 +15,16 @@ from __future__ import annotations from collections import deque -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Literal, Optional from OpenSSL.crypto import X509 from structlog import get_logger -from twisted.internet.address import HostnameAddress +from twisted.internet.address import IPv4Address from twisted.internet.testing import StringTransport from hathor.p2p.peer import PrivatePeer +from hathor.p2p.peer_endpoint import PeerAddress, PeerEndpoint +from hathor.p2p.peer_id import PeerId if TYPE_CHECKING: from hathor.manager import HathorManager @@ -32,8 +34,8 @@ class HathorStringTransport(StringTransport): - def __init__(self, peer: PrivatePeer): - super().__init__() + def __init__(self, peer: PrivatePeer, *, peer_address: IPv4Address): + super().__init__(peerAddress=peer_address) self._peer = peer @property @@ -46,12 +48,27 @@ def getPeerCertificate(self) -> X509: class FakeConnection: - def __init__(self, manager1: 'HathorManager', manager2: 'HathorManager', *, latency: float = 0, - autoreconnect: bool = False): + _next_port: int = 49000 + _port_per_manager: dict['HathorManager', int] = {} + + def __init__( + self, + manager1: 'HathorManager', + manager2: 'HathorManager', + *, + latency: float = 0, + autoreconnect: bool = False, + addr1: IPv4Address | None = None, + addr2: IPv4Address | None = None, + fake_bootstrap_id: PeerId | None | Literal[False] = False, + ): """ :param: latency: Latency between nodes in seconds + :fake_bootstrap_id: when False, bootstrap mode is disabled. When a PeerId or None are passed, bootstrap mode is + enabled and the value is used as the connection's entrypoint.peer_id """ self.log = logger.new() + self._fake_bootstrap_id = fake_bootstrap_id self.manager1 = manager1 self.manager2 = manager2 @@ -64,8 +81,28 @@ def __init__(self, manager1: 'HathorManager', manager2: 'HathorManager', *, late self._buf1: deque[str] = deque() self._buf2: deque[str] = deque() + # manager1's address, the server, where manager2 will connect to + self.addr1 = addr1 or IPv4Address('TCP', '127.0.0.1', self._get_port(manager1)) + # manager2's address, the client, where manager2 will connect from + self.addr2 = addr2 or IPv4Address('TCP', '127.0.0.1', self._get_port(manager2)) + self.reconnect() + @classmethod + def _get_port(cls, manager: 'HathorManager') -> int: + port = cls._port_per_manager.get(manager) + if port is None: + port = cls._next_port + cls._next_port += 1 + return port + + @property + def entrypoint(self) -> PeerEndpoint: + entrypoint = PeerAddress.from_address(self.addr1) + if self._fake_bootstrap_id is False: + return entrypoint.with_id(self.manager1.my_peer.id) + return entrypoint.with_id(self._fake_bootstrap_id) + @property def proto1(self): return self._proto1 @@ -234,10 +271,21 @@ def reconnect(self) -> None: self.disconnect(Failure(Exception('forced reconnection'))) self._buf1.clear() self._buf2.clear() - self._proto1 = self.manager1.connections.server_factory.buildProtocol(HostnameAddress(b'fake', 0)) - self._proto2 = self.manager2.connections.client_factory.buildProtocol(HostnameAddress(b'fake', 0)) - self.tr1 = HathorStringTransport(self._proto2.my_peer) - self.tr2 = HathorStringTransport(self._proto1.my_peer) + + self._proto1 = self.manager1.connections.server_factory.buildProtocol(self.addr2) + self._proto2 = self.manager2.connections.client_factory.buildProtocol(self.addr1) + + # When _fake_bootstrap_id is set we don't pass the peer because that's how bootstrap calls connect_to() + peer = self._proto1.my_peer.to_unverified_peer() if self._fake_bootstrap_id is False else None + self.manager2.connections.connect_to(self.entrypoint, peer) + + connecting_peers = list(self.manager2.connections.connecting_peers.values()) + for connecting_peer in connecting_peers: + if connecting_peer.entrypoint.addr == self.entrypoint.addr: + connecting_peer.endpoint_deferred.callback(self._proto2) + + self.tr1 = HathorStringTransport(self._proto2.my_peer, peer_address=self.addr2) + self.tr2 = HathorStringTransport(self._proto1.my_peer, peer_address=self.addr1) self._proto1.makeConnection(self.tr1) self._proto2.makeConnection(self.tr2) self.is_connected = True diff --git a/tests/others/test_metrics.py b/tests/others/test_metrics.py index bbdede763..b46f6985b 100644 --- a/tests/others/test_metrics.py +++ b/tests/others/test_metrics.py @@ -3,9 +3,9 @@ import pytest -from hathor.p2p.entrypoint import Entrypoint from hathor.p2p.manager import PeerConnectionsMetrics from hathor.p2p.peer import PrivatePeer +from hathor.p2p.peer_endpoint import PeerEndpoint from hathor.p2p.protocol import HathorProtocol from hathor.pubsub import HathorEvents from hathor.simulator.utils import add_new_blocks @@ -70,7 +70,7 @@ def test_connections_manager_integration(self): manager.connections.handshaking_peers.update({Mock()}) # Execution - endpoint = Entrypoint.parse('tcp://127.0.0.1:8005') + endpoint = PeerEndpoint.parse('tcp://127.0.0.1:8005') # This will trigger sending to the pubsub one of the network events manager.connections.connect_to(endpoint, use_ssl=True) diff --git a/tests/p2p/test_bootstrap.py b/tests/p2p/test_bootstrap.py index 3c3d9fa8c..82aa932bb 100644 --- a/tests/p2p/test_bootstrap.py +++ b/tests/p2p/test_bootstrap.py @@ -4,11 +4,11 @@ from twisted.names.dns import TXT, A, Record_A, Record_TXT, RRHeader from typing_extensions import override -from hathor.p2p.entrypoint import Entrypoint, Protocol from hathor.p2p.manager import ConnectionsManager from hathor.p2p.peer import PrivatePeer from hathor.p2p.peer_discovery import DNSPeerDiscovery, PeerDiscovery from hathor.p2p.peer_discovery.dns import LookupResult +from hathor.p2p.peer_endpoint import PeerAddress, PeerEndpoint, Protocol from hathor.pubsub import PubSubManager from tests import unittest from tests.test_memory_reactor_clock import TestMemoryReactorClock @@ -19,9 +19,10 @@ def __init__(self, mocked_host_ports: list[tuple[str, int]]): self.mocked_host_ports = mocked_host_ports @override - async def discover_and_connect(self, connect_to: Callable[[Entrypoint], None]) -> None: + async def discover_and_connect(self, connect_to: Callable[[PeerEndpoint], None]) -> None: for host, port in self.mocked_host_ports: - connect_to(Entrypoint(Protocol.TCP, host, port)) + addr = PeerAddress(Protocol.TCP, host, port) + connect_to(addr.with_id()) class MockDNSPeerDiscovery(DNSPeerDiscovery): diff --git a/tests/p2p/test_connections.py b/tests/p2p/test_connections.py index 570424c84..b27897ca4 100644 --- a/tests/p2p/test_connections.py +++ b/tests/p2p/test_connections.py @@ -1,4 +1,5 @@ -from hathor.p2p.entrypoint import Entrypoint +from hathor.manager import HathorManager +from hathor.p2p.peer_endpoint import PeerEndpoint from tests import unittest from tests.utils import run_server @@ -14,9 +15,9 @@ def test_connections(self) -> None: process3.terminate() def test_manager_connections(self) -> None: - manager = self.create_peer('testnet', enable_sync_v1=True, enable_sync_v2=False) + manager: HathorManager = self.create_peer('testnet', enable_sync_v1=True, enable_sync_v2=False) - endpoint = Entrypoint.parse('tcp://127.0.0.1:8005') + endpoint = PeerEndpoint.parse('tcp://127.0.0.1:8005') manager.connections.connect_to(endpoint, use_ssl=True) self.assertIn(endpoint, manager.connections.iter_not_ready_endpoints()) diff --git a/tests/p2p/test_peer_id.py b/tests/p2p/test_peer_id.py index 1604e29c9..56dfaf79b 100644 --- a/tests/p2p/test_peer_id.py +++ b/tests/p2p/test_peer_id.py @@ -4,10 +4,11 @@ from typing import cast from unittest.mock import Mock +import pytest from twisted.internet.interfaces import ITransport -from hathor.p2p.entrypoint import Entrypoint from hathor.p2p.peer import InvalidPeerIdException, PrivatePeer, PublicPeer, UnverifiedPeer +from hathor.p2p.peer_endpoint import PeerAddress, PeerEndpoint from hathor.p2p.peer_id import PeerId from hathor.p2p.peer_storage import VerifiedPeerStorage from tests import unittest @@ -87,9 +88,9 @@ def test_merge_peer(self) -> None: self.assertEqual(peer.public_key, p1.public_key) self.assertEqual(peer.info.entrypoints, []) - ep1 = Entrypoint.parse('tcp://127.0.0.1:1001') - ep2 = Entrypoint.parse('tcp://127.0.0.1:1002') - ep3 = Entrypoint.parse('tcp://127.0.0.1:1003') + ep1 = PeerAddress.parse('tcp://127.0.0.1:1001') + ep2 = PeerAddress.parse('tcp://127.0.0.1:1002') + ep3 = PeerAddress.parse('tcp://127.0.0.1:1003') p3 = PrivatePeer.auto_generated().to_public_peer() p3.info.entrypoints.append(ep1) @@ -204,6 +205,59 @@ def test_retry_logic(self) -> None: peer.info.reset_retry_timestamp() self.assertTrue(peer.info.can_retry(0)) + def test_unverified_peer_to_json_roundtrip(self) -> None: + peer_id = PrivatePeer.auto_generated().id + addr1 = 'tcp://localhost:40403' + addr2 = 'tcp://192.168.0.1:40404' + addr3 = 'tcp://foo.bar:80' + + peer_json_simple = dict( + id=str(peer_id), + entrypoints=[addr1, addr2, addr3] + ) + result = UnverifiedPeer.create_from_json(peer_json_simple) + + assert result.id == peer_id + assert result.info.entrypoints == [ + PeerAddress.parse(addr1), + PeerAddress.parse(addr2), + PeerAddress.parse(addr3), + ] + assert result.to_json() == peer_json_simple + + # We support this for compatibility with old peers that may send ids in the URLs + peer_json_with_ids = dict( + id=str(peer_id), + entrypoints=[ + f'{addr1}/?id={peer_id}', + f'{addr2}/?id={peer_id}', + addr3, + ] + ) + result = UnverifiedPeer.create_from_json(peer_json_with_ids) + + assert result.id == peer_id + assert result.info.entrypoints == [ + PeerAddress.parse(addr1), + PeerAddress.parse(addr2), + PeerAddress.parse(addr3), + ] + assert result.to_json() == peer_json_simple # the roundtrip erases the ids from the URLs + + other_peer_id = PrivatePeer.auto_generated().id + peer_json_with_conflicting_ids = dict( + id=str(peer_id), + entrypoints=[ + f'{addr1}/?id={peer_id}', + f'{addr2}/?id={other_peer_id}', + addr3, + ] + ) + + with pytest.raises(ValueError) as e: + UnverifiedPeer.create_from_json(peer_json_with_conflicting_ids) + assert str(e.value) == f'conflicting peer_id: {other_peer_id} != {peer_id}' + class BasePeerIdTest(unittest.TestCase): __test__ = False @@ -211,25 +265,25 @@ class BasePeerIdTest(unittest.TestCase): async def test_validate_entrypoint(self) -> None: manager = self.create_peer('testnet', unlock_wallet=False) peer = manager.my_peer - peer.info.entrypoints = [Entrypoint.parse('tcp://127.0.0.1:40403')] + peer.info.entrypoints = [PeerAddress.parse('tcp://127.0.0.1:40403')] # we consider that we are starting the connection to the peer protocol = manager.connections.client_factory.buildProtocol('127.0.0.1') - protocol.entrypoint = Entrypoint.parse('tcp://127.0.0.1:40403') + protocol.entrypoint = PeerEndpoint.parse('tcp://127.0.0.1:40403') result = await peer.info.validate_entrypoint(protocol) self.assertTrue(result) # if entrypoint is an URI - peer.info.entrypoints = [Entrypoint.parse('tcp://uri_name:40403')] + peer.info.entrypoints = [PeerAddress.parse('tcp://uri_name:40403')] result = await peer.info.validate_entrypoint(protocol) self.assertTrue(result) # test invalid. DNS in test mode will resolve to '127.0.0.1:40403' - protocol.entrypoint = Entrypoint.parse('tcp://45.45.45.45:40403') + protocol.entrypoint = PeerEndpoint.parse('tcp://45.45.45.45:40403') result = await peer.info.validate_entrypoint(protocol) self.assertFalse(result) # now test when receiving the connection - i.e. the peer starts it protocol.entrypoint = None - peer.info.entrypoints = [Entrypoint.parse('tcp://127.0.0.1:40403')] + peer.info.entrypoints = [PeerAddress.parse('tcp://127.0.0.1:40403')] from collections import namedtuple DummyPeer = namedtuple('DummyPeer', 'host') @@ -241,7 +295,7 @@ def getPeer(self) -> DummyPeer: result = await peer.info.validate_entrypoint(protocol) self.assertTrue(result) # if entrypoint is an URI - peer.info.entrypoints = [Entrypoint.parse('tcp://uri_name:40403')] + peer.info.entrypoints = [PeerAddress.parse('tcp://uri_name:40403')] result = await peer.info.validate_entrypoint(protocol) self.assertTrue(result) diff --git a/tests/p2p/test_protocol.py b/tests/p2p/test_protocol.py index 34ec291d3..841a45929 100644 --- a/tests/p2p/test_protocol.py +++ b/tests/p2p/test_protocol.py @@ -1,12 +1,16 @@ -from json import JSONDecodeError +import json from typing import Optional from unittest.mock import Mock, patch +from twisted.internet import defer from twisted.internet.protocol import Protocol from twisted.python.failure import Failure -from hathor.p2p.entrypoint import Entrypoint +from hathor.manager import HathorManager +from hathor.p2p.manager import ConnectionsManager +from hathor.p2p.messages import ProtocolMessages from hathor.p2p.peer import PrivatePeer +from hathor.p2p.peer_endpoint import PeerAddress from hathor.p2p.protocol import HathorLineReceiver, HathorProtocol from hathor.simulator import FakeConnection from hathor.util import json_dumps, json_loadb @@ -72,7 +76,7 @@ def test_on_connect(self) -> None: def test_peer_with_entrypoint(self) -> None: entrypoint_str = 'tcp://192.168.1.1:54321' - entrypoint = Entrypoint.parse(entrypoint_str) + entrypoint = PeerAddress.parse(entrypoint_str) self.peer1.info.entrypoints.append(entrypoint) self.peer2.info.entrypoints.append(entrypoint) self.conn.run_one_step() # HELLO @@ -144,8 +148,10 @@ def test_invalid_payload(self) -> None: self.conn.run_one_step() # HELLO self.conn.run_one_step() # PEER-ID self.conn.run_one_step() # READY - with self.assertRaises(JSONDecodeError): - self._send_cmd(self.conn.proto1, 'PEERS', 'abc') + self.conn.tr1.clear() + self._send_cmd(self.conn.proto1, 'PEERS', 'abc') + assert self.conn.peek_tr1_value() == b'ERROR Error processing "PEERS" command\r\n' + self.assertTrue(self.conn.tr1.disconnecting) def test_invalid_hello1(self) -> None: self.conn.tr1.clear() @@ -263,6 +269,68 @@ def test_invalid_same_peer_id2(self) -> None: # connection is still up self.assertIsConnected(conn_alive) + def test_invalid_peer_id1(self) -> None: + """Test no payload""" + self.conn.run_one_step() + self.conn.tr1.clear() + self._send_cmd(self.conn.proto1, 'PEER-ID') + assert self.conn.peek_tr1_value() == b'ERROR Error processing "PEER-ID" command\r\n' + self.assertTrue(self.conn.tr1.disconnecting) + + def test_invalid_peer_id2(self) -> None: + """Test invalid json payload""" + self.conn.run_one_step() + self.conn.tr1.clear() + self._send_cmd(self.conn.proto1, 'PEER-ID', 'invalid_payload') + assert self.conn.peek_tr1_value() == b'ERROR Error processing "PEER-ID" command\r\n' + self.assertTrue(self.conn.tr1.disconnecting) + + def test_invalid_peer_id3(self) -> None: + """Test empty payload""" + self.conn.run_one_step() + self.conn.tr1.clear() + self._send_cmd(self.conn.proto1, 'PEER-ID', '{}') + assert self.conn.peek_tr1_value() == b'ERROR Error processing "PEER-ID" command\r\n' + self.assertTrue(self.conn.tr1.disconnecting) + + def test_invalid_peer_id4(self) -> None: + """Test payload with missing property""" + self.conn.run_one_step() + self.conn.tr1.clear() + data = self.conn.proto2.state._get_peer_id_data() + del data['pubKey'] + self._send_cmd( + self.conn.proto1, + 'PEER-ID', + json.dumps(data) + ) + assert self.conn.peek_tr1_value() == b'ERROR Error processing "PEER-ID" command\r\n' + self.assertTrue(self.conn.tr1.disconnecting) + self.assertTrue(self.conn.tr1.disconnecting) + + def test_invalid_peer_id5(self) -> None: + """Test payload with peer id not matching public key""" + self.conn.run_one_step() + self.conn.tr1.clear() + data = self.conn.proto2.state._get_peer_id_data() + new_peer = PrivatePeer.auto_generated() + data['id'] = str(new_peer.id) + self._send_cmd( + self.conn.proto1, + 'PEER-ID', + json.dumps(data) + ) + assert self.conn.peek_tr1_value() == b'ERROR Error processing "PEER-ID" command\r\n' + self.assertTrue(self.conn.tr1.disconnecting) + + def test_valid_peer_id(self) -> None: + self.conn.run_one_step() + self.conn.run_one_step() + self._check_result_only_cmd(self.conn.peek_tr1_value(), b'READY') + self._check_result_only_cmd(self.conn.peek_tr2_value(), b'READY') + self.assertFalse(self.conn.tr1.disconnecting) + self.assertFalse(self.conn.tr2.disconnecting) + def test_invalid_different_network(self) -> None: manager3 = self.create_peer(network='mainnet') conn = FakeConnection(self.manager1, manager3) @@ -314,6 +382,129 @@ def test_idle_connection(self) -> None: self.clock.advance(15) self.assertIsNotConnected(self.conn) + def test_invalid_expected_peer_id(self) -> None: + p2p_manager: ConnectionsManager = self.manager2.connections + + # Initially, manager1 and manager2 are handshaking, from the setup + assert p2p_manager.connecting_peers == {} + assert p2p_manager.handshaking_peers == {self.conn.proto2} + assert p2p_manager.connected_peers == {} + + # We change our peer id (on manager1) + new_peer = PrivatePeer.auto_generated() + self.conn.proto1.my_peer = new_peer + self.conn.tr2._peer = new_peer + + # We advance the states and fail in the PEER-ID step (on manager2) + self._check_result_only_cmd(self.conn.peek_tr2_value(), b'HELLO') + self.conn.run_one_step() + self._check_result_only_cmd(self.conn.peek_tr2_value(), b'PEER-ID') + self.conn.run_one_step() + assert self.conn.peek_tr2_value() == b'ERROR Peer id different from the requested one.\r\n' + + def test_invalid_expected_peer_id_bootstrap(self) -> None: + p2p_manager: ConnectionsManager = self.manager1.connections + + # Initially, manager1 and manager2 are handshaking, from the setup + assert p2p_manager.connecting_peers == {} + assert p2p_manager.handshaking_peers == {self.conn.proto1} + assert p2p_manager.connected_peers == {} + + # We create a new manager3, and use it as a bootstrap in manager1 + peer3 = PrivatePeer.auto_generated() + manager3: HathorManager = self.create_peer(self.network, peer3) + conn = FakeConnection(manager1=manager3, manager2=self.manager1, fake_bootstrap_id=peer3.id) + + # Now manager1 and manager3 are handshaking + assert p2p_manager.connecting_peers == {} + assert p2p_manager.handshaking_peers == {self.conn.proto1, conn.proto2} + assert p2p_manager.connected_peers == {} + + # We change our peer id (on manager3) + new_peer = PrivatePeer.auto_generated() + conn.proto1.my_peer = new_peer + conn.tr2._peer = new_peer + + # We advance the states and fail in the PEER-ID step (on manager1) + self._check_result_only_cmd(conn.peek_tr2_value(), b'HELLO') + conn.run_one_step() + self._check_result_only_cmd(conn.peek_tr2_value(), b'PEER-ID') + conn.run_one_step() + assert conn.peek_tr2_value() == b'ERROR Peer id different from the requested one.\r\n' + + def test_valid_unset_peer_id_bootstrap(self) -> None: + p2p_manager: ConnectionsManager = self.manager1.connections + + # Initially, manager1 and manager2 are handshaking, from the setup + assert p2p_manager.connecting_peers == {} + assert p2p_manager.handshaking_peers == {self.conn.proto1} + assert p2p_manager.connected_peers == {} + + # We create a new manager3, and use it as a bootstrap in manager1, but without the peer_id + manager3: HathorManager = self.create_peer(self.network) + conn = FakeConnection(manager1=manager3, manager2=self.manager1, fake_bootstrap_id=None) + + # Now manager1 and manager3 are handshaking + assert p2p_manager.connecting_peers == {} + assert p2p_manager.handshaking_peers == {self.conn.proto1, conn.proto2} + assert p2p_manager.connected_peers == {} + + # We change our peer id (on manager3) + new_peer = PrivatePeer.auto_generated() + conn.proto1.my_peer = new_peer + conn.tr2._peer = new_peer + + # We advance the states and in this case succeed (on manager1), because + # even though the peer_id was changed, it wasn't initially set. + self._check_result_only_cmd(conn.peek_tr2_value(), b'HELLO') + conn.run_one_step() + self._check_result_only_cmd(conn.peek_tr2_value(), b'PEER-ID') + conn.run_one_step() + self._check_result_only_cmd(conn.peek_tr2_value(), b'READY') + + def test_exception_on_synchronous_cmd_handler(self) -> None: + self.conn.run_one_step() + self.conn.run_one_step() + + def error() -> None: + raise Exception('some error') + + self.conn.proto1.state.cmd_map = { + ProtocolMessages.READY: error + } + + self.conn.run_one_step() + assert self.conn.peek_tr1_value() == b'ERROR Error processing "READY" command\r\n' + self.assertTrue(self.conn.tr1.disconnecting) + + def test_exception_on_deferred_cmd_handler(self) -> None: + self.conn.run_one_step() + self.conn.run_one_step() + + self.conn.proto1.state.cmd_map = { + ProtocolMessages.READY: lambda: defer.fail(Exception('some error')), + } + + self.conn.run_one_step() + assert self.conn.peek_tr1_value() == b'ERROR Error processing "READY" command\r\n' + self.assertTrue(self.conn.tr1.disconnecting) + + def test_exception_on_asynchronous_cmd_handler(self) -> None: + self.conn.run_one_step() + self.conn.run_one_step() + + async def error() -> None: + raise Exception('some error') + + self.conn.proto1.state.cmd_map = { + ProtocolMessages.READY: error + } + + self.conn.run_one_step() + self.clock.advance(1) + assert self.conn.peek_tr1_value() == b'ERROR Error processing "READY" command\r\n' + self.assertTrue(self.conn.tr1.disconnecting) + class SyncV1HathorProtocolTestCase(unittest.SyncV1Params, BaseHathorProtocolTestCase): __test__ = True diff --git a/tests/resources/p2p/test_add_peer.py b/tests/resources/p2p/test_add_peer.py index ca9ca99a2..f70f3aefe 100644 --- a/tests/resources/p2p/test_add_peer.py +++ b/tests/resources/p2p/test_add_peer.py @@ -1,7 +1,7 @@ from twisted.internet.defer import inlineCallbacks -from hathor.p2p.entrypoint import Entrypoint from hathor.p2p.peer import PrivatePeer +from hathor.p2p.peer_endpoint import PeerAddress from hathor.p2p.resources import AddPeersResource from tests import unittest from tests.resources.base_resource import StubSite, _BaseResourceTest @@ -22,7 +22,7 @@ def test_connecting_peers(self): # test when we send a peer we're already connected to peer = PrivatePeer.auto_generated() - peer.entrypoints = [Entrypoint.parse('tcp://localhost:8006')] + peer.entrypoints = [PeerAddress.parse('tcp://localhost:8006')] self.manager.connections.verified_peer_storage.add(peer) response = yield self.web.post('p2p/peers', ['tcp://localhost:8006', 'tcp://localhost:8007']) data = response.json_value() diff --git a/tests/resources/p2p/test_status.py b/tests/resources/p2p/test_status.py index 68d409348..646ba6903 100644 --- a/tests/resources/p2p/test_status.py +++ b/tests/resources/p2p/test_status.py @@ -1,9 +1,10 @@ from twisted.internet import endpoints +from twisted.internet.address import IPv4Address from twisted.internet.defer import inlineCallbacks import hathor from hathor.conf.unittests import SETTINGS -from hathor.p2p.entrypoint import Entrypoint +from hathor.p2p.peer_endpoint import PeerAddress from hathor.p2p.resources import StatusResource from hathor.simulator import FakeConnection from tests import unittest @@ -16,14 +17,15 @@ class BaseStatusTest(_BaseResourceTest._ResourceTest): def setUp(self): super().setUp() self.web = StubSite(StatusResource(self.manager)) - self.entrypoint = Entrypoint.parse('tcp://192.168.1.1:54321') - self.manager.connections.my_peer.info.entrypoints.append(self.entrypoint) + address1 = IPv4Address('TCP', '192.168.1.1', 54321) + self.manager.connections.my_peer.info.entrypoints.append(PeerAddress.from_address(address1)) self.manager.peers_whitelist.append(self.get_random_peer_from_pool().id) self.manager.peers_whitelist.append(self.get_random_peer_from_pool().id) self.manager2 = self.create_peer('testnet') - self.manager2.connections.my_peer.info.entrypoints.append(self.entrypoint) - self.conn1 = FakeConnection(self.manager, self.manager2) + address2 = IPv4Address('TCP', '192.168.1.1', 54322) + self.manager2.connections.my_peer.info.entrypoints.append(PeerAddress.from_address(address2)) + self.conn1 = FakeConnection(self.manager, self.manager2, addr1=address1, addr2=address2) @inlineCallbacks def test_get(self):