
Commit 883c9a8

Merge branch 'main' into tchow/bq-support-7
2 parents: 8256ef4 + 723c69c

5 files changed: +126 -7 lines changed


build.sbt

Lines changed: 3 additions & 2 deletions
@@ -206,7 +206,7 @@ lazy val flink = project
 
 // GCP requires java 11, can't cross compile higher
 lazy val cloud_gcp = project
-  .dependsOn(api % ("compile->compile;test->test"), online, spark)
+  .dependsOn(api % ("compile->compile;test->test"), online, spark % ("compile->compile;test->test"))
   .settings(
     libraryDependencies += "com.google.cloud" % "google-cloud-bigquery" % "2.42.0",
     libraryDependencies += "com.google.cloud" % "google-cloud-bigtable" % "2.41.0",
@@ -217,7 +217,8 @@ lazy val cloud_gcp = project
     libraryDependencies += "com.google.cloud.spark" %% s"spark-bigquery-with-dependencies" % "0.41.0",
     libraryDependencies ++= circe,
     libraryDependencies ++= avro,
-    libraryDependencies ++= spark_all
+    libraryDependencies ++= spark_all_provided,
+    dependencyOverrides ++= jackson
   )
 
 lazy val cloud_aws = project
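
Two things change here: Spark moves to provided scope so it is not bundled into the cloud_gcp assembly (the Dataproc runtime supplies Spark), and Jackson versions are pinned to avoid clashes between the GCP connectors and Spark's transitive dependencies. The diff does not show how spark_all_provided or jackson are defined; a plausible sketch, assuming spark_all is a Seq[ModuleID] defined earlier in build.sbt and with placeholder version numbers:

// Hypothetical definitions for illustration; the real ones live elsewhere in build.sbt.
// "provided" keeps Spark out of the assembled jar so the cluster runtime's copy is used.
val spark_all_provided = spark_all.map(_ % "provided")

// Pinning Jackson via dependencyOverrides forces one consistent version across
// the GCP connectors' and Spark's transitive dependency trees.
val jackson = Seq(
  "com.fasterxml.jackson.core" % "jackson-core",
  "com.fasterxml.jackson.core" % "jackson-databind",
  "com.fasterxml.jackson.module" %% "jackson-module-scala"
).map(_ % "2.15.2") // placeholder version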
Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
package ai.chronon.integrations.cloud_gcp

import ai.chronon.spark.Format
import ai.chronon.spark.FormatProvider
import ai.chronon.spark.Hive
import com.google.cloud.bigquery.connector.common.BigQueryUtil
import org.apache.spark.sql.SparkSession

case class GCPFormatProvider(sparkSession: SparkSession) extends FormatProvider {

  // Pick a Format from how the table is registered in the Spark catalog:
  // tables surfaced by the BigQuery connector carry Dataproc federation
  // properties in their raw metadata; everything else is treated as Hive.
  def readFormat(tableName: String): Format = {

    val tableIdentifier = sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName)
    val tableMeta = sparkSession.sessionState.catalog.getTableRawMetadata(tableIdentifier)

    val storageProvider = tableMeta.provider
    storageProvider match {
      case Some("com.google.cloud.spark.bigquery") => {

        val tableProperties = tableMeta.properties
        val project = tableProperties
          .get("FEDERATION_BIGQUERY_TABLE_PROPERTY")
          .map(BigQueryUtil.parseTableId)
          .map(_.getProject)
          .getOrElse(throw new IllegalStateException("bigquery project required!"))

        // Note: table types other than EXTERNAL and MANAGED fall through this
        // match and surface as a MatchError.
        val bigQueryTableType = tableProperties.get("federation.bigquery.table.type")
        bigQueryTableType.map(_.toUpperCase) match {
          case Some("EXTERNAL") => throw new IllegalStateException("External tables not yet supported.")
          case Some("MANAGED")  => BQuery(project)
          case None             => throw new IllegalStateException("Dataproc federation service must be available.")
        }
      }

      case Some("hive") | None => Hive
    }
  }

  // For now, fix to BigQuery. We'll clean this up.
  def writeFormat(tableName: String): Format = ???
}

case class BQuery(project: String) extends Format {

  override def primaryPartitions(tableName: String, partitionColumn: String, subPartitionsFilter: Map[String, String])(
      implicit sparkSession: SparkSession): Seq[String] =
    super.primaryPartitions(tableName, partitionColumn, subPartitionsFilter)

  override def partitions(tableName: String)(implicit sparkSession: SparkSession): Seq[Map[String, String]] = {
    import sparkSession.implicits._
    val tableIdentifier = BigQueryUtil.parseTableId(tableName)
    val table = tableIdentifier.getTable
    val database =
      Option(tableIdentifier.getDataset).getOrElse(throw new IllegalArgumentException("database required!"))

    // Remember the connector settings so they can be restored after the
    // INFORMATION_SCHEMA queries below.
    val originalViewsEnabled = sparkSession.conf.get("viewsEnabled", false.toString)
    val originalMaterializationDataset = sparkSession.conf.get("materializationDataset", "")

    // Running arbitrary SQL through the connector requires viewsEnabled and a
    // materializationDataset.
    // See: https://github.com/GoogleCloudDataproc/spark-bigquery-connector/issues/434#issuecomment-886156191
    // and: https://cloud.google.com/bigquery/docs/information-schema-intro#limitations

    sparkSession.conf.set("viewsEnabled", true)
    sparkSession.conf.set("materializationDataset", database)

    try {
      // See: https://cloud.google.com/bigquery/docs/information-schema-columns
      val partColsSql =
        s"""
           |SELECT column_name FROM `${project}.${database}.INFORMATION_SCHEMA.COLUMNS`
           |WHERE table_name = '${table}' AND is_partitioning_column = 'YES'
           |
           |""".stripMargin

      val partitionCol = sparkSession.read
        .format("bigquery")
        .option("project", project)
        .option("query", partColsSql)
        .load()
        .as[String]
        .collect
        .headOption
        .getOrElse(throw new UnsupportedOperationException(s"No partition column for table ${tableName} found."))

      // See: https://cloud.google.com/bigquery/docs/information-schema-partitions
      val partValsSql =
        s"""
           |SELECT partition_id FROM `${project}.${database}.INFORMATION_SCHEMA.PARTITIONS`
           |WHERE table_name = '${table}'
           |
           |""".stripMargin

      val partitionVals = sparkSession.read
        .format("bigquery")
        .option("project", project)
        .option("query", partValsSql)
        .load()
        .as[String]
        .collect
        .toList

      // Pair each partition_id (e.g. "20240101" for daily partitioning) with
      // the single partitioning column found above.
      partitionVals.map((p) => Map(partitionCol -> p))

    } finally {
      sparkSession.conf.set("viewsEnabled", originalViewsEnabled)
      sparkSession.conf.set("materializationDataset", originalMaterializationDataset)
    }
  }

  def createTableTypeString: String = "BIGQUERY"
  def fileFormatString(format: String): String = ""

  override def supportSubPartitionsFilter: Boolean = true
}
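
A sketch of how the two classes compose at a call site. The table name, dataset, and SparkSession wiring here are illustrative, not part of this commit:

// Illustrative only: table and dataset names are placeholders.
implicit val spark: SparkSession = SparkSession.builder().getOrCreate()

GCPFormatProvider(spark).readFormat("my_dataset.purchases") match {
  case bq @ BQuery(project) =>
    // Runs the INFORMATION_SCHEMA queries above against `project.my_dataset`,
    // yielding e.g. Seq(Map("ds" -> "20240101"), Map("ds" -> "20240102"), ...)
    val parts = bq.partitions("my_dataset.purchases")
    println(s"$project: ${parts.size} partitions")
  case Hive =>
    println("resolved to the Hive format")
}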

spark/src/main/scala/ai/chronon/spark/GroupBy.scala

Lines changed: 2 additions & 1 deletion
@@ -677,7 +677,8 @@ object GroupBy {
     tableUtils.scanDfBase(
       selects,
       if (mutations) source.getEntities.mutationTable.cleanSpec else source.table,
-      Option(source.query.wheres).map(_.toScala).getOrElse(Seq.empty[String]) ++ partitionConditions,
+      Option(source.query.wheres).map(_.toScala).getOrElse(Seq.empty[String]),
+      partitionConditions,
       Some(metaColumns ++ keys.map(_ -> null))
     )
   }

spark/src/main/scala/ai/chronon/spark/Join.scala

Lines changed: 2 additions & 2 deletions
@@ -213,10 +213,10 @@ class Join(joinConf: api.Join,
       } else {
         leftRange
       }
-      val wheres = Seq(s"ds >= '${effectiveRange.start}'", s"ds <= '${effectiveRange.end}'")
+      val wheres = effectiveRange.whereClauses("ds")
       val sql = QueryUtils.build(null, partTable, wheres)
       logger.info(s"Pulling data from joinPart table with: $sql")
-      (joinPart, tableUtils.scanDfBase(null, partTable, wheres))
+      (joinPart, tableUtils.scanDfBase(null, partTable, List.empty, wheres, None))
     }
   }
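
The removed line pins down what whereClauses must produce for the behavior to be preserved. A hypothetical reconstruction of the helper, assuming start and end are inclusive date strings (its real definition lives elsewhere in the codebase):

// Hypothetical reconstruction of PartitionRange.whereClauses, inferred from
// the removed line; not the actual implementation.
case class PartitionRange(start: String, end: String) {
  def whereClauses(partitionColumn: String): Seq[String] =
    Seq(s"$partitionColumn >= '$start'", s"$partitionColumn <= '$end'")
}

// PartitionRange("2024-01-01", "2024-01-31").whereClauses("ds")
//   == Seq("ds >= '2024-01-01'", "ds <= '2024-01-31'")

Note the clauses now flow into scanDfBase's rangeWheres slot, so they are applied as partition filters rather than generic wheres.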

spark/src/main/scala/ai/chronon/spark/TableUtils.scala

Lines changed: 6 additions & 2 deletions
@@ -785,7 +785,8 @@ case class TableUtils(sparkSession: SparkSession) {
 
   def scanDfBase(selectMap: Map[String, String],
                  table: String,
-                 wheres: scala.collection.Seq[String],
+                 wheres: Seq[String],
+                 rangeWheres: Seq[String],
                  fallbackSelects: Option[Map[String, String]] = None): DataFrame = {
     val dp = DataPointer(table)
     var df = dp.toDf(sparkSession)
@@ -798,9 +799,12 @@ case class TableUtils(sparkSession: SparkSession) {
         | ${selects.mkString("\n ").green}
         | wheres:
         | ${wheres.mkString(",\n ").green}
+        | partition filters:
+        | ${rangeWheres.mkString(",\n ").green}
         |""".stripMargin.yellow)
     if (selects.nonEmpty) df = df.selectExpr(selects: _*)
     if (wheres.nonEmpty) df = df.where(wheres.map(w => s"($w)").mkString(" AND "))
+    if (rangeWheres.nonEmpty) df = df.where(rangeWheres.map(w => s"($w)").mkString(" AND "))
     df
   }
 
@@ -822,7 +826,7 @@ case class TableUtils(sparkSession: SparkSession) {
 
     val selects = Option(query).flatMap(q => Option(q.selects)).map(_.toScala).getOrElse(Map.empty)
 
-    scanDfBase(selects, table, wheres, fallbackSelects)
+    scanDfBase(selects, table, wheres, rangeWheres, fallbackSelects)
   }
 
   def partitionRange(table: String): PartitionRange = {
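
With the split signature, source predicates and generated partition-range predicates are logged and applied separately instead of being concatenated into one list. An illustrative call; table, columns, and dates are placeholders:

// Illustrative only: names and values are made up.
val df = tableUtils.scanDfBase(
  selectMap = Map("user_id" -> "user_id", "amount" -> "CAST(amount AS DOUBLE)"),
  table = "data.purchases",
  wheres = Seq("amount > 0"),                                   // from the source query
  rangeWheres = Seq("ds >= '2024-01-01'", "ds <= '2024-01-31'") // from the PartitionRange
)
// The two .where calls compose as:
//   WHERE (amount > 0) AND (ds >= '2024-01-01') AND (ds <= '2024-01-31')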
