Skip to content

fix(ws): Add a graceful close mechanism to handle late messages and prevent errors #1129

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions hathor/websocket/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@ class StreamErrorMessage(StreamBase):
class StreamBeginMessage(StreamBase):
type: str = Field('stream:history:begin', const=True)
id: str
seq: int
window_size: Optional[int]


class StreamEndMessage(StreamBase):
type: str = Field('stream:history:end', const=True)
id: str
seq: int


class StreamVertexMessage(StreamBase):
Expand Down
1 change: 1 addition & 0 deletions hathor/websocket/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def fail_if_history_streaming_is_disabled(self) -> bool:

def _create_streamer(self, stream_id: str, search: AddressSearch, window_size: int | None) -> None:
"""Create the streamer and handle its callbacks."""
assert self._history_streamer is None
self._history_streamer = HistoryStreamer(protocol=self, stream_id=stream_id, search=search)
if window_size is not None:
if window_size < 0:
Expand Down
173 changes: 126 additions & 47 deletions hathor/websocket/streamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum, auto
from typing import TYPE_CHECKING, Optional

from twisted.internet.defer import Deferred
Expand All @@ -33,6 +34,27 @@
from hathor.websocket.protocol import HathorAdminWebsocketProtocol


class StreamerState(Enum):
NOT_STARTED = auto()
ACTIVE = auto()
PAUSED = auto()
CLOSING = auto()
CLOSED = auto()

def can_transition_to(self, destination: 'StreamerState') -> bool:
"""Checks if the transition to the destination state is valid."""
return destination in VALID_TRANSITIONS[self]


VALID_TRANSITIONS = {
StreamerState.NOT_STARTED: {StreamerState.ACTIVE},
StreamerState.ACTIVE: {StreamerState.ACTIVE, StreamerState.PAUSED, StreamerState.CLOSING, StreamerState.CLOSED},
StreamerState.PAUSED: {StreamerState.ACTIVE, StreamerState.CLOSED},
StreamerState.CLOSING: {StreamerState.CLOSED},
StreamerState.CLOSED: set()
}


@implementer(IPushProducer)
class HistoryStreamer:
"""A producer that pushes addresses and transactions to a websocket connection.
Expand Down Expand Up @@ -72,23 +94,32 @@ def __init__(self,

self.deferred: Deferred[bool] = Deferred()

# Statistics.
# Statistics
# ----------
self.stats_log_interval = self.STATS_LOG_INTERVAL
self.stats_total_messages: int = 0
self.stats_sent_addresses: int = 0
self.stats_sent_vertices: int = 0

# Execution control.
self._started = False
self._is_running = False
self._paused = False
self._stop = False
# Execution control
# -----------------
self._state = StreamerState.NOT_STARTED
# Used to mark that the streamer is currently running its main loop and sending messages.
self._is_main_loop_running = False

# Flow control.
# Flow control
# ------------
self._next_sequence_number: int = 0
self._last_ack: int = -1
self._sliding_window_size: Optional[int] = self.DEFAULT_SLIDING_WINDOW_SIZE

def get_next_seq(self) -> int:
assert self._state is not StreamerState.CLOSING
assert self._state is not StreamerState.CLOSED
seq = self._next_sequence_number
self._next_sequence_number += 1
return seq

def set_sliding_window_size(self, size: Optional[int]) -> None:
"""Set a new sliding window size for flow control. If size is none, disables flow control.
"""
Expand All @@ -102,87 +133,138 @@ def set_ack(self, ack: int) -> None:

If the new value is bigger than the previous value, the streaming might be resumed.
"""
if ack <= self._last_ack:
if self._state is StreamerState.CLOSING:
closing_ack = self._next_sequence_number - 1
if ack == closing_ack:
self._last_ack = ack
self.stop(True)
return
if ack == self._last_ack:
# We might receive outdated or duplicate ACKs, and we can safely ignore them.
return
if ack < self._last_ack:
# ACK got smaller. Something is wrong...
self.send_message(StreamErrorMessage(
id=self.stream_id,
errmsg=f'Outdated ACK received. Skipping it... (ack={ack})'
errmsg=f'Outdated ACK received (ack={ack})'
))
self.stop(False)
return
if ack >= self._next_sequence_number:
# ACK is higher than the last message sent. Something is wrong...
self.send_message(StreamErrorMessage(
id=self.stream_id,
errmsg=f'Received ACK is higher than the last sent message. Skipping it... (ack={ack})'
errmsg=f'Received ACK is higher than the last sent message (ack={ack})'
))
self.stop(False)
return
self._last_ack = ack
self.resume_if_possible()

