Skip to content

Commit f3820a5

Browse files
committed
fix(ws): Add a graceful close mechanism to handle late messages and prevent errors
1 parent f3e3ab5 commit f3820a5

File tree

4 files changed

+133
-50
lines changed

4 files changed

+133
-50
lines changed

hathor/websocket/messages.py

+2
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,14 @@ class StreamErrorMessage(StreamBase):
4141
class StreamBeginMessage(StreamBase):
4242
type: str = Field('stream:history:begin', const=True)
4343
id: str
44+
seq: int
4445
window_size: Optional[int]
4546

4647

4748
class StreamEndMessage(StreamBase):
4849
type: str = Field('stream:history:end', const=True)
4950
id: str
51+
seq: int
5052

5153

5254
class StreamVertexMessage(StreamBase):

hathor/websocket/protocol.py

+1
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ def fail_if_history_streaming_is_disabled(self) -> bool:
144144

145145
def _create_streamer(self, stream_id: str, search: AddressSearch, window_size: int | None) -> None:
146146
"""Create the streamer and handle its callbacks."""
147+
assert self._history_streamer is None
147148
self._history_streamer = HistoryStreamer(protocol=self, stream_id=stream_id, search=search)
148149
if window_size is not None:
149150
if window_size < 0:

hathor/websocket/streamer.py

+121-48
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from enum import Enum, auto
1516
from typing import TYPE_CHECKING, Optional
1617

1718
from twisted.internet.defer import Deferred
@@ -33,6 +34,27 @@
3334
from hathor.websocket.protocol import HathorAdminWebsocketProtocol
3435

3536

37+
class StreamerState(Enum):
38+
NOT_STARTED = auto()
39+
ACTIVE = auto()
40+
PAUSED = auto()
41+
CLOSING = auto()
42+
CLOSED = auto()
43+
44+
def can_transition_to(self, destination: 'StreamerState') -> bool:
45+
"""Checks if the transition to the destination state is valid."""
46+
return destination in VALID_TRANSITIONS[self]
47+
48+
49+
VALID_TRANSITIONS = {
50+
StreamerState.NOT_STARTED: {StreamerState.ACTIVE},
51+
StreamerState.ACTIVE: {StreamerState.PAUSED, StreamerState.CLOSING, StreamerState.CLOSED},
52+
StreamerState.PAUSED: {StreamerState.ACTIVE, StreamerState.CLOSING, StreamerState.CLOSED},
53+
StreamerState.CLOSING: {StreamerState.CLOSED},
54+
StreamerState.CLOSED: set()
55+
}
56+
57+
3658
@implementer(IPushProducer)
3759
class HistoryStreamer:
3860
"""A producer that pushes addresses and transactions to a websocket connection.
@@ -72,23 +94,34 @@ def __init__(self,
7294

7395
self.deferred: Deferred[bool] = Deferred()
7496

75-
# Statistics.
97+
# Statistics
98+
# ----------
7699
self.stats_log_interval = self.STATS_LOG_INTERVAL
77100
self.stats_total_messages: int = 0
78101
self.stats_sent_addresses: int = 0
79102
self.stats_sent_vertices: int = 0
80103

81-
# Execution control.
82-
self._started = False
83-
self._is_running = False
84-
self._paused = False
85-
self._stop = False
104+
# Execution control
105+
# -----------------
106+
self._state = StreamerState.NOT_STARTED
107+
# Used to mark that the streamer is currently running its main loop and sending messages.
108+
self._is_main_loop_running = False
109+
# Used to mark that the streamer was paused by the transport layer.
110+
self._is_paused_by_transport = False
86111

87-
# Flow control.
112+
# Flow control
113+
# ------------
88114
self._next_sequence_number: int = 0
89115
self._last_ack: int = -1
90116
self._sliding_window_size: Optional[int] = self.DEFAULT_SLIDING_WINDOW_SIZE
91117

118+
def get_next_seq(self) -> int:
119+
assert self._state is not StreamerState.CLOSING
120+
assert self._state is not StreamerState.CLOSED
121+
seq = self._next_sequence_number
122+
self._next_sequence_number += 1
123+
return seq
124+
92125
def set_sliding_window_size(self, size: Optional[int]) -> None:
93126
"""Set a new sliding window size for flow control. If size is none, disables flow control.
94127
"""
@@ -102,87 +135,130 @@ def set_ack(self, ack: int) -> None:
102135
103136
If the new value is bigger than the previous value, the streaming might be resumed.
104137
"""
105-
if ack <= self._last_ack:
138+
if ack == self._last_ack:
106139
# We might receive outdated or duplicate ACKs, and we can safely ignore them.
140+
return
141+
if ack < self._last_ack:
142+
# ACK got smaller. Something is wrong...
107143
self.send_message(StreamErrorMessage(
108144
id=self.stream_id,
109-
errmsg=f'Outdated ACK received. Skipping it... (ack={ack})'
145+
errmsg=f'Outdated ACK received (ack={ack})'
110146
))
147+
self.stop(False)
111148
return
112149
if ack >= self._next_sequence_number:
150+
# ACK is higher than the last message sent. Something is wrong...
113151
self.send_message(StreamErrorMessage(
114152
id=self.stream_id,
115-
errmsg=f'Received ACK is higher than the last sent message. Skipping it... (ack={ack})'
153+
errmsg=f'Received ACK is higher than the last sent message (ack={ack})'
116154
))
155+
self.stop(False)
117156
return
118157
self._last_ack = ack
119-
self.resume_if_possible()
158+
if self._state is not StreamerState.CLOSING:
159+
closing_ack = self._next_sequence_number - 1
160+
if ack == closing_ack:
161+
self.stop(True)
162+
else:
163+
self.resume_if_possible()
120164

