Skip to content

Commit af58faa

Browse files
Unit tests for streams manager (#45090)
1 parent 5bf11d9 commit af58faa

File tree

7 files changed

+412
-71
lines changed

7 files changed

+412
-71
lines changed

airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/command/DestinationCatalog.kt

+5-1
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,12 @@ data class DestinationCatalog(
2525
}
2626
}
2727

28+
interface DestinationCatalogFactory {
29+
fun make(): DestinationCatalog
30+
}
31+
2832
@Factory
29-
class DestinationCatalogFactory(
33+
class DefaultDestinationCatalogFactory(
3034
private val catalog: ConfiguredAirbyteCatalog,
3135
private val streamFactory: DestinationStreamFactory
3236
) {

airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/MessageQueueWriter.kt

+2-3
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,12 @@ class DestinationMessageQueueWriter(
4343
/* If the input message represents a record. */
4444
is DestinationRecordMessage -> {
4545
val manager = streamsManager.getManager(message.stream)
46-
val index = manager.countRecordIn(sizeBytes)
4746
when (message) {
4847
/* If a data record */
4948
is DestinationRecord -> {
5049
val wrapped =
5150
StreamRecordWrapped(
52-
index = index,
51+
index = manager.countRecordIn(),
5352
sizeBytes = sizeBytes,
5453
record = message
5554
)
@@ -58,7 +57,7 @@ class DestinationMessageQueueWriter(
5857

5958
/* If an end-of-stream marker. */
6059
is DestinationStreamComplete -> {
61-
val wrapped = StreamCompleteWrapped(index)
60+
val wrapped = StreamCompleteWrapped(index = manager.countEndOfStream())
6261
messageQueue.getChannel(message.stream).send(wrapped)
6362
}
6463
}

airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/StreamManager.kt airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/StreamsManager.kt

+82-43
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,17 @@ import io.github.oshai.kotlinlogging.KotlinLogging
1515
import io.micronaut.context.annotation.Factory
1616
import jakarta.inject.Singleton
1717
import java.util.concurrent.ConcurrentHashMap
18-
import java.util.concurrent.CountDownLatch
18+
import java.util.concurrent.atomic.AtomicBoolean
1919
import java.util.concurrent.atomic.AtomicLong
20-
import kotlinx.coroutines.Dispatchers
21-
import kotlinx.coroutines.withContext
20+
import kotlinx.coroutines.channels.Channel
2221

2322
/** Manages the state of all streams in the destination. */
2423
interface StreamsManager {
24+
/** Get the manager for the given stream. Throws an exception if the stream is not found. */
2525
fun getManager(stream: DestinationStream): StreamManager
26-
suspend fun awaitAllStreamsComplete()
26+
27+
/** Suspend until all streams are closed. */
28+
suspend fun awaitAllStreamsClosed()
2729
}
2830

2931
class DefaultStreamsManager(
@@ -33,68 +35,98 @@ class DefaultStreamsManager(
3335
return streamManagers[stream] ?: throw IllegalArgumentException("Stream not found: $stream")
3436
}
3537

36-
override suspend fun awaitAllStreamsComplete() {
38+
override suspend fun awaitAllStreamsClosed() {
3739
streamManagers.forEach { (_, manager) -> manager.awaitStreamClosed() }
3840
}
3941
}
4042

4143
/** Manages the state of a single stream. */
4244
interface StreamManager {
43-
fun countRecordIn(sizeBytes: Long): Long
45+
/** Count incoming record and return the record's *index*. */
46+
fun countRecordIn(): Long
47+
48+
/**
49+
* Count the end-of-stream. Expect this exactly once. Expect no further `countRecordIn`, and
50+
* expect that `markClosed` will always occur after this.
51+
*/
52+
fun countEndOfStream(): Long
53+
54+
/**
55+
* Mark a checkpoint in the stream and return the current index and the number of records since
56+
* the last one.
57+
*
58+
* NOTE: Single-writer. If in the future multiple threads set checkpoints, this method should be
59+
* synchronized.
60+
*/
4461
fun markCheckpoint(): Pair<Long, Long>
62+
63+
/** Record that the given batch's state has been reached for the associated range(s). */
4564
fun <B : Batch> updateBatchState(batch: BatchEnvelope<B>)
65+
66+
/**
67+
* True if all are true:
68+
* * all records have been seen (i.e., we've counted an end-of-stream)
69+
* * a [Batch.State.COMPLETE] batch range has been seen covering every record
70+
*
71+
* Does NOT require that the stream be closed.
72+
*/
4673
fun isBatchProcessingComplete(): Boolean
74+
75+
/**
76+
* True if all records in `[0, index)` have at least reached [Batch.State.PERSISTED]. This is
77+
* implicitly true if they have all reached [Batch.State.COMPLETE].
78+
*/
4779
fun areRecordsPersistedUntil(index: Long): Boolean
4880

81+
/** Mark the stream as closed. This should only be called after all records have been read. */
4982
fun markClosed()
83+
84+
/** True if the stream has been marked as closed. */
5085
fun streamIsClosed(): Boolean
86+
87+
/** Suspend until the stream is closed. */
5188
suspend fun awaitStreamClosed()
5289
}
5390

54-
/**
55-
* Maintains a map of stream -> status metadata, and a map of batch state -> record ranges for which
56-
* that state has been reached.
57-
*
58-
* TODO: Log a detailed report of the stream status on a regular cadence.
59-
*/
6091
class DefaultStreamManager(
6192
val stream: DestinationStream,
6293
) : StreamManager {
6394
private val log = KotlinLogging.logger {}
6495

65-
data class StreamStatus(
66-
val recordCount: AtomicLong = AtomicLong(0),
67-
val totalBytes: AtomicLong = AtomicLong(0),
68-
val enqueuedSize: AtomicLong = AtomicLong(0),
69-
val lastCheckpoint: AtomicLong = AtomicLong(0L),
70-
val closedLatch: CountDownLatch = CountDownLatch(1),
71-
)
96+
private val recordCount = AtomicLong(0)
97+
private val lastCheckpoint = AtomicLong(0L)
98+
private val readIsClosed = AtomicBoolean(false)
99+
private val streamIsClosed = AtomicBoolean(false)
100+
// Conflated so that the trySend in markClosed() is buffered even when no receiver is suspended
// yet; a rendezvous channel would drop it and leave a racing awaitStreamClosed() suspended forever.
private val closedLock = Channel<Unit>(Channel.CONFLATED)
72101

73-
private val streamStatus: StreamStatus = StreamStatus()
74102
private val rangesState: ConcurrentHashMap<Batch.State, RangeSet<Long>> = ConcurrentHashMap()
75103

76104
init {
77105
Batch.State.entries.forEach { rangesState[it] = TreeRangeSet.create() }
78106
}
79107

80-
override fun countRecordIn(sizeBytes: Long): Long {
81-
val index = streamStatus.recordCount.getAndIncrement()
82-
streamStatus.totalBytes.addAndGet(sizeBytes)
83-
streamStatus.enqueuedSize.addAndGet(sizeBytes)
84-
return index
108+
override fun countRecordIn(): Long {
109+
if (readIsClosed.get()) {
110+
throw IllegalStateException("Stream is closed for reading")
111+
}
112+
113+
return recordCount.getAndIncrement()
114+
}
115+
116+
override fun countEndOfStream(): Long {
117+
if (readIsClosed.getAndSet(true)) {
118+
throw IllegalStateException("Stream is closed for reading")
119+
}
120+
121+
return recordCount.get()
85122
}
86123

87-
/**
88-
* Mark a checkpoint in the stream and return the current index and the number of records since
89-
* the last one.
90-
*/
91124
override fun markCheckpoint(): Pair<Long, Long> {
92-
val index = streamStatus.recordCount.get()
93-
val lastCheckpoint = streamStatus.lastCheckpoint.getAndSet(index)
125+
val index = recordCount.get()
126+
val lastCheckpoint = lastCheckpoint.getAndSet(index)
94127
return Pair(index, index - lastCheckpoint)
95128
}
96129

97-
/** Record that the given batch's state has been reached for the associated range(s). */
98130
override fun <B : Batch> updateBatchState(batch: BatchEnvelope<B>) {
99131
val stateRanges =
100132
rangesState[batch.batch.state]
@@ -112,37 +144,44 @@ class DefaultStreamManager(
112144
log.info { "Updated ranges for $stream[${batch.batch.state}]: $stateRanges" }
113145
}
114146

115-
/** True if all records in [0, index] have reached the given state. */
147+
/** True if all records in `[0, index)` have reached the given state. */
116148
private fun isProcessingCompleteForState(index: Long, state: Batch.State): Boolean {
117-
118149
val completeRanges = rangesState[state]!!
119150
return completeRanges.encloses(Range.closedOpen(0L, index))
120151
}
121152

122-
/** True if all records have associated [Batch.State.COMPLETE] batches. */
123153
override fun isBatchProcessingComplete(): Boolean {
124-
return isProcessingCompleteForState(streamStatus.recordCount.get(), Batch.State.COMPLETE)
154+
/* If the stream hasn't been fully read, it can't be done. */
155+
if (!readIsClosed.get()) {
156+
return false
157+
}
158+
159+
return isProcessingCompleteForState(recordCount.get(), Batch.State.COMPLETE)
125160
}
126161

127-
/**
128-
* True if all records in [0, index] have at least reached [Batch.State.PERSISTED]. This is
129-
* implicitly true if they have all reached [Batch.State.COMPLETE].
130-
*/
131162
override fun areRecordsPersistedUntil(index: Long): Boolean {
132163
return isProcessingCompleteForState(index, Batch.State.PERSISTED) ||
133164
isProcessingCompleteForState(index, Batch.State.COMPLETE) // complete => persisted
134165
}
135166

136167
override fun markClosed() {
137-
streamStatus.closedLatch.countDown()
168+
if (!readIsClosed.get()) {
169+
throw IllegalStateException("Stream must be fully read before it can be closed")
170+
}
171+
172+
if (streamIsClosed.compareAndSet(false, true)) {
173+
closedLock.trySend(Unit)
174+
}
138175
}
139176

140177
override fun streamIsClosed(): Boolean {
141-
return streamStatus.closedLatch.count == 0L
178+
return streamIsClosed.get()
142179
}
143180

144181
override suspend fun awaitStreamClosed() {
145-
withContext(Dispatchers.IO) { streamStatus.closedLatch.await() }
182+
if (!streamIsClosed.get()) {
183+
closedLock.receive()
184+
}
146185
}
147186
}
148187

airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/task/TeardownTask.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class TeardownTask(
3434
}
3535

3636
/** Ensure we don't run until all streams have been closed */
37-
streamsManager.awaitAllStreamsComplete()
37+
streamsManager.awaitAllStreamsClosed()
3838

3939
destination.teardown()
4040
taskLauncher.stop()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/*
2+
* Copyright (c) 2024 Airbyte, Inc., all rights reserved.
3+
*/
4+
5+
package io.airbyte.cdk.command
6+
7+
import io.micronaut.context.annotation.Factory
8+
import io.micronaut.context.annotation.Replaces
9+
import io.micronaut.context.annotation.Requires
10+
import jakarta.inject.Named
11+
import jakarta.inject.Singleton
12+
13+
@Factory
14+
@Replaces(factory = DestinationCatalogFactory::class)
15+
@Requires(env = ["test"])
16+
class MockCatalogFactory : DestinationCatalogFactory {
17+
companion object {
18+
val stream1 = DestinationStream(DestinationStream.Descriptor("test", "stream1"))
19+
val stream2 = DestinationStream(DestinationStream.Descriptor("test", "stream2"))
20+
}
21+
22+
@Singleton
23+
@Named("mockCatalog")
24+
override fun make(): DestinationCatalog {
25+
return DestinationCatalog(streams = listOf(stream1, stream2))
26+
}
27+
}

airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/StateManagerTest.kt

+13-23
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,17 @@ import com.google.common.collect.Range
88
import com.google.common.collect.RangeSet
99
import com.google.common.collect.TreeRangeSet
1010
import io.airbyte.cdk.command.DestinationCatalog
11-
import io.airbyte.cdk.command.DestinationCatalogFactory
1211
import io.airbyte.cdk.command.DestinationStream
12+
import io.airbyte.cdk.command.MockCatalogFactory.Companion.stream1
13+
import io.airbyte.cdk.command.MockCatalogFactory.Companion.stream2
1314
import io.airbyte.cdk.message.Batch
1415
import io.airbyte.cdk.message.BatchEnvelope
1516
import io.airbyte.cdk.message.MessageConverter
16-
import io.micronaut.context.annotation.Factory
1717
import io.micronaut.context.annotation.Prototype
18-
import io.micronaut.context.annotation.Replaces
1918
import io.micronaut.context.annotation.Requires
2019
import io.micronaut.test.extensions.junit5.annotation.MicronautTest
2120
import jakarta.inject.Inject
21+
import jakarta.inject.Named
2222
import jakarta.inject.Singleton
2323
import java.util.function.Consumer
2424
import java.util.stream.Stream
@@ -29,25 +29,10 @@ import org.junit.jupiter.params.provider.Arguments
2929
import org.junit.jupiter.params.provider.ArgumentsProvider
3030
import org.junit.jupiter.params.provider.ArgumentsSource
3131

32-
@MicronautTest
32+
@MicronautTest(environments = ["StateManagerTest"])
3333
class StateManagerTest {
3434
@Inject lateinit var stateManager: TestStateManager
3535

36-
companion object {
37-
val stream1 = DestinationStream(DestinationStream.Descriptor("test", "stream1"))
38-
val stream2 = DestinationStream(DestinationStream.Descriptor("test", "stream2"))
39-
}
40-
41-
@Factory
42-
@Replaces(factory = DestinationCatalogFactory::class)
43-
class MockCatalogFactory {
44-
@Singleton
45-
@Requires(env = ["test"])
46-
fun make(): DestinationCatalog {
47-
return DestinationCatalog(streams = listOf(stream1, stream2))
48-
}
49-
}
50-
5136
/**
5237
* Test state messages.
5338
*
@@ -95,7 +80,11 @@ class StateManagerTest {
9580
class MockStreamManager : StreamManager {
9681
var persistedRanges: RangeSet<Long> = TreeRangeSet.create()
9782

98-
override fun countRecordIn(sizeBytes: Long): Long {
83+
override fun countRecordIn(): Long {
84+
throw NotImplementedError()
85+
}
86+
87+
override fun countEndOfStream(): Long {
9988
throw NotImplementedError()
10089
}
10190

@@ -129,7 +118,8 @@ class StateManagerTest {
129118
}
130119

131120
@Prototype
132-
class MockStreamsManager(catalog: DestinationCatalog) : StreamsManager {
121+
@Requires(env = ["StateManagerTest"])
122+
class MockStreamsManager(@Named("mockCatalog") catalog: DestinationCatalog) : StreamsManager {
133123
private val mockManagers = catalog.streams.associateWith { MockStreamManager() }
134124

135125
fun addPersistedRanges(stream: DestinationStream, ranges: List<Range<Long>>) {
@@ -141,14 +131,14 @@ class StateManagerTest {
141131
?: throw IllegalArgumentException("Stream not found: $stream")
142132
}
143133

144-
override suspend fun awaitAllStreamsComplete() {
134+
override suspend fun awaitAllStreamsClosed() {
145135
throw NotImplementedError()
146136
}
147137
}
148138

149139
@Prototype
150140
class TestStateManager(
151-
override val catalog: DestinationCatalog,
141+
@Named("mockCatalog") override val catalog: DestinationCatalog,
152142
override val streamsManager: MockStreamsManager,
153143
override val outputFactory: MessageConverter<MockStateIn, MockStateOut>,
154144
override val outputConsumer: MockOutputConsumer

0 commit comments

Comments
 (0)