Skip to content

Destinations CDK: CatalogParser sets default namespace #38121

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions airbyte-cdk/java/airbyte-cdk/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ corresponds to that version.

| Version | Date | Pull Request | Subject |
|:--------|:-----------|:-----------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------|
| 0.37.0 | 2024-06-10 | [\#38121](https://github.com/airbytehq/airbyte/pull/38121) | Destinations: Set default namespace via CatalogParser |
| 0.36.8 | 2024-06-07 | [\#38763](https://github.com/airbytehq/airbyte/pull/38763) | Increase Jackson message length limit |
| 0.36.7 | 2024-06-06 | [\#39220](https://github.com/airbytehq/airbyte/pull/39220) | Handle null messages in ConnectorExceptionUtil |
| 0.36.6 | 2024-06-05 | [\#39106](https://github.com/airbytehq/airbyte/pull/39106) | Skip write to storage with 0 byte file |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
package io.airbyte.cdk.integrations.destination.async

import com.google.common.base.Preconditions
import com.google.common.base.Strings
import io.airbyte.cdk.integrations.base.SerializedAirbyteMessageConsumer
import io.airbyte.cdk.integrations.destination.StreamSyncSummary
import io.airbyte.cdk.integrations.destination.async.buffers.BufferEnqueue
Expand All @@ -28,7 +27,6 @@ import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicLong
import java.util.function.Consumer
import kotlin.jvm.optionals.getOrNull
import org.jetbrains.annotations.VisibleForTesting

private val logger = KotlinLogging.logger {}
Expand All @@ -51,7 +49,6 @@ constructor(
onFlush: DestinationFlushFunction,
private val catalog: ConfiguredAirbyteCatalog,
private val bufferManager: BufferManager,
private val defaultNamespace: Optional<String>,
private val flushFailure: FlushFailure = FlushFailure(),
workerPool: ExecutorService = Executors.newFixedThreadPool(5),
private val airbyteMessageDeserializer: AirbyteMessageDeserializer =
Expand Down Expand Up @@ -79,28 +76,6 @@ constructor(
private var hasClosed = false
private var hasFailed = false

internal constructor(
outputRecordCollector: Consumer<AirbyteMessage>,
onStart: OnStartFunction,
onClose: OnCloseFunction,
flusher: DestinationFlushFunction,
catalog: ConfiguredAirbyteCatalog,
bufferManager: BufferManager,
flushFailure: FlushFailure,
defaultNamespace: Optional<String>,
) : this(
outputRecordCollector,
onStart,
onClose,
flusher,
catalog,
bufferManager,
defaultNamespace,
flushFailure,
Executors.newFixedThreadPool(5),
AirbyteMessageDeserializer(),
)

@Throws(Exception::class)
override fun start() {
Preconditions.checkState(!hasStarted, "Consumer has already been started.")
Expand Down Expand Up @@ -129,9 +104,6 @@ constructor(
message,
)
if (AirbyteMessage.Type.RECORD == partialAirbyteMessage.type) {
if (Strings.isNullOrEmpty(partialAirbyteMessage.record?.namespace)) {
partialAirbyteMessage.record?.namespace = defaultNamespace.getOrNull()
}
validateRecord(partialAirbyteMessage)

partialAirbyteMessage.record?.streamDescriptor?.let {
Expand All @@ -141,7 +113,6 @@ constructor(
bufferEnqueue.addRecord(
partialAirbyteMessage,
sizeInBytes + PARTIAL_DESERIALIZE_REF_BYTES,
defaultNamespace,
)
}

Expand All @@ -159,10 +130,14 @@ constructor(
bufferManager.close()

val streamSyncSummaries =
streamNames.associateWith { streamDescriptor: StreamDescriptor ->
StreamSyncSummary(
Optional.of(getRecordCounter(streamDescriptor).get()),
)
streamNames.associate { streamDescriptor ->
StreamDescriptorUtils.withDefaultNamespace(
streamDescriptor,
bufferManager.defaultNamespace,
) to
StreamSyncSummary(
Optional.of(getRecordCounter(streamDescriptor).get()),
)
}
onClose.accept(hasFailed, streamSyncSummaries)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,11 @@ object StreamDescriptorUtils {

return pairs
}

fun withDefaultNamespace(sd: StreamDescriptor, defaultNamespace: String) =
if (sd.namespace.isNullOrEmpty()) {
StreamDescriptor().withName(sd.name).withNamespace(defaultNamespace)
} else {
sd
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ package io.airbyte.cdk.integrations.destination.async.buffers
import io.airbyte.cdk.integrations.destination.async.GlobalMemoryManager
import io.airbyte.cdk.integrations.destination.async.model.PartialAirbyteMessage
import io.airbyte.cdk.integrations.destination.async.state.GlobalAsyncStateManager
import io.airbyte.commons.json.Jsons
import io.airbyte.protocol.models.v0.AirbyteMessage
import io.airbyte.protocol.models.v0.StreamDescriptor
import java.util.Optional
import java.util.concurrent.ConcurrentMap

/**
Expand All @@ -20,6 +20,7 @@ class BufferEnqueue(
private val memoryManager: GlobalMemoryManager,
private val buffers: ConcurrentMap<StreamDescriptor, StreamAwareQueue>,
private val stateManager: GlobalAsyncStateManager,
private val defaultNamespace: String,
) {
/**
* Buffer a record. Contains memory management logic to dynamically adjust queue size based via
Expand All @@ -31,12 +32,11 @@ class BufferEnqueue(
fun addRecord(
message: PartialAirbyteMessage,
sizeInBytes: Int,
defaultNamespace: Optional<String>,
) {
if (message.type == AirbyteMessage.Type.RECORD) {
handleRecord(message, sizeInBytes)
} else if (message.type == AirbyteMessage.Type.STATE) {
stateManager.trackState(message, sizeInBytes.toLong(), defaultNamespace.orElse(""))
stateManager.trackState(message, sizeInBytes.toLong())
}
}

Expand All @@ -53,15 +53,28 @@ class BufferEnqueue(
}
val stateId = stateManager.getStateIdAndIncrementCounter(streamDescriptor)

var addedToQueue = queue.offer(message, sizeInBytes.toLong(), stateId)
// We don't set the default namespace until after putting this message into the state
// manager/etc.
// All our internal handling is on the true (null) namespace,
// we just set the default namespace when handing off to destination-specific code.
val mangledMessage =
if (message.record!!.namespace.isNullOrEmpty()) {
val clone = Jsons.clone(message)
clone.record!!.namespace = defaultNamespace
clone
} else {
message
}

var addedToQueue = queue.offer(mangledMessage, sizeInBytes.toLong(), stateId)

var i = 0
while (!addedToQueue) {
val newlyAllocatedMemory = memoryManager.requestMemory()
if (newlyAllocatedMemory > 0) {
queue.addMaxMemory(newlyAllocatedMemory)
}
addedToQueue = queue.offer(message, sizeInBytes.toLong(), stateId)
addedToQueue = queue.offer(mangledMessage, sizeInBytes.toLong(), stateId)
i++
if (i > 5) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ private val logger = KotlinLogging.logger {}
class BufferManager
@JvmOverloads
constructor(
/**
* This probably doesn't belong here, but it's the easiest place where both [BufferEnqueue] and
* [io.airbyte.cdk.integrations.destination.async.AsyncStreamConsumer] can both get to it.
*/
public val defaultNamespace: String,
maxMemory: Long = (Runtime.getRuntime().maxMemory() * MEMORY_LIMIT_RATIO).toLong(),
) {
@get:VisibleForTesting val buffers: ConcurrentMap<StreamDescriptor, StreamAwareQueue>
Expand All @@ -46,7 +51,7 @@ constructor(
memoryManager = GlobalMemoryManager(maxMemory)
this.stateManager = GlobalAsyncStateManager(memoryManager)
buffers = ConcurrentHashMap()
bufferEnqueue = BufferEnqueue(memoryManager, buffers, stateManager)
bufferEnqueue = BufferEnqueue(memoryManager, buffers, stateManager, defaultNamespace)
bufferDequeue = BufferDequeue(memoryManager, buffers, stateManager)
debugLoop = Executors.newSingleThreadScheduledExecutor()
debugLoop.scheduleAtFixedRate(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
package io.airbyte.cdk.integrations.destination.async.state

import com.google.common.base.Preconditions
import com.google.common.base.Strings
import io.airbyte.cdk.integrations.destination.async.GlobalMemoryManager
import io.airbyte.cdk.integrations.destination.async.model.PartialAirbyteMessage
import io.airbyte.commons.json.Jsons
Expand Down Expand Up @@ -104,7 +103,6 @@ class GlobalAsyncStateManager(private val memoryManager: GlobalMemoryManager) {
fun trackState(
message: PartialAirbyteMessage,
sizeInBytes: Long,
defaultNamespace: String,
) {
if (preState) {
convertToGlobalIfNeeded(message)
Expand All @@ -113,7 +111,7 @@ class GlobalAsyncStateManager(private val memoryManager: GlobalMemoryManager) {
// stateType should not change after a conversion.
Preconditions.checkArgument(stateType == extractStateType(message))

closeState(message, sizeInBytes, defaultNamespace)
closeState(message, sizeInBytes)
}

/**
Expand Down Expand Up @@ -323,10 +321,9 @@ class GlobalAsyncStateManager(private val memoryManager: GlobalMemoryManager) {
private fun closeState(
message: PartialAirbyteMessage,
sizeInBytes: Long,
defaultNamespace: String,
) {
val resolvedDescriptor: StreamDescriptor =
extractStream(message, defaultNamespace)
extractStream(message)
.orElse(
SENTINEL_GLOBAL_DESC,
)
Expand Down Expand Up @@ -424,38 +421,14 @@ class GlobalAsyncStateManager(private val memoryManager: GlobalMemoryManager) {
UUID.randomUUID().toString(),
)

/**
* If the user has selected the Destination Namespace as the Destination default while
* setting up the connector, the platform sets the namespace as null in the StreamDescriptor
* in the AirbyteMessages (both record and state messages). The destination checks that if
* the namespace is empty or null, if yes then re-populates it with the defaultNamespace.
* See [io.airbyte.cdk.integrations.destination.async.AsyncStreamConsumer.accept] But
* destination only does this for the record messages. So when state messages arrive without
* a namespace and since the destination doesn't repopulate it with the default namespace,
* there is a mismatch between the StreamDescriptor from record messages and state messages.
* That breaks the logic of the state management class as [descToStateIdQ] needs to have
* consistent StreamDescriptor. This is why while trying to extract the StreamDescriptor
* from state messages, we check if the namespace is null, if yes then replace it with
* defaultNamespace to keep it consistent with the record messages.
*/
private fun extractStream(
message: PartialAirbyteMessage,
defaultNamespace: String,
): Optional<StreamDescriptor> {
if (
message.state?.type != null &&
message.state?.type == AirbyteStateMessage.AirbyteStateType.STREAM
) {
val streamDescriptor: StreamDescriptor? = message.state?.stream?.streamDescriptor
if (Strings.isNullOrEmpty(streamDescriptor?.namespace)) {
return Optional.of(
StreamDescriptor()
.withName(
streamDescriptor?.name,
)
.withNamespace(defaultNamespace),
)
}
return streamDescriptor?.let { Optional.of(it) } ?: Optional.empty()
}
return Optional.empty()
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1 +1 @@
version=0.36.8
version=0.37.0
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import io.airbyte.protocol.models.v0.StreamDescriptor
import java.io.IOException
import java.math.BigDecimal
import java.time.Instant
import java.util.Optional
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit
import java.util.concurrent.TimeoutException
Expand Down Expand Up @@ -60,7 +59,7 @@ class AsyncStreamConsumerTest {
private val CATALOG: ConfiguredAirbyteCatalog =
ConfiguredAirbyteCatalog()
.withStreams(
java.util.List.of(
listOf(
CatalogHelpers.createConfiguredAirbyteStream(
STREAM_NAME,
SCHEMA_NAME,
Expand Down Expand Up @@ -145,9 +144,8 @@ class AsyncStreamConsumerTest {
onClose = onClose,
onFlush = flushFunction,
catalog = CATALOG,
bufferManager = BufferManager(),
bufferManager = BufferManager("default_ns"),
flushFailure = flushFailure,
defaultNamespace = Optional.of("default_ns"),
airbyteMessageDeserializer = airbyteMessageDeserializer,
workerPool = Executors.newFixedThreadPool(5),
)
Expand Down Expand Up @@ -264,9 +262,8 @@ class AsyncStreamConsumerTest {
Mockito.mock(OnCloseFunction::class.java),
flushFunction,
CATALOG,
BufferManager((1024 * 10).toLong()),
BufferManager("default_ns", (1024 * 10).toLong()),
flushFailure,
Optional.of("default_ns"),
)
Mockito.`when`(flushFunction.optimalBatchSizeBytes).thenReturn(0L)

Expand Down
Loading
Loading