Skip to content

Commit 7056428

Browse files
Bulk Load CDK: State -> Checkpoint & flush at end (#45377)
1 parent b00dac8 commit 7056428

File tree

7 files changed

+278
-254
lines changed

7 files changed

+278
-254
lines changed

airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/DestinationMessage.kt

+19-19
Original file line numberDiff line numberDiff line change
@@ -38,36 +38,36 @@ data class DestinationStreamComplete(
3838
) : DestinationRecordMessage()
3939

4040
/** State. */
41-
sealed class DestinationStateMessage : DestinationMessage() {
41+
sealed class CheckpointMessage : DestinationMessage() {
4242
data class Stats(val recordCount: Long)
43-
data class StreamState(
43+
data class StreamCheckpoint(
4444
val stream: DestinationStream,
4545
val state: JsonNode,
4646
)
4747

4848
abstract val sourceStats: Stats
4949
abstract val destinationStats: Stats?
5050

51-
abstract fun withDestinationStats(stats: Stats): DestinationStateMessage
51+
abstract fun withDestinationStats(stats: Stats): CheckpointMessage
5252
}
5353

54-
data class DestinationStreamState(
55-
val streamState: StreamState,
54+
data class StreamCheckpoint(
55+
val streamCheckpoint: StreamCheckpoint,
5656
override val sourceStats: Stats,
5757
override val destinationStats: Stats? = null
58-
) : DestinationStateMessage() {
58+
) : CheckpointMessage() {
5959
override fun withDestinationStats(stats: Stats) =
60-
DestinationStreamState(streamState, sourceStats, stats)
60+
StreamCheckpoint(streamCheckpoint, sourceStats, stats)
6161
}
6262

63-
data class DestinationGlobalState(
63+
data class GlobalCheckpoint(
6464
val state: JsonNode,
6565
override val sourceStats: Stats,
6666
override val destinationStats: Stats? = null,
67-
val streamStates: List<StreamState> = emptyList()
68-
) : DestinationStateMessage() {
67+
val streamCheckpoints: List<StreamCheckpoint> = emptyList()
68+
) : CheckpointMessage() {
6969
override fun withDestinationStats(stats: Stats) =
70-
DestinationGlobalState(state, sourceStats, stats, streamStates)
70+
GlobalCheckpoint(state, sourceStats, stats, streamCheckpoints)
7171
}
7272

7373
/** Catchall for anything unimplemented. */
@@ -108,21 +108,21 @@ class DestinationMessageFactory(private val catalog: DestinationCatalog) {
108108
AirbyteMessage.Type.STATE -> {
109109
when (message.state.type) {
110110
AirbyteStateMessage.AirbyteStateType.STREAM ->
111-
DestinationStreamState(
112-
streamState = fromAirbyteStreamState(message.state.stream),
111+
StreamCheckpoint(
112+
streamCheckpoint = fromAirbyteStreamState(message.state.stream),
113113
sourceStats =
114-
DestinationStateMessage.Stats(
114+
CheckpointMessage.Stats(
115115
recordCount = message.state.sourceStats.recordCount.toLong()
116116
)
117117
)
118118
AirbyteStateMessage.AirbyteStateType.GLOBAL ->
119-
DestinationGlobalState(
119+
GlobalCheckpoint(
120120
sourceStats =
121-
DestinationStateMessage.Stats(
121+
CheckpointMessage.Stats(
122122
recordCount = message.state.sourceStats.recordCount.toLong()
123123
),
124124
state = message.state.global.sharedState,
125-
streamStates =
125+
streamCheckpoints =
126126
message.state.global.streamStates.map { fromAirbyteStreamState(it) }
127127
)
128128
else -> // TODO: Do we still need to handle LEGACY?
@@ -135,9 +135,9 @@ class DestinationMessageFactory(private val catalog: DestinationCatalog) {
135135

136136
private fun fromAirbyteStreamState(
137137
streamState: AirbyteStreamState
138-
): DestinationStateMessage.StreamState {
138+
): CheckpointMessage.StreamCheckpoint {
139139
val descriptor = streamState.streamDescriptor
140-
return DestinationStateMessage.StreamState(
140+
return CheckpointMessage.StreamCheckpoint(
141141
stream = catalog.getStream(namespace = descriptor.namespace, name = descriptor.name),
142142
state = streamState.streamState
143143
)

airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/MessageConverter.kt

+15-13
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,19 @@ import io.airbyte.protocol.models.v0.StreamDescriptor
1313
import jakarta.inject.Singleton
1414

1515
/**
16-
* Converts the internal @[DestinationStateMessage] case class to the Protocol state messages
17-
* required by @[io.airbyte.cdk.output.OutputConsumer]
16+
* Converts the internal @[CheckpointMessage] case class to the Protocol state messages required by
17+
* @[io.airbyte.cdk.output.OutputConsumer]
1818
*/
1919
interface MessageConverter<T, U> {
2020
fun from(message: T): U
2121
}
2222

2323
@Singleton
24-
class DefaultMessageConverter : MessageConverter<DestinationStateMessage, AirbyteMessage> {
25-
override fun from(message: DestinationStateMessage): AirbyteMessage {
24+
class DefaultMessageConverter : MessageConverter<CheckpointMessage, AirbyteMessage> {
25+
override fun from(message: CheckpointMessage): AirbyteMessage {
2626
val state =
2727
when (message) {
28-
is DestinationStreamState ->
28+
is StreamCheckpoint ->
2929
AirbyteStateMessage()
3030
.withSourceStats(
3131
AirbyteStateStats()
@@ -40,8 +40,8 @@ class DefaultMessageConverter : MessageConverter<DestinationStateMessage, Airbyt
4040
)
4141
)
4242
.withType(AirbyteStateMessage.AirbyteStateType.STREAM)
43-
.withStream(fromStreamState(message.streamState))
44-
is DestinationGlobalState ->
43+
.withStream(fromStreamState(message.streamCheckpoint))
44+
is GlobalCheckpoint ->
4545
AirbyteStateMessage()
4646
.withSourceStats(
4747
AirbyteStateStats()
@@ -56,21 +56,23 @@ class DefaultMessageConverter : MessageConverter<DestinationStateMessage, Airbyt
5656
.withGlobal(
5757
AirbyteGlobalState()
5858
.withSharedState(message.state)
59-
.withStreamStates(message.streamStates.map { fromStreamState(it) })
59+
.withStreamStates(
60+
message.streamCheckpoints.map { fromStreamState(it) }
61+
)
6062
)
6163
}
62-
return AirbyteMessage().withState(state)
64+
return AirbyteMessage().withType(AirbyteMessage.Type.STATE).withState(state)
6365
}
6466

6567
private fun fromStreamState(
66-
streamState: DestinationStateMessage.StreamState
68+
streamCheckpoint: CheckpointMessage.StreamCheckpoint
6769
): AirbyteStreamState {
6870
return AirbyteStreamState()
6971
.withStreamDescriptor(
7072
StreamDescriptor()
71-
.withNamespace(streamState.stream.descriptor.namespace)
72-
.withName(streamState.stream.descriptor.name)
73+
.withNamespace(streamCheckpoint.stream.descriptor.namespace)
74+
.withName(streamCheckpoint.stream.descriptor.name)
7375
)
74-
.withStreamState(streamState.state)
76+
.withStreamState(streamCheckpoint.state)
7577
}
7678
}

airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/MessageQueueWriter.kt

+15-13
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ package io.airbyte.cdk.message
77
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings
88
import io.airbyte.cdk.command.DestinationCatalog
99
import io.airbyte.cdk.command.DestinationStream
10-
import io.airbyte.cdk.state.StateManager
10+
import io.airbyte.cdk.state.CheckpointManager
1111
import io.airbyte.cdk.state.StreamsManager
1212
import jakarta.inject.Singleton
1313

@@ -18,7 +18,7 @@ interface MessageQueueWriter<T : Any> {
1818

1919
/**
2020
* Routes @[DestinationRecordMessage]s by stream to the appropriate channel and @
21-
* [DestinationStateMessage]s to the state manager.
21+
* [CheckpointMessage]s to the state manager.
2222
*
2323
* TODO: Handle other message types.
2424
*/
@@ -31,7 +31,7 @@ class DestinationMessageQueueWriter(
3131
private val catalog: DestinationCatalog,
3232
private val messageQueue: MessageQueue<DestinationStream, DestinationRecordWrapped>,
3333
private val streamsManager: StreamsManager,
34-
private val stateManager: StateManager<DestinationStream, DestinationStateMessage>
34+
private val checkpointManager: CheckpointManager<DestinationStream, CheckpointMessage>
3535
) : MessageQueueWriter<DestinationMessage> {
3636
/**
3737
* Deserialize and route the message to the appropriate channel.
@@ -62,28 +62,30 @@ class DestinationMessageQueueWriter(
6262
}
6363
}
6464
}
65-
is DestinationStateMessage -> {
65+
is CheckpointMessage -> {
6666
when (message) {
6767
/**
6868
* For a stream state message, mark the checkpoint and add the message with
6969
* index and count to the state manager. Also, add the count to the destination
7070
* stats.
7171
*/
72-
is DestinationStreamState -> {
73-
val stream = message.streamState.stream
72+
is StreamCheckpoint -> {
73+
val stream = message.streamCheckpoint.stream
7474
val manager = streamsManager.getManager(stream)
7575
val (currentIndex, countSinceLast) = manager.markCheckpoint()
7676
val messageWithCount =
77-
message.withDestinationStats(
78-
DestinationStateMessage.Stats(countSinceLast)
79-
)
80-
stateManager.addStreamState(stream, currentIndex, messageWithCount)
77+
message.withDestinationStats(CheckpointMessage.Stats(countSinceLast))
78+
checkpointManager.addStreamCheckpoint(
79+
stream,
80+
currentIndex,
81+
messageWithCount
82+
)
8183
}
8284
/**
8385
* For a global state message, collect the index per stream, but add the total
8486
* count to the destination stats.
8587
*/
86-
is DestinationGlobalState -> {
88+
is GlobalCheckpoint -> {
8789
val streamWithIndexAndCount =
8890
catalog.streams.map { stream ->
8991
val manager = streamsManager.getManager(stream)
@@ -92,9 +94,9 @@ class DestinationMessageQueueWriter(
9294
}
9395
val totalCount = streamWithIndexAndCount.sumOf { it.third }
9496
val messageWithCount =
95-
message.withDestinationStats(DestinationStateMessage.Stats(totalCount))
97+
message.withDestinationStats(CheckpointMessage.Stats(totalCount))
9698
val streamIndexes = streamWithIndexAndCount.map { it.first to it.second }
97-
stateManager.addGlobalState(streamIndexes, messageWithCount)
99+
checkpointManager.addGlobalCheckpoint(streamIndexes, messageWithCount)
98100
}
99101
}
100102
}

0 commit comments

Comments
 (0)