Skip to content

refactor(indexes): Optimize RocksDBAddressIndex to handle pagination in O(log n) #978

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion hathor/indexes/address_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,11 @@ def get_from_address(self, address: str) -> list[bytes]:
raise NotImplementedError

@abstractmethod
def get_sorted_from_address(self, address: str) -> list[bytes]:
def get_sorted_from_address(self, address: str, tx_start: Optional[BaseTransaction] = None) -> Iterable[bytes]:
    """ Get the transaction hashes of an address as a sorted iterable

    `tx_start` serves as a pagination marker, indicating the starting position for the iteration.
    The hash of `tx_start` itself is expected to be the first element produced (the marker is inclusive).
    When tx_start is None, the iteration begins from the initial element.

    The result is an iterable and not necessarily a list; callers that need indexing or equality
    comparison must materialize it themselves (e.g. with `list(...)`).

    This base version is a placeholder that always raises NotImplementedError.
    """
    raise NotImplementedError

Expand Down
4 changes: 2 additions & 2 deletions hathor/indexes/memory_address_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ def add_tx(self, tx: BaseTransaction) -> None:
def get_from_address(self, address: str) -> list[bytes]:
    """Return every transaction hash indexed under `address`, materialized into a list."""
    tx_hashes = self._get_from_key(address)
    return [*tx_hashes]

def get_sorted_from_address(self, address: str) -> list[bytes]:
return list(self._get_sorted_from_key(address))
def get_sorted_from_address(self, address: str, tx_start: Optional[BaseTransaction] = None) -> Iterable[bytes]:
    """Return the sorted transaction hashes of `address`, optionally starting at `tx_start`.

    The work is delegated to `_get_sorted_from_key`; whatever iterable it produces is
    handed back unchanged (it is not copied into a list).
    """
    sorted_hashes = self._get_sorted_from_key(address, tx_start)
    return sorted_hashes

def is_address_empty(self, address: str) -> bool:
    """Tell whether `address` has no entries, as reported by `_is_key_empty`."""
    is_empty = self._is_key_empty(address)
    return is_empty
13 changes: 10 additions & 3 deletions hathor/indexes/memory_tx_group_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from abc import abstractmethod
from collections import defaultdict
from typing import Iterable, Sized, TypeVar
from typing import Iterable, Optional, Sized, TypeVar

from structlog import get_logger

Expand Down Expand Up @@ -63,8 +63,15 @@ def _get_from_key(self, key: KT) -> Iterable[bytes]:
for _, h in self.index[key]:
yield h

def _get_sorted_from_key(self, key: KT) -> Iterable[bytes]:
return [h for _, h in sorted(self.index[key])]
def _get_sorted_from_key(self, key: KT, tx_start: Optional[BaseTransaction] = None) -> Iterable[bytes]:
    """Yield the transaction hashes stored under `key` in sorted order.

    The entries of `self.index[key]` are pairs whose second element is the transaction hash;
    they are sorted by their natural ordering (presumably by timestamp first -- confirm against
    the code that fills the index).

    `tx_start` is a pagination marker: when given, every hash before `tx_start.hash` is skipped
    and iteration starts at (and includes) that hash. If the marker's hash is not present under
    `key`, nothing is yielded. When `tx_start` is None, all hashes are yielded.
    """
    hashes = (h for _, h in sorted(self.index[key]))

    # Compare against None explicitly so that a marker object that happens to be falsy
    # is still honored instead of being treated as "no marker".
    if tx_start is not None:
        start_hash = tx_start.hash
        # Consume the iterator up to the marker; the marker itself is part of the page.
        for h in hashes:
            if h == start_hash:
                yield h
                break

    # Everything left after the marker (or the whole sequence when there is no marker).
    yield from hashes

def _is_key_empty(self, key: KT) -> bool:
    """Return True when the collection stored at `self.index[key]` is falsy (has no entries)."""
    entries = self.index[key]
    return not entries
4 changes: 2 additions & 2 deletions hathor/indexes/rocksdb_address_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ def add_tx(self, tx: BaseTransaction) -> None:
def get_from_address(self, address: str) -> list[bytes]:
    """Collect into a list all transaction hashes that `_get_from_key` yields for `address`."""
    found = self._get_from_key(address)
    return [*found]