121165
def resume_if_possible(self) -> None:
122-
if not self._started:
166+
"""Resume sending messages if possible."""
167+
if not self._state.can_transition_to(StreamerState.ACTIVE):
168+
return
169+
if self._is_main_loop_running:
170+
return
171+
if self._is_paused_by_transport:
123172
return
124-
if not self.should_pause_streaming() and not self._is_running:
125-
self.resumeProducing()
173+
if self.should_pause_streaming():
174+
return
175+
self._run()
126176

127177
def start(self) -> Deferred[bool]:
128178
"""Start streaming items."""
179+
assert self._state is StreamerState.NOT_STARTED
180+
129181
# The websocket connection somehow instantiates an twisted.web.http.HTTPChannel object
130182
# which register a producer. It seems the HTTPChannel is not used anymore after switching
131183
# to websocket but it keep registered. So we have to unregister before registering a new
132184
# producer.
133185
if self.protocol.transport.producer:
134186
self.protocol.unregisterProducer()
135-
136187
self.protocol.registerProducer(self, True)
137188

138-
assert not self._started
139-
self._started = True
140-
self.send_message(StreamBeginMessage(id=self.stream_id, window_size=self._sliding_window_size))
141-
self.resumeProducing()
189+
self.send_message(StreamBeginMessage(
190+
id=self.stream_id,
191+
seq=self.get_next_seq(),
192+
window_size=self._sliding_window_size,
193+
))
194+
self.resume_if_possible()
142195
return self.deferred
143196

144197
def stop(self, success: bool) -> None:
145198
"""Stop streaming items."""
146-
assert self._started
147-
self._stop = True
148-
self._started = False
199+
if not self._state.can_transition_to(StreamerState.CLOSED):
200+
# Do nothing if the streamer has already been stopped.
201+
self.protocol.log.warn('stop called in an unexpected state', state=self._state)
202+
return
203+
self._state = StreamerState.CLOSED
149204
self.protocol.unregisterProducer()
150205
self.deferred.callback(success)
151206

207+
def gracefully_close(self) -> None:
208+
"""Gracefully close the stream by sending the StreamEndMessage and waiting for its ack."""
209+
if not self._state.can_transition_to(StreamerState.CLOSING):
210+
return
211+
self.send_message(StreamEndMessage(id=self.stream_id, seq=self.get_next_seq()))
212+
self._state = StreamerState.CLOSING
213+
152214
def pauseProducing(self) -> None:
153215
"""Pause streaming. Called by twisted."""
154-
self._paused = True
216+
if not self._state.can_transition_to(StreamerState.PAUSED):
217+
self.protocol.log.warn('pause requested in an unexpected state', state=self._state)
218+
return
219+
self._state = StreamerState.PAUSED
220+
self._is_paused_by_transport = True
155221

156222
def stopProducing(self) -> None:
157223
"""Stop streaming. Called by twisted."""
158-
self._stop = True
224+
if not self._state.can_transition_to(StreamerState.CLOSED):
225+
self.protocol.log.warn('stopped requested in an unexpected state', state=self._state)
226+
return
159227
self.stop(False)
160228

161229
def resumeProducing(self) -> None:
162230
"""Resume streaming. Called by twisted."""
163-
self._paused = False
164-
self._run()
165-
166-
def _run(self) -> None:
167-
"""Run the streaming main loop."""
168-
coro = self._async_run()
169-
Deferred.fromCoroutine(coro)
231+
if not self._state.can_transition_to(StreamerState.ACTIVE):
232+
self.protocol.log.warn('resume requested in an unexpected state', state=self._state)
233+
return
234+
self._is_paused_by_transport = False
235+
self.resume_if_possible()
170236

171237
def should_pause_streaming(self) -> bool:
238+
"""Return true if the streaming should pause due to the flow control mechanism."""
172239
if self._sliding_window_size is None:
173240
return False
174241
stop_value = self._last_ack + self._sliding_window_size + 1
175242
if self._next_sequence_number < stop_value:
176243
return False
177244
return True
178245

