
Commit 5f393cb

fix: supply partition column only when needed
Co-authored-by: Thomas Chow <[email protected]>
1 parent dd68dc3 commit 5f393cb

6 files changed: +40 additions, -31 deletions


api/src/main/scala/ai/chronon/api/DataPointer.scala

Lines changed: 9 additions & 3 deletions
@@ -5,16 +5,22 @@ abstract class DataPointer {
   def tableOrPath: String
   def readFormat: Option[String]
   def writeFormat: Option[String]
-  def options: Map[String, String]
+
+  def readOptions: Map[String, String]
+  def writeOptions: Map[String, String]

 }

 case class URIDataPointer(
     override val tableOrPath: String,
     override val readFormat: Option[String],
     override val writeFormat: Option[String],
-    override val options: Map[String, String]
-) extends DataPointer
+    options: Map[String, String]
+) extends DataPointer {
+
+  override val readOptions: Map[String, String] = options
+  override val writeOptions: Map[String, String] = options
+}

 // parses string representations of data pointers
 // ex: namespace.table

api/src/test/scala/ai/chronon/api/test/DataPointerTest.scala

Lines changed: 1 addition & 2 deletions
@@ -5,8 +5,7 @@ import ai.chronon.api.URIDataPointer
 import org.scalatest.flatspec.AnyFlatSpec
 import org.scalatest.matchers.should.Matchers

-class
-DataPointerTest extends AnyFlatSpec with Matchers {
+class DataPointerTest extends AnyFlatSpec with Matchers {

   "DataPointer.apply" should "parse a simple s3 path" in {
     val result = DataPointer("s3://bucket/path/to/data.parquet")

cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/BigQueryFormat.scala

Lines changed: 5 additions & 1 deletion
@@ -40,13 +40,17 @@ case class GcpFormatProvider(sparkSession: SparkSession) extends FormatProvider
   // Fixed to BigQuery for now.
   override def writeFormat(tableName: String): Format = {

+    val tu = TableUtils(sparkSession)
+    val partitionColumnOption =
+      if (tu.tableReachable(tableName)) Map.empty else Map("partitionField" -> tu.partitionColumn)
+
     val sparkOptions: Map[String, String] =
       Map(
         "temporaryGcsBucket" -> sparkSession.conf.get(
           "spark.chronon.table.gcs.temporary_gcs_bucket"
         ), // todo(tchow): No longer needed after https://github.com/GoogleCloudDataproc/spark-bigquery-connector/pull/1320
         "writeMethod" -> "indirect"
-      )
+      ) ++ partitionColumnOption
     BQuery(bqOptions.getProjectId, sparkOptions)
   }

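The intent of the new conditional, sketched with hypothetical values (assuming tu.partitionColumn is "ds" and an illustrative bucket name): a table that TableUtils cannot yet reach is about to be created by this write, so the connector is told which column to partition by; an already reachable table keeps whatever partition spec it has, and the option is omitted.

// Table does not exist yet -> partitionField is supplied:
Map(
  "temporaryGcsBucket" -> "some-temp-bucket", // hypothetical value
  "writeMethod" -> "indirect",
  "partitionField" -> "ds"
)

// Table already exists -> partitionField is left out:
Map(
  "temporaryGcsBucket" -> "some-temp-bucket", // hypothetical value
  "writeMethod" -> "indirect"
)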

spark/src/main/scala/ai/chronon/spark/CatalogAwareDataPointer.scala

Lines changed: 8 additions & 5 deletions
@@ -10,10 +10,13 @@ case class CatalogAwareDataPointer(inputTableOrPath: String, formatProvider: For
   override def tableOrPath: String = {
     formatProvider.resolveTableName(inputTableOrPath)
   }
-  override lazy val options: Map[String, String] = {
-    // Hack for now, include both read and write options for the datapointer.
-    // todo(tchow): rework this abstraction. https://app.asana.com/0/1208785567265389/1209026103291854/f
-    formatProvider.readFormat(inputTableOrPath).options ++ formatProvider.writeFormat(inputTableOrPath).options
+
+  override lazy val readOptions: Map[String, String] = {
+    formatProvider.readFormat(inputTableOrPath).options
+  }
+
+  override lazy val writeOptions: Map[String, String] = {
+    formatProvider.writeFormat(inputTableOrPath).options
   }

   override lazy val readFormat: Option[String] = {
@@ -28,7 +31,7 @@ case class CatalogAwareDataPointer(inputTableOrPath: String, formatProvider: For

 object DataPointer {

-  def apply(tableOrPath: String, sparkSession: SparkSession): DataPointer = {
+  def from(tableOrPath: String, sparkSession: SparkSession): DataPointer = {
     val clazzName =
       sparkSession.conf.get("spark.chronon.table.format_provider.class", classOf[DefaultFormatProvider].getName)
     val mirror = runtimeMirror(getClass.getClassLoader)
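A minimal usage sketch of the renamed factory (the table name is illustrative; spark is an existing SparkSession): DataPointer.from instantiates the FormatProvider named in the Spark conf and returns a pointer whose read and write options are now resolved independently instead of being merged into one map.

val pointer = DataPointer.from("namespace.table", spark)

val forReads  = pointer.readOptions   // formatProvider.readFormat(...).options
val forWrites = pointer.writeOptions  // formatProvider.writeFormat(...).options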

spark/src/main/scala/ai/chronon/spark/Extensions.scala

Lines changed: 12 additions & 15 deletions
@@ -304,28 +304,26 @@ object Extensions {

     def save(dataPointer: DataPointer): Unit = {

+      val optionDfw = dfw.options(dataPointer.writeOptions)
       dataPointer.writeFormat
         .map((wf) => {
           val normalized = wf.toLowerCase
           normalized match {
             case "bigquery" | "bq" =>
-              dfw
+              optionDfw
                 .format("bigquery")
-                .options(dataPointer.options)
                 .save(dataPointer.tableOrPath)
             case "snowflake" | "sf" =>
-              dfw
+              optionDfw
                 .format("net.snowflake.spark.snowflake")
-                .options(dataPointer.options)
                 .option("dbtable", dataPointer.tableOrPath)
                 .save()
             case "parquet" | "csv" =>
-              dfw
+              optionDfw
                 .format(normalized)
-                .options(dataPointer.options)
                 .save(dataPointer.tableOrPath)
             case "hive" | "delta" | "iceberg" =>
-              dfw
+              optionDfw
                 .format(normalized)
                 .insertInto(dataPointer.tableOrPath)
             case _ =>
@@ -334,7 +332,7 @@ object Extensions {
         })
         .getOrElse(
           // None case is just table against default catalog
-          dfw
+          optionDfw
             .format("hive")
             .insertInto(dataPointer.tableOrPath))
     }
@@ -345,25 +343,24 @@ object Extensions {
     def load(dataPointer: DataPointer): DataFrame = {
       val tableOrPath = dataPointer.tableOrPath

+      val optionDfr = dfr.options(dataPointer.readOptions)
+
       dataPointer.readFormat
         .map((fmt) => {
           val normalized = fmt.toLowerCase
           normalized match {
             case "bigquery" | "bq" =>
-              dfr
+              optionDfr
                 .format("bigquery")
-                .options(dataPointer.options)
                 .load(tableOrPath)
             case "snowflake" | "sf" =>
-              dfr
+              optionDfr
                 .format("net.snowflake.spark.snowflake")
-                .options(dataPointer.options)
                 .option("dbtable", tableOrPath)
                 .load()
             case "parquet" | "csv" =>
-              dfr
+              optionDfr
                 .format(normalized)
-                .options(dataPointer.options)
                 .load(tableOrPath)
             case "hive" | "delta" | "iceberg" => dfr.table(tableOrPath)
             case _ =>
@@ -372,7 +369,7 @@ object Extensions {
         })
         .getOrElse {
           // None case is just table against default catalog
-          dfr.table(tableOrPath)
+          optionDfr.table(tableOrPath)
         }
     }
   }
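A hedged sketch of the call sites these extensions serve. The read path matches the TableUtils usage below; the write path assumes the surrounding implicit class wraps a DataFrameWriter (as the dfw receiver suggests), which is not shown in this diff.

// Hypothetical table name; spark is an existing SparkSession.
val pointer = DataPointer.from("namespace.table", spark)

// Read: readOptions are applied once up front, then dispatch on readFormat.
val df = spark.read.load(pointer)

// Write: writeOptions are applied once up front, then dispatch on writeFormat.
// Assumes the implicit wrapper over df.write exposes save(DataPointer).
df.write.save(pointer)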

spark/src/main/scala/ai/chronon/spark/TableUtils.scala

Lines changed: 5 additions & 5 deletions
@@ -139,7 +139,7 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable

   // Needs provider
   def loadTable(tableName: String): DataFrame = {
-    sparkSession.read.load(DataPointer(tableName, sparkSession))
+    sparkSession.read.load(DataPointer.from(tableName, sparkSession))
   }

   def isPartitioned(tableName: String): Boolean = {
@@ -243,7 +243,7 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
   }

   def getSchemaFromTable(tableName: String): StructType = {
-    sparkSession.read.load(DataPointer(tableName, sparkSession)).limit(1).schema
+    sparkSession.read.load(DataPointer.from(tableName, sparkSession)).limit(1).schema
   }

   // method to check if a user has access to a table
@@ -256,7 +256,7 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
     // retrieve one row from the table
     val partitionFilter = lastAvailablePartition(tableName).getOrElse(fallbackPartition)
     sparkSession.read
-      .load(DataPointer(tableName, sparkSession))
+      .load(DataPointer.from(tableName, sparkSession))
       .where(s"$partitionColumn='$partitionFilter'")
       .limit(1)
       .collect()
@@ -533,7 +533,7 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
       (Seq(partitionColumn, saltCol), Seq(partitionColumn) ++ sortByCols)
     } else { (Seq(saltCol), sortByCols) }
     logger.info(s"Sorting within partitions with cols: $partitionSortCols")
-    val dataPointer = DataPointer(tableName, sparkSession)
+    val dataPointer = DataPointer.from(tableName, sparkSession)

     val dfw = saltedDf
       .select(saltedDf.columns.map {
@@ -799,7 +799,7 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
       wheres: Seq[String],
       rangeWheres: Seq[String],
       fallbackSelects: Option[Map[String, String]] = None): DataFrame = {
-    val dp = DataPointer(table, sparkSession)
+    val dp = DataPointer.from(table, sparkSession)
     var df = sparkSession.read.load(dp)
     val selects = QueryUtils.buildSelects(selectMap, fallbackSelects)
     logger.info(s""" Scanning data:
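One distinction worth keeping straight after the rename: the Spark-side factory is now DataPointer.from(table, sparkSession), while the string parser in the api module (exercised by DataPointerTest above) keeps its apply signature. A small sketch, with illustrative names:

// Spark-aware resolution through the configured FormatProvider:
val dp = DataPointer.from("namespace.table", sparkSession)
val df = sparkSession.read.load(dp)

// Plain string parsing in the api module, unchanged by this commit:
val parsed = ai.chronon.api.DataPointer("s3://bucket/path/to/data.parquet")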
