17
17
*/
18
18
package org .apache .beam .sdk .io .jdbc ;
19
19
20
- import static java .sql .JDBCType .NULL ;
21
20
import static org .apache .beam .sdk .io .common .DatabaseTestHelper .assertRowCount ;
22
21
import static org .hamcrest .MatcherAssert .assertThat ;
23
22
import static org .hamcrest .Matchers .closeTo ;
@@ -119,6 +118,8 @@ public class JdbcIOTest implements Serializable {
119
118
120
119
@ Rule public final transient TestPipeline pipeline = TestPipeline .create ();
121
120
121
+ @ Rule public final transient TestPipeline secondPipeline = TestPipeline .create ();
122
+
122
123
@ Rule public final transient ExpectedLogs expectedLogs = ExpectedLogs .none (JdbcIO .class );
123
124
124
125
@ Rule public transient ExpectedException thrown = ExpectedException .none ();
@@ -991,13 +992,13 @@ public void testGetPreparedStatementSetNullsCaller() throws Exception {
991
992
.set (row , psMocked , 10 , SchemaUtil .FieldWithIndex .of (schema .getField (8 ), 8 ));
992
993
993
994
// primitive
994
- verify (psMocked , times (1 )).setNull (1 , NULL .getVendorTypeNumber ());
995
- verify (psMocked , times (1 )).setNull (2 , NULL .getVendorTypeNumber ());
996
- verify (psMocked , times (1 )).setNull (3 , NULL .getVendorTypeNumber ());
997
- verify (psMocked , times (1 )).setNull (4 , NULL .getVendorTypeNumber ());
998
- verify (psMocked , times (1 )).setNull (5 , NULL .getVendorTypeNumber ());
999
- verify (psMocked , times (1 )).setNull (6 , NULL .getVendorTypeNumber ());
1000
- verify (psMocked , times (1 )).setNull (7 , NULL .getVendorTypeNumber ());
995
+ verify (psMocked , times (1 )).setNull (1 , JDBCType . BIGINT .getVendorTypeNumber ());
996
+ verify (psMocked , times (1 )).setNull (2 , JDBCType . BOOLEAN .getVendorTypeNumber ());
997
+ verify (psMocked , times (1 )).setNull (3 , JDBCType . DOUBLE .getVendorTypeNumber ());
998
+ verify (psMocked , times (1 )).setNull (4 , JDBCType . FLOAT .getVendorTypeNumber ());
999
+ verify (psMocked , times (1 )).setNull (5 , JDBCType . INTEGER .getVendorTypeNumber ());
1000
+ verify (psMocked , times (1 )).setNull (6 , JDBCType . SMALLINT .getVendorTypeNumber ());
1001
+ verify (psMocked , times (1 )).setNull (7 , JDBCType . TINYINT .getVendorTypeNumber ());
1001
1002
// reference
1002
1003
verify (psMocked , times (1 )).setBytes (8 , null );
1003
1004
verify (psMocked , times (1 )).setString (9 , null );
@@ -1095,20 +1096,32 @@ public void testGetPreparedStatementSetCallerForArray() throws Exception {
1095
1096
verify (psMocked , times (1 )).setArray (1 , arrayMocked );
1096
1097
}
1097
1098
1098
- private static ArrayList <Row > getRowsToWrite (long rowsToAdd , Schema schema ) {
1099
+ private static ArrayList <Row > getRowsToWrite (long rowsToAdd , Schema schema , boolean hasNulls ) {
1099
1100
1100
1101
ArrayList <Row > data = new ArrayList <>();
1102
+ int numFields = schema .getFields ().size ();
1101
1103
for (int i = 0 ; i < rowsToAdd ; i ++) {
1102
-
1103
- Row row =
1104
- schema .getFields ().stream ()
1105
- .map (field -> dummyFieldValue (field .getType ()))
1106
- .collect (Row .toRow (schema ));
1107
- data .add (row );
1104
+ Row .Builder builder = Row .withSchema (schema );
1105
+ for (int j = 0 ; j < numFields ; j ++) {
1106
+ if (hasNulls && i % numFields == j && schema .getField (j ).getType ().getNullable ()) {
1107
+ builder .addValue (null );
1108
+ } else {
1109
+ builder .addValue (dummyFieldValue (schema .getField (j ).getType ()));
1110
+ }
1111
+ }
1112
+ data .add (builder .build ());
1108
1113
}
1109
1114
return data ;
1110
1115
}
1111
1116
1117
+ private static ArrayList <Row > getRowsToWrite (long rowsToAdd , Schema schema ) {
1118
+ return getRowsToWrite (rowsToAdd , schema , false );
1119
+ }
1120
+
1121
+ private static ArrayList <Row > getNullableRowsToWrite (long rowsToAdd , Schema schema ) {
1122
+ return getRowsToWrite (rowsToAdd , schema , true );
1123
+ }
1124
+
1112
1125
private static ArrayList <RowWithSchema > getRowsWithSchemaToWrite (long rowsToAdd ) {
1113
1126
1114
1127
ArrayList <RowWithSchema > data = new ArrayList <>();
@@ -1118,7 +1131,8 @@ private static ArrayList<RowWithSchema> getRowsWithSchemaToWrite(long rowsToAdd)
1118
1131
return data ;
1119
1132
}
1120
1133
1121
- private static Object dummyFieldValue (Schema .FieldType fieldType ) {
1134
+ private static Object dummyFieldValue (Schema .FieldType maybeNullableType ) {
1135
+ Schema .FieldType fieldType = maybeNullableType .withNullable (false );
1122
1136
long epochMilli = 1558719710000L ;
1123
1137
if (fieldType .equals (Schema .FieldType .STRING )) {
1124
1138
return "string value" ;
@@ -1134,7 +1148,12 @@ private static Object dummyFieldValue(Schema.FieldType fieldType) {
1134
1148
return Long .MAX_VALUE ;
1135
1149
} else if (fieldType .equals (Schema .FieldType .FLOAT )) {
1136
1150
return 15.5F ;
1137
- } else if (fieldType .equals (Schema .FieldType .DECIMAL )) {
1151
+ } else if (fieldType .equals (Schema .FieldType .DECIMAL )
1152
+ || (fieldType .getLogicalType () != null
1153
+ && fieldType
1154
+ .getLogicalType ()
1155
+ .getIdentifier ()
1156
+ .equals (FixedPrecisionNumeric .IDENTIFIER ))) {
1138
1157
return BigDecimal .ONE ;
1139
1158
} else if (fieldType .equals (LogicalTypes .JDBC_DATE_TYPE )) {
1140
1159
return new DateTime (epochMilli , ISOChronology .getInstanceUTC ()).withTimeAtStartOfDay ();
@@ -1326,6 +1345,63 @@ public Void apply(Iterable<Long> input) {
1326
1345
pipeline .run ().waitUntilFinish ();
1327
1346
}
1328
1347
1348
+ @ Test
1349
+ public void testWriteReadNullableTypes () throws SQLException {
1350
+ // first setup data
1351
+ Schema .Builder schemaBuilder = Schema .builder ();
1352
+ schemaBuilder .addField ("column_id" , FieldType .INT32 .withNullable (false ));
1353
+ schemaBuilder .addField ("column_bigint" , Schema .FieldType .INT64 .withNullable (true ));
1354
+ schemaBuilder .addField ("column_boolean" , FieldType .BOOLEAN .withNullable (true ));
1355
+ schemaBuilder .addField ("column_float" , Schema .FieldType .FLOAT .withNullable (true ));
1356
+ schemaBuilder .addField ("column_double" , Schema .FieldType .DOUBLE .withNullable (true ));
1357
+ schemaBuilder .addField (
1358
+ "column_decimal" ,
1359
+ FieldType .logicalType (FixedPrecisionNumeric .of (13 , 0 )).withNullable (true ));
1360
+ Schema schema = schemaBuilder .build ();
1361
+
1362
+ // some types not supported in derby (e.g. tinyint) are not tested here
1363
+ String tableName = DatabaseTestHelper .getTestTableName ("UT_READ_NULLABLE_LG" );
1364
+ StringBuilder stmt = new StringBuilder ("CREATE TABLE " );
1365
+ stmt .append (tableName );
1366
+ stmt .append (" (" );
1367
+ stmt .append ("column_id INTEGER NOT NULL," ); // Integer
1368
+ stmt .append ("column_bigint BIGINT," ); // int64
1369
+ stmt .append ("column_boolean BOOLEAN," ); // boolean
1370
+ stmt .append ("column_float REAL," ); // float
1371
+ stmt .append ("column_double DOUBLE PRECISION," ); // double
1372
+ stmt .append ("column_decimal DECIMAL(13,0)" ); // BigDecimal
1373
+ stmt .append (" )" );
1374
+ DatabaseTestHelper .createTableWithStatement (DATA_SOURCE , stmt .toString ());
1375
+ final int rowsToAdd = 10 ;
1376
+ try {
1377
+ // run write pipeline
1378
+ ArrayList <Row > data = getNullableRowsToWrite (rowsToAdd , schema );
1379
+ pipeline
1380
+ .apply (Create .of (data ))
1381
+ .setRowSchema (schema )
1382
+ .apply (
1383
+ JdbcIO .<Row >write ()
1384
+ .withDataSourceConfiguration (DATA_SOURCE_CONFIGURATION )
1385
+ .withBatchSize (10L )
1386
+ .withTable (tableName ));
1387
+ pipeline .run ();
1388
+ assertRowCount (DATA_SOURCE , tableName , rowsToAdd );
1389
+
1390
+ // run read pipeline
1391
+ PCollection <Row > rows =
1392
+ secondPipeline .apply (
1393
+ JdbcIO .readRows ()
1394
+ .withDataSourceConfiguration (DATA_SOURCE_CONFIGURATION )
1395
+ .withQuery ("SELECT * FROM " + tableName ));
1396
+ PAssert .thatSingleton (rows .apply ("Count All" , Count .globally ())).isEqualTo ((long ) rowsToAdd );
1397
+ PAssert .that (rows ).containsInAnyOrder (data );
1398
+
1399
+ secondPipeline .run ();
1400
+ } finally {
1401
+ DatabaseTestHelper .deleteTable (DATA_SOURCE , tableName );
1402
+ }
1403
+ }
1404
+
1329
1405
@ Test
1330
1406
public void testPartitioningLongs () {
1331
1407
PCollection <KV <Long , Long >> ranges =
0 commit comments