@@ -15,15 +15,17 @@ import io.github.oshai.kotlinlogging.KotlinLogging
 import io.micronaut.context.annotation.Factory
 import jakarta.inject.Singleton
 import java.util.concurrent.ConcurrentHashMap
-import java.util.concurrent.CountDownLatch
+import java.util.concurrent.atomic.AtomicBoolean
 import java.util.concurrent.atomic.AtomicLong
-import kotlinx.coroutines.Dispatchers
-import kotlinx.coroutines.withContext
+import kotlinx.coroutines.channels.Channel
 
 /** Manages the state of all streams in the destination. */
 interface StreamsManager {
+    /** Get the manager for the given stream. Throws an exception if the stream is not found. */
     fun getManager(stream: DestinationStream): StreamManager
-    suspend fun awaitAllStreamsComplete()
+
+    /** Suspend until all streams are closed. */
+    suspend fun awaitAllStreamsClosed()
 }
 
 class DefaultStreamsManager(
@@ -33,68 +35,98 @@ class DefaultStreamsManager(
         return streamManagers[stream] ?: throw IllegalArgumentException("Stream not found: $stream")
     }
 
-    override suspend fun awaitAllStreamsComplete() {
+    override suspend fun awaitAllStreamsClosed() {
         streamManagers.forEach { (_, manager) -> manager.awaitStreamClosed() }
     }
 }
 
 /** Manages the state of a single stream. */
 interface StreamManager {
-    fun countRecordIn(sizeBytes: Long): Long
+    /** Count incoming record and return the record's *index*. */
+    fun countRecordIn(): Long
+
+    /**
+     * Count the end-of-stream. Expect this exactly once. Expect no further `countRecordIn`, and
+     * expect that `markClosed` will always occur after this.
+     */
+    fun countEndOfStream(): Long
+
+    /**
+     * Mark a checkpoint in the stream and return the current index and the number of records since
+     * the last one.
+     *
+     * NOTE: Single-writer. If in the future multiple threads set checkpoints, this method should be
+     * synchronized.
+     */
     fun markCheckpoint(): Pair<Long, Long>
+
+    /** Record that the given batch's state has been reached for the associated range(s). */
     fun <B : Batch> updateBatchState(batch: BatchEnvelope<B>)
+
+    /**
+     * True if all are true:
+     * * all records have been seen (ie, we've counted an end-of-stream)
+     * * a [Batch.State.COMPLETE] batch range has been seen covering every record
+     *
+     * Does NOT require that the stream be closed.
+     */
     fun isBatchProcessingComplete(): Boolean
+
+    /**
+     * True if all records in [0, index] have at least reached [Batch.State.PERSISTED]. This is
+     * implicitly true if they have all reached [Batch.State.COMPLETE].
+     */
     fun areRecordsPersistedUntil(index: Long): Boolean
 
+    /** Mark the stream as closed. This should only be called after all records have been read. */
     fun markClosed()
+
+    /** True if the stream has been marked as closed. */
     fun streamIsClosed(): Boolean
+
+    /** Suspend until the stream is closed. */
     suspend fun awaitStreamClosed()
 }
 
-/**
- * Maintains a map of stream -> status metadata, and a map of batch state -> record ranges for which
- * that state has been reached.
- *
- * TODO: Log a detailed report of the stream status on a regular cadence.
- */
 class DefaultStreamManager(
     val stream: DestinationStream,
 ) : StreamManager {
     private val log = KotlinLogging.logger {}
 
-    data class StreamStatus(
-        val recordCount: AtomicLong = AtomicLong(0),
-        val totalBytes: AtomicLong = AtomicLong(0),
-        val enqueuedSize: AtomicLong = AtomicLong(0),
-        val lastCheckpoint: AtomicLong = AtomicLong(0L),
-        val closedLatch: CountDownLatch = CountDownLatch(1),
-    )
+    private val recordCount = AtomicLong(0)
+    private val lastCheckpoint = AtomicLong(0L)
+    private val readIsClosed = AtomicBoolean(false)
+    private val streamIsClosed = AtomicBoolean(false)
+    private val closedLock = Channel<Unit>()
 
-    private val streamStatus: StreamStatus = StreamStatus()
     private val rangesState: ConcurrentHashMap<Batch.State, RangeSet<Long>> = ConcurrentHashMap()
 
     init {
         Batch.State.entries.forEach { rangesState[it] = TreeRangeSet.create() }
     }
 
-    override fun countRecordIn(sizeBytes: Long): Long {
-        val index = streamStatus.recordCount.getAndIncrement()
-        streamStatus.totalBytes.addAndGet(sizeBytes)
-        streamStatus.enqueuedSize.addAndGet(sizeBytes)
-        return index
+    override fun countRecordIn(): Long {
+        if (readIsClosed.get()) {
+            throw IllegalStateException("Stream is closed for reading")
+        }
+
+        return recordCount.getAndIncrement()
+    }
+
+    override fun countEndOfStream(): Long {
+        if (readIsClosed.getAndSet(true)) {
+            throw IllegalStateException("Stream is closed for reading")
+        }
+
+        return recordCount.get()
     }
 
-    /**
-     * Mark a checkpoint in the stream and return the current index and the number of records since
-     * the last one.
-     */
     override fun markCheckpoint(): Pair<Long, Long> {
-        val index = streamStatus.recordCount.get()
-        val lastCheckpoint = streamStatus.lastCheckpoint.getAndSet(index)
+        val index = recordCount.get()
+        val lastCheckpoint = lastCheckpoint.getAndSet(index)
         return Pair(index, index - lastCheckpoint)
     }
 
-    /** Record that the given batch's state has been reached for the associated range(s). */
     override fun <B : Batch> updateBatchState(batch: BatchEnvelope<B>) {
         val stateRanges =
             rangesState[batch.batch.state]
@@ -112,37 +144,44 @@ class DefaultStreamManager(
         log.info { "Updated ranges for $stream [${batch.batch.state}]: $stateRanges" }
     }
 
-    /** True if all records in [0, index] have reached the given state. */
+    /** True if all records in `[0, index)` have reached the given state. */
     private fun isProcessingCompleteForState(index: Long, state: Batch.State): Boolean {
-
         val completeRanges = rangesState[state]!!
         return completeRanges.encloses(Range.closedOpen(0L, index))
     }
 
-    /** True if all records have associated [Batch.State.COMPLETE] batches. */
     override fun isBatchProcessingComplete(): Boolean {
-        return isProcessingCompleteForState(streamStatus.recordCount.get(), Batch.State.COMPLETE)
+        /* If the stream hasn't been fully read, it can't be done. */
+        if (!readIsClosed.get()) {
+            return false
+        }
+
+        return isProcessingCompleteForState(recordCount.get(), Batch.State.COMPLETE)
     }
 
-    /**
-     * True if all records in [0, index] have at least reached [Batch.State.PERSISTED]. This is
-     * implicitly true if they have all reached [Batch.State.COMPLETE].
-     */
     override fun areRecordsPersistedUntil(index: Long): Boolean {
         return isProcessingCompleteForState(index, Batch.State.PERSISTED) ||
             isProcessingCompleteForState(index, Batch.State.COMPLETE) // complete => persisted
     }
 
     override fun markClosed() {
-        streamStatus.closedLatch.countDown()
+        if (!readIsClosed.get()) {
+            throw IllegalStateException("Stream must be fully read before it can be closed")
+        }
+
+        if (streamIsClosed.compareAndSet(false, true)) {
+            closedLock.trySend(Unit)
+        }
     }
 
     override fun streamIsClosed(): Boolean {
-        return streamStatus.closedLatch.count == 0L
+        return streamIsClosed.get()
     }
 
     override suspend fun awaitStreamClosed() {
-        withContext(Dispatchers.IO) { streamStatus.closedLatch.await() }
+        if (!streamIsClosed.get()) {
+            closedLock.receive()
+        }
     }
 }
 
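A few illustrative sketches follow; they are not part of the diff. First, a hypothetical caller-side loop driving the counting API, written against the `StreamManager` interface above. The record source, the 10,000-record checkpoint interval, and the `println` reporting are placeholders, not CDK code.

```kotlin
// Hypothetical caller of the StreamManager interface above (placeholder record source).
fun consumeStream(manager: StreamManager, records: Sequence<Any>) {
    records.forEach { _ ->
        val index = manager.countRecordIn() // zero-based index of the record just counted
        if ((index + 1) % 10_000 == 0L) {
            // (current index, records since the previous checkpoint), e.g. for a state message
            val (checkpointIndex, sinceLast) = manager.markCheckpoint()
            println("checkpoint at $checkpointIndex ($sinceLast records since the last one)")
        }
    }
    manager.countEndOfStream() // exactly once, after the final record
    // ...later, once the destination is done with the stream, whoever closes it calls:
    manager.markClosed()
}
```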
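Second, the completion checks lean on Guava's `TreeRangeSet` coalescing connected ranges, which is why `encloses(Range.closedOpen(0L, recordCount))` answers whether every record index has been covered by some batch. A standalone sketch:

```kotlin
import com.google.common.collect.Range
import com.google.common.collect.TreeRangeSet

fun main() {
    val persisted = TreeRangeSet.create<Long>()
    persisted.add(Range.closedOpen(0L, 100L))   // batch 1 covered records 0..99
    persisted.add(Range.closedOpen(100L, 250L)) // batch 2 covered records 100..249

    // The adjacent half-open ranges merge into [0, 250), so all 250 records are accounted for.
    println(persisted.encloses(Range.closedOpen(0L, 250L))) // true
    println(persisted.encloses(Range.closedOpen(0L, 251L))) // false: record 250 not yet covered
}
```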
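Third, the `CountDownLatch` is replaced with an `AtomicBoolean` plus a rendezvous `Channel`, so `awaitStreamClosed` suspends rather than blocking a `Dispatchers.IO` thread. A minimal standalone sketch of that signaling pattern, assuming the waiter suspends before the close signal is sent (the `trySend` only finds a receiver in that case; late waiters take the `AtomicBoolean` fast path instead):

```kotlin
import java.util.concurrent.atomic.AtomicBoolean
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking

fun main() = runBlocking {
    val streamIsClosed = AtomicBoolean(false)
    val closedLock = Channel<Unit>() // rendezvous channel, as in DefaultStreamManager

    val waiter = launch {
        // awaitStreamClosed(): fast path on the flag, otherwise suspend until signalled
        if (!streamIsClosed.get()) {
            closedLock.receive()
        }
        println("stream closed")
    }

    delay(100) // illustration only: let the waiter suspend before we signal
    // markClosed(): flip the flag once and wake the suspended waiter
    if (streamIsClosed.compareAndSet(false, true)) {
        closedLock.trySend(Unit)
    }
    waiter.join()
}
```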