def get_sorted_from_address(self, address: str) -> list[bytes]:
return list(self._get_sorted_from_key(address))
def get_sorted_from_address(self, address: str, tx_start: Optional[BaseTransaction] = None) -> Iterable[bytes]:
    """Return the sorted transaction hashes of `address`, optionally starting at `tx_start`.

    This is a thin wrapper: the iterable produced by `_get_sorted_from_key` is returned as is,
    without being copied into a list.
    """
    hashes_iter = self._get_sorted_from_key(address, tx_start)
    return hashes_iter

def is_address_empty(self, address: str) -> bool:
    """Report whether `address` has no indexed entries, by delegating to `_is_key_empty`."""
    result = self._is_key_empty(address)
    return result
11 changes: 7 additions & 4 deletions hathor/indexes/rocksdb_tx_group_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,15 @@ def remove_tx(self, tx: BaseTransaction) -> None:
self._db.delete((self._cf, self._to_rocksdb_key(key, tx)))

def _get_from_key(self, key: KT) -> Iterable[bytes]:
    """Return the transaction hashes stored under `key`.

    Delegates to `_util_get_from_key` without a starting transaction.
    """
    all_hashes = self._util_get_from_key(key)
    return all_hashes

def _get_sorted_from_key(self, key: KT, tx_start: Optional[BaseTransaction] = None) -> Iterable[bytes]:
    """Return the transaction hashes stored under `key`, optionally starting at `tx_start`.

    Delegates to `_util_get_from_key`, forwarding `tx_start` as its starting transaction.
    """
    paged_hashes = self._util_get_from_key(key, tx_start)
    return paged_hashes

def _util_get_from_key(self, key: KT, tx: Optional[BaseTransaction] = None) -> Iterable[bytes]:
self.log.debug('seek to', key=key)
it = self._db.iterkeys(self._cf)
it.seek(self._to_rocksdb_key(key))
it.seek(self._to_rocksdb_key(key, tx))
for _cf, rocksdb_key in it:
key2, _, tx_hash = self._from_rocksdb_key(rocksdb_key)
if key2 != key:
Expand All @@ -119,9 +125,6 @@ def _get_from_key(self, key: KT) -> Iterable[bytes]:
yield tx_hash
self.log.debug('seek end')

def _get_sorted_from_key(self, key: KT) -> Iterable[bytes]:
return self._get_from_key(key)

def _is_key_empty(self, key: KT) -> bool:
self.log.debug('seek to', key=key)
it = self._db.iterkeys(self._cf)
Expand Down
10 changes: 7 additions & 3 deletions hathor/indexes/tx_group_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from abc import abstractmethod
from typing import Generic, Iterable, Sized, TypeVar
from typing import Generic, Iterable, Optional, Sized, TypeVar

from structlog import get_logger

Expand Down Expand Up @@ -49,8 +49,12 @@ def _get_from_key(self, key: KT) -> Iterable[bytes]:
raise NotImplementedError

@abstractmethod
def _get_sorted_from_key(self, key: KT) -> Iterable[bytes]:
"""Get all transactions that have a given key, sorted by timestamp."""
def _get_sorted_from_key(self, key: KT, tx_start: Optional[BaseTransaction] = None) -> Iterable[bytes]:
    """Get all transactions that have a given key, sorted by timestamp.

    `tx_start` serves as a pagination marker, indicating the starting position for the iteration.
    The hash of `tx_start` itself is expected to be the first element produced (the marker is inclusive).
    When tx_start is None, the iteration begins from the initial element.

    The result is an iterable of transaction hashes and may be lazy; callers must not assume a list.

    This base version is a placeholder that always raises NotImplementedError.
    """
    raise NotImplementedError

@abstractmethod
Expand Down
34 changes: 13 additions & 21 deletions hathor/wallet/resources/thin_wallet/address_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from hathor.cli.openapi_files.register import register_resource
from hathor.conf.get_settings import get_global_settings
from hathor.crypto.util import decode_address
from hathor.transaction.storage.exceptions import TransactionDoesNotExist
from hathor.util import json_dumpb, json_loadb
from hathor.wallet.exceptions import InvalidAddress

Expand Down Expand Up @@ -166,12 +167,6 @@ def get_address_history(self, addresses: list[str], ref_hash: Optional[str]) ->

