Skip to content

feat(p2p): add ability to update peer_id.json with SIGUSR1 #981

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 2 additions & 12 deletions hathor/builder/cli_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

import getpass
import json
import os
import platform
import sys
Expand Down Expand Up @@ -96,8 +95,7 @@ def create_manager(self, reactor: Reactor) -> HathorManager:
self.log = logger.new()
self.reactor = reactor

peer_id = self.create_peer_id()

peer_id = PeerId.create_from_json_path(self._args.peer) if self._args.peer else PeerId()
python = f'{platform.python_version()}-{platform.python_implementation()}'

self.log.info(
Expand Down Expand Up @@ -367,7 +365,7 @@ def create_manager(self, reactor: Reactor) -> HathorManager:
self.log.warn('--memory-indexes is implied for memory storage or JSON storage')

for description in self._args.listen:
p2p_manager.add_listen_address(description)
p2p_manager.add_listen_address_description(description)

if self._args.peer_id_blacklist:
self.log.info('with peer id blacklist', blacklist=self._args.peer_id_blacklist)
Expand Down Expand Up @@ -397,14 +395,6 @@ def get_hostname(self) -> Optional[str]:
print('Hostname discovered and set to {}'.format(hostname))
return hostname

def create_peer_id(self) -> PeerId:
    """Create this node's PeerId, loading it from the --peer JSON file when one was given.

    :return: a freshly generated PeerId when no --peer argument was passed,
             otherwise the PeerId deserialized from the JSON file.
    """
    if not self._args.peer:
        # No config file provided: generate a brand new peer id.
        return PeerId()
    # Use a context manager so the file handle is closed instead of leaked
    # (the original `json.load(open(...))` never closed the file).
    with open(self._args.peer, 'r') as f:
        data = json.load(f)
    return PeerId.create_from_json(data)

def create_wallet(self) -> BaseWallet:
if self._args.wallet == 'hd':
kwargs: dict[str, Any] = {
Expand Down
5 changes: 2 additions & 3 deletions hathor/cli/run_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,9 +265,8 @@ def register_signal_handlers(self) -> None:
def signal_usr1_handler(self, sig: int, frame: Any) -> None:
    """Called when USR1 signal is received.

    Reloads the peer entrypoints from the peer config file and resets all
    peer connections.

    :param sig: the signal number (unused, required by the signal API).
    :param frame: the current stack frame (unused, required by the signal API).
    """
    try:
        self.log.warn('USR1 received.')
        # NOTE(review): unlike the previous version there is no guard for
        # `self.manager` being None here; an AttributeError would be swallowed
        # by the except below and only logged — confirm this is intended.
        self.manager.connections.reload_entrypoints_and_connections()
    except Exception:
        # see: https://docs.python.org/3/library/signal.html#note-on-signal-handlers-and-exceptions
        self.log.error('prevented exception from escaping the signal handler', exc_info=True)
Expand Down
4 changes: 4 additions & 0 deletions hathor/conf/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,10 @@ def GENESIS_TX2_TIMESTAMP(self) -> int:
OLD_MAX_MERKLE_PATH_LENGTH: int = 12
NEW_MAX_MERKLE_PATH_LENGTH: int = 20

# Maximum number of tx tips to accept in the initial phase of the mempool sync. 1000 is arbitrary, but it should be
# more than enough for the foreseeable future.
MAX_MEMPOOL_RECEIVING_TIPS: int = 1000

# Used to enable nano contracts.
#
# This should NEVER be enabled for mainnet and testnet, since both networks will
Expand Down
59 changes: 34 additions & 25 deletions hathor/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,31 +89,33 @@ class UnhealthinessReason(str, Enum):
# This is the interval to be used by the task to check if the node is synced
CHECK_SYNC_STATE_INTERVAL = 30 # seconds

def __init__(self,
reactor: Reactor,
*,
settings: HathorSettings,
pubsub: PubSubManager,
consensus_algorithm: ConsensusAlgorithm,
daa: DifficultyAdjustmentAlgorithm,
peer_id: PeerId,
tx_storage: TransactionStorage,
p2p_manager: ConnectionsManager,
event_manager: EventManager,
feature_service: FeatureService,
bit_signaling_service: BitSignalingService,
verification_service: VerificationService,
cpu_mining_service: CpuMiningService,
network: str,
execution_manager: ExecutionManager,
hostname: Optional[str] = None,
wallet: Optional[BaseWallet] = None,
capabilities: Optional[list[str]] = None,
checkpoints: Optional[list[Checkpoint]] = None,
rng: Optional[Random] = None,
environment_info: Optional[EnvironmentInfo] = None,
full_verification: bool = False,
enable_event_queue: bool = False):
def __init__(
self,
reactor: Reactor,
*,
settings: HathorSettings,
pubsub: PubSubManager,
consensus_algorithm: ConsensusAlgorithm,
daa: DifficultyAdjustmentAlgorithm,
peer_id: PeerId,
tx_storage: TransactionStorage,
p2p_manager: ConnectionsManager,
event_manager: EventManager,
feature_service: FeatureService,
bit_signaling_service: BitSignalingService,
verification_service: VerificationService,
cpu_mining_service: CpuMiningService,
network: str,
execution_manager: ExecutionManager,
hostname: Optional[str] = None,
wallet: Optional[BaseWallet] = None,
capabilities: Optional[list[str]] = None,
checkpoints: Optional[list[Checkpoint]] = None,
rng: Optional[Random] = None,
environment_info: Optional[EnvironmentInfo] = None,
full_verification: bool = False,
enable_event_queue: bool = False,
) -> None:
"""
:param reactor: Twisted reactor which handles the mainloop and the events.
:param peer_id: Id of this node.
Expand Down Expand Up @@ -1173,6 +1175,13 @@ def get_cmd_path(self) -> Optional[str]:
"""Return the cmd path. If no cmd path is set, returns None."""
return self._cmd_path

def set_hostname_and_reset_connections(self, new_hostname: str) -> None:
    """Replace this node's hostname and force all peer connections to restart.

    The entrypoints advertised to peers are rewritten from the previous
    hostname to the new one before the connections are dropped.
    """
    previous_hostname = self.hostname
    self.hostname = new_hostname
    self.connections.update_hostname_entrypoints(
        old_hostname=previous_hostname,
        new_hostname=self.hostname,
    )
    self.connections.disconnect_all_peers(force=True)


class ParentTxs(NamedTuple):
""" Tuple where the `must_include` hash, when present (at most 1), must be included in a pair, and a list of hashes
Expand Down
69 changes: 51 additions & 18 deletions hathor/p2p/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@

from structlog import get_logger
from twisted.internet import endpoints
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IProtocolFactory, IStreamClientEndpoint, IStreamServerEndpoint
from twisted.internet.interfaces import IListeningPort, IProtocolFactory, IStreamClientEndpoint
from twisted.internet.task import LoopingCall
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.python.failure import Failure
Expand Down Expand Up @@ -108,8 +109,11 @@ def __init__(self,

self.network = network

# List of addresses to listen for new connections (eg: [tcp:8000])
self.listen_addresses: list[str] = []
# List of address descriptions to listen for new connections (eg: [tcp:8000])
self.listen_address_descriptions: list[str] = []

# List of actual IP address instances to listen for new connections
self._listen_addresses: list[IPv4Address | IPv6Address] = []

# List of peer discovery methods.
self.peer_discoveries: list[PeerDiscovery] = []
Expand Down Expand Up @@ -239,9 +243,9 @@ def set_manager(self, manager: 'HathorManager') -> None:
self.log.debug('enable sync-v2 indexes')
indexes.enable_mempool_index()

def add_listen_address_description(self, addr: str) -> None:
    """Add an address description (eg: 'tcp:8000') to listen for incoming connections."""
    self.listen_address_descriptions.append(addr)

def add_peer_discovery(self, peer_discovery: PeerDiscovery) -> None:
"""Add a peer discovery method."""
Expand Down Expand Up @@ -279,7 +283,7 @@ def start(self) -> None:
if self._settings.ENABLE_PEER_WHITELIST:
self._start_whitelist_reconnect()

for description in self.listen_addresses:
for description in self.listen_address_descriptions:
self.listen(description)

self.do_discovery()
Expand Down Expand Up @@ -635,7 +639,7 @@ def connect_to(self, description: str, peer: Optional[PeerId] = None, use_ssl: O
peers_count=self._get_peers_count()
)

def listen(self, description: str, use_ssl: Optional[bool] = None) -> IStreamServerEndpoint:
def listen(self, description: str, use_ssl: Optional[bool] = None) -> None:
""" Start to listen for new connection according to the description.

If `ssl` is True, then the connection will be wraped by a TLS.
Expand All @@ -661,20 +665,43 @@ def listen(self, description: str, use_ssl: Optional[bool] = None) -> IStreamSer

factory = NetfilterFactory(self, factory)

self.log.info('listen on', endpoint=description)
endpoint.listen(factory)
self.log.info('trying to listen on', endpoint=description)
deferred: Deferred[IListeningPort] = endpoint.listen(factory)
deferred.addCallback(self._on_listen_success, description)

def _on_listen_success(self, listening_port: IListeningPort, description: str) -> None:
    """Callback to be called when listening to an endpoint succeeds.

    Records the bound address so hostname entrypoints can later be rebuilt,
    and, when a hostname is already configured, advertises a new entrypoint
    for this listen address immediately.

    :param listening_port: the port object returned by the endpoint's listen().
    :param description: the original listen-address description (eg: 'tcp:8000').
    """
    self.log.info('success listening on', endpoint=description)
    address = listening_port.getHost()

    if not isinstance(address, (IPv4Address, IPv6Address)):
        # We only know how to derive entrypoints from plain IPv4/IPv6 addresses.
        self.log.error(f'unhandled address type for endpoint "{description}": {str(type(address))}')
        return

    self._listen_addresses.append(address)

    assert self.manager is not None
    if self.manager.hostname:
        self._add_hostname_entrypoint(self.manager.hostname, address)
def update_hostname_entrypoints(self, *, old_hostname: str | None, new_hostname: str) -> None:
    """For each listen address, swap the old-hostname entrypoint for a new-hostname one.

    Entrypoints derived from ``old_hostname`` are removed (when present) and a
    fresh entrypoint using ``new_hostname`` is appended for every address we are
    currently listening on.
    """
    assert self.manager is not None
    entrypoints = self.my_peer.entrypoints
    for listen_address in self._listen_addresses:
        if old_hostname is not None:
            stale_entrypoint = self._get_hostname_address_str(old_hostname, listen_address)
            if stale_entrypoint in entrypoints:
                entrypoints.remove(stale_entrypoint)

        self._add_hostname_entrypoint(new_hostname, listen_address)

def _add_hostname_entrypoint(self, hostname: str, address: IPv4Address | IPv6Address) -> None:
    """Append an entrypoint built from *hostname* and a listen *address* to our peer."""
    self.my_peer.entrypoints.append(self._get_hostname_address_str(hostname, address))

@staticmethod
def _get_hostname_address_str(hostname: str, address: IPv4Address | IPv6Address) -> str:
    """Build the lowercase entrypoint string for *hostname* and a listen *address*,
    eg: 'tcp://my.host:40403'."""
    entrypoint = '{}://{}:{}'.format(address.type, hostname, address.port)
    return entrypoint.lower()

def get_connection_to_drop(self, protocol: HathorProtocol) -> HathorProtocol:
""" When there are duplicate connections, determine which one should be dropped.
Expand Down Expand Up @@ -796,3 +823,9 @@ def _sync_rotate_if_needed(self, *, force: bool = False) -> None:

for peer_id in info.to_enable:
self.connected_peers[peer_id].enable_sync()

def reload_entrypoints_and_connections(self) -> None:
    """Kill all connections and reload entrypoints from the original peer config file."""
    self.log.warn('Killing all connections and resetting entrypoints...')
    # Drop every peer first so reconnections use the refreshed entrypoints.
    self.disconnect_all_peers(force=True)
    # Re-reads the JSON file the peer was created from; raises if none was provided.
    self.my_peer.reload_entrypoints_from_source_file()
31 changes: 30 additions & 1 deletion hathor/p2p/peer_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import base64
import hashlib
import json
from enum import Enum
from math import inf
from typing import TYPE_CHECKING, Any, Optional, cast
Expand All @@ -24,6 +25,7 @@
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding, rsa
from OpenSSL.crypto import X509, PKey
from structlog import get_logger
from twisted.internet.interfaces import ISSLTransport
from twisted.internet.ssl import Certificate, CertificateOptions, TLSVersion, trustRootFromCertificates

Expand All @@ -35,6 +37,8 @@
if TYPE_CHECKING:
from hathor.p2p.protocol import HathorProtocol # noqa: F401

logger = get_logger()


class InvalidPeerIdException(Exception):
pass
Expand Down Expand Up @@ -64,8 +68,10 @@ class PeerId:
retry_attempts: int # how many retries were made
last_seen: float # last time this peer was seen
flags: set[str]
source_file: str | None

def __init__(self, auto_generate_keys: bool = True) -> None:
self._log = logger.new()
self._settings = get_global_settings()
self.id = None
self.private_key = None
Expand Down Expand Up @@ -159,9 +165,15 @@ def verify_signature(self, signature: bytes, data: bytes) -> bool:
else:
return True

@classmethod
def create_from_json_path(cls, path: str) -> 'PeerId':
    """Create a new PeerId from a JSON file.

    :param path: filesystem path of the peer-config JSON file.
    """
    # Use a context manager so the file handle is closed instead of leaked
    # (the original `json.load(open(...))` never closed the file).
    with open(path, 'r') as fp:
        data = json.load(fp)
    return PeerId.create_from_json(data)

@classmethod
def create_from_json(cls, data: dict[str, Any]) -> 'PeerId':
""" Create a new PeerId from a JSON.
""" Create a new PeerId from JSON data.

It is used both to load a PeerId from disk and to create a PeerId
from a peer connection.
Expand Down Expand Up @@ -408,3 +420,20 @@ def validate_certificate(self, protocol: 'HathorProtocol') -> bool:
return False

return True

def reload_entrypoints_from_source_file(self) -> None:
    """Update this PeerId's entrypoints from the json file.

    The file is re-read from ``self.source_file``; the update is ignored (and an
    error logged) when the reloaded peer id does not match this one.

    :raises Exception: when this PeerId was not created from a config file.
    """
    if not self.source_file:
        raise Exception('Trying to reload entrypoints but no peer config file was provided.')

    reloaded_peer = PeerId.create_from_json_path(self.source_file)

    if reloaded_peer.id != self.id:
        # A mismatching id means the file now describes a different peer.
        self._log.error(
            'Ignoring peer id file update because the peer_id does not match.',
            current_peer_id=self.id,
            new_peer_id=reloaded_peer.id,
        )
        return

    self.entrypoints = reloaded_peer.entrypoints
11 changes: 9 additions & 2 deletions hathor/p2p/sync_v2/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ def __init__(self, protocol: 'HathorProtocol', reactor: Reactor) -> None:
# Saves if I am in the middle of a mempool sync
# we don't execute any sync while in the middle of it
self.mempool_manager = SyncMempoolManager(self)
self._receiving_tips: Optional[list[bytes]] = None
self._receiving_tips: Optional[list[VertexId]] = None
self.max_receiving_tips: int = self._settings.MAX_MEMPOOL_RECEIVING_TIPS

# Cache for get_tx calls
self._get_tx_cache: OrderedDict[bytes, BaseTransaction] = OrderedDict()
Expand Down Expand Up @@ -476,7 +477,13 @@ def handle_tips(self, payload: str) -> None:
data = json.loads(payload)
data = [bytes.fromhex(x) for x in data]
# filter-out txs we already have
self._receiving_tips.extend(tx_id for tx_id in data if not self.partial_vertex_exists(tx_id))
try:
self._receiving_tips.extend(VertexId(tx_id) for tx_id in data if not self.partial_vertex_exists(tx_id))
except ValueError:
self.protocol.send_error_and_close_connection('Invalid trasaction ID received')
# XXX: it's OK to do this *after* the extend because the payload is limited by the line protocol
if len(self._receiving_tips) > self.max_receiving_tips:
self.protocol.send_error_and_close_connection(f'Too many tips: {len(self._receiving_tips)}')

def handle_tips_end(self, _payload: str) -> None:
""" Handle a TIPS-END message.
Expand Down
12 changes: 6 additions & 6 deletions hathor/p2p/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,18 @@
from hathor.transaction.genesis import get_representation_for_all_genesis


def discover_hostname(timeout: float | None = None) -> Optional[str]:
    """ Try to discover your hostname. It is a synchronous operation and
    should not be called from twisted main loop.

    :param timeout: optional timeout in seconds, forwarded to the underlying HTTP request.
    """
    return discover_ip_ipify(timeout)


def discover_ip_ipify() -> Optional[str]:
def discover_ip_ipify(timeout: float | None = None) -> Optional[str]:
""" Try to discover your IP address using ipify's api.
It is a synchonous operation and should not be called from twisted main loop.
It is a synchronous operation and should not be called from twisted main loop.
"""
response = requests.get('https://api.ipify.org')
response = requests.get('https://api.ipify.org', timeout=timeout)
if response.ok:
# It may be either an ipv4 or ipv6 in string format.
ip = response.text
Expand Down
Loading
Loading