Skip to content

Commit eb08cb1

Browse files
fix: properly detect bigquery catalog
Co-authored-by: Thomas Chow <[email protected]>
1 parent 6547bf0 commit eb08cb1

File tree

1 file changed

+42
-12
lines changed

1 file changed

+42
-12
lines changed

cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/GcpFormatProvider.scala

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@ package ai.chronon.integrations.cloud_gcp
22
import ai.chronon.spark.format.{DefaultFormatProvider, Format, Iceberg}
33
import com.google.cloud.bigquery._
44
import com.google.cloud.iceberg.bigquery.relocated.com.google.api.services.bigquery.model.TableReference
5+
import com.google.cloud.spark.bigquery.BigQueryCatalog
56
import org.apache.iceberg.exceptions.NoSuchIcebergTableException
6-
import org.apache.iceberg.gcp.bigquery.{BigQueryClient, BigQueryClientImpl}
7+
import org.apache.iceberg.gcp.bigquery.{BigQueryClient, BigQueryClientImpl, BigQueryMetastoreCatalog}
8+
import org.apache.iceberg.spark.SparkCatalog
79
import org.apache.spark.sql.SparkSession
810

911
import scala.jdk.CollectionConverters._
@@ -23,19 +25,47 @@ class GcpFormatProvider(override val sparkSession: SparkSession) extends Default
2325
private lazy val icebergClient: BigQueryClient = new BigQueryClientImpl()
2426

2527
/** Resolves the storage [[Format]] for `tableName`, routing by owning catalog.
  *
  * Tables living in a BigQuery-backed catalog are resolved through the BigQuery
  * client; any other catalog falls back to the default provider's detection.
  *
  * @param tableName possibly multi-part table identifier (catalog.namespace.table)
  * @return the detected format, or None when the default provider finds nothing
  * @throws IllegalStateException when the table claims a BigQuery catalog but
  *         cannot be resolved there — that inconsistency is not recoverable here
  */
override def readFormat(tableName: String): scala.Option[Format] = {
  val catalogName = getCatalog(tableName)

  if (!isBigQueryCatalog(catalogName)) {
    logger.info(s"Detected non-BigQuery catalog: $catalogName")
    super.readFormat(tableName)
  } else {
    logger.info(s"Detected BigQuery catalog: $catalogName")
    val lookup = Try {
      val tableId = SparkBQUtils.toTableId(tableName)(sparkSession)
      getFormat(bigQueryClient.getTable(tableId))
    }
    lookup.fold(
      // The catalog says BigQuery owns this table, yet the lookup failed:
      // surface the inconsistency instead of silently falling through.
      cause =>
        throw new IllegalStateException(
          s"${tableName} belongs to bigquery catalog ${catalogName} but could not be found",
          cause),
      detected => scala.Option(detected)
    )
  }
}
2749

28-
// order is important here. we want the Hive case where we just check for table in catalog to be last
29-
Try {
30-
val btTableIdentifier = SparkBQUtils.toTableId(tableName)(sparkSession)
31-
val bqTable = bigQueryClient.getTable(btTableIdentifier)
32-
getFormat(bqTable)
33-
} match {
34-
case Success(format) => scala.Option(format)
35-
case Failure(e) =>
36-
logger.info(s"${tableName} is not a BigQuery table")
37-
super.readFormat(tableName)
50+
/** Extracts the catalog portion of a (possibly partially qualified) table name.
  *
  * A fully qualified name (`catalog.namespace.table`) yields its explicit
  * catalog; shorter forms (`namespace.table`, `table`) fall back to the
  * session's current catalog.
  *
  * Fixes vs. previous revision: the match bindings `namespace`/`tableName`
  * shadowed the method parameter (compiler warning, misleading) — unused
  * components are now wildcards — and the log message no longer claims to be
  * retrieving a read format.
  *
  * @param tableName table identifier to parse
  * @return the catalog name owning the table
  * @throws IllegalStateException for identifiers with more than three parts
  */
private def getCatalog(tableName: String): String = {
  logger.info(s"Resolving catalog for table: ${tableName}")
  val parts = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(tableName)
  parts match {
    case catalog :: _ :: _ :: Nil => catalog // catalog.namespace.table
    case _ :: _ :: Nil            => sparkSession.catalog.currentCatalog() // namespace.table
    case _ :: Nil                 => sparkSession.catalog.currentCatalog() // bare table
    case _ =>
      throw new IllegalStateException(s"Invalid table naming convention specified: ${tableName}")
  }
}
61+
62+
/** Returns true when the named Spark catalog is backed by BigQuery.
  *
  * Three shapes count as "BigQuery": our delegating metastore catalog, the
  * Spark BigQuery connector's catalog, and an Iceberg [[SparkCatalog]] whose
  * underlying catalog is a [[BigQueryMetastoreCatalog]].
  *
  * Rewritten from an `isInstanceOf`/`asInstanceOf` chain to an exhaustive
  * pattern match — same semantics, no unchecked casts.
  *
  * @param catalog name registered with the session's catalog manager
  */
private def isBigQueryCatalog(catalog: String): Boolean = {
  sparkSession.sessionState.catalogManager.catalog(catalog) match {
    case _: DelegatingBigQueryMetastoreCatalog => true
    case _: BigQueryCatalog                    => true
    case iceberg: SparkCatalog                 => iceberg.icebergCatalog().isInstanceOf[BigQueryMetastoreCatalog]
    case _                                     => false
  }
}
4070

4171
private[cloud_gcp] def getFormat(table: Table): Format = {

0 commit comments

Comments
 (0)