246+
def _run(self) -> None:
247+
"""Run the streaming main loop."""
248+
if not self._state.can_transition_to(StreamerState.ACTIVE):
249+
self.protocol.log.warn('_run() called in an unexpected state', state=self._state)
250+
return
251+
coro = self._async_run()
252+
Deferred.fromCoroutine(coro)
253+
179254
async def _async_run(self):
180-
assert not self._is_running
181-
self._is_running = True
255+
assert not self._is_main_loop_running
256+
self._state = StreamerState.ACTIVE
257+
self._is_main_loop_running = True
182258
try:
183259
await self._async_run_unsafe()
184260
finally:
185-
self._is_running = False
261+
self._is_main_loop_running = False
186262

187263
async def _async_run_unsafe(self):
188264
"""Internal method that runs the streaming main loop."""
@@ -204,7 +280,7 @@ async def _async_run_unsafe(self):
204280
self.stats_sent_addresses += 1
205281
self.send_message(StreamAddressMessage(
206282
id=self.stream_id,
207-
seq=self._next_sequence_number,
283+
seq=self.get_next_seq(),
208284
index=item.index,
209285
address=item.address,
210286
subscribed=subscribed,
@@ -214,42 +290,39 @@ async def _async_run_unsafe(self):
214290
self.stats_sent_vertices += 1
215291
self.send_message(StreamVertexMessage(
216292
id=self.stream_id,
217-
seq=self._next_sequence_number,
293+
seq=self.get_next_seq(),
218294
data=item.vertex.to_json_extended(),
219295
))
220296

221297
case _:
222298
assert False
223299

224-
self._next_sequence_number += 1
225300
if self.should_pause_streaming():
226301
break
227302

228-
# The methods `pauseProducing()` and `stopProducing()` might be called during the
229-
# call to `self.protocol.sendMessage()`. So both `_paused` and `_stop` might change
230-
# during the loop.
231-
if self._paused or self._stop:
232-
break
233-
234303
self.stats_total_messages += 1
235304
if self.stats_total_messages % self.stats_log_interval == 0:
236305
self.protocol.log.info('websocket streaming statistics',
237306
total_messages=self.stats_total_messages,
238307
sent_vertices=self.stats_sent_vertices,
239308
sent_addresses=self.stats_sent_addresses)
240309

310+
# The methods `pauseProducing()` and `stopProducing()` might be called during the
311+
# call to `self.protocol.sendMessage()`. So the streamer state might change during
312+
# the loop.
313+
if self._state is not StreamerState.ACTIVE:
314+
break
315+
316+
# Limit blocking of the event loop to a maximum of N seconds.
241317
dt = self.reactor.seconds() - t0
242318
if dt > self.max_seconds_locking_event_loop:
243319
# Let the event loop run at least once.
244320
await deferLater(self.reactor, 0, lambda: None)
245321
t0 = self.reactor.seconds()
246322

247323
else:
248-
if self._stop:
249-
# If the streamer has been stopped, there is nothing else to do.
250-
return
251-
self.send_message(StreamEndMessage(id=self.stream_id))
252-
self.stop(True)
324+
# Iterator is empty so we can close the stream.
325+
self.gracefully_close()
253326

254327
def send_message(self, message: StreamBase) -> None:
255328
"""Send a message to the websocket connection."""

tests/websocket/test_streamer.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from hathor.wallet import HDWallet
77
from hathor.websocket.factory import HathorAdminWebsocketFactory
88
from hathor.websocket.iterators import AddressItem, ManualAddressSequencer, gap_limit_search
9-
from hathor.websocket.streamer import HistoryStreamer
9+
from hathor.websocket.streamer import HistoryStreamer, StreamerState
1010
from tests.unittest import TestCase
1111
from tests.utils import GENESIS_ADDRESS_B58
1212

@@ -60,7 +60,7 @@ def test_streamer(self) -> None:
6060
'data': genesis.to_json_extended(),
6161
})
6262
expected_result.append({'type': 'stream:history:end', 'id': stream_id})
63-
for index, item in enumerate(expected_result[1:-1]):
63+
for index, item in enumerate(expected_result):
6464
item['seq'] = index
6565

6666
# Create both the address iterator and the GAP limit searcher.
@@ -86,6 +86,13 @@ def test_streamer(self) -> None:
8686
# Run the streamer.
8787
manager.reactor.advance(10)
8888

89+
# Check the streamer is waiting for the last ACK.
90+
self.assertTrue(streamer._state, StreamerState.CLOSING)
91+
streamer.set_ack(1)
92+
self.assertTrue(streamer._state, StreamerState.CLOSING)
93+
streamer.set_ack(len(expected_result) - 1)
94+
self.assertTrue(streamer._state, StreamerState.CLOSED)
95+
8996
# Check the results.
9097
items_iter = self._parse_ws_raw(transport.value())
9198
result = list(items_iter)

0 commit comments

Comments
 (0)