Skip to content

Commit 4cc0bcc

Browse files
authored
Fix excessive checkStateNotNull in JdbcUtil (#25847)
* Also fix the inability to write null values in Derby (covered by a unit test)
1 parent 5ffb2d5 commit 4cc0bcc

File tree

3 files changed

+107
-28
lines changed

3 files changed

+107
-28
lines changed

sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ static String generateStatement(String tableName, List<Schema.Field> fields) {
8585
(element, ps, i, fieldWithIndex) -> {
8686
Byte value = element.getByte(fieldWithIndex.getIndex());
8787
if (value == null) {
88-
setNullToPreparedStatement(ps, i);
88+
setNullToPreparedStatement(ps, i, JDBCType.TINYINT);
8989
} else {
9090
ps.setByte(i + 1, value);
9191
}
@@ -95,7 +95,7 @@ static String generateStatement(String tableName, List<Schema.Field> fields) {
9595
(element, ps, i, fieldWithIndex) -> {
9696
Short value = element.getInt16(fieldWithIndex.getIndex());
9797
if (value == null) {
98-
setNullToPreparedStatement(ps, i);
98+
setNullToPreparedStatement(ps, i, JDBCType.SMALLINT);
9999
} else {
100100
ps.setInt(i + 1, value);
101101
}
@@ -105,7 +105,7 @@ static String generateStatement(String tableName, List<Schema.Field> fields) {
105105
(element, ps, i, fieldWithIndex) -> {
106106
Long value = element.getInt64(fieldWithIndex.getIndex());
107107
if (value == null) {
108-
setNullToPreparedStatement(ps, i);
108+
setNullToPreparedStatement(ps, i, JDBCType.BIGINT);
109109
} else {
110110
ps.setLong(i + 1, value);
111111
}
@@ -119,7 +119,7 @@ static String generateStatement(String tableName, List<Schema.Field> fields) {
119119
(element, ps, i, fieldWithIndex) -> {
120120
Float value = element.getFloat(fieldWithIndex.getIndex());
121121
if (value == null) {
122-
setNullToPreparedStatement(ps, i);
122+
setNullToPreparedStatement(ps, i, JDBCType.FLOAT);
123123
} else {
124124
ps.setFloat(i + 1, value);
125125
}
@@ -129,7 +129,7 @@ static String generateStatement(String tableName, List<Schema.Field> fields) {
129129
(element, ps, i, fieldWithIndex) -> {
130130
Double value = element.getDouble(fieldWithIndex.getIndex());
131131
if (value == null) {
132-
setNullToPreparedStatement(ps, i);
132+
setNullToPreparedStatement(ps, i, JDBCType.DOUBLE);
133133
} else {
134134
ps.setDouble(i + 1, value);
135135
}
@@ -145,7 +145,7 @@ static String generateStatement(String tableName, List<Schema.Field> fields) {
145145
(element, ps, i, fieldWithIndex) -> {
146146
Boolean value = element.getBoolean(fieldWithIndex.getIndex());
147147
if (value == null) {
148-
setNullToPreparedStatement(ps, i);
148+
setNullToPreparedStatement(ps, i, JDBCType.BOOLEAN);
149149
} else {
150150
ps.setBoolean(i + 1, value);
151151
}
@@ -156,7 +156,7 @@ static String generateStatement(String tableName, List<Schema.Field> fields) {
156156
(element, ps, i, fieldWithIndex) -> {
157157
Integer value = element.getInt32(fieldWithIndex.getIndex());
158158
if (value == null) {
159-
setNullToPreparedStatement(ps, i);
159+
setNullToPreparedStatement(ps, i, JDBCType.INTEGER);
160160
} else {
161161
ps.setInt(i + 1, value);
162162
}
@@ -267,8 +267,9 @@ private static void setArrayNull(PreparedStatement ps, int i) throws SQLExceptio
267267
ps.setArray(i + 1, null);
268268
}
269269

270-
static void setNullToPreparedStatement(PreparedStatement ps, int i) throws SQLException {
271-
ps.setNull(i + 1, JDBCType.NULL.getVendorTypeNumber());
270+
static void setNullToPreparedStatement(PreparedStatement ps, int i, JDBCType type)
271+
throws SQLException {
272+
ps.setNull(i + 1, type.getVendorTypeNumber());
272273
}
273274

274275
static class BeamRowPreparedStatementSetter implements JdbcIO.PreparedStatementSetter<Row> {

sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/SchemaUtil.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
import static java.sql.JDBCType.VARBINARY;
2828
import static java.sql.JDBCType.VARCHAR;
2929
import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull;
30-
import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
3130
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
3231

3332
import java.io.Serializable;
@@ -304,7 +303,10 @@ private static <InputT, BaseT> ResultSetFieldExtractor createLogicalTypeExtracto
304303
} else {
305304
ResultSetFieldExtractor extractor = createFieldExtractor(fieldType.getBaseType());
306305
return (rs, index) -> {
307-
BaseT v = checkStateNotNull((BaseT) extractor.extract(rs, index));
306+
BaseT v = (BaseT) extractor.extract(rs, index);
307+
if (v == null) {
308+
return null;
309+
}
308310
return fieldType.toInputType(v);
309311
};
310312
}

sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java

Lines changed: 93 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
*/
1818
package org.apache.beam.sdk.io.jdbc;
1919

20-
import static java.sql.JDBCType.NULL;
2120
import static org.apache.beam.sdk.io.common.DatabaseTestHelper.assertRowCount;
2221
import static org.hamcrest.MatcherAssert.assertThat;
2322
import static org.hamcrest.Matchers.closeTo;
@@ -119,6 +118,8 @@ public class JdbcIOTest implements Serializable {
119118

120119
@Rule public final transient TestPipeline pipeline = TestPipeline.create();
121120

121+
@Rule public final transient TestPipeline secondPipeline = TestPipeline.create();
122+
122123
@Rule public final transient ExpectedLogs expectedLogs = ExpectedLogs.none(JdbcIO.class);
123124

124125
@Rule public transient ExpectedException thrown = ExpectedException.none();
@@ -991,13 +992,13 @@ public void testGetPreparedStatementSetNullsCaller() throws Exception {
991992
.set(row, psMocked, 10, SchemaUtil.FieldWithIndex.of(schema.getField(8), 8));
992993

993994
// primitive
994-
verify(psMocked, times(1)).setNull(1, NULL.getVendorTypeNumber());
995-
verify(psMocked, times(1)).setNull(2, NULL.getVendorTypeNumber());
996-
verify(psMocked, times(1)).setNull(3, NULL.getVendorTypeNumber());
997-
verify(psMocked, times(1)).setNull(4, NULL.getVendorTypeNumber());
998-
verify(psMocked, times(1)).setNull(5, NULL.getVendorTypeNumber());
999-
verify(psMocked, times(1)).setNull(6, NULL.getVendorTypeNumber());
1000-
verify(psMocked, times(1)).setNull(7, NULL.getVendorTypeNumber());
995+
verify(psMocked, times(1)).setNull(1, JDBCType.BIGINT.getVendorTypeNumber());
996+
verify(psMocked, times(1)).setNull(2, JDBCType.BOOLEAN.getVendorTypeNumber());
997+
verify(psMocked, times(1)).setNull(3, JDBCType.DOUBLE.getVendorTypeNumber());
998+
verify(psMocked, times(1)).setNull(4, JDBCType.FLOAT.getVendorTypeNumber());
999+
verify(psMocked, times(1)).setNull(5, JDBCType.INTEGER.getVendorTypeNumber());
1000+
verify(psMocked, times(1)).setNull(6, JDBCType.SMALLINT.getVendorTypeNumber());
1001+
verify(psMocked, times(1)).setNull(7, JDBCType.TINYINT.getVendorTypeNumber());
10011002
// reference
10021003
verify(psMocked, times(1)).setBytes(8, null);
10031004
verify(psMocked, times(1)).setString(9, null);
@@ -1095,20 +1096,32 @@ public void testGetPreparedStatementSetCallerForArray() throws Exception {
10951096
verify(psMocked, times(1)).setArray(1, arrayMocked);
10961097
}
10971098

1098-
private static ArrayList<Row> getRowsToWrite(long rowsToAdd, Schema schema) {
1099+
private static ArrayList<Row> getRowsToWrite(long rowsToAdd, Schema schema, boolean hasNulls) {
10991100

11001101
ArrayList<Row> data = new ArrayList<>();
1102+
int numFields = schema.getFields().size();
11011103
for (int i = 0; i < rowsToAdd; i++) {
1102-
1103-
Row row =
1104-
schema.getFields().stream()
1105-
.map(field -> dummyFieldValue(field.getType()))
1106-
.collect(Row.toRow(schema));
1107-
data.add(row);
1104+
Row.Builder builder = Row.withSchema(schema);
1105+
for (int j = 0; j < numFields; j++) {
1106+
if (hasNulls && i % numFields == j && schema.getField(j).getType().getNullable()) {
1107+
builder.addValue(null);
1108+
} else {
1109+
builder.addValue(dummyFieldValue(schema.getField(j).getType()));
1110+
}
1111+
}
1112+
data.add(builder.build());
11081113
}
11091114
return data;
11101115
}
11111116

1117+
private static ArrayList<Row> getRowsToWrite(long rowsToAdd, Schema schema) {
1118+
return getRowsToWrite(rowsToAdd, schema, false);
1119+
}
1120+
1121+
private static ArrayList<Row> getNullableRowsToWrite(long rowsToAdd, Schema schema) {
1122+
return getRowsToWrite(rowsToAdd, schema, true);
1123+
}
1124+
11121125
private static ArrayList<RowWithSchema> getRowsWithSchemaToWrite(long rowsToAdd) {
11131126

11141127
ArrayList<RowWithSchema> data = new ArrayList<>();
@@ -1118,7 +1131,8 @@ private static ArrayList<RowWithSchema> getRowsWithSchemaToWrite(long rowsToAdd)
11181131
return data;
11191132
}
11201133

1121-
private static Object dummyFieldValue(Schema.FieldType fieldType) {
1134+
private static Object dummyFieldValue(Schema.FieldType maybeNullableType) {
1135+
Schema.FieldType fieldType = maybeNullableType.withNullable(false);
11221136
long epochMilli = 1558719710000L;
11231137
if (fieldType.equals(Schema.FieldType.STRING)) {
11241138
return "string value";
@@ -1134,7 +1148,12 @@ private static Object dummyFieldValue(Schema.FieldType fieldType) {
11341148
return Long.MAX_VALUE;
11351149
} else if (fieldType.equals(Schema.FieldType.FLOAT)) {
11361150
return 15.5F;
1137-
} else if (fieldType.equals(Schema.FieldType.DECIMAL)) {
1151+
} else if (fieldType.equals(Schema.FieldType.DECIMAL)
1152+
|| (fieldType.getLogicalType() != null
1153+
&& fieldType
1154+
.getLogicalType()
1155+
.getIdentifier()
1156+
.equals(FixedPrecisionNumeric.IDENTIFIER))) {
11381157
return BigDecimal.ONE;
11391158
} else if (fieldType.equals(LogicalTypes.JDBC_DATE_TYPE)) {
11401159
return new DateTime(epochMilli, ISOChronology.getInstanceUTC()).withTimeAtStartOfDay();
@@ -1326,6 +1345,63 @@ public Void apply(Iterable<Long> input) {
13261345
pipeline.run().waitUntilFinish();
13271346
}
13281347

1348+
@Test
1349+
public void testWriteReadNullableTypes() throws SQLException {
1350+
// first setup data
1351+
Schema.Builder schemaBuilder = Schema.builder();
1352+
schemaBuilder.addField("column_id", FieldType.INT32.withNullable(false));
1353+
schemaBuilder.addField("column_bigint", Schema.FieldType.INT64.withNullable(true));
1354+
schemaBuilder.addField("column_boolean", FieldType.BOOLEAN.withNullable(true));
1355+
schemaBuilder.addField("column_float", Schema.FieldType.FLOAT.withNullable(true));
1356+
schemaBuilder.addField("column_double", Schema.FieldType.DOUBLE.withNullable(true));
1357+
schemaBuilder.addField(
1358+
"column_decimal",
1359+
FieldType.logicalType(FixedPrecisionNumeric.of(13, 0)).withNullable(true));
1360+
Schema schema = schemaBuilder.build();
1361+
1362+
// some types not supported in derby (e.g. tinyint) are not tested here
1363+
String tableName = DatabaseTestHelper.getTestTableName("UT_READ_NULLABLE_LG");
1364+
StringBuilder stmt = new StringBuilder("CREATE TABLE ");
1365+
stmt.append(tableName);
1366+
stmt.append(" (");
1367+
stmt.append("column_id INTEGER NOT NULL,"); // Integer
1368+
stmt.append("column_bigint BIGINT,"); // int64
1369+
stmt.append("column_boolean BOOLEAN,"); // boolean
1370+
stmt.append("column_float REAL,"); // float
1371+
stmt.append("column_double DOUBLE PRECISION,"); // double
1372+
stmt.append("column_decimal DECIMAL(13,0)"); // BigDecimal
1373+
stmt.append(" )");
1374+
DatabaseTestHelper.createTableWithStatement(DATA_SOURCE, stmt.toString());
1375+
final int rowsToAdd = 10;
1376+
try {
1377+
// run write pipeline
1378+
ArrayList<Row> data = getNullableRowsToWrite(rowsToAdd, schema);
1379+
pipeline
1380+
.apply(Create.of(data))
1381+
.setRowSchema(schema)
1382+
.apply(
1383+
JdbcIO.<Row>write()
1384+
.withDataSourceConfiguration(DATA_SOURCE_CONFIGURATION)
1385+
.withBatchSize(10L)
1386+
.withTable(tableName));
1387+
pipeline.run();
1388+
assertRowCount(DATA_SOURCE, tableName, rowsToAdd);
1389+
1390+
// run read pipeline
1391+
PCollection<Row> rows =
1392+
secondPipeline.apply(
1393+
JdbcIO.readRows()
1394+
.withDataSourceConfiguration(DATA_SOURCE_CONFIGURATION)
1395+
.withQuery("SELECT * FROM " + tableName));
1396+
PAssert.thatSingleton(rows.apply("Count All", Count.globally())).isEqualTo((long) rowsToAdd);
1397+
PAssert.that(rows).containsInAnyOrder(data);
1398+
1399+
secondPipeline.run();
1400+
} finally {
1401+
DatabaseTestHelper.deleteTable(DATA_SOURCE, tableName);
1402+
}
1403+
}
1404+
13291405
@Test
13301406
public void testPartitioningLongs() {
13311407
PCollection<KV<Long, Long>> ranges =

0 commit comments

Comments
 (0)