
Commit a7a9a42

fix: handle partition overwrite (#206)
## Summary

Based on [@David Han](https://zipline-2kh4520.slack.com/team/U0846REC8F7)'s observations around the spark-bigquery [connector](https://zipline-2kh4520.slack.com/archives/C08710CDH8D/p1736700947319249?thread_ts=1736644291.357239&cid=C08710CDH8D), there is indeed a lurking behavior:

- When creating BQ tables (in the case they don't exist), the user needs to specify a partitioning. This is expected, and we do it in the form of a partitionColumn write option.
- When the connector performs dynamic partition overwrites, you must not specify the partitioning at all. It does the right thing because the destination table was already created with a partition spec.

Ideally the above would be idempotent even if the user passed the partition column to the write, but unfortunately it is a strict requirement that the partition column is not defined when doing dynamic partition overwrites.

The fix is to specify the partition column only when the table does not exist and needs to be created, and to leave it out in all other cases (see the sketch at the end of this description).

## Checklist
- [ ] Added Unit Tests
- [ ] Covered by existing CI
- [ ] Integration tested
- [ ] Documentation update

## Summary by CodeRabbit

- **New Features**
  - Introduced separate read and write options for data pointers
  - Added support for the Google Cloud SDK tool
  - Added new plugins for `gcloud` and `thrift`
- **Bug Fixes**
  - Enhanced error handling in the BigQuery format provider
- **Refactor**
  - Standardized `DataPointer` instantiation from `apply` to `from`
  - Improved options handling in data operations
- **Chores**
  - Updated plugin and tool versions for the development environment

---

- To see the specific tasks where the Asana app for GitHub is being used, see:
  - https://app.asana.com/0/0/1209143482009688

Co-authored-by: Thomas Chow <[email protected]>
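For illustration, here is a minimal sketch of the write-side behavior described above, seen from the caller's side. The DataFrame, table name, and helper are hypothetical; the `partitionField` option key is the one used by `GcpFormatProvider` in this diff.

```scala
import org.apache.spark.sql.{DataFrame, SaveMode}

// Sketch only: pass "partitionField" solely when the destination table still has to be created.
def writeToBigQuery(df: DataFrame, table: String, tableExists: Boolean, partitionColumn: String): Unit = {
  val writer = df.write
    .mode(SaveMode.Overwrite)
    .format("bigquery")

  // Existing table: the connector reuses the table's partition spec and requires that no
  // partitionField be supplied for dynamic partition overwrites.
  // Missing table: partitionField is needed so the table is created with a partition spec.
  val writerWithPartition =
    if (tableExists) writer
    else writer.option("partitionField", partitionColumn)

  writerWithPartition.save(table)
}
```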
1 parent 7902121 commit a7a9a42

File tree

8 files changed, +51 -38 lines changed

.plugin-versions

Lines changed: 2 additions & 0 deletions
@@ -1,5 +1,7 @@
 asdf-plugin-manager https://github.com/asdf-community/asdf-plugin-manager.git b5862c1
+gcloud https://github.com/jthegedus/asdf-gcloud.git 00cdf06
 java https://github.com/halcyon/asdf-java.git 0ec69b2
 python https://github.com/danhper/asdf-python.git a3a0185
 sbt https://github.com/lerencao/asdf-sbt 53c9f4b
 scala https://github.com/asdf-community/asdf-scala.git 0533444
+thrift https://github.com/alisaifee/asdf-thrift.git fecdd6c

.tool-versions

Lines changed: 1 addition & 0 deletions
@@ -5,3 +5,4 @@ sbt 1.8.2
 python
 3.7.17
 3.11.0
+gcloud 504.0.1

api/src/main/scala/ai/chronon/api/DataPointer.scala

Lines changed: 9 additions & 3 deletions
@@ -5,16 +5,22 @@ abstract class DataPointer {
   def tableOrPath: String
   def readFormat: Option[String]
   def writeFormat: Option[String]
-  def options: Map[String, String]
+
+  def readOptions: Map[String, String]
+  def writeOptions: Map[String, String]

 }

 case class URIDataPointer(
     override val tableOrPath: String,
     override val readFormat: Option[String],
     override val writeFormat: Option[String],
-    override val options: Map[String, String]
-) extends DataPointer
+    options: Map[String, String]
+) extends DataPointer {
+
+  override val readOptions: Map[String, String] = options
+  override val writeOptions: Map[String, String] = options
+}

 // parses string representations of data pointers
 // ex: namespace.table
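Illustrative usage of the refactored interface (not part of the diff): a `URIDataPointer` still takes a single option map and exposes it through both new accessors. The path and option key below are hypothetical.

```scala
import ai.chronon.api.URIDataPointer

val pointer = URIDataPointer(
  tableOrPath = "gs://bucket/path/to/data", // hypothetical path
  readFormat = Some("parquet"),
  writeFormat = Some("parquet"),
  options = Map("mergeSchema" -> "true")    // hypothetical option
)

// A URI pointer exposes the same map through both views; catalog-aware pointers
// resolve the two separately via their format provider.
assert(pointer.readOptions == pointer.writeOptions)
```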

api/src/test/scala/ai/chronon/api/test/DataPointerTest.scala

Lines changed: 1 addition & 2 deletions
@@ -5,8 +5,7 @@ import ai.chronon.api.URIDataPointer
 import org.scalatest.flatspec.AnyFlatSpec
 import org.scalatest.matchers.should.Matchers

-class
-DataPointerTest extends AnyFlatSpec with Matchers {
+class DataPointerTest extends AnyFlatSpec with Matchers {

   "DataPointer.apply" should "parse a simple s3 path" in {
     val result = DataPointer("s3://bucket/path/to/data.parquet")

cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/GcpFormatProvider.scala

Lines changed: 11 additions & 4 deletions
@@ -32,17 +32,21 @@ case class GcpFormatProvider(sparkSession: SparkSession) extends FormatProvider
   override def readFormat(tableName: String): Format = format(tableName)

   override def writeFormat(table: String): Format = {
+    val tableId = BigQueryUtil.parseTableId(table)
+    assert(scala.Option(tableId.getProject).isDefined, s"project required for ${table}")
+    assert(scala.Option(tableId.getDataset).isDefined, s"dataset required for ${table}")

     val tu = TableUtils(sparkSession)
+    val partitionColumnOption =
+      if (tu.tableReachable(table)) Map.empty else Map("partitionField" -> tu.partitionColumn)

     val sparkOptions: Map[String, String] = Map(
-      "partitionField" -> tu.partitionColumn,
       // todo(tchow): No longer needed after https://github.com/GoogleCloudDataproc/spark-bigquery-connector/pull/1320
       "temporaryGcsBucket" -> sparkSession.conf.get("spark.chronon.table.gcs.temporary_gcs_bucket"),
       "writeMethod" -> "indirect"
-    )
+    ) ++ partitionColumnOption

-    BigQueryFormat(bqOptions.getProjectId, sparkOptions)
+    BigQueryFormat(tableId.getProject, sparkOptions)
   }

   private def getFormat(table: Table): Format =
@@ -72,7 +76,10 @@ case class GcpFormatProvider(sparkSession: SparkSession) extends FormatProvider
     val table = bigQueryClient.getTable(btTableIdentifier.getDataset, btTableIdentifier.getTable)

     // lookup bq for the table, if not fall back to hive
-    scala.Option(table).map(getFormat).getOrElse(Hive)
+    scala
+      .Option(table)
+      .map(getFormat)
+      .getOrElse(scala.Option(btTableIdentifier.getProject).map(BigQueryFormat(_, Map.empty)).getOrElse(Hive))

   }
 }
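To make the conditional explicit, a standalone sketch of the option-building behavior that `writeFormat` now encodes. In the real code the flag comes from `TableUtils.tableReachable(table)` and the bucket from `spark.chronon.table.gcs.temporary_gcs_bucket`; the helper and argument names here are illustrative.

```scala
def bigQueryWriteOptions(tableReachable: Boolean,
                         partitionColumn: String,
                         temporaryGcsBucket: String): Map[String, String] = {
  // Only a table that does not exist yet gets an explicit partition spec.
  val partitionColumnOption =
    if (tableReachable) Map.empty[String, String]
    else Map("partitionField" -> partitionColumn)

  Map(
    "temporaryGcsBucket" -> temporaryGcsBucket,
    "writeMethod" -> "indirect"
  ) ++ partitionColumnOption
}

// bigQueryWriteOptions(tableReachable = true,  "ds", "bucket") contains no "partitionField"
// bigQueryWriteOptions(tableReachable = false, "ds", "bucket") maps "partitionField" -> "ds"
```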

spark/src/main/scala/ai/chronon/spark/CatalogAwareDataPointer.scala

Lines changed: 7 additions & 5 deletions
@@ -10,10 +10,12 @@ case class CatalogAwareDataPointer(inputTableOrPath: String, formatProvider: For
     formatProvider.resolveTableName(inputTableOrPath)
   }

-  override lazy val options: Map[String, String] = {
-    // Hack for now, include both read and write options for the datapointer.
-    // todo(tchow): rework this abstraction. https://app.asana.com/0/1208785567265389/1209026103291854/f
-    formatProvider.readFormat(inputTableOrPath).options ++ formatProvider.writeFormat(inputTableOrPath).options
+  override lazy val readOptions: Map[String, String] = {
+    formatProvider.readFormat(inputTableOrPath).options
+  }
+
+  override lazy val writeOptions: Map[String, String] = {
+    formatProvider.writeFormat(inputTableOrPath).options
   }

   override lazy val readFormat: Option[String] = {
@@ -28,7 +30,7 @@ case class CatalogAwareDataPointer(inputTableOrPath: String, formatProvider: For

 object DataPointer {

-  def apply(tableOrPath: String, sparkSession: SparkSession): DataPointer = {
+  def from(tableOrPath: String, sparkSession: SparkSession): DataPointer = {

     CatalogAwareDataPointer(tableOrPath, FormatProvider.from(sparkSession))
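Call sites change accordingly; a small usage sketch, assuming the `DataPointer` companion lives in `ai.chronon.spark` as the file path suggests (the table name is hypothetical):

```scala
import org.apache.spark.sql.SparkSession
import ai.chronon.spark.DataPointer

val spark: SparkSession = SparkSession.builder().getOrCreate()

// The factory is now `from` rather than `apply`.
val dp = DataPointer.from("project.dataset.table", spark)

// Read and write options are resolved independently by the format provider.
val readOpts: Map[String, String] = dp.readOptions
val writeOpts: Map[String, String] = dp.writeOptions
```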

spark/src/main/scala/ai/chronon/spark/Extensions.scala

Lines changed: 13 additions & 16 deletions
@@ -304,28 +304,26 @@ object Extensions {

     def save(dataPointer: DataPointer): Unit = {

+      val optionDfw = dfw.options(dataPointer.writeOptions)
       dataPointer.writeFormat
         .map((wf) => {
           val normalized = wf.toLowerCase
           normalized match {
             case "bigquery" | "bq" =>
-              dfw
+              optionDfw
                 .format("bigquery")
-                .options(dataPointer.options)
                 .save(dataPointer.tableOrPath)
             case "snowflake" | "sf" =>
-              dfw
+              optionDfw
                 .format("net.snowflake.spark.snowflake")
-                .options(dataPointer.options)
                 .option("dbtable", dataPointer.tableOrPath)
                 .save()
             case "parquet" | "csv" =>
-              dfw
+              optionDfw
                 .format(normalized)
-                .options(dataPointer.options)
                 .save(dataPointer.tableOrPath)
             case "hive" | "delta" | "iceberg" =>
-              dfw
+              optionDfw
                 .format(normalized)
                 .insertInto(dataPointer.tableOrPath)
             case _ =>
@@ -334,7 +332,7 @@
         })
         .getOrElse(
           // None case is just table against default catalog
-          dfw
+          optionDfw
             .format("hive")
             .insertInto(dataPointer.tableOrPath))
     }
@@ -345,29 +343,28 @@
     def load(dataPointer: DataPointer): DataFrame = {
       val tableOrPath = dataPointer.tableOrPath

+      val optionDfr = dfr.options(dataPointer.readOptions)
+
       dataPointer.readFormat
         .map { fmt =>
           val fmtLower = fmt.toLowerCase

           fmtLower match {

             case "bigquery" | "bq" =>
-              dfr
+              optionDfr
                 .format("bigquery")
-                .options(dataPointer.options)
                 .load(tableOrPath)

             case "snowflake" | "sf" =>
-              dfr
+              optionDfr
                 .format("net.snowflake.spark.snowflake")
-                .options(dataPointer.options)
                 .option("dbtable", tableOrPath)
                 .load()

             case "parquet" | "csv" =>
-              dfr
-                .format(fmtLower)
-                .options(dataPointer.options)
+              optionDfr
+                .format(fmt)
                 .load(tableOrPath)

             case "hive" | "delta" | "iceberg" => dfr.table(tableOrPath)
@@ -379,7 +376,7 @@
         }
         .getOrElse {
           // None case is just table against default catalog
-          dfr.table(tableOrPath)
+          optionDfr.table(tableOrPath)
         }
     }
   }
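A usage sketch, assuming the implicit reader/writer wrappers in `Extensions` are in scope and the import path shown; the table name and write mode are illustrative. Readers now pick up `readOptions` before dispatching on `readFormat`, and writers pick up `writeOptions` before dispatching on `writeFormat`.

```scala
import org.apache.spark.sql.SparkSession
import ai.chronon.spark.DataPointer
import ai.chronon.spark.Extensions._

val spark = SparkSession.builder().getOrCreate()
val dp = DataPointer.from("project.dataset.table", spark) // hypothetical table

// load() applies dp.readOptions to the reader, then dispatches on dp.readFormat.
val df = spark.read.load(dp)

// save() applies dp.writeOptions to the writer, then dispatches on dp.writeFormat.
df.write.mode("overwrite").save(dp)
```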

spark/src/main/scala/ai/chronon/spark/TableUtils.scala

Lines changed: 7 additions & 8 deletions
@@ -125,7 +125,7 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
       true
     } catch {
       case ex: Exception =>
-        logger.info(s"""Couldn't reach $tableName. Error: ${ex.getMessage.red}
+        logger.debug(s"""Couldn't reach $tableName. Error: ${ex.getMessage.red}
                        |Call path:
                        |${cleanStackTrace(ex).yellow}
                        |""".stripMargin)
@@ -135,7 +135,7 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable

   // Needs provider
   def loadTable(tableName: String): DataFrame = {
-    sparkSession.read.load(DataPointer(tableName, sparkSession))
+    sparkSession.read.load(DataPointer.from(tableName, sparkSession))
   }

   def isPartitioned(tableName: String): Boolean = {
@@ -241,7 +241,7 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
   }

   def getSchemaFromTable(tableName: String): StructType = {
-    sparkSession.read.load(DataPointer(tableName, sparkSession)).limit(1).schema
+    sparkSession.read.load(DataPointer.from(tableName, sparkSession)).limit(1).schema
   }

   // method to check if a user has access to a table
@@ -254,7 +254,7 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
     // retrieve one row from the table
     val partitionFilter = lastAvailablePartition(tableName).getOrElse(fallbackPartition)
     sparkSession.read
-      .load(DataPointer(tableName, sparkSession))
+      .load(DataPointer.from(tableName, sparkSession))
       .where(s"$partitionColumn='$partitionFilter'")
       .limit(1)
       .collect()
@@ -545,8 +545,8 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
       (Seq(partitionColumn, saltCol), Seq(partitionColumn) ++ sortByCols)
     } else { (Seq(saltCol), sortByCols) }
     logger.info(s"Sorting within partitions with cols: $partitionSortCols")
+    val dataPointer = DataPointer.from(tableName, sparkSession)

-    val dataPointer = DataPointer(tableName, sparkSession)
     saltedDf
       .select(saltedDf.columns.map {
         case c if c == partitionColumn && dataPointer.writeFormat.map(_.toUpperCase).exists("BIGQUERY".equals) =>
@@ -763,14 +763,13 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
       wheres: Seq[String],
       rangeWheres: Seq[String],
       fallbackSelects: Option[Map[String, String]] = None): DataFrame = {
-
-    val dp = DataPointer(table, sparkSession)
+    val dp = DataPointer.from(table, sparkSession)
     var df = sparkSession.read.load(dp)
     val selects = QueryUtils.buildSelects(selectMap, fallbackSelects)

     logger.info(s""" Scanning data:
                    | table: ${dp.tableOrPath.green}
-                   | options: ${dp.options}
+                   | options: ${dp.readOptions}
                    | format: ${dp.readFormat}
                    | selects:
                    | ${selects.mkString("\n ").green}
