Skip to content

refactor(mypy): add stricter rules to unittest and utils [part V/VI] #974

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,8 @@ module = [
"tests.p2p.*",
"tests.pubsub.*",
"tests.simulation.*",
"tests.unittest",
"tests.utils",
]
disallow_untyped_defs = true

Expand Down
3 changes: 2 additions & 1 deletion tests/p2p/test_double_spending.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from hathor.manager import HathorManager
from hathor.simulator.utils import add_new_blocks
from hathor.transaction import Transaction
from hathor.util import not_none
from tests import unittest
from tests.utils import add_blocks_unlock_reward, add_new_tx

Expand All @@ -23,7 +24,7 @@ def setUp(self) -> None:
def _add_new_transactions(self, manager: HathorManager, num_txs: int) -> list[Transaction]:
txs = []
for _ in range(num_txs):
address = self.get_address(0)
address = not_none(self.get_address(0))
value = self.rng.choice([5, 10, 15, 20])
tx = add_new_tx(manager, address, value)
txs.append(tx)
Expand Down
2 changes: 1 addition & 1 deletion tests/tx/test_indexes2.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_timestamp_index(self):
# XXX: we verified they're the same, doesn't matter which we pick:
idx = idx_memory
hashes = hashes_memory
self.log.debug('indexes match', idx=idx, hashes=unittest.shorten_hash(hashes))
self.log.debug('indexes match', idx=idx, hashes=unittest.short_hashes(hashes))
if idx is None:
break
offset_variety.add(idx[1])
Expand Down
112 changes: 70 additions & 42 deletions tests/unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import shutil
import tempfile
import time
from typing import Iterator, Optional
from typing import Callable, Collection, Iterable, Iterator, Optional
from unittest import main as ut_main

from structlog import get_logger
Expand All @@ -16,13 +16,17 @@
from hathor.daa import DifficultyAdjustmentAlgorithm, TestMode
from hathor.event import EventManager
from hathor.event.storage import EventStorage
from hathor.manager import HathorManager
from hathor.p2p.peer_id import PeerId
from hathor.p2p.sync_v1.agent import NodeSyncTimestamp
from hathor.p2p.sync_v2.agent import NodeBlockSync
from hathor.p2p.sync_version import SyncVersion
from hathor.pubsub import PubSubManager
from hathor.reactor import ReactorProtocol as Reactor, get_global_reactor
from hathor.simulator.clock import MemoryReactorHeapClock
from hathor.transaction import BaseTransaction
from hathor.transaction import BaseTransaction, Block, Transaction
from hathor.transaction.storage.transaction_storage import TransactionStorage
from hathor.types import VertexId
from hathor.util import Random, not_none
from hathor.wallet import BaseWallet, HDWallet, Wallet
from tests.test_memory_reactor_clock import TestMemoryReactorClock
Expand All @@ -33,9 +37,8 @@
USE_MEMORY_STORAGE = os.environ.get('HATHOR_TEST_MEMORY_STORAGE', 'false').lower() == 'true'


def shorten_hash(container):
container_type = type(container)
return container_type(h[-2:].hex() for h in container)
def short_hashes(container: Collection[bytes]) -> Iterable[str]:
return map(lambda hash_bytes: hash_bytes[-2:].hex(), container)


def _load_peer_id_pool(file_path: Optional[str] = None) -> Iterator[PeerId]:
Expand All @@ -50,7 +53,7 @@ def _load_peer_id_pool(file_path: Optional[str] = None) -> Iterator[PeerId]:
yield PeerId.create_from_json(peer_id_dict)


def _get_default_peer_id_pool_filepath():
def _get_default_peer_id_pool_filepath() -> str:
this_file_path = os.path.dirname(__file__)
file_name = 'peer_id_pool.json'
file_path = os.path.join(this_file_path, file_name)
Expand Down Expand Up @@ -109,19 +112,19 @@ class TestCase(unittest.TestCase):
use_memory_storage: bool = USE_MEMORY_STORAGE
seed_config: Optional[int] = None

