feat: do partition filtering on bq native tables by union individual partitions #690


Merged
merged 2 commits into from
Apr 28, 2025
@@ -5,7 +5,7 @@ import ai.chronon.spark.catalog.Format
import com.google.cloud.bigquery.BigQueryOptions
import com.google.cloud.spark.bigquery.v2.Spark35BigQueryTableProvider
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, date_format, to_date}
import org.apache.spark.sql.functions.{col, date_format, to_date, lit}

case object BigQueryNative extends Format {

@@ -16,56 +16,61 @@ case object BigQueryNative extends Format {

override def table(tableName: String, partitionFilters: String)(implicit sparkSession: SparkSession): DataFrame = {
import sparkSession.implicits._

// First, need to clean the spark-based table name for the bigquery queries below.
val bqTableId = SparkBQUtils.toTableId(tableName)
val providedProject = scala.Option(bqTableId.getProject).getOrElse(bqOptions.getProjectId)
val bqFriendlyName = f"${providedProject}.${bqTableId.getDataset}.${bqTableId.getTable}"

// Then, we query the BQ information schema to grab the table's partition column.
val partColsSql =
s"""
|SELECT column_name, IS_SYSTEM_DEFINED FROM `${providedProject}.${bqTableId.getDataset}.INFORMATION_SCHEMA.COLUMNS`
|SELECT column_name FROM `${providedProject}.${bqTableId.getDataset}.INFORMATION_SCHEMA.COLUMNS`
|WHERE table_name = '${bqTableId.getTable}' AND is_partitioning_column = 'YES'
|
|""".stripMargin

val (partColName, systemDefined) = sparkSession.read
val partColName = sparkSession.read
.format(bqFormat)
.option("project", providedProject)
// See: https://github.com/GoogleCloudDataproc/spark-bigquery-connector/issues/434#issuecomment-886156191
// and: https://cloud.google.com/bigquery/docs/information-schema-intro#limitations
.option("viewsEnabled", true)
.option("materializationDataset", bqTableId.getDataset)
.load(partColsSql)
.as[(String, String)]
.as[String]
.collect
.headOption
.getOrElse(throw new UnsupportedOperationException(s"No partition column for table ${tableName} found."))

val isPseudoColumn = systemDefined match {
case "YES" => true
case "NO" => false
case _ => throw new IllegalArgumentException(s"Unknown partition column system definition: ${systemDefined}")
}

logger.info(
s"Found bigquery partition column: ${partColName} with system defined status: ${systemDefined} for table: ${tableName}")
.getOrElse(
throw new UnsupportedOperationException(s"No partition column for table ${tableName} found.")
) // TODO: support unpartitioned tables (uncommon case).

// Next, we query the BQ table using the requested partitionFilter to grab all the distinct partition values that match the filter.
val partitionWheres = if (partitionFilters.nonEmpty) s"WHERE ${partitionFilters}" else partitionFilters
val partitionFormat = TableUtils(sparkSession).partitionFormat
val dfw = sparkSession.read
val select = s"SELECT distinct(${partColName}) AS ${internalBQCol} FROM ${bqFriendlyName} ${partitionWheres}"

⚠️ Potential issue

Use the `DISTINCT` keyword; `distinct()` is invalid BigQuery SQL

`distinct(${partColName})` will fail on BigQuery.
Replace it with the keyword form.

-val select = s"SELECT distinct(${partColName}) AS ${internalBQCol} FROM ${bqFriendlyName} ${partitionWheres}"
+val select = s"SELECT DISTINCT ${partColName} AS ${internalBQCol} FROM ${bqFriendlyName} ${partitionWheres}"
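As a minimal sketch of the keyword form the suggestion produces (the helper name and column values here are hypothetical, not part of the PR):

```scala
// Hypothetical query builder mirroring the corrected interpolation;
// only the SELECT DISTINCT shape comes from the suggestion above.
object DistinctSelect {
  def build(partCol: String, alias: String, table: String, where: String): String =
    s"SELECT DISTINCT ${partCol} AS ${alias} FROM ${table} ${where}".trim

  def main(args: Array[String]): Unit = {
    // With an empty partition filter, the trailing space is trimmed away.
    println(build("ds", "__chronon_part", "proj.dataset.tbl", ""))
  }
}
```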

val selectedParts = sparkSession.read
.format(bqFormat)
.option("viewsEnabled", true)
.option("materializationDataset", bqTableId.getDataset)
if (isPseudoColumn) {
val select = s"SELECT ${partColName} AS ${internalBQCol}, * FROM ${bqFriendlyName} ${partitionWheres}"
logger.info(s"BQ select: ${select}")
dfw
.load(select)
.withColumn(partColName, date_format(col(internalBQCol), partitionFormat))
.drop(internalBQCol)
} else {
dfw
.load(s"SELECT * FROM ${bqFriendlyName} ${partitionWheres}")
}
.load(select)
.select(date_format(col(internalBQCol), partitionFormat))
Comment on lines +52 to +57

⚠️ Potential issue

Add missing project option when reading partitions

Reads here omit .option("project", providedProject). Cross-project tables will fail.

 val selectedParts = sparkSession.read
   .format(bqFormat)
+  .option("project", providedProject)
   .option("viewsEnabled", true)
   .option("materializationDataset", bqTableId.getDataset)
   .load(select)

Apply the same when loading each partition below.
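One way to keep the three reads consistent is to build the common options once. A sketch using the option keys that appear in the diff; the helper itself is an assumption, not the PR's code:

```scala
// Sketch: shared reader options so the INFORMATION_SCHEMA query, the
// partition-listing query, and each per-partition load all carry the
// project. Option keys match those used in the diff above.
object BqReadOptions {
  def forTable(project: String, materializationDataset: String): Map[String, String] =
    Map(
      "project" -> project,                        // required for cross-project tables
      "viewsEnabled" -> "true",                    // INFORMATION_SCHEMA reads behave like views
      "materializationDataset" -> materializationDataset
    )
}
```

Each `sparkSession.read` could then apply them in one call via `.options(BqReadOptions.forTable(providedProject, bqTableId.getDataset))`.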

.as[String]
.collect
.toList
logger.info(s"Part values: ${selectedParts}")

// Finally, we query the BQ table for each of the selected partition values and union them together.
selectedParts
.map((partValue) => {
val pFilter = f"${partColName} = '${partValue}'"
sparkSession.read
.format(bqFormat)
.option("filter", pFilter)
.load(bqFriendlyName)
.withColumn(partColName, lit(partValue))
}) // todo: make it nullable
.reduce(_ unionByName _)
Comment on lines +64 to +73

⚠️ Potential issue

Guard against empty partition list

reduce throws on Nil.

-  .reduce(_ unionByName _)
+  .reduceOption(_ unionByName _)
+  .getOrElse(sparkSession.emptyDataFrame)
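The behavioral difference can be sketched without Spark, using plain Scala lists as stand-ins for DataFrames (an illustration, not the PR's code):

```scala
// Sketch: reduce throws UnsupportedOperationException on an empty
// collection, while reduceOption returns None, which getOrElse turns
// into a safe default (here an empty list stands in for
// sparkSession.emptyDataFrame).
object ReduceGuard {
  def unionAll(parts: List[List[String]]): List[String] =
    parts.reduceOption(_ ++ _).getOrElse(List.empty)

  def main(args: Array[String]): Unit = {
    println(unionAll(List(List("2025-04-27"), List("2025-04-28"))))
    println(unionAll(Nil)) // no exception, just the empty default
  }
}
```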

}

override def primaryPartitions(tableName: String, partitionColumn: String, subPartitionsFilter: Map[String, String])(