🎉 Destination snowflake: reduce memory consumption #10297


Merged (10 commits) on Feb 15, 2022
BufferedStreamConsumer.java
@@ -23,10 +23,8 @@
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
-import java.util.Objects;
 import java.util.Set;
 import java.util.function.Consumer;
-import java.util.stream.Collectors;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;

@@ -81,13 +79,13 @@ public class BufferedStreamConsumer extends FailureTrackingAirbyteMessageConsumer
   private final RecordWriter recordWriter;
   private final CheckedConsumer<Boolean, Exception> onClose;
   private final Set<AirbyteStreamNameNamespacePair> streamNames;
-  private final List<AirbyteMessage> buffer;
   private final ConfiguredAirbyteCatalog catalog;
   private final CheckedFunction<JsonNode, Boolean, Exception> isValidRecord;
-  private final Map<AirbyteStreamNameNamespacePair, Long> pairToIgnoredRecordCount;
+  private final Map<AirbyteStreamNameNamespacePair, Long> streamToIgnoredRecordCount;
   private final Consumer<AirbyteMessage> outputRecordCollector;
   private final long maxQueueSizeInBytes;
   private long bufferSizeInBytes;
+  private Map<AirbyteStreamNameNamespacePair, List<AirbyteRecordMessage>> streamBuffer;

   private boolean hasStarted;
   private boolean hasClosed;
@@ -112,9 +110,9 @@ public BufferedStreamConsumer(final Consumer<AirbyteMessage> outputRecordCollector,
     this.catalog = catalog;
     this.streamNames = AirbyteStreamNameNamespacePair.fromConfiguredCatalog(catalog);
     this.isValidRecord = isValidRecord;
-    this.buffer = new ArrayList<>(10_000);
     this.bufferSizeInBytes = 0;
-    this.pairToIgnoredRecordCount = new HashMap<>();
+    this.streamToIgnoredRecordCount = new HashMap<>();
+    this.streamBuffer = new HashMap<>();
   }

   @Override
@@ -123,7 +121,7 @@ protected void startTracked() throws Exception {
     Preconditions.checkState(!hasStarted, "Consumer has already been started.");
     hasStarted = true;

-    pairToIgnoredRecordCount.clear();
+    streamToIgnoredRecordCount.clear();
     LOGGER.info("{} started.", BufferedStreamConsumer.class);

     onStart.call();
@@ -141,7 +139,7 @@ protected void acceptTracked(final AirbyteMessage message) throws Exception {
       }

       if (!isValidRecord.apply(message.getRecord().getData())) {
-        pairToIgnoredRecordCount.put(stream, pairToIgnoredRecordCount.getOrDefault(stream, 0L) + 1L);
+        streamToIgnoredRecordCount.put(stream, streamToIgnoredRecordCount.getOrDefault(stream, 0L) + 1L);
         return;
       }

@@ -151,15 +149,12 @@ protected void acceptTracked(final AirbyteMessage message) throws Exception {
       final long messageSizeInBytes = ByteUtils.getSizeInBytesForUTF8CharSet(Jsons.serialize(recordMessage.getData()));
       if (bufferSizeInBytes + messageSizeInBytes > maxQueueSizeInBytes) {
         LOGGER.info("Flushing buffer...");
-        AirbyteSentry.executeWithTracing("FlushBuffer",
-            this::flushQueueToDestination,
-            Map.of("stream", stream.getName(),
-                "namespace", Objects.requireNonNullElse(stream.getNamespace(), "null"),
-                "bufferSizeInBytes", bufferSizeInBytes));
+        flushQueueToDestination(bufferSizeInBytes);
         bufferSizeInBytes = 0;
       }

-      buffer.add(message);
+      final List<AirbyteRecordMessage> bufferedRecords = streamBuffer.computeIfAbsent(stream, k -> new ArrayList<>());
+      bufferedRecords.add(message.getRecord());
       bufferSizeInBytes += messageSizeInBytes;

     } else if (message.getType() == Type.STATE) {
@@ -170,16 +165,13 @@ protected void acceptTracked(final AirbyteMessage message) throws Exception {

   }

-  private void flushQueueToDestination() throws Exception {
-    final Map<AirbyteStreamNameNamespacePair, List<AirbyteRecordMessage>> recordsByStream = buffer.stream()
-        .map(AirbyteMessage::getRecord)
-        .collect(Collectors.groupingBy(AirbyteStreamNameNamespacePair::fromRecordMessage));
-
-    buffer.clear();
-
-    for (final Map.Entry<AirbyteStreamNameNamespacePair, List<AirbyteRecordMessage>> entry : recordsByStream.entrySet()) {
-      recordWriter.accept(entry.getKey(), entry.getValue());
-    }
+  private void flushQueueToDestination(long bufferSizeInBytes) throws Exception {
+    AirbyteSentry.executeWithTracing("FlushBuffer", () -> {
+      for (final Map.Entry<AirbyteStreamNameNamespacePair, List<AirbyteRecordMessage>> entry : streamBuffer.entrySet()) {
+        recordWriter.accept(entry.getKey(), entry.getValue());
+      }
+    }, Map.of("bufferSizeInBytes", bufferSizeInBytes));
+    streamBuffer = new HashMap<>();

     if (pendingState != null) {
       lastFlushedState = pendingState;
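The two state fields seen here implement the usual checkpointing rule: a state message only becomes safe to report once every record received before it has actually been written. A minimal sketch of that contract, with a hypothetical StateMessage stand-in (the real consumer holds AirbyteMessage values):

```java
// Simplified sketch of the pendingState/lastFlushedState handshake above.
class StateCheckpointer {

  record StateMessage(String payload) {} // stand-in for AirbyteMessage

  private StateMessage pendingState;     // newest state seen; records before it may still be buffered
  private StateMessage lastFlushedState; // newest state whose preceding records are all persisted

  void onState(final StateMessage state) {
    pendingState = state;
  }

  // Called after the buffer has been handed to the destination successfully.
  void onFlushCompleted() {
    if (pendingState != null) {
      lastFlushedState = pendingState;
      pendingState = null;
    }
  }
}
```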
@@ -199,13 +191,13 @@ protected void close(final boolean hasFailed) throws Exception {
     Preconditions.checkState(!hasClosed, "Has already closed.");
     hasClosed = true;

-    pairToIgnoredRecordCount
-        .forEach((pair, count) -> LOGGER.warn("A total of {} record(s) of data from stream {} were invalid and were ignored.", count, pair));
+    streamToIgnoredRecordCount.forEach((pair, count) ->
+        LOGGER.warn("A total of {} record(s) of data from stream {} were invalid and were ignored.", count, pair));
     if (hasFailed) {
       LOGGER.error("executing on failed close procedure.");
     } else {
       LOGGER.info("executing on success close procedure.");
-      flushQueueToDestination();
+      flushQueueToDestination(bufferSizeInBytes);
     }

     try {
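Taken together, this file's change replaces a single List<AirbyteMessage> that was re-grouped by stream at flush time (the old Collectors.groupingBy pass, which briefly held a second grouped copy of every buffered record) with a map that groups records as they arrive, and it buffers only the record payload rather than the whole message wrapper. A self-contained sketch of that pattern, using simplified stand-in types (StreamId and String records) rather than Airbyte's real classes:

```java
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Records are grouped per stream on arrival, so flushing just iterates the
// map; no second grouping pass (and no second copy of the data) is needed.
class GroupedBuffer {

  record StreamId(String namespace, String name) {}

  private final long maxBufferSizeInBytes;
  private long bufferSizeInBytes = 0;
  private Map<StreamId, List<String>> streamBuffer = new HashMap<>();

  GroupedBuffer(final long maxBufferSizeInBytes) {
    this.maxBufferSizeInBytes = maxBufferSizeInBytes;
  }

  void accept(final StreamId stream, final String serializedRecord) {
    final long size = serializedRecord.getBytes(StandardCharsets.UTF_8).length;
    if (bufferSizeInBytes + size > maxBufferSizeInBytes) {
      flush();
    }
    streamBuffer.computeIfAbsent(stream, k -> new ArrayList<>()).add(serializedRecord);
    bufferSizeInBytes += size;
  }

  void flush() {
    for (final Map.Entry<StreamId, List<String>> entry : streamBuffer.entrySet()) {
      writeBatch(entry.getKey(), entry.getValue());
    }
    // Replacing the map (rather than clearing each list) drops every
    // reference at once, letting flushed batches be garbage-collected.
    streamBuffer = new HashMap<>();
    bufferSizeInBytes = 0;
  }

  void writeBatch(final StreamId stream, final List<String> records) {
    // stand-in for RecordWriter.accept(pair, records)
  }
}
```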
DataAdapter.java
@@ -8,13 +8,9 @@
 import com.fasterxml.jackson.databind.node.ObjectNode;
 import java.util.function.Function;
 import java.util.function.Predicate;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;

 public class DataAdapter {

-  private static final Logger LOGGER = LoggerFactory.getLogger(DataAdapter.class);
-
   private final Predicate<JsonNode> filterValueNode;
   private final Function<JsonNode, JsonNode> valueNodeAdapter;
JdbcSqlOperations.java
@@ -19,15 +19,13 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.Optional;
 import java.util.UUID;
 import org.apache.commons.csv.CSVFormat;
 import org.apache.commons.csv.CSVPrinter;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;

 public abstract class JdbcSqlOperations implements SqlOperations {

-  private static final Logger LOGGER = LoggerFactory.getLogger(JdbcSqlOperations.class);
   protected static final String SHOW_SCHEMAS = "show schemas;";
   protected static final String NAME = "name";
@@ -63,21 +61,14 @@ public String createTableQuery(final JdbcDatabase database, final String schemaName,
   }

   protected void writeBatchToFile(final File tmpFile, final List<AirbyteRecordMessage> records) throws Exception {
-    PrintWriter writer = null;
-    try {
-      writer = new PrintWriter(tmpFile, StandardCharsets.UTF_8);
-      final var csvPrinter = new CSVPrinter(writer, CSVFormat.DEFAULT);
-
+    try (final PrintWriter writer = new PrintWriter(tmpFile, StandardCharsets.UTF_8);
+        final CSVPrinter csvPrinter = new CSVPrinter(writer, CSVFormat.DEFAULT)) {
       for (final AirbyteRecordMessage record : records) {
         final var uuid = UUID.randomUUID().toString();
         final var jsonData = Jsons.serialize(formatData(record.getData()));
         final var emittedAt = Timestamp.from(Instant.ofEpochMilli(record.getEmittedAt()));
         csvPrinter.printRecord(uuid, jsonData, emittedAt);
       }
-    } finally {
-      if (writer != null) {
-        writer.close();
-      }
     }
   }
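The rewritten writeBatchToFile relies on try-with-resources: resources declared in the header are closed in reverse declaration order even when the loop body throws, which removes the null-checked finally block and also closes the CSVPrinter, something the old code never did. A minimal standalone illustration of the shape (the column values are made-up placeholders):

```java
import java.io.File;
import java.io.PrintWriter;
import java.nio.charset.StandardCharsets;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVPrinter;

public class CsvWriteSketch {

  public static void main(final String[] args) throws Exception {
    final File tmpFile = File.createTempFile("records", ".csv");
    // Both resources are closed automatically (printer first, then writer),
    // even if printRecord throws mid-loop.
    try (final PrintWriter writer = new PrintWriter(tmpFile, StandardCharsets.UTF_8);
        final CSVPrinter csvPrinter = new CSVPrinter(writer, CSVFormat.DEFAULT)) {
      csvPrinter.printRecord("some-uuid", "{\"col\":1}", "2022-02-15T00:00:00Z");
    }
  }
}
```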

@@ -137,7 +128,8 @@ public final void insertRecords(final JdbcDatabase database,
       throws Exception {
     AirbyteSentry.executeWithTracing("InsertRecords",
         () -> {
-          records.forEach(airbyteRecordMessage -> getDataAdapter().adapt(airbyteRecordMessage.getData()));

[Review thread on the removed line above]
Contributor (author): This is bad. For every record message, a new data adapter object is created (in the getDataAdapter method).
@edgao (Feb 15, 2022): obviously this is already way better, but could we just dump the data adapter into a field rather than constructing one per record batch?
Contributor (author): Ah, good point! Done.

+          final Optional<DataAdapter> dataAdapter = getDataAdapter();
+          dataAdapter.ifPresent(adapter -> records.forEach(airbyteRecordMessage -> adapter.adapt(airbyteRecordMessage.getData())));
           insertRecordsInternal(database, records, schemaName, tableName);
         },
         Map.of("schema", Objects.requireNonNullElse(schemaName, "null"), "table", tableName, "recordCount", records.size()));
@@ -149,8 +141,8 @@ protected abstract void insertRecordsInternal(JdbcDatabase database,
                                                  String tableName)
       throws Exception;

-  protected DataAdapter getDataAdapter() {
-    return new DataAdapter(j -> false, c -> c);
+  protected Optional<DataAdapter> getDataAdapter() {
+    return Optional.empty();
   }

 }
PostgresSqlOperations.java
@@ -15,15 +15,12 @@
 import java.nio.file.Files;
 import java.sql.SQLException;
 import java.util.List;
+import java.util.Optional;
 import org.postgresql.copy.CopyManager;
 import org.postgresql.core.BaseConnection;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;

 public class PostgresSqlOperations extends JdbcSqlOperations {

-  private static final Logger LOGGER = LoggerFactory.getLogger(PostgresSqlOperations.class);
-
   @Override
   public void insertRecordsInternal(final JdbcDatabase database,
                                     final List<AirbyteRecordMessage> records,
@@ -59,8 +56,8 @@ public void insertRecordsInternal(final JdbcDatabase database,
   }

   @Override
-  protected DataAdapter getDataAdapter() {
-    return new PostgresDataAdapter();
+  protected Optional<DataAdapter> getDataAdapter() {
+    return Optional.of(new PostgresDataAdapter());
   }

 }
SnowflakeDestination.java
@@ -10,13 +10,9 @@
 import io.airbyte.integrations.base.IntegrationRunner;
 import io.airbyte.integrations.destination.jdbc.copy.SwitchingDestination;
 import java.util.Map;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;

 public class SnowflakeDestination extends SwitchingDestination<SnowflakeDestination.DestinationType> {

-  private static final Logger LOGGER = LoggerFactory.getLogger(SnowflakeDestination.class);
-
   enum DestinationType {
     COPY_S3,
     COPY_GCS,
SnowflakeInternalStagingConsumerFactory.java
@@ -46,7 +46,7 @@ public class SnowflakeInternalStagingConsumerFactory {

   private static final Logger LOGGER = LoggerFactory.getLogger(SnowflakeInternalStagingConsumerFactory.class);

-  private static final long MAX_BATCH_SIZE_BYTES = 1024 * 1024 * 1024 / 4; // 256mb
+  private static final long MAX_BATCH_SIZE_BYTES = 128 * 1024 * 1024; // 128mb
   private final String CURRENT_SYNC_PATH = UUID.randomUUID().toString();

   public AirbyteMessageConsumer create(final Consumer<AirbyteMessage> outputRecordCollector,
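Both constants are easy to sanity-check: the old expression already evaluated to 256 MiB (1 GiB / 4, matching its comment), so this change halves the flush threshold, and with it the peak record-buffer footprint, to 128 MiB. A quick check:

```java
public class BatchSizeCheck {

  public static void main(final String[] args) {
    final long oldMax = 1024 * 1024 * 1024 / 4; // int arithmetic is safe: 2^30 fits in an int
    final long newMax = 128 * 1024 * 1024;
    System.out.println(oldMax / (1024 * 1024) + " MiB -> " + newMax / (1024 * 1024) + " MiB");
    // prints: 256 MiB -> 128 MiB
  }
}
```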