Skip to content

Commit ded7fef

Browse files
authored
refactor(mypy): add stricter rules to unittest and utils (#974)
1 parent 466550d commit ded7fef

File tree

5 files changed

+128
-57
lines changed

5 files changed

+128
-57
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,8 @@ module = [
165165
"tests.p2p.*",
166166
"tests.pubsub.*",
167167
"tests.simulation.*",
168+
"tests.unittest",
169+
"tests.utils",
168170
]
169171
disallow_untyped_defs = true
170172

tests/p2p/test_double_spending.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from hathor.manager import HathorManager
55
from hathor.simulator.utils import add_new_blocks
66
from hathor.transaction import Transaction
7+
from hathor.util import not_none
78
from tests import unittest
89
from tests.utils import add_blocks_unlock_reward, add_new_tx
910

@@ -23,7 +24,7 @@ def setUp(self) -> None:
2324
def _add_new_transactions(self, manager: HathorManager, num_txs: int) -> list[Transaction]:
2425
txs = []
2526
for _ in range(num_txs):
26-
address = self.get_address(0)
27+
address = not_none(self.get_address(0))
2728
value = self.rng.choice([5, 10, 15, 20])
2829
tx = add_new_tx(manager, address, value)
2930
txs.append(tx)

tests/tx/test_indexes2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def test_timestamp_index(self):
6464
# XXX: we verified they're the same, doesn't matter which we pick:
6565
idx = idx_memory
6666
hashes = hashes_memory
67-
self.log.debug('indexes match', idx=idx, hashes=unittest.shorten_hash(hashes))
67+
self.log.debug('indexes match', idx=idx, hashes=unittest.short_hashes(hashes))
6868
if idx is None:
6969
break
7070
offset_variety.add(idx[1])

tests/unittest.py

Lines changed: 70 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import shutil
44
import tempfile
55
import time
6-
from typing import Iterator, Optional
6+
from typing import Callable, Collection, Iterable, Iterator, Optional
77
from unittest import main as ut_main
88

99
from structlog import get_logger
@@ -16,13 +16,17 @@
1616
from hathor.daa import DifficultyAdjustmentAlgorithm, TestMode
1717
from hathor.event import EventManager
1818
from hathor.event.storage import EventStorage
19+
from hathor.manager import HathorManager
1920
from hathor.p2p.peer_id import PeerId
21+
from hathor.p2p.sync_v1.agent import NodeSyncTimestamp
22+
from hathor.p2p.sync_v2.agent import NodeBlockSync
2023
from hathor.p2p.sync_version import SyncVersion
2124
from hathor.pubsub import PubSubManager
2225
from hathor.reactor import ReactorProtocol as Reactor, get_global_reactor
2326
from hathor.simulator.clock import MemoryReactorHeapClock
24-
from hathor.transaction import BaseTransaction
27+
from hathor.transaction import BaseTransaction, Block, Transaction
2528
from hathor.transaction.storage.transaction_storage import TransactionStorage
29+
from hathor.types import VertexId
2630
from hathor.util import Random, not_none
2731
from hathor.wallet import BaseWallet, HDWallet, Wallet
2832
from tests.test_memory_reactor_clock import TestMemoryReactorClock
@@ -33,9 +37,8 @@
3337
USE_MEMORY_STORAGE = os.environ.get('HATHOR_TEST_MEMORY_STORAGE', 'false').lower() == 'true'
3438

3539

36-
def shorten_hash(container):
37-
container_type = type(container)
38-
return container_type(h[-2:].hex() for h in container)
40+
def short_hashes(container: Collection[bytes]) -> Iterable[str]:
41+
return map(lambda hash_bytes: hash_bytes[-2:].hex(), container)
3942

4043

4144
def _load_peer_id_pool(file_path: Optional[str] = None) -> Iterator[PeerId]:
@@ -50,7 +53,7 @@ def _load_peer_id_pool(file_path: Optional[str] = None) -> Iterator[PeerId]:
5053
yield PeerId.create_from_json(peer_id_dict)
5154

5255

53-
def _get_default_peer_id_pool_filepath():
56+
def _get_default_peer_id_pool_filepath() -> str:
5457
this_file_path = os.path.dirname(__file__)
5558
file_name = 'peer_id_pool.json'
5659
file_path = os.path.join(this_file_path, file_name)
@@ -109,19 +112,19 @@ class TestCase(unittest.TestCase):
109112
use_memory_storage: bool = USE_MEMORY_STORAGE
110113
seed_config: Optional[int] = None
111114

112-
def setUp(self):
113-
self.tmpdirs = []
115+
def setUp(self) -> None:
116+
self.tmpdirs: list[str] = []
114117
self.clock = TestMemoryReactorClock()
115118
self.clock.advance(time.time())
116119
self.log = logger.new()
117120
self.reset_peer_id_pool()
118121
self.seed = secrets.randbits(64) if self.seed_config is None else self.seed_config
119122
self.log.info('set seed', seed=self.seed)
120123
self.rng = Random(self.seed)
121-
self._pending_cleanups = []
124+
self._pending_cleanups: list[Callable] = []
122125
self._settings = get_global_settings()
123126

124-
def tearDown(self):
127+
def tearDown(self) -> None:
125128
self.clean_tmpdirs()
126129
for fn in self._pending_cleanups:
127130
fn()
@@ -144,12 +147,12 @@ def get_random_peer_id_from_pool(self, pool: Optional[list[PeerId]] = None,
144147
pool.remove(peer_id)
145148
return peer_id
146149

147-
def mkdtemp(self):
150+
def mkdtemp(self) -> str:
148151
tmpdir = tempfile.mkdtemp()
149152
self.tmpdirs.append(tmpdir)
150153
return tmpdir
151154

152-
def _create_test_wallet(self, unlocked=False):
155+
def _create_test_wallet(self, unlocked: bool = False) -> Wallet:
153156
""" Generate a Wallet with a number of keypairs for testing
154157
:rtype: Wallet
155158
"""
@@ -169,14 +172,14 @@ def get_builder(self, network: str) -> TestBuilder:
169172
.set_network(network)
170173
return builder
171174

172-
def create_peer_from_builder(self, builder, start_manager=True):
175+
def create_peer_from_builder(self, builder: Builder, start_manager: bool = True) -> HathorManager:
173176
artifacts = builder.build()
174177
manager = artifacts.manager
175178

176179
if artifacts.rocksdb_storage:
177180
self._pending_cleanups.append(artifacts.rocksdb_storage.close)
178181

179-
manager.avg_time_between_blocks = 0.0001
182+
# manager.avg_time_between_blocks = 0.0001 # FIXME: This property is not defined. Fix this.
180183

181184
if start_manager:
182185
manager.start()
@@ -277,7 +280,7 @@ def create_peer( # type: ignore[no-untyped-def]
277280

278281
return manager
279282

280-
def run_to_completion(self):
283+
def run_to_completion(self) -> None:
281284
""" This will advance the test's clock until all calls scheduled are done.
282285
"""
283286
for call in self.clock.getDelayedCalls():
@@ -300,7 +303,11 @@ def assertIsTopological(self, tx_sequence: Iterator[BaseTransaction], message: O
300303
self.assertIn(dep, valid_deps, message)
301304
valid_deps.add(tx.hash)
302305

303-
def _syncVersionFlags(self, enable_sync_v1=None, enable_sync_v2=None):
306+
def _syncVersionFlags(
307+
self,
308+
enable_sync_v1: bool | None = None,
309+
enable_sync_v2: bool | None = None
310+
) -> tuple[bool, bool]:
304311
"""Internal: use this to check and get the flags and optionally provide override values."""
305312
if enable_sync_v1 is None:
306313
assert hasattr(self, '_enable_sync_v1'), ('`_enable_sync_v1` has no default by design, either set one on '
@@ -313,19 +320,19 @@ def _syncVersionFlags(self, enable_sync_v1=None, enable_sync_v2=None):
313320
assert enable_sync_v1 or enable_sync_v2, 'enable at least one sync version'
314321
return enable_sync_v1, enable_sync_v2
315322

316-
def assertTipsEqual(self, manager1, manager2):
323+
def assertTipsEqual(self, manager1: HathorManager, manager2: HathorManager) -> None:
317324
_, enable_sync_v2 = self._syncVersionFlags()
318325
if enable_sync_v2:
319326
self.assertTipsEqualSyncV2(manager1, manager2)
320327
else:
321328
self.assertTipsEqualSyncV1(manager1, manager2)
322329

323-
def assertTipsNotEqual(self, manager1, manager2):
330+
def assertTipsNotEqual(self, manager1: HathorManager, manager2: HathorManager) -> None:
324331
s1 = set(manager1.tx_storage.get_all_tips())
325332
s2 = set(manager2.tx_storage.get_all_tips())
326333
self.assertNotEqual(s1, s2)
327334

328-
def assertTipsEqualSyncV1(self, manager1, manager2):
335+
def assertTipsEqualSyncV1(self, manager1: HathorManager, manager2: HathorManager) -> None:
329336
# XXX: this is the original implementation of assertTipsEqual
330337
s1 = set(manager1.tx_storage.get_all_tips())
331338
s2 = set(manager2.tx_storage.get_all_tips())
@@ -335,39 +342,45 @@ def assertTipsEqualSyncV1(self, manager1, manager2):
335342
s2 = set(manager2.tx_storage.get_tx_tips())
336343
self.assertEqual(s1, s2)
337344

338-
def assertTipsEqualSyncV2(self, manager1, manager2, *, strict_sync_v2_indexes=True):
345+
def assertTipsEqualSyncV2(
346+
self,
347+
manager1: HathorManager,
348+
manager2: HathorManager,
349+
*,
350+
strict_sync_v2_indexes: bool = True
351+
) -> None:
339352
# tx tips
340353
if strict_sync_v2_indexes:
341-
tips1 = manager1.tx_storage.indexes.mempool_tips.get()
342-
tips2 = manager2.tx_storage.indexes.mempool_tips.get()
354+
tips1 = not_none(not_none(manager1.tx_storage.indexes).mempool_tips).get()
355+
tips2 = not_none(not_none(manager2.tx_storage.indexes).mempool_tips).get()
343356
else:
344357
tips1 = {tx.hash for tx in manager1.tx_storage.iter_mempool_tips_from_best_index()}
345358
tips2 = {tx.hash for tx in manager2.tx_storage.iter_mempool_tips_from_best_index()}
346-
self.log.debug('tx tips1', len=len(tips1), list=shorten_hash(tips1))
347-
self.log.debug('tx tips2', len=len(tips2), list=shorten_hash(tips2))
359+
self.log.debug('tx tips1', len=len(tips1), list=short_hashes(tips1))
360+
self.log.debug('tx tips2', len=len(tips2), list=short_hashes(tips2))
348361
self.assertEqual(tips1, tips2)
349362

350363
# best block
351364
s1 = set(manager1.tx_storage.get_best_block_tips())
352365
s2 = set(manager2.tx_storage.get_best_block_tips())
353-
self.log.debug('block tips1', len=len(s1), list=shorten_hash(s1))
354-
self.log.debug('block tips2', len=len(s2), list=shorten_hash(s2))
366+
self.log.debug('block tips1', len=len(s1), list=short_hashes(s1))
367+
self.log.debug('block tips2', len=len(s2), list=short_hashes(s2))
355368
self.assertEqual(s1, s2)
356369

357370
# best block (from height index)
358-
b1 = manager1.tx_storage.indexes.height.get_tip()
359-
b2 = manager2.tx_storage.indexes.height.get_tip()
371+
b1 = not_none(manager1.tx_storage.indexes).height.get_tip()
372+
b2 = not_none(manager2.tx_storage.indexes).height.get_tip()
360373
self.assertIn(b1, s2)
361374
self.assertIn(b2, s1)
362375

363-
def assertConsensusEqual(self, manager1, manager2):
376+
def assertConsensusEqual(self, manager1: HathorManager, manager2: HathorManager) -> None:
364377
_, enable_sync_v2 = self._syncVersionFlags()
365378
if enable_sync_v2:
366379
self.assertConsensusEqualSyncV2(manager1, manager2)
367380
else:
368381
self.assertConsensusEqualSyncV1(manager1, manager2)
369382

370-
def assertConsensusEqualSyncV1(self, manager1, manager2):
383+
def assertConsensusEqualSyncV1(self, manager1: HathorManager, manager2: HathorManager) -> None:
371384
self.assertEqual(manager1.tx_storage.get_vertices_count(), manager2.tx_storage.get_vertices_count())
372385
for tx1 in manager1.tx_storage.get_all_transactions():
373386
tx2 = manager2.tx_storage.get_transaction(tx1.hash)
@@ -381,12 +394,20 @@ def assertConsensusEqualSyncV1(self, manager1, manager2):
381394
self.assertIsNone(tx2_meta.voided_by)
382395
else:
383396
# If tx1 is voided, then tx2 must be voided.
397+
assert tx1_meta.voided_by is not None
398+
assert tx2_meta.voided_by is not None
384399
self.assertGreaterEqual(len(tx1_meta.voided_by), 1)
385400
self.assertGreaterEqual(len(tx2_meta.voided_by), 1)
386401
# Hard verification
387402
# self.assertEqual(tx1_meta.voided_by, tx2_meta.voided_by)
388403

389-
def assertConsensusEqualSyncV2(self, manager1, manager2, *, strict_sync_v2_indexes=True):
404+
def assertConsensusEqualSyncV2(
405+
self,
406+
manager1: HathorManager,
407+
manager2: HathorManager,
408+
*,
409+
strict_sync_v2_indexes: bool = True
410+
) -> None:
390411
# The current sync algorithm does not propagate voided blocks/txs
391412
# so the count might be different even though the consensus is equal
392413
# One peer might have voided txs that the other does not have
@@ -397,7 +418,9 @@ def assertConsensusEqualSyncV2(self, manager1, manager2, *, strict_sync_v2_index
397418
# the following is specific to sync-v2
398419

399420
# helper function:
400-
def get_all_executed_or_voided(tx_storage):
421+
def get_all_executed_or_voided(
422+
tx_storage: TransactionStorage
423+
) -> tuple[set[VertexId], set[VertexId], set[VertexId]]:
401424
"""Get all txs separated into three sets: executed, voided, partial"""
402425
tx_executed = set()
403426
tx_voided = set()
@@ -424,14 +447,16 @@ def get_all_executed_or_voided(tx_storage):
424447
self.log.debug('node1 rest', len_voided=len(tx_voided1), len_partial=len(tx_partial1))
425448
self.log.debug('node2 rest', len_voided=len(tx_voided2), len_partial=len(tx_partial2))
426449

427-
def assertConsensusValid(self, manager):
450+
def assertConsensusValid(self, manager: HathorManager) -> None:
428451
for tx in manager.tx_storage.get_all_transactions():
429452
if tx.is_block:
453+
assert isinstance(tx, Block)
430454
self.assertBlockConsensusValid(tx)
431455
else:
456+
assert isinstance(tx, Transaction)
432457
self.assertTransactionConsensusValid(tx)
433458

434-
def assertBlockConsensusValid(self, block):
459+
def assertBlockConsensusValid(self, block: Block) -> None:
435460
self.assertTrue(block.is_block)
436461
if not block.parents:
437462
# Genesis
@@ -442,7 +467,8 @@ def assertBlockConsensusValid(self, block):
442467
parent_meta = parent.get_metadata()
443468
self.assertIsNone(parent_meta.voided_by)
444469

445-
def assertTransactionConsensusValid(self, tx):
470+
def assertTransactionConsensusValid(self, tx: Transaction) -> None:
471+
assert tx.storage is not None
446472
self.assertFalse(tx.is_block)
447473
meta = tx.get_metadata()
448474
if meta.voided_by and tx.hash in meta.voided_by:
@@ -462,38 +488,40 @@ def assertTransactionConsensusValid(self, tx):
462488
spent_meta = spent_tx.get_metadata()
463489

464490
if spent_meta.voided_by is not None:
465-
self.assertIsNotNone(meta.voided_by)
491+
assert meta.voided_by is not None
466492
self.assertTrue(spent_meta.voided_by)
467493
self.assertTrue(meta.voided_by)
468494
self.assertTrue(spent_meta.voided_by.issubset(meta.voided_by))
469495

470496
for parent in tx.get_parents():
471497
parent_meta = parent.get_metadata()
472498
if parent_meta.voided_by is not None:
473-
self.assertIsNotNone(meta.voided_by)
499+
assert meta.voided_by is not None
474500
self.assertTrue(parent_meta.voided_by)
475501
self.assertTrue(meta.voided_by)
476502
self.assertTrue(parent_meta.voided_by.issubset(meta.voided_by))
477503

478-
def assertSyncedProgress(self, node_sync):
504+
def assertSyncedProgress(self, node_sync: NodeSyncTimestamp | NodeBlockSync) -> None:
479505
"""Check "synced" status of p2p-manager, uses self._enable_sync_vX to choose which check to run."""
480506
enable_sync_v1, enable_sync_v2 = self._syncVersionFlags()
481507
if enable_sync_v2:
508+
assert isinstance(node_sync, NodeBlockSync)
482509
self.assertV2SyncedProgress(node_sync)
483510
elif enable_sync_v1:
511+
assert isinstance(node_sync, NodeSyncTimestamp)
484512
self.assertV1SyncedProgress(node_sync)
485513

486-
def assertV1SyncedProgress(self, node_sync):
514+
def assertV1SyncedProgress(self, node_sync: NodeSyncTimestamp) -> None:
487515
self.assertEqual(node_sync.synced_timestamp, node_sync.peer_timestamp)
488516

489-
def assertV2SyncedProgress(self, node_sync):
517+
def assertV2SyncedProgress(self, node_sync: NodeBlockSync) -> None:
490518
self.assertEqual(node_sync.synced_block, node_sync.peer_best_block)
491519

492-
def clean_tmpdirs(self):
520+
def clean_tmpdirs(self) -> None:
493521
for tmpdir in self.tmpdirs:
494522
shutil.rmtree(tmpdir)
495523

496-
def clean_pending(self, required_to_quiesce=True):
524+
def clean_pending(self, required_to_quiesce: bool = True) -> None:
497525
"""
498526
This handy method cleans all pending tasks from the reactor.
499527

0 commit comments

Comments
 (0)