
Commit e28e543

Correctly reading empty fields in as null rather than throwing exception (elastic#1816)
By default we intend to treat empty fields as nulls when they are read in through Spark SQL. However, we were actually turning them into None objects, which causes Spark SQL to blow up on Spark 2 and 3. This commit treats them as nulls instead, which works for all versions of Spark we currently support. Closes elastic#1635
1 parent 5f4e9e9 commit e28e543
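
In user terms, this is a sketch of the behavior being fixed (not part of the commit; the index name, sqlContext value, and description field are assumed, while the format and option names come straight from the tests below). With the default es.field.read.empty.as.null=yes, an empty string field now reads back as a SQL NULL instead of an object Spark cannot encode, and the option can be turned off to preserve empty strings:

    // Default behavior after this fix: empty fields read back as SQL NULLs.
    val df = sqlContext.read
      .format("org.elasticsearch.spark.sql")
      .load("my_index") // hypothetical index
    df.filter("description IS NULL").count()

    // Opt out to keep empty strings verbatim.
    val rawDf = sqlContext.read
      .format("org.elasticsearch.spark.sql")
      .option("es.field.read.empty.as.null", "no")
      .load("my_index")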

File tree

6 files changed, +93 -8 lines changed


spark/core/src/main/scala/org/elasticsearch/spark/serialization/ScalaValueReader.scala

Lines changed: 1 addition & 1 deletion

@@ -126,7 +126,7 @@ class ScalaValueReader extends AbstractValueReader with SettingsAware {
     }
   }
 
-  def nullValue() = { None }
+  def nullValue() = { null }
   def textValue(value: String, parser: Parser) = { checkNull (parseText, value, parser) }
  protected def parseText(value:String, parser: Parser) = { value }
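
Why this one-liner matters: nullValue() is the placeholder ScalaValueReader hands back for null and (by default) empty fields, and that value ends up inside the Row objects Spark SQL has to encode. Catalyst accepts a Java null for a nullable column but rejects scala.None as an external type, which is the blow-up the commit message describes. A minimal repro sketch, assuming a local SparkSession (illustrative only, not part of this commit):

    import org.apache.spark.sql.{Row, SparkSession}
    import org.apache.spark.sql.types.{StringType, StructField, StructType}

    val spark = SparkSession.builder().master("local[1]").appName("none-vs-null").getOrCreate()
    val schema = StructType(Array(StructField("description", StringType, true)))

    // A Java null is a valid external value for a nullable string column.
    spark.createDataFrame(spark.sparkContext.parallelize(Seq(Row(null))), schema).count()

    // scala.None is not: on Spark 2.x/3.x this fails at runtime with something
    // like "scala.None$ is not a valid external type for schema of string".
    spark.createDataFrame(spark.sparkContext.parallelize(Seq(Row(None))), schema).count()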

spark/core/src/test/scala/org/elasticsearch/spark/ScalaExtendedBooleanValueReaderTest.scala

Lines changed: 1 addition & 1 deletion

@@ -49,7 +49,7 @@ class ScalaExtendedBooleanValueReaderTest(jsonString: String, expected: Expected
 
   def isNull: Matcher[AnyRef] = {
     return new BaseMatcher[AnyRef] {
-      override def matches(item: scala.Any): Boolean = item == None
+      override def matches(item: scala.Any): Boolean = item == null
       override def describeTo(description: Description): Unit = description.appendText("null")
     }
   }

spark/core/src/test/scala/org/elasticsearch/spark/ScalaValueReaderTest.scala

Lines changed: 2 additions & 2 deletions

@@ -26,8 +26,8 @@ class ScalaValueReaderTest extends BaseValueReaderTest {
 
   override def createValueReader() = new ScalaValueReader()
 
-  override def checkNull(result: Object): Unit = { assertEquals(None, result)}
-  override def checkEmptyString(result: Object): Unit = { assertEquals(None, result)}
+  override def checkNull(result: Object): Unit = { assertEquals(null, result)}
+  override def checkEmptyString(result: Object): Unit = { assertEquals(null, result)}
   override def checkInteger(result: Object): Unit = { assertEquals(Int.MaxValue, result)}
   override def checkLong(result: Object): Unit = { assertEquals(Long.MaxValue, result)}
   override def checkDouble(result: Object): Unit = { assertEquals(Double.MaxValue, result)}

spark/sql-13/src/itest/scala/org/elasticsearch/spark/integration/AbstractScalaEsSparkSQL.scala

Lines changed: 29 additions & 1 deletion

@@ -23,7 +23,6 @@ import java.{lang => jl}
 import java.sql.Timestamp
 import java.{util => ju}
 import java.util.concurrent.TimeUnit
-
 import scala.collection.JavaConversions.propertiesAsScalaMap
 import scala.collection.JavaConverters.asScalaBufferConverter
 import scala.collection.JavaConverters.mapAsJavaMapConverter
@@ -68,6 +67,8 @@ import org.junit.runners.Parameterized.Parameters
 import org.junit.runners.MethodSorters
 import com.esotericsoftware.kryo.io.{Input => KryoInput}
 import com.esotericsoftware.kryo.io.{Output => KryoOutput}
+import org.apache.spark.rdd.RDD
+
 import javax.xml.bind.DatatypeConverter
 import org.elasticsearch.hadoop.{EsHadoopIllegalArgumentException, EsHadoopIllegalStateException}
 import org.apache.spark.sql.types.DoubleType
@@ -419,6 +420,33 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
     val results = sqc.sql("SELECT name FROM datfile WHERE id >=1 AND id <=10")
     //results.take(5).foreach(println)
   }
+
+  @Test
+  def testEmptyStrings(): Unit = {
+    val data = Seq(("Java", "20000"), ("Python", ""), ("Scala", "3000"))
+    val rdd: RDD[Row] = sc.parallelize(data).map(row => Row(row._1, row._2))
+    val schema = StructType( Array(
+      StructField("language", StringType,true),
+      StructField("description", StringType,true)
+    ))
+    val inputDf = sqc.createDataFrame(rdd, schema)
+    inputDf.write
+      .format("org.elasticsearch.spark.sql")
+      .save("empty_strings_test")
+    val reader = sqc.read.format("org.elasticsearch.spark.sql")
+    val outputDf = reader.load("empty_strings_test")
+    assertEquals(data.size, outputDf.count)
+    val nullDescriptionsDf = outputDf.filter("language = 'Python'")
+    assertEquals(1, nullDescriptionsDf.count)
+    assertEquals(null, nullDescriptionsDf.first().getAs("description"))
+
+    val reader2 = sqc.read.format("org.elasticsearch.spark.sql").option("es.field.read.empty.as.null", "no")
+    val outputDf2 = reader2.load("empty_strings_test")
+    assertEquals(data.size, outputDf2.count)
+    val emptyDescriptionsDf = outputDf2.filter("language = 'Python'")
+    assertEquals(1, emptyDescriptionsDf.count)
+    assertEquals("", emptyDescriptionsDf.first().getAs("description"))
+  }
 
   @Test
   def test0WriteFieldNameWithPercentage() {
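
One version-specific wrinkle between this sql-13 test and the sql-20/sql-30 variants below: Spark 1.3's DataFrame.filter only accepts a Column or a SQL condition string, so the assertion here goes through getAs after filtering on language, while the later modules can filter with a plain Scala predicate because DataFrame became Dataset[Row] in Spark 2.0. Both styles, sketched against an assumed DataFrame df with a nullable description column:

    // Spark 1.3.x: filter with a SQL condition string (or a Column).
    val viaSql = df.filter("description IS NULL")

    // Spark 2.x/3.x: DataFrame is Dataset[Row], so a Scala predicate works too.
    val viaLambda = df.filter(row => row.getAs[String]("description") == null)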

spark/sql-20/src/itest/scala/org/elasticsearch/spark/integration/AbstractScalaEsSparkSQL.scala

Lines changed: 29 additions & 1 deletion

@@ -27,7 +27,6 @@ import java.nio.file.Paths
 import java.sql.Timestamp
 import java.{util => ju}
 import java.util.concurrent.TimeUnit
-
 import scala.collection.JavaConversions.propertiesAsScalaMap
 import scala.collection.JavaConverters.asScalaBufferConverter
 import scala.collection.JavaConverters.mapAsJavaMapConverter
@@ -86,6 +85,8 @@ import org.junit.runners.Parameterized
 import org.junit.runners.Parameterized.Parameters
 import com.esotericsoftware.kryo.io.{Input => KryoInput}
 import com.esotericsoftware.kryo.io.{Output => KryoOutput}
+import org.apache.spark.rdd.RDD
+
 import javax.xml.bind.DatatypeConverter
 import org.apache.spark.sql.SparkSession
 import org.elasticsearch.hadoop.EsAssume
@@ -438,6 +439,33 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
     val results = sqc.sql("SELECT name FROM datfile WHERE id >=1 AND id <=10")
     //results.take(5).foreach(println)
   }
+
+  @Test
+  def testEmptyStrings(): Unit = {
+    val data = Seq(("Java", "20000"), ("Python", ""), ("Scala", "3000"))
+    val rdd: RDD[Row] = sc.parallelize(data).map(row => Row(row._1, row._2))
+    val schema = StructType( Array(
+      StructField("language", StringType,true),
+      StructField("description", StringType,true)
+    ))
+    val inputDf = sqc.createDataFrame(rdd, schema)
+    inputDf.write
+      .format("org.elasticsearch.spark.sql")
+      .save("empty_strings_test")
+    val reader = sqc.read.format("org.elasticsearch.spark.sql")
+    val outputDf = reader.load("empty_strings_test")
+    assertEquals(data.size, outputDf.count)
+    val nullDescriptionsDf = outputDf.filter(row => row.getAs("description") == null)
+    assertEquals(1, nullDescriptionsDf.count)
+
+    val reader2 = sqc.read.format("org.elasticsearch.spark.sql").option("es.field.read.empty.as.null", "no")
+    val outputDf2 = reader2.load("empty_strings_test")
+    assertEquals(data.size, outputDf2.count)
+    val nullDescriptionsDf2 = outputDf2.filter(row => row.getAs("description") == null)
+    assertEquals(0, nullDescriptionsDf2.count)
+    val emptyDescriptionsDf = outputDf2.filter(row => row.getAs("description") == "")
+    assertEquals(1, emptyDescriptionsDf.count)
+  }
 
   @Test
   def test0WriteFieldNameWithPercentage() {

spark/sql-30/src/itest/scala/org/elasticsearch/spark/integration/AbstractScalaEsSparkSQL.scala

Lines changed: 31 additions & 2 deletions

@@ -27,7 +27,6 @@ import java.nio.file.Paths
 import java.sql.Timestamp
 import java.{util => ju}
 import java.util.concurrent.TimeUnit
-
 import scala.collection.JavaConversions.propertiesAsScalaMap
 import scala.collection.JavaConverters.asScalaBufferConverter
 import scala.collection.JavaConverters.mapAsJavaMapConverter
@@ -86,6 +85,8 @@ import org.junit.runners.Parameterized
 import org.junit.runners.Parameterized.Parameters
 import com.esotericsoftware.kryo.io.{Input => KryoInput}
 import com.esotericsoftware.kryo.io.{Output => KryoOutput}
+import org.apache.spark.rdd.RDD
+
 import javax.xml.bind.DatatypeConverter
 import org.apache.spark.sql.SparkSession
 import org.elasticsearch.hadoop.EsAssume
@@ -98,6 +99,7 @@ import org.junit.Assert._
 import org.junit.ClassRule
 
 object AbstractScalaEsScalaSparkSQL {
+
   @transient val conf = new SparkConf()
     .setAll(propertiesAsScalaMap(TestSettings.TESTING_PROPS))
     .setAppName("estest")
@@ -438,7 +440,34 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
     val results = sqc.sql("SELECT name FROM datfile WHERE id >=1 AND id <=10")
     //results.take(5).foreach(println)
   }
-
+
+  @Test
+  def testEmptyStrings(): Unit = {
+    val data = Seq(("Java", "20000"), ("Python", ""), ("Scala", "3000"))
+    val rdd: RDD[Row] = sc.parallelize(data).map(row => Row(row._1, row._2))
+    val schema = StructType( Array(
+      StructField("language", StringType,true),
+      StructField("description", StringType,true)
+    ))
+    val inputDf = sqc.createDataFrame(rdd, schema)
+    inputDf.write
+      .format("org.elasticsearch.spark.sql")
+      .save("empty_strings_test")
+    val reader = sqc.read.format("org.elasticsearch.spark.sql")
+    val outputDf = reader.load("empty_strings_test")
+    assertEquals(data.size, outputDf.count)
+    val nullDescriptionsDf = outputDf.filter(row => row.getAs("description") == null)
+    assertEquals(1, nullDescriptionsDf.count)
+
+    val reader2 = sqc.read.format("org.elasticsearch.spark.sql").option("es.field.read.empty.as.null", "no")
+    val outputDf2 = reader2.load("empty_strings_test")
+    assertEquals(data.size, outputDf2.count)
+    val nullDescriptionsDf2 = outputDf2.filter(row => row.getAs("description") == null)
+    assertEquals(0, nullDescriptionsDf2.count)
+    val emptyDescriptionsDf = outputDf2.filter(row => row.getAs("description") == "")
+    assertEquals(1, emptyDescriptionsDf.count)
+  }
+
   @Test
   def test0WriteFieldNameWithPercentage() {
     val index = wrapIndex("spark-test-scala-sql-field-with-percentage")
