22
22
import io .airbyte .commons .util .MoreIterators ;
23
23
import io .airbyte .db .Databases ;
24
24
import io .airbyte .db .jdbc .JdbcDatabase ;
25
+ import io .airbyte .integrations .base .Source ;
25
26
import io .airbyte .integrations .source .jdbc .AbstractJdbcSource ;
26
27
import io .airbyte .integrations .source .jdbc .SourceJdbcUtils ;
27
28
import io .airbyte .integrations .source .relationaldb .models .DbState ;
50
51
import java .util .List ;
51
52
import java .util .Optional ;
52
53
import java .util .Set ;
54
+ import java .util .function .Function ;
53
55
import java .util .stream .Collectors ;
54
56
import org .hamcrest .Matchers ;
55
57
import org .junit .jupiter .api .Test ;
@@ -95,7 +97,7 @@ public abstract class JdbcSourceAcceptanceTest {
95
97
96
98
public JsonNode config ;
97
99
public JdbcDatabase database ;
98
- public AbstractJdbcSource source ;
100
+ public Source source ;
99
101
public static String streamName ;
100
102
101
103
/**
@@ -126,21 +128,43 @@ public abstract class JdbcSourceAcceptanceTest {
126
128
/**
127
129
* An instance of the source that should be tests.
128
130
*
131
+ * @return abstract jdbc source
132
+ */
133
+ public abstract AbstractJdbcSource getJdbcSource ();
134
+
135
+ /**
136
+ * In some cases the Source that is being tested may be an AbstractJdbcSource, but because it is
137
+ * decorated, Java cannot recognize it as such. In these cases, as a workaround a user can choose to
138
+ * override getJdbcSource and have it return null. Then they can override this method with the
139
+ * decorated source AND override getToDatabaseConfigFunction with the appropriate
140
+ * toDatabaseConfigFunction that is hidden behind the decorator.
141
+ *
129
142
* @return source
130
143
*/
131
- public abstract AbstractJdbcSource getSource ();
144
+ public Source getSource () {
145
+ return getJdbcSource ();
146
+ }
132
147
133
- protected String createTableQuery (String tableName , String columnClause , String primaryKeyClause ) {
148
+ /**
149
+ * See getSource() for when to override this method.
150
+ *
151
+ * @return a function that maps a source's config to a jdbc config.
152
+ */
153
+ public Function <JsonNode , JsonNode > getToDatabaseConfigFunction () {
154
+ return getJdbcSource ()::toDatabaseConfig ;
155
+ }
156
+
157
+ protected String createTableQuery (final String tableName , final String columnClause , final String primaryKeyClause ) {
134
158
return String .format ("CREATE TABLE %s(%s %s %s)" ,
135
159
tableName , columnClause , primaryKeyClause .equals ("" ) ? "" : "," , primaryKeyClause );
136
160
}
137
161
138
- protected String primaryKeyClause (List <String > columns ) {
162
+ protected String primaryKeyClause (final List <String > columns ) {
139
163
if (columns .isEmpty ()) {
140
164
return "" ;
141
165
}
142
166
143
- StringBuilder clause = new StringBuilder ();
167
+ final StringBuilder clause = new StringBuilder ();
144
168
clause .append ("PRIMARY KEY (" );
145
169
for (int i = 0 ; i < columns .size (); i ++) {
146
170
clause .append (columns .get (i ));
@@ -155,7 +179,7 @@ protected String primaryKeyClause(List<String> columns) {
155
179
public void setup () throws Exception {
156
180
source = getSource ();
157
181
config = getConfig ();
158
- final JsonNode jdbcConfig = source . toDatabaseConfig (config );
182
+ final JsonNode jdbcConfig = getToDatabaseConfigFunction (). apply (config );
159
183
160
184
streamName = TABLE_NAME ;
161
185
@@ -253,7 +277,7 @@ void testCheckFailure() throws Exception {
253
277
@ Test
254
278
void testDiscover () throws Exception {
255
279
final AirbyteCatalog actual = filterOutOtherSchemas (source .discover (config ));
256
- AirbyteCatalog expected = getCatalog (getDefaultNamespace ());
280
+ final AirbyteCatalog expected = getCatalog (getDefaultNamespace ());
257
281
assertEquals (expected .getStreams ().size (), actual .getStreams ().size ());
258
282
actual .getStreams ().forEach (actualStream -> {
259
283
final Optional <AirbyteStream > expectedStream =
@@ -265,7 +289,7 @@ void testDiscover() throws Exception {
265
289
});
266
290
}
267
291
268
- protected AirbyteCatalog filterOutOtherSchemas (AirbyteCatalog catalog ) {
292
+ protected AirbyteCatalog filterOutOtherSchemas (final AirbyteCatalog catalog ) {
269
293
if (supportsSchemas ()) {
270
294
final AirbyteCatalog filteredCatalog = Jsons .clone (catalog );
271
295
filteredCatalog .setStreams (filteredCatalog .getStreams ()
@@ -312,7 +336,7 @@ void testDiscoverWithMultipleSchemas() throws Exception {
312
336
Field .of (COL_NAME , JsonSchemaPrimitive .STRING ))
313
337
.withSupportedSyncModes (Lists .newArrayList (SyncMode .FULL_REFRESH , SyncMode .INCREMENTAL )));
314
338
// sort streams by name so that we are comparing lists with the same order.
315
- Comparator <AirbyteStream > schemaTableCompare = Comparator .comparing (stream -> stream .getNamespace () + "." + stream .getName ());
339
+ final Comparator <AirbyteStream > schemaTableCompare = Comparator .comparing (stream -> stream .getNamespace () + "." + stream .getName ());
316
340
expected .getStreams ().sort (schemaTableCompare );
317
341
actual .getStreams ().sort (schemaTableCompare );
318
342
assertEquals (expected , filterOutOtherSchemas (actual ));
@@ -325,7 +349,7 @@ void testReadSuccess() throws Exception {
325
349
source .read (config , getConfiguredCatalogWithOneStream (getDefaultNamespace ()), null ));
326
350
327
351
setEmittedAtToNull (actualMessages );
328
- List <AirbyteMessage > expectedMessages = getTestMessages ();
352
+ final List <AirbyteMessage > expectedMessages = getTestMessages ();
329
353
assertThat (expectedMessages , Matchers .containsInAnyOrder (actualMessages .toArray ()));
330
354
assertThat (actualMessages , Matchers .containsInAnyOrder (expectedMessages .toArray ()));
331
355
}
@@ -596,7 +620,7 @@ void testReadOneTableIncrementallyTwice() throws Exception {
596
620
@ Test
597
621
void testReadMultipleTablesIncrementally () throws Exception {
598
622
final String tableName2 = TABLE_NAME + 2 ;
599
- String streamName2 = streamName + 2 ;
623
+ final String streamName2 = streamName + 2 ;
600
624
database .execute (ctx -> {
601
625
ctx .createStatement ().execute (
602
626
createTableQuery (getFullyQualifiedTableName (tableName2 ), "id INTEGER, name VARCHAR(200)" , "" ));
@@ -692,34 +716,34 @@ void testReadMultipleTablesIncrementally() throws Exception {
692
716
693
717
// when initial and final cursor fields are the same.
694
718
private void incrementalCursorCheck (
695
- String cursorField ,
696
- String initialCursorValue ,
697
- String endCursorValue ,
698
- List <AirbyteMessage > expectedRecordMessages )
719
+ final String cursorField ,
720
+ final String initialCursorValue ,
721
+ final String endCursorValue ,
722
+ final List <AirbyteMessage > expectedRecordMessages )
699
723
throws Exception {
700
724
incrementalCursorCheck (cursorField , cursorField , initialCursorValue , endCursorValue ,
701
725
expectedRecordMessages );
702
726
}
703
727
704
728
private void incrementalCursorCheck (
705
- String initialCursorField ,
706
- String cursorField ,
707
- String initialCursorValue ,
708
- String endCursorValue ,
709
- List <AirbyteMessage > expectedRecordMessages )
729
+ final String initialCursorField ,
730
+ final String cursorField ,
731
+ final String initialCursorValue ,
732
+ final String endCursorValue ,
733
+ final List <AirbyteMessage > expectedRecordMessages )
710
734
throws Exception {
711
735
incrementalCursorCheck (initialCursorField , cursorField , initialCursorValue , endCursorValue ,
712
736
expectedRecordMessages ,
713
737
getConfiguredCatalogWithOneStream (getDefaultNamespace ()).getStreams ().get (0 ));
714
738
}
715
739
716
740
private void incrementalCursorCheck (
717
- String initialCursorField ,
718
- String cursorField ,
719
- String initialCursorValue ,
720
- String endCursorValue ,
721
- List <AirbyteMessage > expectedRecordMessages ,
722
- ConfiguredAirbyteStream airbyteStream )
741
+ final String initialCursorField ,
742
+ final String cursorField ,
743
+ final String initialCursorValue ,
744
+ final String endCursorValue ,
745
+ final List <AirbyteMessage > expectedRecordMessages ,
746
+ final ConfiguredAirbyteStream airbyteStream )
723
747
throws Exception {
724
748
airbyteStream .setSyncMode (SyncMode .INCREMENTAL );
725
749
airbyteStream .setCursorField (Lists .newArrayList (cursorField ));
@@ -856,13 +880,13 @@ protected ConfiguredAirbyteStream createTableWithSpaces() throws SQLException {
856
880
Field .of (COL_LAST_NAME_WITH_SPACE , JsonSchemaPrimitive .STRING ));
857
881
}
858
882
859
- public String getFullyQualifiedTableName (String tableName ) {
883
+ public String getFullyQualifiedTableName (final String tableName ) {
860
884
return SourceJdbcUtils .getFullyQualifiedTableName (getDefaultSchemaName (), tableName );
861
885
}
862
886
863
887
public void createSchemas () throws SQLException {
864
888
if (supportsSchemas ()) {
865
- for (String schemaName : TEST_SCHEMAS ) {
889
+ for (final String schemaName : TEST_SCHEMAS ) {
866
890
final String createSchemaQuery = String .format ("CREATE SCHEMA %s;" , schemaName );
867
891
database .execute (connection -> connection .createStatement ().execute (createSchemaQuery ));
868
892
}
@@ -871,15 +895,15 @@ public void createSchemas() throws SQLException {
871
895
872
896
public void dropSchemas () throws SQLException {
873
897
if (supportsSchemas ()) {
874
- for (String schemaName : TEST_SCHEMAS ) {
898
+ for (final String schemaName : TEST_SCHEMAS ) {
875
899
final String dropSchemaQuery = String
876
900
.format (DROP_SCHEMA_QUERY , schemaName );
877
901
database .execute (connection -> connection .createStatement ().execute (dropSchemaQuery ));
878
902
}
879
903
}
880
904
}
881
905
882
- private JsonNode convertIdBasedOnDatabase (int idValue ) {
906
+ private JsonNode convertIdBasedOnDatabase (final int idValue ) {
883
907
if (getDriverClass ().toLowerCase ().contains ("oracle" )) {
884
908
return Jsons .jsonNode (BigDecimal .valueOf (idValue ));
885
909
} else if (getDriverClass ().toLowerCase ().contains ("snowflake" )) {
@@ -902,8 +926,8 @@ protected String getDefaultNamespace() {
902
926
}
903
927
}
904
928
905
- protected static void setEmittedAtToNull (Iterable <AirbyteMessage > messages ) {
906
- for (AirbyteMessage actualMessage : messages ) {
929
+ protected static void setEmittedAtToNull (final Iterable <AirbyteMessage > messages ) {
930
+ for (final AirbyteMessage actualMessage : messages ) {
907
931
if (actualMessage .getRecord () != null ) {
908
932
actualMessage .getRecord ().setEmittedAt (null );
909
933
}
0 commit comments