Skip to content

Commit 8721dc5

Browse files
committed
fix(ws): Add a graceful close mechanism to handle late messages and prevent errors
1 parent ec2563e commit 8721dc5

File tree

4 files changed

+137
-50
lines changed

4 files changed

+137
-50
lines changed

hathor/websocket/messages.py

+2
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,14 @@ class StreamErrorMessage(StreamBase):
4141
class StreamBeginMessage(StreamBase):
4242
type: str = Field('stream:history:begin', const=True)
4343
id: str
44+
seq: int
4445
window_size: Optional[int]
4546

4647

4748
class StreamEndMessage(StreamBase):
4849
type: str = Field('stream:history:end', const=True)
4950
id: str
51+
seq: int
5052

5153

5254
class StreamVertexMessage(StreamBase):

hathor/websocket/protocol.py

+1
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ def fail_if_history_streaming_is_disabled(self) -> bool:
144144

145145
def _create_streamer(self, stream_id: str, search: AddressSearch, window_size: int | None) -> None:
146146
"""Create the streamer and handle its callbacks."""
147+
assert self._history_streamer is None
147148
self._history_streamer = HistoryStreamer(protocol=self, stream_id=stream_id, search=search)
148149
if window_size is not None:
149150
if window_size < 0:

hathor/websocket/streamer.py

+125-48
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from enum import Enum, auto
1516
from typing import TYPE_CHECKING, Optional
1617

1718
from twisted.internet.defer import Deferred
@@ -33,6 +34,27 @@
3334
from hathor.websocket.protocol import HathorAdminWebsocketProtocol
3435

3536

37+
class StreamerState(Enum):
38+
NOT_STARTED = auto()
39+
ACTIVE = auto()
40+
PAUSED = auto()
41+
CLOSING = auto()
42+
CLOSED = auto()
43+
44+
def can_transition_to(self, destination: 'StreamerState') -> bool:
45+
"""Checks if the transition to the destination state is valid."""
46+
return destination in VALID_TRANSITIONS[self]
47+
48+
49+
VALID_TRANSITIONS = {
50+
StreamerState.NOT_STARTED: {StreamerState.ACTIVE},
51+
StreamerState.ACTIVE: {StreamerState.PAUSED, StreamerState.CLOSING, StreamerState.CLOSED},
52+
StreamerState.PAUSED: {StreamerState.ACTIVE, StreamerState.CLOSED},
53+
StreamerState.CLOSING: {StreamerState.CLOSED},
54+
StreamerState.CLOSED: set()
55+
}
56+
57+
3658
@implementer(IPushProducer)
3759
class HistoryStreamer:
3860
"""A producer that pushes addresses and transactions to a websocket connection.
@@ -72,23 +94,32 @@ def __init__(self,
7294

7395
self.deferred: Deferred[bool] = Deferred()
7496

75-
# Statistics.
97+
# Statistics
98+
# ----------
7699
self.stats_log_interval = self.STATS_LOG_INTERVAL
77100
self.stats_total_messages: int = 0
78101
self.stats_sent_addresses: int = 0
79102
self.stats_sent_vertices: int = 0
80103

81-
# Execution control.
82-
self._started = False
83-
self._is_running = False
84-
self._paused = False
85-
self._stop = False
104+
# Execution control
105+
# -----------------
106+
self._state = StreamerState.NOT_STARTED
107+
# Used to mark that the streamer is currently running its main loop and sending messages.
108+
self._is_main_loop_running = False
86109

87-
# Flow control.
110+
# Flow control
111+
# ------------
88112
self._next_sequence_number: int = 0
89113
self._last_ack: int = -1
90114
self._sliding_window_size: Optional[int] = self.DEFAULT_SLIDING_WINDOW_SIZE
91115

116+
def get_next_seq(self) -> int:
117+
assert self._state is not StreamerState.CLOSING
118+
assert self._state is not StreamerState.CLOSED
119+
seq = self._next_sequence_number
120+
self._next_sequence_number += 1
121+
return seq
122+
92123
def set_sliding_window_size(self, size: Optional[int]) -> None:
93124
"""Set a new sliding window size for flow control. If size is none, disables flow control.
94125
"""
@@ -102,87 +133,136 @@ def set_ack(self, ack: int) -> None:
102133
103134
If the new value is bigger than the previous value, the streaming might be resumed.
104135
"""
105-
if ack <= self._last_ack:
136+
if ack == self._last_ack:
106137
# We might receive duplicate ACKs, and we can safely ignore them.
138+
return
139+
if ack < self._last_ack:
140+
# ACK got smaller. Something is wrong...
107141
self.send_message(StreamErrorMessage(
108142
id=self.stream_id,
109-
errmsg=f'Outdated ACK received. Skipping it... (ack={ack})'
143+
errmsg=f'Outdated ACK received (ack={ack})'
110144
))
145+
self.stop(False)
111146
return
112147
if ack >= self._next_sequence_number:
148+
# ACK is higher than the last message sent. Something is wrong...
113149
self.send_message(StreamErrorMessage(
114150
id=self.stream_id,
115-
errmsg=f'Received ACK is higher than the last sent message. Skipping it... (ack={ack})'
151+
errmsg=f'Received ACK is higher than the last sent message (ack={ack})'
116152
))
153+
self.stop(False)
117154
return
118155
self._last_ack = ack
119-
self.resume_if_possible()
156+
if self._state is not StreamerState.CLOSING:
157+
closing_ack = self._next_sequence_number - 1
158+
if ack == closing_ack:
159+
self.stop(True)
160+
else:
161+
self.resume_if_possible()
120162