def resume_if_possible(self) -> None:
if not self._started:
"""Resume sending messages if possible."""
if self._state is StreamerState.PAUSED:
return
if not self._state.can_transition_to(StreamerState.ACTIVE):
return
if self._is_main_loop_running:
return
if self.should_pause_streaming():
return
if not self.should_pause_streaming() and not self._is_running:
self.resumeProducing()
self._run()

def set_state(self, new_state: StreamerState) -> None:
"""Set a new state for the streamer."""
if self._state == new_state:
return
assert self._state.can_transition_to(new_state)
self._state = new_state

def start(self) -> Deferred[bool]:
"""Start streaming items."""
assert self._state is StreamerState.NOT_STARTED

# The websocket connection somehow instantiates an twisted.web.http.HTTPChannel object
# which register a producer. It seems the HTTPChannel is not used anymore after switching
# to websocket but it keep registered. So we have to unregister before registering a new
# producer.
if self.protocol.transport.producer:
self.protocol.unregisterProducer()

self.protocol.registerProducer(self, True)

assert not self._started
self._started = True
self.send_message(StreamBeginMessage(id=self.stream_id, window_size=self._sliding_window_size))
self.resumeProducing()
self.send_message(StreamBeginMessage(
id=self.stream_id,
seq=self.get_next_seq(),
window_size=self._sliding_window_size,
))
self.resume_if_possible()
return self.deferred

def stop(self, success: bool) -> None:
"""Stop streaming items."""
assert self._started
self._stop = True
self._started = False
if not self._state.can_transition_to(StreamerState.CLOSED):
# Do nothing if the streamer has already been stopped.
self.protocol.log.warn('stop called in an unexpected state', state=self._state)
return
self.set_state(StreamerState.CLOSED)
self.protocol.unregisterProducer()
self.deferred.callback(success)

def gracefully_close(self) -> None:
"""Gracefully close the stream by sending the StreamEndMessage and waiting for its ack."""
if not self._state.can_transition_to(StreamerState.CLOSING):
return
self.protocol.log.info('websocket streaming ended, waiting for ACK')
self.send_message(StreamEndMessage(id=self.stream_id, seq=self.get_next_seq()))
self.set_state(StreamerState.CLOSING)

def pauseProducing(self) -> None:
"""Pause streaming. Called by twisted."""
self._paused = True
if not self._state.can_transition_to(StreamerState.PAUSED):
self.protocol.log.warn('pause requested in an unexpected state', state=self._state)
return
self.set_state(StreamerState.PAUSED)

def stopProducing(self) -> None:
"""Stop streaming. Called by twisted."""
self._stop = True
if not self._state.can_transition_to(StreamerState.CLOSED):
self.protocol.log.warn('stopped requested in an unexpected state', state=self._state)
return
self.stop(False)

def resumeProducing(self) -> None:
"""Resume streaming. Called by twisted."""
self._paused = False
self._run()

def _run(self) -> None:
"""Run the streaming main loop."""
coro = self._async_run()
Deferred.fromCoroutine(coro)
if not self._state.can_transition_to(StreamerState.ACTIVE):
self.protocol.log.warn('resume requested in an unexpected state', state=self._state)
return
self.set_state(StreamerState.ACTIVE)
self.resume_if_possible()

def should_pause_streaming(self) -> bool:
"""Return true if the streaming should pause due to the flow control mechanism."""
if self._sliding_window_size is None:
return False
stop_value = self._last_ack + self._sliding_window_size + 1
if self._next_sequence_number < stop_value:
return False
return True

