@@ -53,6 +53,10 @@ class PeerState(Enum):
53
53
SYNCING_MEMPOOL = 'syncing-mempool'
54
54
55
55
56
+ class _GetDataOrigin (Enum ):
57
+ MEMPOOL = 'mempool'
58
+
59
+
56
60
class NodeBlockSync (SyncAgent ):
57
61
""" An algorithm to sync two peers based on their blockchain.
58
62
"""
@@ -1000,33 +1004,34 @@ def handle_transaction(self, payload: str) -> None:
1000
1004
self .log .debug ('tx streaming in progress' , txs_received = self ._tx_received )
1001
1005
1002
1006
@inlineCallbacks
1003
- def get_tx (self , tx_id : bytes ) -> Generator [Deferred , Any , BaseTransaction ]:
1007
+ def get_tx_mempool (self , tx_id : bytes ) -> Generator [Deferred , Any , BaseTransaction ]:
1004
1008
""" Async method to get a transaction from the db/cache or to download it.
1005
1009
"""
1010
+ assert self .state is PeerState .SYNCING_MEMPOOL , 'get_tx_mempool must only be called on mempool state'
1006
1011
tx = self ._get_tx_cache .get (tx_id )
1007
1012
if tx is not None :
1008
1013
self .log .debug ('tx in cache' , tx = tx_id .hex ())
1009
1014
return tx
1010
1015
try :
1011
1016
tx = self .tx_storage .get_transaction (tx_id )
1012
1017
except TransactionDoesNotExist :
1013
- tx = yield self .get_data (tx_id , 'mempool' )
1018
+ tx = yield self .get_data (tx_id , _GetDataOrigin . MEMPOOL )
1014
1019
assert tx is not None
1015
1020
if tx .hash != tx_id :
1016
1021
self .protocol .send_error_and_close_connection (f'DATA mempool { tx_id .hex ()} hash mismatch' )
1017
1022
raise
1018
1023
return tx
1019
1024
1020
- def get_data (self , tx_id : bytes , origin : str ) -> Deferred [BaseTransaction ]:
1025
+ def get_data (self , tx_id : bytes , origin : _GetDataOrigin ) -> Deferred [BaseTransaction ]:
1021
1026
""" Async method to request a tx by id.
1022
1027
"""
1023
1028
# TODO: deal with stale `get_data` calls
1024
- if origin != 'mempool' :
1029
+ if origin is not _GetDataOrigin . MEMPOOL :
1025
1030
raise ValueError (f'origin={ origin } not supported, only origin=mempool is supported' )
1026
1031
deferred = self ._deferred_txs .get (tx_id , None )
1027
1032
if deferred is None :
1028
1033
deferred = self ._deferred_txs [tx_id ] = Deferred ()
1029
- self .send_get_data (tx_id , origin = origin )
1034
+ self .send_get_data (tx_id , origin = origin . name )
1030
1035
self .log .debug ('get_data of new tx_id' , deferred = deferred , key = tx_id .hex ())
1031
1036
else :
1032
1037
# XXX: can we re-use deferred objects like this?
0 commit comments