history = []
seen: set[bytes] = set()
# XXX In this algorithm we need to sort all transactions of an address
# and find one specific (in case of a pagination request)
# so if this address has many txs, this could become slow
# I've done some tests with 10k txs in one address and the request
# returned in less than 50ms, so we will move forward with it for now
# but this could be improved in the future
for idx, address in enumerate(addresses):
try:
decode_address(address)
Expand All @@ -181,31 +176,28 @@ def get_address_history(self, addresses: list[str], ref_hash: Optional[str]) ->
'message': 'The address {} is invalid'.format(address)
})

hashes = addresses_index.get_sorted_from_address(address)
start_index = 0
if ref_hash_bytes and idx == 0:
# It's not the first request, so we must continue from the hash
# but we do it only for the first address
tx = None
if ref_hash_bytes:
try:
# Find index where the hash is
start_index = hashes.index(ref_hash_bytes)
except ValueError:
# ref_hash is not in the list
tx = self.manager.tx_storage.get_transaction(ref_hash_bytes)
except TransactionDoesNotExist:
return json_dumpb({
'success': False,
'message': 'Hash {} is not a transaction from the address {}.'.format(ref_hash, address)
'message': 'Hash {} is not a transaction hash.'.format(ref_hash)
})

# Slice the hashes array from the start_index
to_iterate = hashes[start_index:]
# The address index returns an iterable that starts at `tx`.
hashes = addresses_index.get_sorted_from_address(address, tx)
did_break = False
for index, tx_hash in enumerate(to_iterate):
for tx_hash in hashes:
if total_added == self._settings.MAX_TX_ADDRESSES_HISTORY:
# If already added the max number of elements possible, then break
# I need to add this if at the beginning of the loop to handle the case
# when the first tx of the address exceeds the limit, so we must return
# that the next request should start in the first tx of this address
did_break = True
# Saving the first tx hash for the next request
first_hash = tx_hash.hex()
break

if tx_hash not in seen:
Expand All @@ -216,6 +208,8 @@ def get_address_history(self, addresses: list[str], ref_hash: Optional[str]) ->
# It's important to validate also the maximum number of inputs and outputs because some txs
# are really big and the response payload becomes too big
did_break = True
# Saving the first tx hash for the next request
first_hash = tx_hash.hex()
break

seen.add(tx_hash)
Expand All @@ -226,10 +220,8 @@ def get_address_history(self, addresses: list[str], ref_hash: Optional[str]) ->
if did_break:
# We stopped in the middle of the txs of this address
# So we return that we still have more data to send
break_index = start_index + index
has_more = True
# The hash to start the search and which address this hash belongs
first_hash = hashes[break_index].hex()
first_address = address
break

Expand Down
6 changes: 3 additions & 3 deletions tests/tx/test_indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@ def test_addresses_index_empty(self):
address = self.get_address(10)
assert address is not None
self.assertTrue(addresses_indexes.is_address_empty(address))
self.assertEqual(addresses_indexes.get_sorted_from_address(address), [])
self.assertEqual(list(addresses_indexes.get_sorted_from_address(address)), [])

def test_addresses_index_last(self):
"""
Expand All @@ -653,7 +653,7 @@ def test_addresses_index_last(self):
# XXX: this artificial address should be greater (byte-wise) than any possible "natural" address
address = '\x7f' * 34
self.assertTrue(addresses_indexes.is_address_empty(address))
self.assertEqual(addresses_indexes.get_sorted_from_address(address), [])
self.assertEqual(list(addresses_indexes.get_sorted_from_address(address)), [])

# XXX: since we didn't add any multisig address, this is guaranteed to reach the tail end of the index
assert self._settings.P2PKH_VERSION_BYTE[0] < self._settings.MULTISIG_VERSION_BYTE[0]
Expand All @@ -666,7 +666,7 @@ def test_addresses_index_last(self):
assert address is not None

self.assertTrue(addresses_indexes.is_address_empty(address))
self.assertEqual(addresses_indexes.get_sorted_from_address(address), [])
self.assertEqual(list(addresses_indexes.get_sorted_from_address(address)), [])

def test_height_index(self):
from hathor.indexes.height_index import HeightInfo
Expand Down