def _run(self) -> None:
"""Run the streaming main loop."""
if not self._state.can_transition_to(StreamerState.ACTIVE):
self.protocol.log.warn('_run() called in an unexpected state', state=self._state)
return
coro = self._async_run()
Deferred.fromCoroutine(coro)

async def _async_run(self):
assert not self._is_running
self._is_running = True
assert not self._is_main_loop_running
self.set_state(StreamerState.ACTIVE)
self._is_main_loop_running = True
try:
await self._async_run_unsafe()
finally:
self._is_running = False
self._is_main_loop_running = False

async def _async_run_unsafe(self):
"""Internal method that runs the streaming main loop."""
Expand All @@ -204,7 +286,7 @@ async def _async_run_unsafe(self):
self.stats_sent_addresses += 1
self.send_message(StreamAddressMessage(
id=self.stream_id,
seq=self._next_sequence_number,
seq=self.get_next_seq(),
index=item.index,
address=item.address,
subscribed=subscribed,
Expand All @@ -214,42 +296,39 @@ async def _async_run_unsafe(self):
self.stats_sent_vertices += 1
self.send_message(StreamVertexMessage(
id=self.stream_id,
seq=self._next_sequence_number,
seq=self.get_next_seq(),
data=item.vertex.to_json_extended(),
))

case _:
assert False

self._next_sequence_number += 1
if self.should_pause_streaming():
break

# The methods `pauseProducing()` and `stopProducing()` might be called during the
# call to `self.protocol.sendMessage()`. So both `_paused` and `_stop` might change
# during the loop.
if self._paused or self._stop:
break

self.stats_total_messages += 1
if self.stats_total_messages % self.stats_log_interval == 0:
self.protocol.log.info('websocket streaming statistics',
total_messages=self.stats_total_messages,
sent_vertices=self.stats_sent_vertices,
sent_addresses=self.stats_sent_addresses)

# The methods `pauseProducing()` and `stopProducing()` might be called during the
# call to `self.protocol.sendMessage()`. So the streamer state might change during
# the loop.
if self._state is not StreamerState.ACTIVE:
break

# Limit blocking of the event loop to a maximum of N seconds.
dt = self.reactor.seconds() - t0
if dt > self.max_seconds_locking_event_loop:
# Let the event loop run at least once.
await deferLater(self.reactor, 0, lambda: None)
t0 = self.reactor.seconds()

else:
if self._stop:
# If the streamer has been stopped, there is nothing else to do.
return
self.send_message(StreamEndMessage(id=self.stream_id))
self.stop(True)
# Iterator is empty so we can close the stream.
self.gracefully_close()

def send_message(self, message: StreamBase) -> None:
"""Send a message to the websocket connection."""
Expand Down
11 changes: 9 additions & 2 deletions tests/websocket/test_streamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from hathor.wallet import HDWallet
from hathor.websocket.factory import HathorAdminWebsocketFactory
from hathor.websocket.iterators import AddressItem, ManualAddressSequencer, gap_limit_search
from hathor.websocket.streamer import HistoryStreamer
from hathor.websocket.streamer import HistoryStreamer, StreamerState
from tests.unittest import TestCase
from tests.utils import GENESIS_ADDRESS_B58

Expand Down Expand Up @@ -60,7 +60,7 @@ def test_streamer(self) -> None:
'data': genesis.to_json_extended(),
})
expected_result.append({'type': 'stream:history:end', 'id': stream_id})
for index, item in enumerate(expected_result[1:-1]):
for index, item in enumerate(expected_result):
item['seq'] = index

# Create both the address iterator and the GAP limit searcher.
Expand All @@ -86,6 +86,13 @@ def test_streamer(self) -> None:
# Run the streamer.
manager.reactor.advance(10)

# Check the streamer is waiting for the last ACK.
self.assertTrue(streamer._state, StreamerState.CLOSING)
streamer.set_ack(1)
self.assertTrue(streamer._state, StreamerState.CLOSING)
streamer.set_ack(len(expected_result) - 1)
self.assertTrue(streamer._state, StreamerState.CLOSED)

# Check the results.
items_iter = self._parse_ws_raw(transport.value())
result = list(items_iter)
Expand Down
Loading