Commit c1df43b

support format provider
add thrift 18 using user format foo more log add log lines more long lines only use delta format need to commit try add duck wip wip add log line provider revert thrift
1 parent 8f2c70d commit c1df43b

2 files changed: +58 −31 lines

spark/src/main/scala/ai/chronon/spark/TableUtils.scala

Lines changed: 57 additions & 30 deletions
@@ -58,6 +58,14 @@ trait Format {
   def fileFormatString(format: String): String
 }
 
+/**
+ * [DPRTI-492][OAI_CHANGES]: Dynamically provide table Format depending on table name.
+ * This supports reading/writing tables with heterogeneous formats.
+ */
+trait FormatProvider {
+  def get(tableName: String): Format
+}
+
 case object Hive extends Format {
   def parseHivePartition(pstring: String): Map[String, String] = {
     pstring
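
For illustration only (not part of this commit): since FormatProvider is a plain trait with a single method, a custom provider is just a Scala object implementing it. The object name and naming convention below are hypothetical:

// Hypothetical provider routing tables to formats by a made-up naming
// convention. It must be a Scala `object` so the reflective MODULE$
// lookup in TableUtils (see the next hunk) can find the singleton.
package ai.chronon.spark

object PrefixFormatProvider extends FormatProvider {
  override def get(tableName: String): Format =
    if (tableName.startsWith("iceberg_")) Iceberg
    else if (tableName.startsWith("delta_")) DeltaLake
    else Hive
}

Because the trait has a single abstract method, function literals also satisfy it via SAM conversion, which is how the default providers in the next hunk are written.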
@@ -144,14 +152,52 @@ case class TableUtils(sparkSession: SparkSession) {
 
   val useIceberg: Boolean = sparkSession.conf.get("spark.chronon.table_write.iceberg", "false").toBoolean
 
-  // write data using the relevant supported Chronon write format
-  val maybeWriteFormat: Option[Format] =
-    sparkSession.conf.getOption("spark.chronon.table_write.format").map(_.toLowerCase) match {
-      case Some("hive") => Some(Hive)
-      case Some("iceberg") => Some(Iceberg)
-      case Some("delta") => Some(DeltaLake)
-      case _ => None
+  // [DPRTI-492][OAI_CHANGES]: Although Chronon OSS supports delta (WIP), a delta version mismatch causes a
+  // runtime java.lang.NoSuchMethodError: 'org.apache.spark.sql.delta.Snapshot org.apache.spark.sql.delta.DeltaLog.update(boolean)'.
+  // In OAI delta, the signature is DeltaLog.update(boolean, Option[Long]). Therefore, we have to supply our own
+  // delta format compiled against the delta version we use.
+  private lazy val tableReadFormatProvider: FormatProvider = {
+    sparkSession.conf.getOption("spark.chronon.table_read.format_provider") match {
+      case Some(clazzName) =>
+        // Load a Scala object instead of a class/case class
+        Class.forName(clazzName).getField("MODULE$").get(null).asInstanceOf[FormatProvider]
+      case None => (tableName: String) => {
+        if (isIcebergTable(tableName)) {
+          Iceberg
+        } else if (isDeltaTable(tableName)) {
+          DeltaLake
+        } else {
+          Hive
+        }
+      }
+    }
+  }
+
+  private lazy val tableWriteFormatProvider: FormatProvider = {
+    sparkSession.conf.getOption("spark.chronon.table_write.format_provider") match {
+      case Some(clazzName) =>
+        val clazz = Class.forName(clazzName)
+        clazz.getField("MODULE$").get(null).asInstanceOf[FormatProvider]
+      case None => (_: String) => {
+        // The default provider just looks for the default write-format config.
+        // Unlike read tables, these write tables might not exist yet.
+        val maybeFormat = sparkSession.conf.getOption("spark.chronon.default_table_write.format").map(_.toLowerCase) match {
+          case Some("hive") => Some(Hive)
+          case Some("iceberg") => Some(Iceberg)
+          case Some("delta") => Some(DeltaLake)
+          case _ => None
+        }
+        (useIceberg, maybeFormat) match {
+          // if Iceberg is explicitly configured, we go with that setting
+          case (true, _) => Iceberg
+          // else if there is a write format, we pick that
+          case (false, Some(format)) => format
+          // fall back to Hive (parquet)
+          case (false, None) => Hive
+        }
+      }
     }
+  }
 
   val cacheLevel: Option[StorageLevel] = Try {
     if (cacheLevelString == "NONE") None
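
The case Some(clazzName) branches above load a Scala object reflectively rather than instantiating a class. A minimal sketch of what that lookup resolves, reusing the hypothetical PrefixFormatProvider from the earlier sketch; note that a compiled Scala object's class name carries a trailing $:

// Sketch of the reflective branch. `object Foo` compiles to a JVM class
// `Foo$` whose static MODULE$ field holds the singleton, so the configured
// class name should be the `$`-suffixed one (hypothetical name below).
val provider: FormatProvider = Class
  .forName("ai.chronon.spark.PrefixFormatProvider$")
  .getField("MODULE$")
  .get(null)
  .asInstanceOf[FormatProvider]

// Wiring it up via the conf keys this diff reads (values hypothetical):
// spark.chronon.table_read.format_provider=ai.chronon.spark.PrefixFormatProvider$
// spark.chronon.table_write.format_provider=ai.chronon.spark.PrefixFormatProvider$

Since both providers are lazy vals, these configs must be in place before the first table read or write in the session.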
@@ -207,20 +253,12 @@ case class TableUtils(sparkSession: SparkSession) {
     }
   }
 
-  def tableFormat(tableName: String): Format = {
-    if (isIcebergTable(tableName)) {
-      Iceberg
-    } else if (isDeltaTable(tableName)) {
-      DeltaLake
-    } else {
-      Hive
-    }
-  }
+  def tableReadFormat(tableName: String): Format = tableReadFormatProvider.get(tableName)
 
   // return all specified partition columns in a table in format of Map[partitionName, PartitionValue]
   def allPartitions(tableName: String, partitionColumnsFilter: Seq[String] = Seq.empty): Seq[Map[String, String]] = {
     if (!tableExists(tableName)) return Seq.empty[Map[String, String]]
-    val format = tableFormat(tableName)
+    val format = tableReadFormat(tableName)
     val partitionSeq = format.partitions(tableName)(sparkSession)
     if (partitionColumnsFilter.isEmpty) {
       partitionSeq
@@ -233,7 +271,7 @@ case class TableUtils(sparkSession: SparkSession) {
 
   def partitions(tableName: String, subPartitionsFilter: Map[String, String] = Map.empty): Seq[String] = {
     if (!tableExists(tableName)) return Seq.empty[String]
-    val format = tableFormat(tableName)
+    val format = tableReadFormat(tableName)
 
     if (format == Iceberg) {
       if (subPartitionsFilter.nonEmpty) {
@@ -595,17 +633,6 @@ case class TableUtils(sparkSession: SparkSession) {
     }
   }
 
-  protected[spark] def getWriteFormat: Format = {
-    (useIceberg, maybeWriteFormat) match {
-      // if explicitly configured Iceberg - we go with that setting
-      case (true, _) => Iceberg
-      // else if there is a write format we pick that
-      case (false, Some(format)) => format
-      // fallback to hive (parquet)
-      case (false, None) => Hive
-    }
-  }
-
   private def createTableSql(tableName: String,
                              schema: StructType,
                              partitionColumns: Seq[String],
@@ -615,7 +642,7 @@ case class TableUtils(sparkSession: SparkSession) {
       .filterNot(field => partitionColumns.contains(field.name))
      .map(field => s"`${field.name}` ${field.dataType.catalogString}")
 
-    val writeFormat = getWriteFormat
+    val writeFormat = tableWriteFormatProvider.get(tableName)
 
     logger.info(
       s"Choosing format: $writeFormat based on useIceberg flag = $useIceberg and " +

spark/src/test/scala/ai/chronon/spark/test/TableUtilsFormatTest.scala

Lines changed: 1 addition & 1 deletion
@@ -164,7 +164,7 @@ object TableUtilsFormatTest {
     tableUtils.insertPartitions(df2, tableName, autoExpand = true)
 
     // check that we wrote out a table in the right format
-    val readTableFormat = tableUtils.tableFormat(tableName).toString
+    val readTableFormat = tableUtils.tableReadFormat(tableName).toString
     assertTrue(s"Mismatch in table format: $readTableFormat; expected: $format", readTableFormat.toLowerCase == format)
 
     // check we have all the partitions written