121163
def resume_if_possible(self) -> None:
122-
if not self._started:
164+
"""Resume sending messages if possible."""
165+
if self._state is StreamerState.PAUSED:
166+
return
167+
if not self._state.can_transition_to(StreamerState.ACTIVE):
168+
return
169+
if self._is_main_loop_running:
123170
return
124-
if not self.should_pause_streaming() and not self._is_running:
125-
self.resumeProducing()
171+
if self.should_pause_streaming():
172+
return
173+
self._run()
174+
175+
def set_state(self, new_state: StreamerState) -> None:
176+
"""Set a new state for the streamer."""
177+
if self._state == new_state:
178+
return
179+
assert self._state.can_transition_to(new_state)
180+
self._state = new_state
126181

127182
def start(self) -> Deferred[bool]:
128183
"""Start streaming items."""
184+
assert self._state is StreamerState.NOT_STARTED
185+
129186
# The websocket connection somehow instantiates a twisted.web.http.HTTPChannel object
130187
# which registers a producer. It seems the HTTPChannel is not used anymore after switching
131188
# to websocket, but it stays registered. So we have to unregister it before registering a new
132189
# producer.
133190
if self.protocol.transport.producer:
134191
self.protocol.unregisterProducer()
135-
136192
self.protocol.registerProducer(self, True)
137193

138-
assert not self._started
139-
self._started = True
140-
self.send_message(StreamBeginMessage(id=self.stream_id, window_size=self._sliding_window_size))
141-
self.resumeProducing()
194+
self.send_message(StreamBeginMessage(
195+
id=self.stream_id,
196+
seq=self.get_next_seq(),
197+
window_size=self._sliding_window_size,
198+
))
199+
self.resume_if_possible()
142200
return self.deferred
143201

144202
def stop(self, success: bool) -> None:
145203
"""Stop streaming items."""
146-
assert self._started
147-
self._stop = True
148-
self._started = False
204+
if not self._state.can_transition_to(StreamerState.CLOSED):
205+
# Do nothing if the streamer has already been stopped.
206+
self.protocol.log.warn('stop called in an unexpected state', state=self._state)
207+
return
208+
self.set_state(StreamerState.CLOSED)
149209
self.protocol.unregisterProducer()
150210
self.deferred.callback(success)
151211

212+
def gracefully_close(self) -> None:
213+
"""Gracefully close the stream by sending the StreamEndMessage and waiting for its ack."""
214+
if not self._state.can_transition_to(StreamerState.CLOSING):
215+
return
216+
self.send_message(StreamEndMessage(id=self.stream_id, seq=self.get_next_seq()))
217+
self.set_state(StreamerState.CLOSING)
218+
152219
def pauseProducing(self) -> None:
153220
"""Pause streaming. Called by twisted."""
154-
self._paused = True
221+
if not self._state.can_transition_to(StreamerState.PAUSED):
222+
self.protocol.log.warn('pause requested in an unexpected state', state=self._state)
223+
return
224+
self.set_state(StreamerState.PAUSED)
155225

156226
def stopProducing(self) -> None:
157227
"""Stop streaming. Called by twisted."""
158-
self._stop = True
228+
if not self._state.can_transition_to(StreamerState.CLOSED):
229+
self.protocol.log.warn('stopped requested in an unexpected state', state=self._state)
230+
return
159231
self.stop(False)
160232

161233
def resumeProducing(self) -> None:
162234
"""Resume streaming. Called by twisted."""
163-
self._paused = False
164-
self._run()
165-
166-
def _run(self) -> None:
167-
"""Run the streaming main loop."""
168-
coro = self._async_run()
169-
Deferred.fromCoroutine(coro)
235+
if not self._state.can_transition_to(StreamerState.ACTIVE):
236+
self.protocol.log.warn('resume requested in an unexpected state', state=self._state)
237+
return
238+
self.set_state(StreamerState.ACTIVE)
239+
self.resume_if_possible()
170240

171241
def should_pause_streaming(self) -> bool:
242+
"""Return true if the streaming should pause due to the flow control mechanism."""
172243
if self._sliding_window_size is None:
173244
return False
174245
stop_value = self._last_ack + self._sliding_window_size + 1
175246
if self._next_sequence_number < stop_value:
176247
return False
177248
return True
178249

