Skip to content

Commit 095d30e

Browse files
State manager test, rational range behavior throughout (#44945)
1 parent 2073513 commit 095d30e

File tree

6 files changed

+629
-132
lines changed

6 files changed

+629
-132
lines changed

airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/AirbyteStateMessageFactory.kt

-75
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*
2+
* Copyright (c) 2024 Airbyte, Inc., all rights reserved.
3+
*/
4+
5+
package io.airbyte.cdk.message
6+
7+
import io.airbyte.protocol.models.v0.AirbyteGlobalState
8+
import io.airbyte.protocol.models.v0.AirbyteMessage
9+
import io.airbyte.protocol.models.v0.AirbyteStateMessage
10+
import io.airbyte.protocol.models.v0.AirbyteStateStats
11+
import io.airbyte.protocol.models.v0.AirbyteStreamState
12+
import io.airbyte.protocol.models.v0.StreamDescriptor
13+
import jakarta.inject.Singleton
14+
15+
/**
16+
* Converts the internal @[DestinationStateMessage] case class to the Protocol state messages
17+
* required by @[io.airbyte.cdk.output.OutputConsumer]
18+
*/
19+
interface MessageConverter<T, U> {
20+
fun from(message: T): U
21+
}
22+
23+
@Singleton
24+
class DefaultMessageConverter : MessageConverter<DestinationStateMessage, AirbyteMessage> {
25+
override fun from(message: DestinationStateMessage): AirbyteMessage {
26+
val state =
27+
when (message) {
28+
is DestinationStreamState ->
29+
AirbyteStateMessage()
30+
.withSourceStats(
31+
AirbyteStateStats()
32+
.withRecordCount(message.sourceStats.recordCount.toDouble())
33+
)
34+
.withDestinationStats(
35+
message.destinationStats?.let {
36+
AirbyteStateStats().withRecordCount(it.recordCount.toDouble())
37+
}
38+
?: throw IllegalStateException(
39+
"Destination stats must be provided for DestinationStreamState"
40+
)
41+
)
42+
.withType(AirbyteStateMessage.AirbyteStateType.STREAM)
43+
.withStream(fromStreamState(message.streamState))
44+
is DestinationGlobalState ->
45+
AirbyteStateMessage()
46+
.withSourceStats(
47+
AirbyteStateStats()
48+
.withRecordCount(message.sourceStats.recordCount.toDouble())
49+
)
50+
.withDestinationStats(
51+
message.destinationStats?.let {
52+
AirbyteStateStats().withRecordCount(it.recordCount.toDouble())
53+
}
54+
)
55+
.withType(AirbyteStateMessage.AirbyteStateType.GLOBAL)
56+
.withGlobal(
57+
AirbyteGlobalState()
58+
.withSharedState(message.state)
59+
.withStreamStates(message.streamStates.map { fromStreamState(it) })
60+
)
61+
}
62+
return AirbyteMessage().withState(state)
63+
}
64+
65+
private fun fromStreamState(
66+
streamState: DestinationStateMessage.StreamState
67+
): AirbyteStreamState {
68+
return AirbyteStreamState()
69+
.withStreamDescriptor(
70+
StreamDescriptor()
71+
.withNamespace(streamState.stream.descriptor.namespace)
72+
.withName(streamState.stream.descriptor.name)
73+
)
74+
.withStreamState(streamState.state)
75+
}
76+
}

airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/MessageQueueWriter.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class DestinationMessageQueueWriter(
3131
private val catalog: DestinationCatalog,
3232
private val messageQueue: MessageQueue<DestinationStream, DestinationRecordWrapped>,
3333
private val streamsManager: StreamsManager,
34-
private val stateManager: StateManager
34+
private val stateManager: StateManager<DestinationStream, DestinationStateMessage>
3535
) : MessageQueueWriter<DestinationMessage> {
3636
/**
3737
* Deserialize and route the message to the appropriate channel.

airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/StateManager.kt

+80-54
Original file line numberDiff line numberDiff line change
@@ -6,34 +6,29 @@ package io.airbyte.cdk.state
66

77
import io.airbyte.cdk.command.DestinationCatalog
88
import io.airbyte.cdk.command.DestinationStream
9-
import io.airbyte.cdk.message.AirbyteStateMessageFactory
109
import io.airbyte.cdk.message.DestinationStateMessage
11-
import io.airbyte.cdk.output.OutputConsumer
10+
import io.airbyte.cdk.message.MessageConverter
11+
import io.airbyte.protocol.models.v0.AirbyteMessage
1212
import io.github.oshai.kotlinlogging.KotlinLogging
13+
import io.micronaut.core.util.clhm.ConcurrentLinkedHashMap
1314
import jakarta.inject.Singleton
1415
import java.util.concurrent.ConcurrentHashMap
1516
import java.util.concurrent.ConcurrentLinkedQueue
1617
import java.util.concurrent.atomic.AtomicReference
18+
import java.util.function.Consumer
1719

1820
/**
1921
* Interface for state management. Should accept stream and global state, as well as requests to
2022
* flush all data-sufficient states.
2123
*/
22-
interface StateManager {
23-
fun addStreamState(
24-
stream: DestinationStream,
25-
index: Long,
26-
stateMessage: DestinationStateMessage
27-
)
28-
fun addGlobalState(
29-
streamIndexes: List<Pair<DestinationStream, Long>>,
30-
stateMessage: DestinationStateMessage
31-
)
24+
interface StateManager<K, T> {
25+
fun addStreamState(key: K, index: Long, stateMessage: T)
26+
fun addGlobalState(keyIndexes: List<Pair<K, Long>>, stateMessage: T)
3227
fun flushStates()
3328
}
3429

3530
/**
36-
* Destination state manager.
31+
* Message-type agnostic streams state manager.
3732
*
3833
* Accepts global and stream states, and enforces that stream and global state are not mixed.
3934
* Determines ready states by querying the StreamsManager for the state of the record index range
@@ -44,50 +39,73 @@ interface StateManager {
4439
* TODO: Ensure that state is flushed at the end, and require that all state be flushed before the
4540
* destination can succeed.
4641
*/
47-
@Singleton
48-
class DefaultStateManager(
49-
private val catalog: DestinationCatalog,
50-
private val streamsManager: StreamsManager,
51-
private val stateMessageFactory: AirbyteStateMessageFactory,
52-
private val outputConsumer: OutputConsumer
53-
) : StateManager {
42+
abstract class StreamsStateManager<T, U>() : StateManager<DestinationStream, T> {
5443
private val log = KotlinLogging.logger {}
5544

56-
data class GlobalState(
45+
abstract val catalog: DestinationCatalog
46+
abstract val streamsManager: StreamsManager
47+
abstract val outputFactory: MessageConverter<T, U>
48+
abstract val outputConsumer: Consumer<U>
49+
50+
data class GlobalState<T>(
5751
val streamIndexes: List<Pair<DestinationStream, Long>>,
58-
val stateMessage: DestinationStateMessage
52+
val stateMessage: T
5953
)
6054

6155
private val stateIsGlobal: AtomicReference<Boolean?> = AtomicReference(null)
6256
private val streamStates:
63-
ConcurrentHashMap<DestinationStream, LinkedHashMap<Long, DestinationStateMessage>> =
57+
ConcurrentHashMap<DestinationStream, ConcurrentLinkedHashMap<Long, T>> =
6458
ConcurrentHashMap()
65-
private val globalStates: ConcurrentLinkedQueue<GlobalState> = ConcurrentLinkedQueue()
66-
67-
override fun addStreamState(
68-
stream: DestinationStream,
69-
index: Long,
70-
stateMessage: DestinationStateMessage
71-
) {
72-
if (stateIsGlobal.getAndSet(false) != false) {
59+
private val globalStates: ConcurrentLinkedQueue<GlobalState<T>> = ConcurrentLinkedQueue()
60+
61+
override fun addStreamState(key: DestinationStream, index: Long, stateMessage: T) {
62+
if (stateIsGlobal.updateAndGet { it == true } != false) {
7363
throw IllegalStateException("Global state cannot be mixed with non-global state")
7464
}
7565

76-
val streamStates = streamStates.getOrPut(stream) { LinkedHashMap() }
77-
streamStates[index] = stateMessage
78-
log.info { "Added state for stream: $stream at index: $index" }
66+
streamStates.compute(key) { _, indexToMessage ->
67+
val map =
68+
if (indexToMessage == null) {
69+
// If the map doesn't exist yet, build it.
70+
ConcurrentLinkedHashMap.Builder<Long, T>().maximumWeightedCapacity(1000).build()
71+
} else {
72+
if (indexToMessage.isNotEmpty()) {
73+
// Make sure the messages are coming in order
74+
val oldestIndex = indexToMessage.ascendingKeySet().first()
75+
if (oldestIndex > index) {
76+
throw IllegalStateException(
77+
"State message received out of order ($oldestIndex before $index)"
78+
)
79+
}
80+
}
81+
indexToMessage
82+
}
83+
// Actually add the message
84+
map[index] = stateMessage
85+
map
86+
}
87+
88+
log.info { "Added state for stream: $key at index: $index" }
7989
}
8090

81-
override fun addGlobalState(
82-
streamIndexes: List<Pair<DestinationStream, Long>>,
83-
stateMessage: DestinationStateMessage
84-
) {
85-
if (stateIsGlobal.getAndSet(true) != true) {
91+
// TODO: Is it an error if we don't get all the streams every time?
92+
override fun addGlobalState(keyIndexes: List<Pair<DestinationStream, Long>>, stateMessage: T) {
93+
if (stateIsGlobal.updateAndGet { it != false } != true) {
8694
throw IllegalStateException("Global state cannot be mixed with non-global state")
8795
}
8896

89-
globalStates.add(GlobalState(streamIndexes, stateMessage))
90-
log.info { "Added global state with stream indexes: $streamIndexes" }
97+
val head = globalStates.peek()
98+
if (head != null) {
99+
val keyIndexesByStream = keyIndexes.associate { it.first to it.second }
100+
head.streamIndexes.forEach {
101+
if (keyIndexesByStream[it.first]!! < it.second) {
102+
throw IllegalStateException("Global state message received out of order")
103+
}
104+
}
105+
}
106+
107+
globalStates.add(GlobalState(keyIndexes, stateMessage))
108+
log.info { "Added global state with stream indexes: $keyIndexes" }
91109
}
92110

93111
override fun flushStates() {
@@ -105,19 +123,19 @@ class DefaultStateManager(
105123
}
106124

107125
private fun flushGlobalStates() {
108-
if (globalStates.isEmpty()) {
109-
return
110-
}
111-
112-
val head = globalStates.peek()
113-
val allStreamsPersisted =
114-
head.streamIndexes.all { (stream, index) ->
115-
streamsManager.getManager(stream).areRecordsPersistedUntil(index)
126+
while (!globalStates.isEmpty()) {
127+
val head = globalStates.peek()
128+
val allStreamsPersisted =
129+
head.streamIndexes.all { (stream, index) ->
130+
streamsManager.getManager(stream).areRecordsPersistedUntil(index)
131+
}
132+
if (allStreamsPersisted) {
133+
globalStates.poll()
134+
val outMessage = outputFactory.from(head.stateMessage)
135+
outputConsumer.accept(outMessage)
136+
} else {
137+
break
116138
}
117-
if (allStreamsPersisted) {
118-
globalStates.poll()
119-
val outMessage = stateMessageFactory.fromDestinationStateMessage(head.stateMessage)
120-
outputConsumer.accept(outMessage)
121139
}
122140
}
123141

@@ -131,7 +149,7 @@ class DefaultStateManager(
131149
streamStates.remove(index)
132150
?: throw IllegalStateException("State not found for index: $index")
133151
log.info { "Flushing state for stream: $stream at index: $index" }
134-
val outMessage = stateMessageFactory.fromDestinationStateMessage(stateMessage)
152+
val outMessage = outputFactory.from(stateMessage)
135153
outputConsumer.accept(outMessage)
136154
} else {
137155
break
@@ -140,3 +158,11 @@ class DefaultStateManager(
140158
}
141159
}
142160
}
161+
162+
@Singleton
163+
class DefaultStateManager(
164+
override val catalog: DestinationCatalog,
165+
override val streamsManager: StreamsManager,
166+
override val outputFactory: MessageConverter<DestinationStateMessage, AirbyteMessage>,
167+
override val outputConsumer: Consumer<AirbyteMessage>
168+
) : StreamsStateManager<DestinationStateMessage, AirbyteMessage>()

airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/StreamManager.kt

+10-2
Original file line numberDiff line numberDiff line change
@@ -100,15 +100,23 @@ class DefaultStreamManager(
100100
rangesState[batch.batch.state]
101101
?: throw IllegalArgumentException("Invalid batch state: ${batch.batch.state}")
102102

103-
stateRanges.addAll(batch.ranges)
103+
// Force the ranges to overlap at their endpoints, in order to work around
104+
// the behavior of `.encloses`, which otherwise would not consider adjacent ranges as
105+
// contiguous.
106+
// This ensures that a state message received at eg, index 10 (after messages 0..9 have
107+
// been received), will pass `{'[0..5]','[6..9]'}.encloses('[0..10)')`.
108+
val expanded =
109+
batch.ranges.asRanges().map { it.span(Range.singleton(it.upperEndpoint() + 1)) }
110+
111+
stateRanges.addAll(expanded)
104112
log.info { "Updated ranges for $stream[${batch.batch.state}]: $stateRanges" }
105113
}
106114

107115
/** True if all records in [0, index] have reached the given state. */
108116
private fun isProcessingCompleteForState(index: Long, state: Batch.State): Boolean {
109117

110118
val completeRanges = rangesState[state]!!
111-
return completeRanges.encloses(Range.closed(0L, index - 1))
119+
return completeRanges.encloses(Range.closedOpen(0L, index))
112120
}
113121

114122
/** True if all records have associated [Batch.State.COMPLETE] batches. */

0 commit comments

Comments
 (0)