diff --git a/.circleci/config.yml b/.circleci/config.yml index c4779d6a5..87c1c6aed 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -78,6 +78,36 @@ jobs: destination: spark_warehouse.tar.gz when: on_fail + # run these separately as we need a isolated JVM to not have Spark session settings interfere with other runs + # long term goal is to refactor the current testing spark session builder and avoid adding new single test to CI + "Scala 13 -- Delta Lake Format Tests": + executor: docker_baseimg_executor + steps: + - checkout + - run: + name: Run Scala 13 tests for Delta Lake format + environment: + format_test: deltalake + shell: /bin/bash -leuxo pipefail + command: | + conda activate chronon_py + # Increase if we see OOM. + export SBT_OPTS="-XX:+CMSClassUnloadingEnabled -XX:MaxPermSize=4G -Xmx4G -Xms2G" + sbt '++ 2.13.6' "testOnly ai.chronon.spark.test.TableUtilsFormatTest" + - store_test_results: + path: /chronon/spark/target/test-reports + - store_test_results: + path: /chronon/aggregator/target/test-reports + - run: + name: Compress spark-warehouse + command: | + cd /tmp/ && tar -czvf spark-warehouse.tar.gz chronon/spark-warehouse + when: on_fail + - store_artifacts: + path: /tmp/spark-warehouse.tar.gz + destination: spark_warehouse.tar.gz + when: on_fail + "Scala 11 -- Compile": executor: docker_baseimg_executor steps: @@ -147,6 +177,9 @@ workflows: - "Scala 13 -- Tests": requires: - "Pull Docker Image" + - "Scala 13 -- Delta Lake Format Tests": + requires: + - "Pull Docker Image" - "Scalafmt Check": requires: - "Pull Docker Image" diff --git a/build.sbt b/build.sbt index 8b1098c02..381a76759 100644 --- a/build.sbt +++ b/build.sbt @@ -156,6 +156,14 @@ val VersionMatrix: Map[String, VersionDependency] = Map( None, Some("1.0.4") ), + "delta-core" -> VersionDependency( + Seq( + "io.delta" %% "delta-core" + ), + Some("0.6.1"), + Some("1.0.1"), + Some("2.0.2") + ), "jackson" -> VersionDependency( Seq( "com.fasterxml.jackson.core" % "jackson-core", @@ -365,7 +373,7 @@ lazy val spark_uber = (project in file("spark")) sparkBaseSettings, version := git.versionProperty.value, crossScalaVersions := supportedVersions, - libraryDependencies ++= fromMatrix(scalaVersion.value, "jackson", "spark-all/provided") + libraryDependencies ++= fromMatrix(scalaVersion.value, "jackson", "spark-all/provided", "delta-core/provided") ) lazy val spark_embedded = (project in file("spark")) @@ -374,7 +382,7 @@ lazy val spark_embedded = (project in file("spark")) sparkBaseSettings, version := git.versionProperty.value, crossScalaVersions := supportedVersions, - libraryDependencies ++= fromMatrix(scalaVersion.value, "spark-all"), + libraryDependencies ++= fromMatrix(scalaVersion.value, "spark-all", "delta-core"), target := target.value.toPath.resolveSibling("target-embedded").toFile, Test / test := {} ) diff --git a/spark/src/main/scala/ai/chronon/spark/ChrononKryoRegistrator.scala b/spark/src/main/scala/ai/chronon/spark/ChrononKryoRegistrator.scala index c054c0063..931efe608 100644 --- a/spark/src/main/scala/ai/chronon/spark/ChrononKryoRegistrator.scala +++ b/spark/src/main/scala/ai/chronon/spark/ChrononKryoRegistrator.scala @@ -149,14 +149,7 @@ class ChrononKryoRegistrator extends KryoRegistrator { "scala.collection.immutable.ArraySeq$ofRef", "org.apache.spark.sql.catalyst.expressions.GenericInternalRow" ) - names.foreach { name => - try { - kryo.register(Class.forName(name)) - kryo.register(Class.forName(s"[L$name;")) // represents array of a type to jvm - } catch { - case _: ClassNotFoundException 
=> // do nothing - } - } + names.foreach(name => doRegister(name, kryo)) kryo.register(classOf[Array[Array[Array[AnyRef]]]]) kryo.register(classOf[Array[Array[AnyRef]]]) @@ -164,4 +157,24 @@ class ChrononKryoRegistrator extends KryoRegistrator { kryo.register(classOf[Array[ItemSketchSerializable]]) kryo.register(classOf[ItemsSketchIR[AnyRef]], new ItemsSketchKryoSerializer[AnyRef]) } + + def doRegister(name: String, kryo: Kryo): Unit = { + try { + kryo.register(Class.forName(name)) + kryo.register(Class.forName(s"[L$name;")) // represents array of a type to jvm + } catch { + case _: ClassNotFoundException => // do nothing + } + } +} + +class ChrononDeltaLakeKryoRegistrator extends ChrononKryoRegistrator { + override def registerClasses(kryo: Kryo): Unit = { + super.registerClasses(kryo) + val additionalDeltaNames = Seq( + "org.apache.spark.sql.delta.stats.DeltaFileStatistics", + "org.apache.spark.sql.delta.actions.AddFile" + ) + additionalDeltaNames.foreach(name => doRegister(name, kryo)) + } } diff --git a/spark/src/main/scala/ai/chronon/spark/Driver.scala b/spark/src/main/scala/ai/chronon/spark/Driver.scala index a8bc2781f..cb6bb9c38 100644 --- a/spark/src/main/scala/ai/chronon/spark/Driver.scala +++ b/spark/src/main/scala/ai/chronon/spark/Driver.scala @@ -81,6 +81,9 @@ object Driver { default = Some(false), descr = "Skip the first unfilled partition range if some future partitions have been populated.") + val useDeltaCatalog: ScallopOption[Boolean] = + opt[Boolean](required = false, default = Some(false), descr = "Enable the use of the delta lake catalog") + val stepDays: ScallopOption[Int] = opt[Int](required = false, descr = "Runs offline backfill in steps, step-days at a time. Default is 30 days", @@ -136,8 +139,22 @@ object Driver { def isLocal: Boolean = localTableMapping.nonEmpty || localDataPath.isDefined protected def buildSparkSession(): SparkSession = { + // use of the delta lake catalog requires a couple of additional spark config options + val extraDeltaConfigs = useDeltaCatalog.toOption match { + case Some(true) => + Some( + Map( + "spark.sql.extensions" -> "io.delta.sql.DeltaSparkSessionExtension", + "spark.sql.catalog.spark_catalog" -> "org.apache.spark.sql.delta.catalog.DeltaCatalog" + )) + case _ => None + } + if (localTableMapping.nonEmpty) { - val localSession = SparkSessionBuilder.build(subcommandName(), local = true, localWarehouseLocation.toOption) + val localSession = SparkSessionBuilder.build(subcommandName(), + local = true, + localWarehouseLocation.toOption, + additionalConfig = extraDeltaConfigs) localTableMapping.foreach { case (table, filePath) => val file = new File(filePath) @@ -150,13 +167,16 @@ object Driver { val localSession = SparkSessionBuilder.build(subcommandName(), local = true, - localWarehouseLocation = localWarehouseLocation.toOption) + localWarehouseLocation = localWarehouseLocation.toOption, + additionalConfig = extraDeltaConfigs) LocalDataLoader.loadDataRecursively(dir, localSession) localSession } else { // We use the KryoSerializer for group bys and joins since we serialize the IRs. // But since staging query is fairly freeform, it's better to stick to the java serializer. 
- SparkSessionBuilder.build(subcommandName(), enforceKryoSerializer = !subcommandName().contains("staging_query")) + SparkSessionBuilder.build(subcommandName(), + enforceKryoSerializer = !subcommandName().contains("staging_query"), + additionalConfig = extraDeltaConfigs) } } diff --git a/spark/src/main/scala/ai/chronon/spark/SparkSessionBuilder.scala b/spark/src/main/scala/ai/chronon/spark/SparkSessionBuilder.scala index dded015c2..ebe6f6b1a 100644 --- a/spark/src/main/scala/ai/chronon/spark/SparkSessionBuilder.scala +++ b/spark/src/main/scala/ai/chronon/spark/SparkSessionBuilder.scala @@ -22,7 +22,6 @@ import org.apache.spark.SPARK_VERSION import java.io.File import java.util.logging.Logger -import scala.reflect.io.Path import scala.util.Properties object SparkSessionBuilder { @@ -30,6 +29,7 @@ object SparkSessionBuilder { private val warehouseId = java.util.UUID.randomUUID().toString.takeRight(6) private val DefaultWarehouseDir = new File("/tmp/chronon/spark-warehouse_" + warehouseId) + val FormatTestEnvVar: String = "format_test" def expandUser(path: String): String = path.replaceFirst("~", System.getProperty("user.home")) // we would want to share locally generated warehouse during CI testing @@ -38,6 +38,24 @@ object SparkSessionBuilder { localWarehouseLocation: Option[String] = None, additionalConfig: Option[Map[String, String]] = None, enforceKryoSerializer: Boolean = true): SparkSession = { + + // allow us to override the format by specifying env vars. This allows us to not have to worry about interference + // between Spark sessions created in existing chronon tests that need the hive format and some specific tests + // that require a format override like delta lake. + val (formatConfigs, kryoRegistrator) = sys.env.get(FormatTestEnvVar) match { + case Some("deltalake") => + val configMap = Map( + "spark.sql.extensions" -> "io.delta.sql.DeltaSparkSessionExtension", + "spark.sql.catalog.spark_catalog" -> "org.apache.spark.sql.delta.catalog.DeltaCatalog", + "spark.chronon.table_write.format" -> "delta" + ) + (configMap, "ai.chronon.spark.ChrononDeltaLakeKryoRegistrator") + case _ => (Map.empty, "ai.chronon.spark.ChrononKryoRegistrator") + } + + // tack on format configs with additional configs + val mergedConfigs = additionalConfig.getOrElse(Map.empty) ++ formatConfigs + if (local) { //required to run spark locally with hive support enabled - for sbt test System.setSecurityManager(null) @@ -61,13 +79,12 @@ object SparkSessionBuilder { if (enforceKryoSerializer) { baseBuilder .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - .config("spark.kryo.registrator", "ai.chronon.spark.ChrononKryoRegistrator") + .config("spark.kryo.registrator", kryoRegistrator) .config("spark.kryoserializer.buffer.max", "2000m") .config("spark.kryo.referenceTracking", "false") } - additionalConfig.foreach { configMap => - configMap.foreach { config => baseBuilder = baseBuilder.config(config._1, config._2) } - } + + mergedConfigs.foreach { config => baseBuilder = baseBuilder.config(config._1, config._2) } if (SPARK_VERSION.startsWith("2")) { // Otherwise files left from deleting the table with the same name result in test failures diff --git a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala index 01f7c5950..483b1b552 100644 --- a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala +++ b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala @@ -16,18 +16,20 @@ package ai.chronon.spark -import java.io.{PrintWriter, 
StringWriter} +import java.io.{PrintWriter, Serializable, StringWriter} import org.slf4j.LoggerFactory import ai.chronon.aggregator.windowing.TsUtils import ai.chronon.api.{Constants, PartitionSpec} import ai.chronon.api.Extensions._ import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import ai.chronon.spark.Extensions.{DfStats, DfWithStats} +import io.delta.tables.DeltaTable import jnr.ffi.annotations.Synchronized import org.apache.hadoop.hive.metastore.api.AlreadyExistsException import org.apache.spark.SparkException import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project} +import org.apache.spark.sql.delta.DeltaLog import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode, SparkSession} @@ -41,6 +43,237 @@ import scala.collection.immutable import scala.concurrent.{ExecutionContext, ExecutionContextExecutor} import scala.util.{Failure, Success, Try} +/** + * Trait to track the table format in use by a Chronon dataset and some utility methods to help + * retrieve metadata / configure it appropriately at creation time + */ +trait Format { + // Return the primary partitions (based on the 'partitionColumn') filtered down by sub-partition filters if provided + // If subpartition filters are supplied and the format doesn't support it, we throw an error + def primaryPartitions(tableName: String, + partitionColumn: String, + subPartitionsFilter: Map[String, String] = Map.empty)(implicit + sparkSession: SparkSession): Seq[String] = { + if (!supportSubPartitionsFilter && subPartitionsFilter.nonEmpty) { + throw new NotImplementedError(s"subPartitionsFilter is not supported on this format") + } + + val partitionSeq = partitions(tableName)(sparkSession) + partitionSeq.flatMap { partitionMap => + if ( + subPartitionsFilter.forall { + case (k, v) => partitionMap.get(k).contains(v) + } + ) { + partitionMap.get(partitionColumn) + } else { + None + } + } + } + + // Return a sequence for partitions where each partition entry consists of a Map of partition keys to values + def partitions(tableName: String)(implicit sparkSession: SparkSession): Seq[Map[String, String]] + + // Help specify the appropriate table type to use in the Spark create table DDL query + def createTableTypeString: String + + // Help specify the appropriate file format to use in the Spark create table DDL query + def fileFormatString(format: String): String + + // Does this format support sub partitions filters + def supportSubPartitionsFilter: Boolean +} + +/** + * Dynamically provide the read / write table format depending on table name. + * This supports reading/writing tables with heterogeneous formats. + * This approach enables users to override and specify a custom format provider if needed. This is useful in + * cases such as leveraging different library versions from what we support in the Chronon project (e.g. newer delta lake) + * as well as working with custom internal company logic / checks. + */ +trait FormatProvider extends Serializable { + def readFormat(tableName: String): Format + def writeFormat(tableName: String): Format +} + +/** + * Default format provider implementation based on default Chronon supported open source library versions. 
+ */ +case class DefaultFormatProvider(sparkSession: SparkSession) extends FormatProvider { + @transient lazy val logger = LoggerFactory.getLogger(getClass) + + // Checks the format of a given table by checking the format it's written out as + override def readFormat(tableName: String): Format = { + if (isIcebergTable(tableName)) { + Iceberg + } else if (isDeltaTable(tableName)) { + DeltaLake + } else { + Hive + } + } + + private def isIcebergTable(tableName: String): Boolean = + Try { + sparkSession.read.format("iceberg").load(tableName) + } match { + case Success(_) => + logger.info(s"IcebergCheck: Detected iceberg formatted table $tableName.") + true + case _ => + logger.info(s"IcebergCheck: Checked table $tableName is not iceberg format.") + false + } + + private def isDeltaTable(tableName: String): Boolean = { + Try { + val describeResult = sparkSession.sql(s"DESCRIBE DETAIL $tableName") + describeResult.select("format").first().getString(0).toLowerCase + } match { + case Success(format) => + logger.info(s"Delta check: Successfully read the format of table: $tableName as $format") + format == "delta" + case _ => + // the describe detail calls fails for Delta Lake tables + logger.info(s"Delta check: Unable to read the format of the table $tableName using DESCRIBE DETAIL") + false + } + } + + // Return the write format to use for the given table. The logic at a high level is: + // 1) If the user specifies the spark.chronon.table_write.iceberg - we go with Iceberg + // 2) If the user specifies a spark.chronon.table_write.format as Hive (parquet), Iceberg or Delta we go with their choice + // 3) Default to Hive (parquet) + // Note the table_write.iceberg is supported for legacy reasons. Specifying "iceberg" in spark.chronon.table_write.format + // is preferred as the latter conf also allows us to support additional formats + override def writeFormat(tableName: String): Format = { + val useIceberg: Boolean = sparkSession.conf.get("spark.chronon.table_write.iceberg", "false").toBoolean + + // Default provider just looks for any default config. + // Unlike read table, these write tables might not already exist. + val maybeFormat = sparkSession.conf.getOption("spark.chronon.table_write.format").map(_.toLowerCase) match { + case Some("hive") => Some(Hive) + case Some("iceberg") => Some(Iceberg) + case Some("delta") => Some(DeltaLake) + case _ => None + } + (useIceberg, maybeFormat) match { + // if explicitly configured Iceberg - we go with that setting + case (true, _) => Iceberg + // else if there is a write format we pick that + case (false, Some(format)) => format + // fallback to hive (parquet) + case (false, None) => Hive + } + } +} + +case object Hive extends Format { + override def primaryPartitions(tableName: String, partitionColumn: String, subPartitionsFilter: Map[String, String])( + implicit sparkSession: SparkSession): Seq[String] = + super.primaryPartitions(tableName, partitionColumn, subPartitionsFilter) + + def parseHivePartition(pstring: String): Map[String, String] = { + pstring + .split("/") + .map { part => + val p = part.split("=", 2) + p(0) -> p(1) + } + .toMap + } + + override def partitions(tableName: String)(implicit sparkSession: SparkSession): Seq[Map[String, String]] = { + // data is structured as a Df with single composite partition key column. Every row is a partition with the + // column values filled out as a formatted key=value pair + // Eg. df schema = (partitions: String) + // rows = [ "day=2020-10-10/hour=00", ... 
] + sparkSession.sqlContext + .sql(s"SHOW PARTITIONS $tableName") + .collect() + .map(row => parseHivePartition(row.getString(0))) + } + + def createTableTypeString: String = "" + def fileFormatString(format: String): String = s"STORED AS $format" + + override def supportSubPartitionsFilter: Boolean = true +} + +case object Iceberg extends Format { + override def primaryPartitions(tableName: String, partitionColumn: String, subPartitionsFilter: Map[String, String])( + implicit sparkSession: SparkSession): Seq[String] = { + if (!supportSubPartitionsFilter && subPartitionsFilter.nonEmpty) { + throw new NotImplementedError(s"subPartitionsFilter is not supported on this format") + } + + getIcebergPartitions(tableName) + } + + override def partitions(tableName: String)(implicit sparkSession: SparkSession): Seq[Map[String, String]] = { + throw new NotImplementedError( + "Multi-partitions retrieval is not supported on Iceberg tables yet." + + "For single partition retrieval, please use 'partition' method.") + } + + private def getIcebergPartitions(tableName: String)(implicit sparkSession: SparkSession): Seq[String] = { + val partitionsDf = sparkSession.read.format("iceberg").load(s"$tableName.partitions") + val index = partitionsDf.schema.fieldIndex("partition") + if (partitionsDf.schema(index).dataType.asInstanceOf[StructType].fieldNames.contains("hr")) { + // Hour filter is currently buggy in iceberg. https://github.com/apache/iceberg/issues/4718 + // so we collect and then filter. + partitionsDf + .select("partition.ds", "partition.hr") + .collect() + .filter(_.get(1) == null) + .map(_.getString(0)) + .toSeq + } else { + partitionsDf + .select("partition.ds") + .collect() + .map(_.getString(0)) + .toSeq + } + } + + def createTableTypeString: String = "USING iceberg" + def fileFormatString(format: String): String = "" + + override def supportSubPartitionsFilter: Boolean = false +} + +// The Delta Lake format is compatible with the Delta lake and Spark versions currently supported by the project. +// Attempting to use newer Delta lake library versions (e.g. 
3.2 which works with Spark 3.5) results in errors: +// java.lang.NoSuchMethodError: 'org.apache.spark.sql.delta.Snapshot org.apache.spark.sql.delta.DeltaLog.update(boolean)' +// In such cases, you should implement your own FormatProvider built on the newer Delta lake version +case object DeltaLake extends Format { + + override def primaryPartitions(tableName: String, partitionColumn: String, subPartitionsFilter: Map[String, String])( + implicit sparkSession: SparkSession): Seq[String] = + super.primaryPartitions(tableName, partitionColumn, subPartitionsFilter) + + override def partitions(tableName: String)(implicit sparkSession: SparkSession): Seq[Map[String, String]] = { + // delta lake doesn't support the `SHOW PARTITIONS ` syntax - https://github.com/delta-io/delta/issues/996 + // there's alternative ways to retrieve partitions using the DeltaLog abstraction which is what we have to lean into + // below + // first pull table location as that is what we need to pass to the delta log + val describeResult = sparkSession.sql(s"DESCRIBE DETAIL $tableName") + val tablePath = describeResult.select("location").head().getString(0) + + val snapshot = DeltaLog.forTable(sparkSession, tablePath).update() + val snapshotPartitionsDf = snapshot.allFiles.toDF().select("partitionValues") + val partitions = snapshotPartitionsDf.collect().map(r => r.getAs[Map[String, String]](0)) + partitions + } + + def createTableTypeString: String = "USING DELTA" + def fileFormatString(format: String): String = "" + + override def supportSubPartitionsFilter: Boolean = true +} + case class TableUtils(sparkSession: SparkSession) { @transient lazy val logger = LoggerFactory.getLogger(getClass) @@ -69,7 +302,16 @@ case class TableUtils(sparkSession: SparkSession) { val blockingCacheEviction: Boolean = sparkSession.conf.get("spark.chronon.table_write.cache.blocking", "false").toBoolean - val useIceberg: Boolean = sparkSession.conf.get("spark.chronon.table_write.iceberg", "false").toBoolean + private lazy val tableFormatProvider: FormatProvider = { + sparkSession.conf.getOption("spark.chronon.table.format_provider") match { + case Some(clazzName) => + // Load object instead of class/case class + Class.forName(clazzName).getField("MODULE$").get(null).asInstanceOf[FormatProvider] + case None => + DefaultFormatProvider(sparkSession) + } + } + val cacheLevel: Option[StorageLevel] = Try { if (cacheLevelString == "NONE") None else Some(StorageLevel.fromString(cacheLevelString)) @@ -101,16 +343,6 @@ case class TableUtils(sparkSession: SparkSession) { rdd } - def parsePartition(pstring: String): Map[String, String] = { - pstring - .split("/") - .map { part => - val p = part.split("=", 2) - p(0) -> p(1) - } - .toMap - } - def tableExists(tableName: String): Boolean = sparkSession.catalog.tableExists(tableName) def loadEntireTable(tableName: String): DataFrame = sparkSession.table(tableName) @@ -136,87 +368,26 @@ case class TableUtils(sparkSession: SparkSession) { } } + def tableReadFormat(tableName: String): Format = tableFormatProvider.readFormat(tableName) + // return all specified partition columns in a table in format of Map[partitionName, PartitionValue] def allPartitions(tableName: String, partitionColumnsFilter: Seq[String] = Seq.empty): Seq[Map[String, String]] = { if (!tableExists(tableName)) return Seq.empty[Map[String, String]] - if (isIcebergTable(tableName)) { - throw new NotImplementedError( - "Multi-partitions retrieval is not supported on Iceberg tables yet." 
+ - "For single partition retrieval, please use 'partition' method.") - } - sparkSession.sqlContext - .sql(s"SHOW PARTITIONS $tableName") - .collect() - .map { row => - { - val partitionMap = parsePartition(row.getString(0)) - if (partitionColumnsFilter.isEmpty) { - partitionMap - } else { - partitionMap.filterKeys(key => partitionColumnsFilter.contains(key)).toMap - } - } + val format = tableReadFormat(tableName) + val partitionSeq = format.partitions(tableName)(sparkSession) + if (partitionColumnsFilter.isEmpty) { + partitionSeq + } else { + partitionSeq.map { partitionMap => + partitionMap.filterKeys(key => partitionColumnsFilter.contains(key)).toMap } + } } def partitions(tableName: String, subPartitionsFilter: Map[String, String] = Map.empty): Seq[String] = { if (!tableExists(tableName)) return Seq.empty[String] - if (isIcebergTable(tableName)) { - if (subPartitionsFilter.nonEmpty) { - throw new NotImplementedError("subPartitionsFilter is not supported on Iceberg tables yet.") - } - return getIcebergPartitions(tableName) - } - sparkSession.sqlContext - .sql(s"SHOW PARTITIONS $tableName") - .collect() - .flatMap { row => - { - val partitionMap = parsePartition(row.getString(0)) - if ( - subPartitionsFilter.forall { - case (k, v) => partitionMap.get(k).contains(v) - } - ) { - partitionMap.get(partitionColumn) - } else { - None - } - } - } - } - - private def isIcebergTable(tableName: String): Boolean = - Try { - sparkSession.read.format("iceberg").load(tableName) - } match { - case Success(_) => - logger.info(s"IcebergCheck: Detected iceberg formatted table $tableName.") - true - case _ => - logger.info(s"IcebergCheck: Checked table $tableName is not iceberg format.") - false - } - - private def getIcebergPartitions(tableName: String): Seq[String] = { - val partitionsDf = sparkSession.read.format("iceberg").load(s"$tableName.partitions") - val index = partitionsDf.schema.fieldIndex("partition") - if (partitionsDf.schema(index).dataType.asInstanceOf[StructType].fieldNames.contains("hr")) { - // Hour filter is currently buggy in iceberg. https://github.com/apache/iceberg/issues/4718 - // so we collect and then filter. - partitionsDf - .select("partition.ds", "partition.hr") - .collect() - .filter(_.get(1) == null) - .map(_.getString(0)) - .toSeq - } else { - partitionsDf - .select("partition.ds") - .collect() - .map(_.getString(0)) - .toSeq - } + val format = tableReadFormat(tableName) + format.primaryPartitions(tableName, partitionColumn, subPartitionsFilter)(sparkSession) } // Given a table and a query extract the schema of the columns involved as input. 
@@ -519,11 +690,10 @@ case class TableUtils(sparkSession: SparkSession) { .filterNot(field => partitionColumns.contains(field.name)) .map(field => s"`${field.name}` ${field.dataType.catalogString}") - val tableTypString = if (useIceberg) { - "USING iceberg" - } else { - "" - } + val writeFormat = tableFormatProvider.writeFormat(tableName) + + val tableTypString = writeFormat.createTableTypeString + val createFragment = s"""CREATE TABLE $tableName ( | ${fieldDefinitions.mkString(",\n ")} @@ -545,11 +715,7 @@ case class TableUtils(sparkSession: SparkSession) { } else { "" } - val fileFormatString = if (useIceberg) { - "" - } else { - s"STORED AS $fileFormat" - } + val fileFormatString = writeFormat.fileFormatString(fileFormat) Seq(createFragment, partitionFragment, fileFormatString, propertiesFragment).mkString("\n") } diff --git a/spark/src/test/scala/ai/chronon/spark/test/TableUtilsFormatTest.scala b/spark/src/test/scala/ai/chronon/spark/test/TableUtilsFormatTest.scala new file mode 100644 index 000000000..d4c8b806a --- /dev/null +++ b/spark/src/test/scala/ai/chronon/spark/test/TableUtilsFormatTest.scala @@ -0,0 +1,192 @@ +package ai.chronon.spark.test + +import ai.chronon.api.{DoubleType, IntType, LongType, StringType, StructField, StructType} +import ai.chronon.spark.SparkSessionBuilder.FormatTestEnvVar +import ai.chronon.spark.test.TestUtils.makeDf +import ai.chronon.spark.{IncompatibleSchemaException, SparkSessionBuilder, TableUtils} +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SparkSession} +import org.junit.Assert.{assertEquals, assertTrue} +import org.junit.Test + +import scala.util.Try + +class TableUtilsFormatTest { + + import TableUtilsFormatTest._ + + // Read the format we want this instantiation of the test to run via environment vars + val format: String = sys.env.getOrElse(FormatTestEnvVar, "hive") + val spark = SparkSessionBuilder.build("TableUtilsFormatTest", local = true) + val tableUtils = TableUtils(spark) + + @Test + def testInsertPartitionsAddColumns(): Unit = { + val dbName = s"db_${System.currentTimeMillis()}" + val tableName = s"$dbName.test_table_1_$format" + spark.sql(s"CREATE DATABASE IF NOT EXISTS $dbName") + val columns1 = Array( + StructField("long_field", LongType), + StructField("int_field", IntType), + StructField("string_field", StringType) + ) + val df1 = makeDf( + spark, + StructType( + tableName, + columns1 :+ StructField("ds", StringType) + ), + List( + Row(1L, 2, "3", "2022-10-01") + ) + ) + + val df2 = makeDf( + spark, + StructType( + tableName, + columns1 + :+ StructField("double_field", DoubleType) + :+ StructField("ds", StringType) + ), + List( + Row(4L, 5, "6", 7.0, "2022-10-02") + ) + ) + testInsertPartitions(spark, tableUtils, tableName, format, df1, df2, ds1 = "2022-10-01", ds2 = "2022-10-02") + } + + @Test + def testInsertPartitionsAddRemoveColumns(): Unit = { + val dbName = s"db_${System.currentTimeMillis()}" + val tableName = s"$dbName.test_table_2_$format" + spark.sql(s"CREATE DATABASE IF NOT EXISTS $dbName") + val columns1 = Array( + StructField("long_field", LongType), + StructField("int_field", IntType), + StructField("string_field", StringType) + ) + val df1 = makeDf( + spark, + StructType( + tableName, + columns1 + :+ StructField("double_field", DoubleType) + :+ StructField("ds", StringType) + ), + List( + Row(1L, 2, "3", 4.0, "2022-10-01") + ) + ) + + val df2 = makeDf( + spark, + StructType( + tableName, + columns1 :+ StructField("ds", StringType) + ), + List( + Row(5L, 
6, "7", "2022-10-02") + ) + ) + testInsertPartitions(spark, tableUtils, tableName, format, df1, df2, ds1 = "2022-10-01", ds2 = "2022-10-02") + } + + @Test + def testInsertPartitionsAddModifyColumns(): Unit = { + val dbName = s"db_${System.currentTimeMillis()}" + val tableName = s"$dbName.test_table_3_$format" + spark.sql(s"CREATE DATABASE IF NOT EXISTS $dbName") + val columns1 = Array( + StructField("long_field", LongType), + StructField("int_field", IntType) + ) + val df1 = makeDf( + spark, + StructType( + tableName, + columns1 + :+ StructField("string_field", StringType) + :+ StructField("ds", StringType) + ), + List( + Row(1L, 2, "3", "2022-10-01") + ) + ) + + val df2 = makeDf( + spark, + StructType( + tableName, + columns1 + :+ StructField("string_field", DoubleType) // modified column data type + :+ StructField("ds", StringType) + ), + List( + Row(1L, 2, 3.0, "2022-10-02") + ) + ) + + testInsertPartitions(spark, tableUtils, tableName, format, df1, df2, ds1 = "2022-10-01", ds2 = "2022-10-02") + } +} + +object TableUtilsFormatTest { + private def testInsertPartitions(spark: SparkSession, + tableUtils: TableUtils, + tableName: String, + format: String, + df1: DataFrame, + df2: DataFrame, + ds1: String, + ds2: String): Unit = { + tableUtils.insertPartitions(df1, tableName, autoExpand = true) + val addedColumns = df2.schema.fieldNames.filterNot(df1.schema.fieldNames.contains) + val removedColumns = df1.schema.fieldNames.filterNot(df2.schema.fieldNames.contains) + val inconsistentColumns = ( + for ( + (name1, dtype1) <- df1.schema.fields.map(structField => (structField.name, structField.dataType)); + (name2, dtype2) <- df2.schema.fields.map(structField => (structField.name, structField.dataType)) + ) yield { + name1 == name2 && dtype1 != dtype2 + } + ).filter(identity) + + if (inconsistentColumns.nonEmpty) { + val insertTry = Try(tableUtils.insertPartitions(df2, tableName, autoExpand = true)) + val e = insertTry.failed.get.asInstanceOf[IncompatibleSchemaException] + assertEquals(inconsistentColumns.length, e.inconsistencies.length) + return + } + + if (df2.schema != df1.schema) { + val insertTry = Try(tableUtils.insertPartitions(df2, tableName)) + assertTrue(insertTry.failed.get.isInstanceOf[AnalysisException]) + } + + tableUtils.insertPartitions(df2, tableName, autoExpand = true) + + // check that we wrote out a table in the right format + val readTableFormat = tableUtils.tableReadFormat(tableName).toString + assertTrue(s"Mismatch in table format: $readTableFormat; expected: $format", readTableFormat.toLowerCase == format) + + // check we have all the partitions written + val returnedPartitions = tableUtils.partitions(tableName) + assertTrue(returnedPartitions.toSet == Set(ds1, ds2)) + + val dataRead1 = spark.table(tableName).where(col("ds") === ds1) + val dataRead2 = spark.table(tableName).where(col("ds") === ds2) + assertTrue(dataRead1.columns.length == dataRead2.columns.length) + + val totalColumnsCount = (df1.schema.fieldNames.toSet ++ df2.schema.fieldNames.toSet).size + assertEquals(totalColumnsCount, dataRead1.columns.length) + assertEquals(totalColumnsCount, dataRead2.columns.length) + + addedColumns.foreach(col => { + dataRead1.foreach(row => assertTrue(Option(row.getAs[Any](col)).isEmpty)) + }) + removedColumns.foreach(col => { + dataRead2.foreach(row => assertTrue(Option(row.getAs[Any](col)).isEmpty)) + }) + } +}