250+
def _run(self) -> None:
251+
"""Run the streaming main loop."""
252+
if not self._state.can_transition_to(StreamerState.ACTIVE):
253+
self.protocol.log.warn('_run() called in an unexpected state', state=self._state)
254+
return
255+
coro = self._async_run()
256+
Deferred.fromCoroutine(coro)
257+
179258
async def _async_run(self):
180-
assert not self._is_running
181-
self._is_running = True
259+
assert not self._is_main_loop_running
260+
self.set_state(StreamerState.ACTIVE)
261+
self._is_main_loop_running = True
182262
try:
183263
await self._async_run_unsafe()
184264
finally:
185-
self._is_running = False
265+
self._is_main_loop_running = False
186266

187267
async def _async_run_unsafe(self):
188268
"""Internal method that runs the streaming main loop."""
@@ -204,7 +284,7 @@ async def _async_run_unsafe(self):
204284
self.stats_sent_addresses += 1
205285
self.send_message(StreamAddressMessage(
206286
id=self.stream_id,
207-
seq=self._next_sequence_number,
287+
seq=self.get_next_seq(),
208288
index=item.index,
209289
address=item.address,
210290
subscribed=subscribed,
@@ -214,42 +294,39 @@ async def _async_run_unsafe(self):
214294
self.stats_sent_vertices += 1
215295
self.send_message(StreamVertexMessage(
216296
id=self.stream_id,
217-
seq=self._next_sequence_number,
297+
seq=self.get_next_seq(),
218298
data=item.vertex.to_json_extended(),
219299
))
220300

221301
case _:
222302
assert False
223303

224-
self._next_sequence_number += 1
225304
if self.should_pause_streaming():
226305
break
227306

228-
# The methods `pauseProducing()` and `stopProducing()` might be called during the
229-
# call to `self.protocol.sendMessage()`. So both `_paused` and `_stop` might change
230-
# during the loop.
231-
if self._paused or self._stop:
232-
break
233-
234307
self.stats_total_messages += 1
235308
if self.stats_total_messages % self.stats_log_interval == 0:
236309
self.protocol.log.info('websocket streaming statistics',
237310
total_messages=self.stats_total_messages,
238311
sent_vertices=self.stats_sent_vertices,
239312
sent_addresses=self.stats_sent_addresses)
240313

314+
# The methods `pauseProducing()` and `stopProducing()` might be called during the
315+
# call to `self.protocol.sendMessage()`. So the streamer state might change during
316+
# the loop.
317+
if self._state is not StreamerState.ACTIVE:
318+
break
319+
320+
# Limit blocking of the event loop to a maximum of N seconds.
241321
dt = self.reactor.seconds() - t0
242322
if dt > self.max_seconds_locking_event_loop:
243323
# Let the event loop run at least once.
244324
await deferLater(self.reactor, 0, lambda: None)
245325
t0 = self.reactor.seconds()
246326

247327
else:
248-
if self._stop:
249-
# If the streamer has been stopped, there is nothing else to do.
250-
return
251-
self.send_message(StreamEndMessage(id=self.stream_id))
252-
self.stop(True)
328+
# Iterator is empty so we can close the stream.
329+
self.gracefully_close()
253330

254331
def send_message(self, message: StreamBase) -> None:
255332
"""Send a message to the websocket connection."""

tests/websocket/test_streamer.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from hathor.wallet import HDWallet
77
from hathor.websocket.factory import HathorAdminWebsocketFactory
88
from hathor.websocket.iterators import AddressItem, ManualAddressSequencer, gap_limit_search
9-
from hathor.websocket.streamer import HistoryStreamer
9+
from hathor.websocket.streamer import HistoryStreamer, StreamerState
1010
from tests.unittest import TestCase
1111
from tests.utils import GENESIS_ADDRESS_B58
1212

@@ -60,7 +60,7 @@ def test_streamer(self) -> None:
6060
'data': genesis.to_json_extended(),
6161
})
6262
expected_result.append({'type': 'stream:history:end', 'id': stream_id})
63-
for index, item in enumerate(expected_result[1:-1]):
63+
for index, item in enumerate(expected_result):
6464
item['seq'] = index
6565

6666
# Create both the address iterator and the GAP limit searcher.
@@ -86,6 +86,13 @@ def test_streamer(self) -> None:
8686
# Run the streamer.
8787
manager.reactor.advance(10)
8888

89+
# Check the streamer is waiting for the last ACK.
90+
self.assertTrue(streamer._state, StreamerState.CLOSING)
91+
streamer.set_ack(1)
92+
self.assertTrue(streamer._state, StreamerState.CLOSING)
93+
streamer.set_ack(len(expected_result) - 1)
94+
self.assertTrue(streamer._state, StreamerState.CLOSED)
95+
8996
# Check the results.
9097
items_iter = self._parse_ws_raw(transport.value())
9198
result = list(items_iter)

0 commit comments

Comments
 (0)