
Commit a849504

feat: use spark bq connector v1 (#664)
## Summary

We need to bring back the v1 version of the DataSource for the Spark BigQuery connector, since it supports partition pushdown and alternative project IDs. The catalog version of the Spark BigQuery connector does not support these.

## Checklist

- [ ] Added Unit Tests
- [ ] Covered by existing CI
- [ ] Integration tested
- [ ] Documentation update

## Summary by CodeRabbit

- **New Features**
  - Enhanced table reading with support for applying partition filters and combining multiple predicates for more flexible data queries.
- **Refactor**
  - Improved internal handling of predicate filters and table-loading logic for more consistent and maintainable data access.
  - Refined data filtering by explicitly incorporating partition column information for more precise queries.
- **Chores**
  - Updated the install script to clean up temporary files more reliably during installation.

---------

Co-authored-by: Thomas Chow <[email protected]>
1 parent 5afc499 commit a849504
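
For context, a minimal sketch (not part of this commit) of the v1 read path the summary refers to, assuming the connector's documented `bigquery` format name and `filter` option; the project, dataset, table, and filter value are hypothetical:

```scala
import org.apache.spark.sql.SparkSession

object BqV1ReadSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("bq-v1-sketch").getOrCreate()

    // The v1 DataSource accepts a fully qualified table id, so the data can
    // live in a project other than the one running the job.
    val df = spark.read
      .format("bigquery")
      .option("filter", "ds = '2024-01-01'") // pushed to BigQuery for partition pruning
      .load("other-project.analytics.events")

    df.show()
  }
}
```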

File tree

6 files changed: +48 −13 lines changed

api/python/ai/chronon/resources/gcp/zipline-cli-install.sh

Lines changed: 2 additions & 2 deletions
```diff
@@ -49,8 +49,8 @@ done

 gcloud storage cp "${ARTIFACT_PREFIX%/}/release/$VERSION/wheels/zipline_ai-$VERSION-py3-none-any.whl" .

+trap 'rm -f ./zipline_ai-$VERSION-py3-none-any.whl' EXIT
+
 pip3 uninstall zipline-ai

 pip3 install ./zipline_ai-$VERSION-py3-none-any.whl
-
-trap 'rm -f ./zipline_ai-$VERSION-py3-none-any.whl' EXIT
```

cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/BigQueryNative.scala

Lines changed: 17 additions & 1 deletion
```diff
@@ -4,14 +4,30 @@ import ai.chronon.spark.TableUtils
 import ai.chronon.spark.format.Format
 import com.google.cloud.bigquery.BigQueryOptions
 import com.google.cloud.spark.bigquery.v2.Spark35BigQueryTableProvider
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.apache.spark.sql.functions.{col, date_format, to_date}

 case object BigQueryNative extends Format {

   private val bqFormat = classOf[Spark35BigQueryTableProvider].getName
   private lazy val bqOptions = BigQueryOptions.getDefaultInstance

+  override def table(tableName: String, partitionFilters: String)(implicit sparkSession: SparkSession): DataFrame = {
+    val bqTableId = SparkBQUtils.toTableId(tableName)
+    val bqFriendlyName = scala.Option(bqTableId.getProject) match {
+      case Some(project) => f"${project}.${bqTableId.getDataset}.${bqTableId.getTable}"
+      case None          => f"${bqTableId.getDataset}.${bqTableId.getTable}"
+    }
+    val dfw = sparkSession.read.format(bqFormat)
+    if (partitionFilters.isEmpty) {
+      dfw.load(bqFriendlyName)
+    } else {
+      dfw
+        .option("filter", partitionFilters.trim.stripPrefix("(").stripSuffix(")"))
+        .load(bqFriendlyName)
+    }
+  }
+
   override def primaryPartitions(tableName: String, partitionColumn: String, subPartitionsFilter: Map[String, String])(
       implicit sparkSession: SparkSession): List[String] =
     super.primaryPartitions(tableName, partitionColumn, subPartitionsFilter)
```
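
A hedged usage sketch of the new override (not from this commit); the table identifier and filter literal are illustrative, and an implicit SparkSession is assumed to be in scope:

```scala
import ai.chronon.integrations.cloud_gcp.BigQueryNative
import org.apache.spark.sql.{DataFrame, SparkSession}

object BigQueryNativeTableSketch {
  implicit val spark: SparkSession = SparkSession.builder().getOrCreate()

  // Non-empty filter: the predicate goes to the connector's "filter" option
  // (outer parentheses stripped) so BigQuery can prune partitions server-side.
  val filtered: DataFrame =
    BigQueryNative.table("other-project.dataset.events", "(ds = '2024-01-01')")

  // Empty filter: the table is loaded without a pushed-down predicate.
  val full: DataFrame =
    BigQueryNative.table("other-project.dataset.events", "")
}
```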

spark/src/main/scala/ai/chronon/spark/GroupBy.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -655,7 +655,7 @@ object GroupBy {
          |""".stripMargin)
     metaColumns ++= timeMapping

-    val partitionConditions = tableUtils.whereClauses(intersectedRange)
+    val partitionConditions = tableUtils.whereClauses(intersectedRange, source.partitionColumn(tableUtils))

     logger.info(s"""
                    |Rendering source query:
```
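
Illustrative only (the column name and bounds are assumed, and the output of whereClauses is paraphrased): with the source's own partition column passed through, the rendered conditions target that column rather than the default partition column.

```scala
object PartitionConditionsSketch {
  // A hypothetical source partitioned on "date_key" instead of the default column.
  val partitionColumn = "date_key"
  val (start, end) = ("2024-01-01", "2024-01-07")

  // Shape of the conditions the scan would now render for this source.
  val partitionConditions: Seq[String] = Seq(
    s"$partitionColumn >= '$start'",
    s"$partitionColumn <= '$end'"
  )
}
```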

spark/src/main/scala/ai/chronon/spark/TableUtils.scala

Lines changed: 15 additions & 7 deletions
```diff
@@ -115,8 +115,12 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
     }
   }

-  def loadTable(tableName: String): DataFrame = {
-    sparkSession.read.table(tableName)
+  def loadTable(tableName: String, rangeWheres: Seq[String] = List.empty[String]): DataFrame = {
+    tableFormatProvider
+      .readFormat(tableName)
+      .map(_.table(tableName, andPredicates(rangeWheres))(sparkSession))
+      .getOrElse(
+        throw new RuntimeException(s"Could not load table: ${tableName} with partition filter: ${rangeWheres}"))
   }

   def createDatabase(database: String): Boolean = {
@@ -564,6 +568,12 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
     }
   }

+  private def andPredicates(predicates: Seq[String]): String = {
+    val whereStr = predicates.map(p => s"($p)").mkString(" AND ")
+    logger.info(s"""Where str: $whereStr""")
+    whereStr
+  }
+
   def scanDfBase(selectMap: Map[String, String],
                  table: String,
                  wheres: Seq[String],
@@ -582,14 +592,12 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
         | ${rangeWheres.mkString(",\n ").green}
         |""".stripMargin)

-    var df = loadTable(table)
+    var df = loadTable(table, rangeWheres)

     if (selects.nonEmpty) df = df.selectExpr(selects: _*)

-    val allWheres = wheres ++ rangeWheres
-    if (allWheres.nonEmpty) {
-      val whereStr = allWheres.map(w => s"($w)").mkString(" AND ")
-      logger.info(s"""Where str: $whereStr""")
+    if (wheres.nonEmpty) {
+      val whereStr = andPredicates(wheres)
       df = df.where(whereStr)
     }
```
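
A small worked example of the combination rule used by the new andPredicates helper (the predicate strings are hypothetical):

```scala
object AndPredicatesSketch {
  val rangeWheres = Seq("ds >= '2024-01-01'", "ds <= '2024-01-31'")

  // Same rule as andPredicates: wrap each predicate and join with AND.
  val combined: String = rangeWheres.map(p => s"($p)").mkString(" AND ")
  // combined == "(ds >= '2024-01-01') AND (ds <= '2024-01-31')"
}
```

loadTable hands this combined string to the format's table(), so range predicates can be pushed down by formats such as BigQueryNative, while the remaining wheres are still applied with DataFrame.where in scanDfBase.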

spark/src/main/scala/ai/chronon/spark/format/Format.scala

Lines changed: 11 additions & 0 deletions
```diff
@@ -2,10 +2,21 @@ package ai.chronon.spark.format

 import org.apache.spark.sql.SparkSession
 import org.slf4j.{Logger, LoggerFactory}
+import org.apache.spark.sql.DataFrame

 trait Format {
+
   @transient protected lazy val logger: Logger = LoggerFactory.getLogger(getClass)

+  def table(tableName: String, partitionFilters: String)(implicit sparkSession: SparkSession): DataFrame = {
+    val df = sparkSession.read.table(tableName)
+    if (partitionFilters.isEmpty) {
+      df
+    } else {
+      df.where(partitionFilters)
+    }
+  }
+
   // Return the primary partitions (based on the 'partitionColumn') filtered down by sub-partition filters if provided
   // If subpartition filters are supplied and the format doesn't support it, we throw an error
   def primaryPartitions(tableName: String,
```
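
A sketch of the default path for formats that do not override table() (table name and predicate are hypothetical): the table is read from the catalog as usual and the filter is applied after the read, rather than pushed to the source.

```scala
import org.apache.spark.sql.SparkSession

object DefaultFormatTableSketch {
  val spark: SparkSession = SparkSession.builder().getOrCreate()

  // Equivalent to the trait's default implementation: plain catalog read,
  // then the combined predicate applied as a where clause.
  val df = spark.read.table("db.events")
  val filtered = df.where("(ds = '2024-01-01')")
}
```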

spark/src/test/scala/ai/chronon/spark/test/ResultValidationAbilityTest.scala

Lines changed: 2 additions & 2 deletions
```diff
@@ -69,7 +69,7 @@ class ResultValidationAbilityTest extends AnyFlatSpec with BeforeAndAfter {
     val rdd = args.sparkSession.sparkContext.parallelize(leftData)
     val df = args.sparkSession.createDataFrame(rdd).toDF(columns: _*)

-    when(mockTableUtils.loadTable(any())).thenReturn(df)
+    when(mockTableUtils.loadTable(any(), any())).thenReturn(df)

     assertTrue(args.validateResult(df, Seq("keyId", "ds"), mockTableUtils))
   }
@@ -85,7 +85,7 @@ class ResultValidationAbilityTest extends AnyFlatSpec with BeforeAndAfter {
     val rightRdd = args.sparkSession.sparkContext.parallelize(rightData)
     val rightDf = args.sparkSession.createDataFrame(rightRdd).toDF(columns: _*)

-    when(mockTableUtils.loadTable(any())).thenReturn(rightDf)
+    when(mockTableUtils.loadTable(any(), any())).thenReturn(rightDf)

     assertFalse(args.validateResult(leftDf, Seq("keyId", "ds"), mockTableUtils))
   }
```
