@@ -2,8 +2,10 @@ package ai.chronon.integrations.cloud_gcp
 import ai.chronon.spark.format.{DefaultFormatProvider, Format, Iceberg}
 import com.google.cloud.bigquery._
 import com.google.cloud.iceberg.bigquery.relocated.com.google.api.services.bigquery.model.TableReference
+import com.google.cloud.spark.bigquery.BigQueryCatalog
 import org.apache.iceberg.exceptions.NoSuchIcebergTableException
-import org.apache.iceberg.gcp.bigquery.{BigQueryClient, BigQueryClientImpl}
+import org.apache.iceberg.gcp.bigquery.{BigQueryClient, BigQueryClientImpl, BigQueryMetastoreCatalog}
+import org.apache.iceberg.spark.SparkCatalog
 import org.apache.spark.sql.SparkSession
 
 import scala.jdk.CollectionConverters._
@@ -23,19 +25,47 @@ class GcpFormatProvider(override val sparkSession: SparkSession) extends Default
   private lazy val icebergClient: BigQueryClient = new BigQueryClientImpl()
 
   override def readFormat(tableName: String): scala.Option[Format] = {
-    logger.info(s"Retrieving read format for table: ${tableName}")
+    val parsedCatalog = getCatalog(tableName)
+
+    if (isBigQueryCatalog(parsedCatalog)) {
+      logger.info(s"Detected BigQuery catalog: $parsedCatalog")
+      Try {
+        val btTableIdentifier = SparkBQUtils.toTableId(tableName)(sparkSession)
+        val bqTable = bigQueryClient.getTable(btTableIdentifier)
+        getFormat(bqTable)
+      } match {
+        case Success(format) => scala.Option(format)
+        case Failure(e) =>
+          throw new IllegalStateException(
+            s"${tableName} belongs to bigquery catalog ${parsedCatalog} but could not be found",
+            e)
+      }
+    } else {
+
+      logger.info(s"Detected non-BigQuery catalog: $parsedCatalog")
+      super.readFormat(tableName)
+    }
+  }
 
-    // order is important here. we want the Hive case where we just check for table in catalog to be last
-    Try {
-      val btTableIdentifier = SparkBQUtils.toTableId(tableName)(sparkSession)
-      val bqTable = bigQueryClient.getTable(btTableIdentifier)
-      getFormat(bqTable)
-    } match {
-      case Success(format) => scala.Option(format)
-      case Failure(e) =>
-        logger.info(s"${tableName} is not a BigQuery table")
-        super.readFormat(tableName)
+  private def getCatalog(tableName: String): String = {
+    logger.info(s"Retrieving read format for table: ${tableName}")
+    val parsed = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(tableName)
+    val parsedCatalog = parsed match {
+      case catalog :: namespace :: tableName :: Nil => catalog
+      case namespace :: tableName :: Nil => sparkSession.catalog.currentCatalog()
+      case tableName :: Nil => sparkSession.catalog.currentCatalog()
+      case _ => throw new IllegalStateException(s"Invalid table naming convention specified: ${tableName}")
     }
+    parsedCatalog
+  }
+
+  private def isBigQueryCatalog(catalog: String): Boolean = {
+    val cat = sparkSession.sessionState.catalogManager.catalog(catalog)
+    cat.isInstanceOf[DelegatingBigQueryMetastoreCatalog] || cat
+      .isInstanceOf[BigQueryCatalog] || (cat.isInstanceOf[SparkCatalog] && cat
+      .asInstanceOf[SparkCatalog]
+      .icebergCatalog()
+      .isInstanceOf[BigQueryMetastoreCatalog])
   }
 
   private[cloud_gcp] def getFormat(table: Table): Format = {
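
The new getCatalog helper leans on Spark's multipart-identifier parser: an explicit catalog is honored only when all three name parts are present; otherwise the session's current catalog is assumed. A minimal sketch of the shapes it handles, using hypothetical table names and a session whose current catalog is named spark_catalog:

getCatalog("bq_catalog.my_dataset.my_table") // => "bq_catalog" (explicit catalog)
getCatalog("my_dataset.my_table")            // => "spark_catalog" (falls back to the current catalog)
getCatalog("my_table")                       // => "spark_catalog" (falls back to the current catalog)
getCatalog("a.b.c.d")                        // => throws IllegalStateException (too many parts)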
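The third clause of isBigQueryCatalog covers catalogs registered through Iceberg's SparkCatalog wrapper, whose BigQuery delegate is only visible by unwrapping via icebergCatalog(). A sketch of how such a catalog might be registered; the catalog name bq_iceberg is hypothetical:

val spark = SparkSession
  .builder()
  // Iceberg's Spark catalog plugin, delegating to the BigQuery metastore catalog
  .config("spark.sql.catalog.bq_iceberg", "org.apache.iceberg.spark.SparkCatalog")
  .config("spark.sql.catalog.bq_iceberg.catalog-impl",
          "org.apache.iceberg.gcp.bigquery.BigQueryMetastoreCatalog")
  .getOrCreate()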
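With that wiring in place, readFormat dispatches on the resolved catalog rather than probing BigQuery and falling back on failure. A hypothetical usage, assuming the catalog names above:

val provider = new GcpFormatProvider(spark)
provider.readFormat("bq_iceberg.my_dataset.my_table") // BigQuery path: looked up via the BigQuery client
provider.readFormat("hive_cat.db.events")             // non-BigQuery path: delegates to super.readFormat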