import io.airbyte.cdk.integrations.destination.jdbc.ColumnDefinition;
import io.airbyte.cdk.integrations.destination.jdbc.TableDefinition;
import io.airbyte.cdk.integrations.destination.jdbc.typing_deduping.JdbcDestinationHandler;
+import io.airbyte.commons.json.Jsons;
import io.airbyte.integrations.base.destination.typing_deduping.AirbyteProtocolType;
import io.airbyte.integrations.base.destination.typing_deduping.AirbyteType;
import io.airbyte.integrations.base.destination.typing_deduping.Array;
import io.airbyte.integrations.base.destination.typing_deduping.ColumnId;
-import io.airbyte.integrations.base.destination.typing_deduping.DestinationInitialState;
-import io.airbyte.integrations.base.destination.typing_deduping.DestinationInitialStateImpl;
-import io.airbyte.integrations.base.destination.typing_deduping.InitialRawTableState;
+import io.airbyte.integrations.base.destination.typing_deduping.DestinationInitialStatus;
+import io.airbyte.integrations.base.destination.typing_deduping.InitialRawTableStatus;
import io.airbyte.integrations.base.destination.typing_deduping.Sql;
import io.airbyte.integrations.base.destination.typing_deduping.StreamConfig;
import io.airbyte.integrations.base.destination.typing_deduping.StreamId;
import io.airbyte.integrations.base.destination.typing_deduping.Struct;
import io.airbyte.integrations.base.destination.typing_deduping.Union;
import io.airbyte.integrations.base.destination.typing_deduping.UnsupportedOneOf;
+import io.airbyte.integrations.destination.snowflake.typing_deduping.migrations.SnowflakeState;
+import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair;
+import io.airbyte.protocol.models.v0.DestinationSyncMode;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.time.Instant;
import java.util.stream.Collectors;
import net.snowflake.client.jdbc.SnowflakeSQLException;
import org.apache.commons.text.StringSubstitutor;
+import org.jooq.SQLDialect;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

-public class SnowflakeDestinationHandler extends JdbcDestinationHandler {
+public class SnowflakeDestinationHandler extends JdbcDestinationHandler<SnowflakeState> {

  private static final Logger LOGGER = LoggerFactory.getLogger(SnowflakeDestinationHandler.class);
  public static final String EXCEPTION_COMMON_PREFIX = "JavaScript execution error: Uncaught Execution of multiple statements failed on statement";

  private final String databaseName;
  private final JdbcDatabase database;

-  public SnowflakeDestinationHandler(final String databaseName, final JdbcDatabase database) {
-    super(databaseName, database);
-    this.databaseName = databaseName;
+  public SnowflakeDestinationHandler(final String databaseName, final JdbcDatabase database, final String rawTableSchema) {
+    // Postgres is close enough to Snowflake SQL for our purposes.
+    super(databaseName, database, rawTableSchema, SQLDialect.POSTGRES);
+    // We don't quote the database name in any queries, so just upcase it.
+    this.databaseName = databaseName.toUpperCase();
    this.database = database;
  }
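Note on the new constructor: Snowflake folds unquoted identifiers to uppercase, so a database created as `airbyte_db` shows up in `information_schema` as `AIRBYTE_DB`. Upcasing once here keeps every later metadata comparison consistent. A minimal sketch of that behavior (class and sample names are illustrative, not from this PR):

```java
import java.util.Locale;

public class UnquotedIdentifierSketch {

  // Mirrors `this.databaseName = databaseName.toUpperCase()`: normalize once,
  // because information_schema reports unquoted identifiers in uppercase.
  static String normalize(final String identifier) {
    return identifier.toUpperCase(Locale.ROOT);
  }

  public static void main(String[] args) {
    // A database created as `CREATE DATABASE airbyte_db` matches rows where
    // table_catalog = 'AIRBYTE_DB', never 'airbyte_db'.
    System.out.println(normalize("airbyte_db")); // prints AIRBYTE_DB
  }
}
```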
@@ -107,7 +113,7 @@ AND table_schema IN (%s)
          AND table_name IN (%s)
        """.formatted(paramHolder, paramHolder);
    final String[] bindValues = new String[streamIds.size() * 2 + 1];
-    bindValues[0] = databaseName.toUpperCase();
+    bindValues[0] = databaseName;
    System.arraycopy(namespaces, 0, bindValues, 1, namespaces.length);
    System.arraycopy(names, 0, bindValues, namespaces.length + 1, names.length);
    final List<JsonNode> results = database.queryJsons(query, bindValues);
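The bind array above reserves slot 0 for the database name and packs the namespace list and the table-name list behind it, which is why its length is `streamIds.size() * 2 + 1`. A runnable sketch of that layout (sample values are hypothetical):

```java
import java.util.Arrays;

public class BindValueLayoutSketch {

  public static void main(String[] args) {
    final String[] namespaces = {"PUBLIC", "STAGING"};
    final String[] names = {"USERS", "ORDERS"};

    // Slot 0: database name; slots 1..n: namespaces; slots n+1..2n: table names.
    final String[] bindValues = new String[namespaces.length + names.length + 1];
    bindValues[0] = "AIRBYTE_DB";
    System.arraycopy(namespaces, 0, bindValues, 1, namespaces.length);
    System.arraycopy(names, 0, bindValues, namespaces.length + 1, names.length);

    // Prints [AIRBYTE_DB, PUBLIC, STAGING, USERS, ORDERS]
    System.out.println(Arrays.toString(bindValues));
  }
}
```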
@@ -120,14 +126,18 @@ AND table_name IN (%s)
    return tableRowCounts;
  }

-  public InitialRawTableState getInitialRawTableState(final StreamId id) throws Exception {
+  private InitialRawTableStatus getInitialRawTableState(final StreamId id, final DestinationSyncMode destinationSyncMode) throws Exception {
+    // Short-circuit for overwrite: the table will be truncated anyway.
+    if (destinationSyncMode == DestinationSyncMode.OVERWRITE) {
+      return new InitialRawTableStatus(false, false, Optional.empty());
+    }
    final ResultSet tables = database.getMetaData().getTables(
        databaseName,
        id.rawNamespace(),
        id.rawName(),
        null);
    if (!tables.next()) {
-      return new InitialRawTableState(false, Optional.empty());
+      return new InitialRawTableStatus(false, false, Optional.empty());
    }
    // Snowflake timestamps have nanosecond precision, so decrement by 1ns,
    // and use two explicit queries because COALESCE doesn't short-circuit.
@@ -136,33 +146,55 @@ public InitialRawTableState getInitialRawTableState(final StreamId id) throws Ex
        conn -> conn.createStatement().executeQuery(new StringSubstitutor(Map.of(
            "raw_table", id.rawTableId(SnowflakeSqlGenerator.QUOTE))).replace(
            """
-            SELECT to_varchar(
-              TIMESTAMPADD(NANOSECOND, -1, MIN("_airbyte_extracted_at")),
-              'YYYY-MM-DDTHH24:MI:SS.FF9TZH:TZM'
-            ) AS MIN_TIMESTAMP
-            FROM ${raw_table}
-            WHERE "_airbyte_loaded_at" IS NULL
+            WITH MIN_TS AS (
+              SELECT TIMESTAMPADD(NANOSECOND, -1,
+                MIN(TIMESTAMPADD(
+                  HOUR,
+                  EXTRACT(timezone_hour FROM "_airbyte_extracted_at"),
+                  TIMESTAMPADD(
+                    MINUTE,
+                    EXTRACT(timezone_minute FROM "_airbyte_extracted_at"),
+                    CONVERT_TIMEZONE('UTC', "_airbyte_extracted_at")
+                  )
+                ))) AS MIN_TIMESTAMP
+              FROM ${raw_table}
+              WHERE "_airbyte_loaded_at" IS NULL
+            ) SELECT TO_VARCHAR(MIN_TIMESTAMP, 'YYYY-MM-DDTHH24:MI:SS.FF9TZH:TZM') AS MIN_TIMESTAMP_UTC FROM MIN_TS;
            """)),
        // The query will always return exactly one record, so use .get(0)
-        record -> record.getString("MIN_TIMESTAMP")).get(0));
+        record -> record.getString("MIN_TIMESTAMP_UTC")).get(0));
    if (minUnloadedTimestamp.isPresent()) {
-      return new InitialRawTableState(true, minUnloadedTimestamp.map(Instant::parse));
+      return new InitialRawTableStatus(true, true, minUnloadedTimestamp.map(Instant::parse));
    }

    // If there are no unloaded raw records, then we can safely skip all existing raw records.
    // This second query just finds the newest raw record.
+
+    // This is _technically_ wrong, because during the DST transition we might select
+    // the wrong max timestamp. We _should_ do the UTC conversion inside the CTE, but that's a lot
+    // of work for a very small edge case.
+    // We released the fix to write extracted_at in UTC before DST changed, so this is fine.
    final Optional<String> maxTimestamp = Optional.ofNullable(database.queryStrings(
        conn -> conn.createStatement().executeQuery(new StringSubstitutor(Map.of(
            "raw_table", id.rawTableId(SnowflakeSqlGenerator.QUOTE))).replace(
            """
-            SELECT to_varchar(
-              MAX("_airbyte_extracted_at"),
-              'YYYY-MM-DDTHH24:MI:SS.FF9TZH:TZM'
-            ) AS MIN_TIMESTAMP
-            FROM ${raw_table}
+            WITH MAX_TS AS (
+              SELECT MAX("_airbyte_extracted_at") AS MAX_TIMESTAMP
+              FROM ${raw_table}
+            ) SELECT TO_VARCHAR(
+              TIMESTAMPADD(
+                HOUR,
+                EXTRACT(timezone_hour FROM MAX_TIMESTAMP),
+                TIMESTAMPADD(
+                  MINUTE,
+                  EXTRACT(timezone_minute FROM MAX_TIMESTAMP),
+                  CONVERT_TIMEZONE('UTC', MAX_TIMESTAMP)
+                )
+              ), 'YYYY-MM-DDTHH24:MI:SS.FF9TZH:TZM') AS MAX_TIMESTAMP_UTC FROM MAX_TS;
            """)),
-        record -> record.getString("MIN_TIMESTAMP")).get(0));
-    return new InitialRawTableState(false, maxTimestamp.map(Instant::parse));
+        record -> record.getString("MAX_TIMESTAMP_UTC")).get(0));
+    return new InitialRawTableStatus(true, false, maxTimestamp.map(Instant::parse));
  }
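What the rewritten queries compute: `CONVERT_TIMEZONE('UTC', ...)` shifts the value to UTC, then the nested `TIMESTAMPADD` calls add the row's own `timezone_hour`/`timezone_minute` offsets back. The net effect is to keep the stored wall-clock reading and re-label it as UTC, rather than converting the instant. A hedged java.time sketch of the same transformation (class and method names are illustrative):

```java
import java.time.OffsetDateTime;
import java.time.ZoneOffset;

public class ExtractedAtUtcSketch {

  // Equivalent of CONVERT_TIMEZONE('UTC', ts) followed by TIMESTAMPADDs of the
  // extracted timezone_hour/timezone_minute: the wall-clock value is kept and
  // the offset is replaced with +00:00.
  static OffsetDateTime wallClockAsUtc(final OffsetDateTime extractedAt) {
    return extractedAt.toLocalDateTime().atOffset(ZoneOffset.UTC);
  }

  public static void main(String[] args) {
    final OffsetDateTime legacy = OffsetDateTime.parse("2024-03-09T10:00:00-08:00");
    System.out.println(wallClockAsUtc(legacy)); // prints 2024-03-09T10:00Z
  }
}
```

Because each row contributes its own offset, rows written on either side of a DST change can reorder under MIN/MAX, which is exactly the edge case the added comment accepts.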

  @Override
@@ -171,7 +203,7 @@ public void execute(final Sql sql) throws Exception {
    final UUID queryId = UUID.randomUUID();
    for (final String transaction : transactions) {
      final UUID transactionId = UUID.randomUUID();
-      LOGGER.debug("Executing sql {}-{}: {}", queryId, transactionId, transaction);
+      LOGGER.info("Executing sql {}-{}: {}", queryId, transactionId, transaction);
      final long startTime = System.currentTimeMillis();

      try {
@@ -190,7 +222,7 @@ public void execute(final Sql sql) throws Exception {
        throw new RuntimeException(trimmedMessage, e);
      }

-      LOGGER.debug("Sql {}-{} completed in {} ms", queryId, transactionId, System.currentTimeMillis() - startTime);
+      LOGGER.info("Sql {}-{} completed in {} ms", queryId, transactionId, System.currentTimeMillis() - startTime);
    }
  }
@@ -250,7 +282,9 @@ protected boolean existingSchemaMatchesStreamConfig(final StreamConfig stream, f
  }

  @Override
-  public List<DestinationInitialState> gatherInitialState(List<StreamConfig> streamConfigs) throws Exception {
+  public List<DestinationInitialStatus<SnowflakeState>> gatherInitialState(List<StreamConfig> streamConfigs) throws Exception {
+    final Map<AirbyteStreamNameNamespacePair, SnowflakeState> destinationStates = super.getAllDestinationStates();
+
    List<StreamId> streamIds = streamConfigs.stream().map(StreamConfig::id).toList();
    final LinkedHashMap<String, LinkedHashMap<String, TableDefinition>> existingTables = findExistingTables(database, databaseName, streamIds);
    final LinkedHashMap<String, LinkedHashMap<String, Integer>> tableRowCounts = getFinalTableRowCount(streamIds);
@@ -267,8 +301,15 @@ public List<DestinationInitialState> gatherInitialState(List<StreamConfig> strea
          isSchemaMismatch = !existingSchemaMatchesStreamConfig(streamConfig, existingTable);
          isFinalTableEmpty = hasRowCount && tableRowCounts.get(namespace).get(name) == 0;
        }
-        final InitialRawTableState initialRawTableState = getInitialRawTableState(streamConfig.id());
-        return new DestinationInitialStateImpl(streamConfig, isFinalTablePresent, initialRawTableState, isSchemaMismatch, isFinalTableEmpty);
+        final InitialRawTableStatus initialRawTableState = getInitialRawTableState(streamConfig.id(), streamConfig.destinationSyncMode());
+        final SnowflakeState destinationState = destinationStates.getOrDefault(streamConfig.id().asPair(), toDestinationState(Jsons.emptyObject()));
+        return new DestinationInitialStatus<>(
+            streamConfig,
+            isFinalTablePresent,
+            initialRawTableState,
+            isSchemaMismatch,
+            isFinalTableEmpty,
+            destinationState);
      } catch (Exception e) {
        throw new RuntimeException(e);
      }
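Note the fallback in the state lookup: a stream with no persisted destination state gets `toDestinationState(Jsons.emptyObject())`, i.e. the state parsed from an empty JSON object. A small sketch of the `getOrDefault` pattern; the `State` record here is a plausible shape inferred from `toDestinationState` below, not the actual `SnowflakeState` source:

```java
import java.util.Map;

public class StateLookupSketch {

  // Plausible stand-in for SnowflakeState, inferred from toDestinationState.
  record State(boolean needsSoftReset) {}

  public static void main(String[] args) {
    final Map<String, State> states = Map.of("public.users", new State(true));
    final State defaultState = new State(false); // what parsing {} would yield

    // Known stream: the persisted state wins.
    System.out.println(states.getOrDefault("public.users", defaultState).needsSoftReset());  // true
    // Unknown stream: falls back to the default.
    System.out.println(states.getOrDefault("public.orders", defaultState).needsSoftReset()); // false
  }
}
```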
@@ -290,6 +331,12 @@ protected String toJdbcTypeName(AirbyteType airbyteType) {
    };
  }

+  @Override
+  protected SnowflakeState toDestinationState(JsonNode json) {
+    return new SnowflakeState(
+        json.hasNonNull("needsSoftReset") && json.get("needsSoftReset").asBoolean());
+  }
+
  private String toJdbcTypeName(final AirbyteProtocolType airbyteProtocolType) {
    return switch (airbyteProtocolType) {
      case STRING -> "TEXT";
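The `hasNonNull` guard in `toDestinationState` means a missing or explicitly null `needsSoftReset` field deserializes to `false`. A runnable sketch of the same predicate using Jackson directly (`Jsons` is Airbyte's thin wrapper over Jackson; the class name below is illustrative):

```java
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;

public class DestinationStateParseSketch {

  // Same predicate as toDestinationState: absent or null fields become false.
  static boolean needsSoftReset(final JsonNode json) {
    return json.hasNonNull("needsSoftReset") && json.get("needsSoftReset").asBoolean();
  }

  public static void main(String[] args) throws Exception {
    final ObjectMapper mapper = new ObjectMapper();
    System.out.println(needsSoftReset(mapper.readTree("{\"needsSoftReset\": true}"))); // true
    System.out.println(needsSoftReset(mapper.readTree("{\"needsSoftReset\": null}"))); // false
    System.out.println(needsSoftReset(mapper.readTree("{}"))); // false
  }
}
```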