diff --git a/hathor/websocket/messages.py b/hathor/websocket/messages.py index 01b3b4f45..86058759b 100644 --- a/hathor/websocket/messages.py +++ b/hathor/websocket/messages.py @@ -41,12 +41,14 @@ class StreamErrorMessage(StreamBase): class StreamBeginMessage(StreamBase): type: str = Field('stream:history:begin', const=True) id: str + seq: int window_size: Optional[int] class StreamEndMessage(StreamBase): type: str = Field('stream:history:end', const=True) id: str + seq: int class StreamVertexMessage(StreamBase): diff --git a/hathor/websocket/protocol.py b/hathor/websocket/protocol.py index 319eee555..e23d2b60a 100644 --- a/hathor/websocket/protocol.py +++ b/hathor/websocket/protocol.py @@ -144,6 +144,7 @@ def fail_if_history_streaming_is_disabled(self) -> bool: def _create_streamer(self, stream_id: str, search: AddressSearch, window_size: int | None) -> None: """Create the streamer and handle its callbacks.""" + assert self._history_streamer is None self._history_streamer = HistoryStreamer(protocol=self, stream_id=stream_id, search=search) if window_size is not None: if window_size < 0: diff --git a/hathor/websocket/streamer.py b/hathor/websocket/streamer.py index 828d12472..08eb6ca89 100644 --- a/hathor/websocket/streamer.py +++ b/hathor/websocket/streamer.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from enum import Enum, auto from typing import TYPE_CHECKING, Optional from twisted.internet.defer import Deferred @@ -33,6 +34,27 @@ from hathor.websocket.protocol import HathorAdminWebsocketProtocol +class StreamerState(Enum): + NOT_STARTED = auto() + ACTIVE = auto() + PAUSED = auto() + CLOSING = auto() + CLOSED = auto() + + def can_transition_to(self, destination: 'StreamerState') -> bool: + """Checks if the transition to the destination state is valid.""" + return destination in VALID_TRANSITIONS[self] + + +VALID_TRANSITIONS = { + StreamerState.NOT_STARTED: {StreamerState.ACTIVE}, + StreamerState.ACTIVE: {StreamerState.ACTIVE, StreamerState.PAUSED, StreamerState.CLOSING, StreamerState.CLOSED}, + StreamerState.PAUSED: {StreamerState.ACTIVE, StreamerState.CLOSED}, + StreamerState.CLOSING: {StreamerState.CLOSED}, + StreamerState.CLOSED: set() +} + + @implementer(IPushProducer) class HistoryStreamer: """A producer that pushes addresses and transactions to a websocket connection. @@ -72,23 +94,32 @@ def __init__(self, self.deferred: Deferred[bool] = Deferred() - # Statistics. + # Statistics + # ---------- self.stats_log_interval = self.STATS_LOG_INTERVAL self.stats_total_messages: int = 0 self.stats_sent_addresses: int = 0 self.stats_sent_vertices: int = 0 - # Execution control. - self._started = False - self._is_running = False - self._paused = False - self._stop = False + # Execution control + # ----------------- + self._state = StreamerState.NOT_STARTED + # Used to mark that the streamer is currently running its main loop and sending messages. + self._is_main_loop_running = False - # Flow control. + # Flow control + # ------------ self._next_sequence_number: int = 0 self._last_ack: int = -1 self._sliding_window_size: Optional[int] = self.DEFAULT_SLIDING_WINDOW_SIZE + def get_next_seq(self) -> int: + assert self._state is not StreamerState.CLOSING + assert self._state is not StreamerState.CLOSED + seq = self._next_sequence_number + self._next_sequence_number += 1 + return seq + def set_sliding_window_size(self, size: Optional[int]) -> None: """Set a new sliding window size for flow control. If size is none, disables flow control. """ @@ -102,73 +133,115 @@ def set_ack(self, ack: int) -> None: If the new value is bigger than the previous value, the streaming might be resumed. """ - if ack <= self._last_ack: + if self._state is StreamerState.CLOSING: + closing_ack = self._next_sequence_number - 1 + if ack == closing_ack: + self._last_ack = ack + self.stop(True) + return + if ack == self._last_ack: # We might receive outdated or duplicate ACKs, and we can safely ignore them. + return + if ack < self._last_ack: + # ACK got smaller. Something is wrong... self.send_message(StreamErrorMessage( id=self.stream_id, - errmsg=f'Outdated ACK received. Skipping it... (ack={ack})' + errmsg=f'Outdated ACK received (ack={ack})' )) + self.stop(False) return if ack >= self._next_sequence_number: + # ACK is higher than the last message sent. Something is wrong... self.send_message(StreamErrorMessage( id=self.stream_id, - errmsg=f'Received ACK is higher than the last sent message. Skipping it... (ack={ack})' + errmsg=f'Received ACK is higher than the last sent message (ack={ack})' )) + self.stop(False) return self._last_ack = ack self.resume_if_possible() def resume_if_possible(self) -> None: - if not self._started: + """Resume sending messages if possible.""" + if self._state is StreamerState.PAUSED: + return + if not self._state.can_transition_to(StreamerState.ACTIVE): + return + if self._is_main_loop_running: + return + if self.should_pause_streaming(): return - if not self.should_pause_streaming() and not self._is_running: - self.resumeProducing() + self._run() + + def set_state(self, new_state: StreamerState) -> None: + """Set a new state for the streamer.""" + if self._state == new_state: + return + assert self._state.can_transition_to(new_state) + self._state = new_state def start(self) -> Deferred[bool]: """Start streaming items.""" + assert self._state is StreamerState.NOT_STARTED + # The websocket connection somehow instantiates an twisted.web.http.HTTPChannel object # which register a producer. It seems the HTTPChannel is not used anymore after switching # to websocket but it keep registered. So we have to unregister before registering a new # producer. if self.protocol.transport.producer: self.protocol.unregisterProducer() - self.protocol.registerProducer(self, True) - assert not self._started - self._started = True - self.send_message(StreamBeginMessage(id=self.stream_id, window_size=self._sliding_window_size)) - self.resumeProducing() + self.send_message(StreamBeginMessage( + id=self.stream_id, + seq=self.get_next_seq(), + window_size=self._sliding_window_size, + )) + self.resume_if_possible() return self.deferred def stop(self, success: bool) -> None: """Stop streaming items.""" - assert self._started - self._stop = True - self._started = False + if not self._state.can_transition_to(StreamerState.CLOSED): + # Do nothing if the streamer has already been stopped. + self.protocol.log.warn('stop called in an unexpected state', state=self._state) + return + self.set_state(StreamerState.CLOSED) self.protocol.unregisterProducer() self.deferred.callback(success) + def gracefully_close(self) -> None: + """Gracefully close the stream by sending the StreamEndMessage and waiting for its ack.""" + if not self._state.can_transition_to(StreamerState.CLOSING): + return + self.protocol.log.info('websocket streaming ended, waiting for ACK') + self.send_message(StreamEndMessage(id=self.stream_id, seq=self.get_next_seq())) + self.set_state(StreamerState.CLOSING) + def pauseProducing(self) -> None: """Pause streaming. Called by twisted.""" - self._paused = True + if not self._state.can_transition_to(StreamerState.PAUSED): + self.protocol.log.warn('pause requested in an unexpected state', state=self._state) + return + self.set_state(StreamerState.PAUSED) def stopProducing(self) -> None: """Stop streaming. Called by twisted.""" - self._stop = True + if not self._state.can_transition_to(StreamerState.CLOSED): + self.protocol.log.warn('stopped requested in an unexpected state', state=self._state) + return self.stop(False) def resumeProducing(self) -> None: """Resume streaming. Called by twisted.""" - self._paused = False - self._run() - - def _run(self) -> None: - """Run the streaming main loop.""" - coro = self._async_run() - Deferred.fromCoroutine(coro) + if not self._state.can_transition_to(StreamerState.ACTIVE): + self.protocol.log.warn('resume requested in an unexpected state', state=self._state) + return + self.set_state(StreamerState.ACTIVE) + self.resume_if_possible() def should_pause_streaming(self) -> bool: + """Return true if the streaming should pause due to the flow control mechanism.""" if self._sliding_window_size is None: return False stop_value = self._last_ack + self._sliding_window_size + 1 @@ -176,13 +249,22 @@ def should_pause_streaming(self) -> bool: return False return True + def _run(self) -> None: + """Run the streaming main loop.""" + if not self._state.can_transition_to(StreamerState.ACTIVE): + self.protocol.log.warn('_run() called in an unexpected state', state=self._state) + return + coro = self._async_run() + Deferred.fromCoroutine(coro) + async def _async_run(self): - assert not self._is_running - self._is_running = True + assert not self._is_main_loop_running + self.set_state(StreamerState.ACTIVE) + self._is_main_loop_running = True try: await self._async_run_unsafe() finally: - self._is_running = False + self._is_main_loop_running = False async def _async_run_unsafe(self): """Internal method that runs the streaming main loop.""" @@ -204,7 +286,7 @@ async def _async_run_unsafe(self): self.stats_sent_addresses += 1 self.send_message(StreamAddressMessage( id=self.stream_id, - seq=self._next_sequence_number, + seq=self.get_next_seq(), index=item.index, address=item.address, subscribed=subscribed, @@ -214,23 +296,16 @@ async def _async_run_unsafe(self): self.stats_sent_vertices += 1 self.send_message(StreamVertexMessage( id=self.stream_id, - seq=self._next_sequence_number, + seq=self.get_next_seq(), data=item.vertex.to_json_extended(), )) case _: assert False - self._next_sequence_number += 1 if self.should_pause_streaming(): break - # The methods `pauseProducing()` and `stopProducing()` might be called during the - # call to `self.protocol.sendMessage()`. So both `_paused` and `_stop` might change - # during the loop. - if self._paused or self._stop: - break - self.stats_total_messages += 1 if self.stats_total_messages % self.stats_log_interval == 0: self.protocol.log.info('websocket streaming statistics', @@ -238,6 +313,13 @@ async def _async_run_unsafe(self): sent_vertices=self.stats_sent_vertices, sent_addresses=self.stats_sent_addresses) + # The methods `pauseProducing()` and `stopProducing()` might be called during the + # call to `self.protocol.sendMessage()`. So the streamer state might change during + # the loop. + if self._state is not StreamerState.ACTIVE: + break + + # Limit blocking of the event loop to a maximum of N seconds. dt = self.reactor.seconds() - t0 if dt > self.max_seconds_locking_event_loop: # Let the event loop run at least once. @@ -245,11 +327,8 @@ async def _async_run_unsafe(self): t0 = self.reactor.seconds() else: - if self._stop: - # If the streamer has been stopped, there is nothing else to do. - return - self.send_message(StreamEndMessage(id=self.stream_id)) - self.stop(True) + # Iterator is empty so we can close the stream. + self.gracefully_close() def send_message(self, message: StreamBase) -> None: """Send a message to the websocket connection.""" diff --git a/tests/websocket/test_streamer.py b/tests/websocket/test_streamer.py index e83a81438..87c4a5407 100644 --- a/tests/websocket/test_streamer.py +++ b/tests/websocket/test_streamer.py @@ -6,7 +6,7 @@ from hathor.wallet import HDWallet from hathor.websocket.factory import HathorAdminWebsocketFactory from hathor.websocket.iterators import AddressItem, ManualAddressSequencer, gap_limit_search -from hathor.websocket.streamer import HistoryStreamer +from hathor.websocket.streamer import HistoryStreamer, StreamerState from tests.unittest import TestCase from tests.utils import GENESIS_ADDRESS_B58 @@ -60,7 +60,7 @@ def test_streamer(self) -> None: 'data': genesis.to_json_extended(), }) expected_result.append({'type': 'stream:history:end', 'id': stream_id}) - for index, item in enumerate(expected_result[1:-1]): + for index, item in enumerate(expected_result): item['seq'] = index # Create both the address iterator and the GAP limit searcher. @@ -86,6 +86,13 @@ def test_streamer(self) -> None: # Run the streamer. manager.reactor.advance(10) + # Check the streamer is waiting for the last ACK. + self.assertTrue(streamer._state, StreamerState.CLOSING) + streamer.set_ack(1) + self.assertTrue(streamer._state, StreamerState.CLOSING) + streamer.set_ack(len(expected_result) - 1) + self.assertTrue(streamer._state, StreamerState.CLOSED) + # Check the results. items_iter = self._parse_ws_raw(transport.value()) result = list(items_iter)