def setUp(self):
self.tmpdirs = []
def setUp(self) -> None:
self.tmpdirs: list[str] = []
self.clock = TestMemoryReactorClock()
self.clock.advance(time.time())
self.log = logger.new()
self.reset_peer_id_pool()
self.seed = secrets.randbits(64) if self.seed_config is None else self.seed_config
self.log.info('set seed', seed=self.seed)
self.rng = Random(self.seed)
self._pending_cleanups = []
self._pending_cleanups: list[Callable] = []
self._settings = get_global_settings()

def tearDown(self):
def tearDown(self) -> None:
self.clean_tmpdirs()
for fn in self._pending_cleanups:
fn()
Expand All @@ -144,12 +147,12 @@ def get_random_peer_id_from_pool(self, pool: Optional[list[PeerId]] = None,
pool.remove(peer_id)
return peer_id

def mkdtemp(self):
def mkdtemp(self) -> str:
tmpdir = tempfile.mkdtemp()
self.tmpdirs.append(tmpdir)
return tmpdir

def _create_test_wallet(self, unlocked=False):
def _create_test_wallet(self, unlocked: bool = False) -> Wallet:
""" Generate a Wallet with a number of keypairs for testing
:rtype: Wallet
"""
Expand All @@ -169,14 +172,14 @@ def get_builder(self, network: str) -> TestBuilder:
.set_network(network)
return builder

def create_peer_from_builder(self, builder, start_manager=True):
def create_peer_from_builder(self, builder: Builder, start_manager: bool = True) -> HathorManager:
artifacts = builder.build()
manager = artifacts.manager

if artifacts.rocksdb_storage:
self._pending_cleanups.append(artifacts.rocksdb_storage.close)

manager.avg_time_between_blocks = 0.0001
# manager.avg_time_between_blocks = 0.0001 # FIXME: This property is not defined. Fix this.

if start_manager:
manager.start()
Expand Down Expand Up @@ -277,7 +280,7 @@ def create_peer( # type: ignore[no-untyped-def]

return manager

def run_to_completion(self):
def run_to_completion(self) -> None:
""" This will advance the test's clock until all calls scheduled are done.
"""
for call in self.clock.getDelayedCalls():
Expand All @@ -300,7 +303,11 @@ def assertIsTopological(self, tx_sequence: Iterator[BaseTransaction], message: O
self.assertIn(dep, valid_deps, message)
valid_deps.add(tx.hash)

def _syncVersionFlags(self, enable_sync_v1=None, enable_sync_v2=None):
def _syncVersionFlags(
self,
enable_sync_v1: bool | None = None,
enable_sync_v2: bool | None = None
) -> tuple[bool, bool]:
"""Internal: use this to check and get the flags and optionally provide override values."""
if enable_sync_v1 is None:
assert hasattr(self, '_enable_sync_v1'), ('`_enable_sync_v1` has no default by design, either set one on '
Expand All @@ -313,19 +320,19 @@ def _syncVersionFlags(self, enable_sync_v1=None, enable_sync_v2=None):
assert enable_sync_v1 or enable_sync_v2, 'enable at least one sync version'
return enable_sync_v1, enable_sync_v2

def assertTipsEqual(self, manager1, manager2):
def assertTipsEqual(self, manager1: HathorManager, manager2: HathorManager) -> None:
_, enable_sync_v2 = self._syncVersionFlags()
if enable_sync_v2:
self.assertTipsEqualSyncV2(manager1, manager2)
else:
self.assertTipsEqualSyncV1(manager1, manager2)

def assertTipsNotEqual(self, manager1, manager2):
def assertTipsNotEqual(self, manager1: HathorManager, manager2: HathorManager) -> None:
s1 = set(manager1.tx_storage.get_all_tips())
s2 = set(manager2.tx_storage.get_all_tips())
self.assertNotEqual(s1, s2)

def assertTipsEqualSyncV1(self, manager1, manager2):
def assertTipsEqualSyncV1(self, manager1: HathorManager, manager2: HathorManager) -> None:
# XXX: this is the original implementation of assertTipsEqual
s1 = set(manager1.tx_storage.get_all_tips())
s2 = set(manager2.tx_storage.get_all_tips())
Expand All @@ -335,39 +342,45 @@ def assertTipsEqualSyncV1(self, manager1, manager2):
s2 = set(manager2.tx_storage.get_tx_tips())
self.assertEqual(s1, s2)

def assertTipsEqualSyncV2(self, manager1, manager2, *, strict_sync_v2_indexes=True):
def assertTipsEqualSyncV2(
self,
manager1: HathorManager,
manager2: HathorManager,
*,
strict_sync_v2_indexes: bool = True
) -> None:
# tx tips
if strict_sync_v2_indexes:
tips1 = manager1.tx_storage.indexes.mempool_tips.get()
tips2 = manager2.tx_storage.indexes.mempool_tips.get()
tips1 = not_none(not_none(manager1.tx_storage.indexes).mempool_tips).get()
tips2 = not_none(not_none(manager2.tx_storage.indexes).mempool_tips).get()
else:
tips1 = {tx.hash for tx in manager1.tx_storage.iter_mempool_tips_from_best_index()}
tips2 = {tx.hash for tx in manager2.tx_storage.iter_mempool_tips_from_best_index()}
self.log.debug('tx tips1', len=len(tips1), list=shorten_hash(tips1))
self.log.debug('tx tips2', len=len(tips2), list=shorten_hash(tips2))
self.log.debug('tx tips1', len=len(tips1), list=short_hashes(tips1))
self.log.debug('tx tips2', len=len(tips2), list=short_hashes(tips2))
self.assertEqual(tips1, tips2)

# best block
s1 = set(manager1.tx_storage.get_best_block_tips())
s2 = set(manager2.tx_storage.get_best_block_tips())
self.log.debug('block tips1', len=len(s1), list=shorten_hash(s1))
self.log.debug('block tips2', len=len(s2), list=shorten_hash(s2))
self.log.debug('block tips1', len=len(s1), list=short_hashes(s1))
self.log.debug('block tips2', len=len(s2), list=short_hashes(s2))
self.assertEqual(s1, s2)

# best block (from height index)
b1 = manager1.tx_storage.indexes.height.get_tip()
b2 = manager2.tx_storage.indexes.height.get_tip()
b1 = not_none(manager1.tx_storage.indexes).height.get_tip()
b2 = not_none(manager2.tx_storage.indexes).height.get_tip()
self.assertIn(b1, s2)
self.assertIn(b2, s1)

def assertConsensusEqual(self, manager1, manager2):
def assertConsensusEqual(self, manager1: HathorManager, manager2: HathorManager) -> None:
_, enable_sync_v2 = self._syncVersionFlags()
if enable_sync_v2:
self.assertConsensusEqualSyncV2(manager1, manager2)
else:
self.assertConsensusEqualSyncV1(manager1, manager2)

def assertConsensusEqualSyncV1(self, manager1, manager2):
def assertConsensusEqualSyncV1(self, manager1: HathorManager, manager2: HathorManager) -> None:
self.assertEqual(manager1.tx_storage.get_vertices_count(), manager2.tx_storage.get_vertices_count())
for tx1 in manager1.tx_storage.get_all_transactions():
tx2 = manager2.tx_storage.get_transaction(tx1.hash)
Expand All @@ -381,12 +394,20 @@ def assertConsensusEqualSyncV1(self, manager1, manager2):
self.assertIsNone(tx2_meta.voided_by)
else:
# If tx1 is voided, then tx2 must be voided.
assert tx1_meta.voided_by is not None
assert tx2_meta.voided_by is not None
self.assertGreaterEqual(len(tx1_meta.voided_by), 1)
self.assertGreaterEqual(len(tx2_meta.voided_by), 1)
# Hard verification
# self.assertEqual(tx1_meta.voided_by, tx2_meta.voided_by)

def assertConsensusEqualSyncV2(self, manager1, manager2, *, strict_sync_v2_indexes=True):
def assertConsensusEqualSyncV2(
self,
manager1: HathorManager,
manager2: HathorManager,
*,
strict_sync_v2_indexes: bool = True
) -> None:
# The current sync algorithm does not propagate voided blocks/txs
# so the count might be different even though the consensus is equal
# One peer might have voided txs that the other does not have
Expand All @@ -397,7 +418,9 @@ def assertConsensusEqualSyncV2(self, manager1, manager2, *, strict_sync_v2_index
# the following is specific to sync-v2

# helper function:
def get_all_executed_or_voided(tx_storage):
def get_all_executed_or_voided(
tx_storage: TransactionStorage
) -> tuple[set[VertexId], set[VertexId], set[VertexId]]:
"""Get all txs separated into three sets: executed, voided, partial"""
tx_executed = set()
tx_voided = set()
Expand All @@ -424,14 +447,16 @@ def get_all_executed_or_voided(tx_storage):
self.log.debug('node1 rest', len_voided=len(tx_voided1), len_partial=len(tx_partial1))
self.log.debug('node2 rest', len_voided=len(tx_voided2), len_partial=len(tx_partial2))

def assertConsensusValid(self, manager):
def assertConsensusValid(self, manager: HathorManager) -> None:
for tx in manager.tx_storage.get_all_transactions():
if tx.is_block:
assert isinstance(tx, Block)
self.assertBlockConsensusValid(tx)
else:
assert isinstance(tx, Transaction)
self.assertTransactionConsensusValid(tx)

def assertBlockConsensusValid(self, block):
def assertBlockConsensusValid(self, block: Block) -> None:
self.assertTrue(block.is_block)
if not block.parents:
# Genesis
Expand All @@ -442,7 +467,8 @@ def assertBlockConsensusValid(self, block):
parent_meta = parent.get_metadata()
self.assertIsNone(parent_meta.voided_by)

def assertTransactionConsensusValid(self, tx):
def assertTransactionConsensusValid(self, tx: Transaction) -> None:
assert tx.storage is not None
self.assertFalse(tx.is_block)
meta = tx.get_metadata()
if meta.voided_by and tx.hash in meta.voided_by:
Expand All @@ -462,38 +488,40 @@ def assertTransactionConsensusValid(self, tx):
spent_meta = spent_tx.get_metadata()

if spent_meta.voided_by is not None:
self.assertIsNotNone(meta.voided_by)
assert meta.voided_by is not None
self.assertTrue(spent_meta.voided_by)
self.assertTrue(meta.voided_by)
self.assertTrue(spent_meta.voided_by.issubset(meta.voided_by))

for parent in tx.get_parents():
parent_meta = parent.get_metadata()
if parent_meta.voided_by is not None:
self.assertIsNotNone(meta.voided_by)
assert meta.voided_by is not None
self.assertTrue(parent_meta.voided_by)
self.assertTrue(meta.voided_by)
self.assertTrue(parent_meta.voided_by.issubset(meta.voided_by))

def assertSyncedProgress(self, node_sync):
def assertSyncedProgress(self, node_sync: NodeSyncTimestamp | NodeBlockSync) -> None:
"""Check "synced" status of p2p-manager, uses self._enable_sync_vX to choose which check to run."""
enable_sync_v1, enable_sync_v2 = self._syncVersionFlags()
if enable_sync_v2:
assert isinstance(node_sync, NodeBlockSync)
self.assertV2SyncedProgress(node_sync)
elif enable_sync_v1:
assert isinstance(node_sync, NodeSyncTimestamp)
self.assertV1SyncedProgress(node_sync)

def assertV1SyncedProgress(self, node_sync):
def assertV1SyncedProgress(self, node_sync: NodeSyncTimestamp) -> None:
self.assertEqual(node_sync.synced_timestamp, node_sync.peer_timestamp)

def assertV2SyncedProgress(self, node_sync):
def assertV2SyncedProgress(self, node_sync: NodeBlockSync) -> None:
self.assertEqual(node_sync.synced_block, node_sync.peer_best_block)

def clean_tmpdirs(self):
def clean_tmpdirs(self) -> None:
for tmpdir in self.tmpdirs:
shutil.rmtree(tmpdir)

def clean_pending(self, required_to_quiesce=True):
def clean_pending(self, required_to_quiesce: bool = True) -> None:
"""
This handy method cleans all pending tasks from the reactor.

Expand Down
Loading