@@ -6,34 +6,29 @@ package io.airbyte.cdk.state
6
6
7
7
import io.airbyte.cdk.command.DestinationCatalog
8
8
import io.airbyte.cdk.command.DestinationStream
9
- import io.airbyte.cdk.message.AirbyteStateMessageFactory
10
9
import io.airbyte.cdk.message.DestinationStateMessage
11
- import io.airbyte.cdk.output.OutputConsumer
10
+ import io.airbyte.cdk.message.MessageConverter
11
+ import io.airbyte.protocol.models.v0.AirbyteMessage
12
12
import io.github.oshai.kotlinlogging.KotlinLogging
13
+ import io.micronaut.core.util.clhm.ConcurrentLinkedHashMap
13
14
import jakarta.inject.Singleton
14
15
import java.util.concurrent.ConcurrentHashMap
15
16
import java.util.concurrent.ConcurrentLinkedQueue
16
17
import java.util.concurrent.atomic.AtomicReference
18
+ import java.util.function.Consumer
17
19
18
20
/* *
19
21
* Interface for state management. Should accept stream and global state, as well as requests to
20
22
* flush all data-sufficient states.
21
23
*/
22
- interface StateManager {
23
- fun addStreamState (
24
- stream : DestinationStream ,
25
- index : Long ,
26
- stateMessage : DestinationStateMessage
27
- )
28
- fun addGlobalState (
29
- streamIndexes : List <Pair <DestinationStream , Long >>,
30
- stateMessage : DestinationStateMessage
31
- )
24
+ interface StateManager <K , T > {
25
+ fun addStreamState (key : K , index : Long , stateMessage : T )
26
+ fun addGlobalState (keyIndexes : List <Pair <K , Long >>, stateMessage : T )
32
27
fun flushStates ()
33
28
}
34
29
35
30
/* *
36
- * Destination state manager.
31
+ * Message-type agnostic streams state manager.
37
32
*
38
33
* Accepts global and stream states, and enforces that stream and global state are not mixed.
39
34
* Determines ready states by querying the StreamsManager for the state of the record index range
@@ -44,50 +39,73 @@ interface StateManager {
44
39
* TODO: Ensure that state is flushed at the end, and require that all state be flushed before the
45
40
* destination can succeed.
46
41
*/
47
- @Singleton
48
- class DefaultStateManager (
49
- private val catalog : DestinationCatalog ,
50
- private val streamsManager : StreamsManager ,
51
- private val stateMessageFactory : AirbyteStateMessageFactory ,
52
- private val outputConsumer : OutputConsumer
53
- ) : StateManager {
42
+ abstract class StreamsStateManager <T , U >() : StateManager<DestinationStream, T> {
54
43
private val log = KotlinLogging .logger {}
55
44
56
- data class GlobalState (
45
+ abstract val catalog: DestinationCatalog
46
+ abstract val streamsManager: StreamsManager
47
+ abstract val outputFactory: MessageConverter <T , U >
48
+ abstract val outputConsumer: Consumer <U >
49
+
50
+ data class GlobalState <T >(
57
51
val streamIndexes : List <Pair <DestinationStream , Long >>,
58
- val stateMessage : DestinationStateMessage
52
+ val stateMessage : T
59
53
)
60
54
61
55
private val stateIsGlobal: AtomicReference <Boolean ?> = AtomicReference (null )
62
56
private val streamStates:
63
- ConcurrentHashMap <DestinationStream , LinkedHashMap <Long , DestinationStateMessage >> =
57
+ ConcurrentHashMap <DestinationStream , ConcurrentLinkedHashMap <Long , T >> =
64
58
ConcurrentHashMap ()
65
- private val globalStates: ConcurrentLinkedQueue <GlobalState > = ConcurrentLinkedQueue ()
66
-
67
- override fun addStreamState (
68
- stream : DestinationStream ,
69
- index : Long ,
70
- stateMessage : DestinationStateMessage
71
- ) {
72
- if (stateIsGlobal.getAndSet(false ) != false ) {
59
+ private val globalStates: ConcurrentLinkedQueue <GlobalState <T >> = ConcurrentLinkedQueue ()
60
+
61
+ override fun addStreamState (key : DestinationStream , index : Long , stateMessage : T ) {
62
+ if (stateIsGlobal.updateAndGet { it == true } != false ) {
73
63
throw IllegalStateException (" Global state cannot be mixed with non-global state" )
74
64
}
75
65
76
- val streamStates = streamStates.getOrPut(stream) { LinkedHashMap () }
77
- streamStates[index] = stateMessage
78
- log.info { " Added state for stream: $stream at index: $index " }
66
+ streamStates.compute(key) { _, indexToMessage ->
67
+ val map =
68
+ if (indexToMessage == null ) {
69
+ // If the map doesn't exist yet, build it.
70
+ ConcurrentLinkedHashMap .Builder <Long , T >().maximumWeightedCapacity(1000 ).build()
71
+ } else {
72
+ if (indexToMessage.isNotEmpty()) {
73
+ // Make sure the messages are coming in order
74
+ val oldestIndex = indexToMessage.ascendingKeySet().first()
75
+ if (oldestIndex > index) {
76
+ throw IllegalStateException (
77
+ " State message received out of order ($oldestIndex before $index )"
78
+ )
79
+ }
80
+ }
81
+ indexToMessage
82
+ }
83
+ // Actually add the message
84
+ map[index] = stateMessage
85
+ map
86
+ }
87
+
88
+ log.info { " Added state for stream: $key at index: $index " }
79
89
}
80
90
81
- override fun addGlobalState (
82
- streamIndexes : List <Pair <DestinationStream , Long >>,
83
- stateMessage : DestinationStateMessage
84
- ) {
85
- if (stateIsGlobal.getAndSet(true ) != true ) {
91
+ // TODO: Is it an error if we don't get all the streams every time?
92
+ override fun addGlobalState (keyIndexes : List <Pair <DestinationStream , Long >>, stateMessage : T ) {
93
+ if (stateIsGlobal.updateAndGet { it != false } != true ) {
86
94
throw IllegalStateException (" Global state cannot be mixed with non-global state" )
87
95
}
88
96
89
- globalStates.add(GlobalState (streamIndexes, stateMessage))
90
- log.info { " Added global state with stream indexes: $streamIndexes " }
97
+ val head = globalStates.peek()
98
+ if (head != null ) {
99
+ val keyIndexesByStream = keyIndexes.associate { it.first to it.second }
100
+ head.streamIndexes.forEach {
101
+ if (keyIndexesByStream[it.first]!! < it.second) {
102
+ throw IllegalStateException (" Global state message received out of order" )
103
+ }
104
+ }
105
+ }
106
+
107
+ globalStates.add(GlobalState (keyIndexes, stateMessage))
108
+ log.info { " Added global state with stream indexes: $keyIndexes " }
91
109
}
92
110
93
111
override fun flushStates () {
@@ -105,19 +123,19 @@ class DefaultStateManager(
105
123
}
106
124
107
125
private fun flushGlobalStates () {
108
- if (globalStates.isEmpty()) {
109
- return
110
- }
111
-
112
- val head = globalStates.peek()
113
- val allStreamsPersisted =
114
- head.streamIndexes.all { (stream, index) ->
115
- streamsManager.getManager(stream).areRecordsPersistedUntil(index)
126
+ while (! globalStates.isEmpty()) {
127
+ val head = globalStates.peek()
128
+ val allStreamsPersisted =
129
+ head.streamIndexes.all { (stream, index) ->
130
+ streamsManager.getManager(stream).areRecordsPersistedUntil(index)
131
+ }
132
+ if (allStreamsPersisted) {
133
+ globalStates.poll()
134
+ val outMessage = outputFactory.from(head.stateMessage)
135
+ outputConsumer.accept(outMessage)
136
+ } else {
137
+ break
116
138
}
117
- if (allStreamsPersisted) {
118
- globalStates.poll()
119
- val outMessage = stateMessageFactory.fromDestinationStateMessage(head.stateMessage)
120
- outputConsumer.accept(outMessage)
121
139
}
122
140
}
123
141
@@ -131,7 +149,7 @@ class DefaultStateManager(
131
149
streamStates.remove(index)
132
150
? : throw IllegalStateException (" State not found for index: $index " )
133
151
log.info { " Flushing state for stream: $stream at index: $index " }
134
- val outMessage = stateMessageFactory.fromDestinationStateMessage (stateMessage)
152
+ val outMessage = outputFactory.from (stateMessage)
135
153
outputConsumer.accept(outMessage)
136
154
} else {
137
155
break
@@ -140,3 +158,11 @@ class DefaultStateManager(
140
158
}
141
159
}
142
160
}
161
+
162
+ @Singleton
163
+ class DefaultStateManager (
164
+ override val catalog : DestinationCatalog ,
165
+ override val streamsManager : StreamsManager ,
166
+ override val outputFactory : MessageConverter <DestinationStateMessage , AirbyteMessage >,
167
+ override val outputConsumer : Consumer <AirbyteMessage >
168
+ ) : StreamsStateManager<DestinationStateMessage, AirbyteMessage>()
0 commit comments