fix: properly detect bigquery catalog #629

Merged · 13 commits · Apr 12, 2025

@@ -21,6 +21,7 @@ import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetTable
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
import org.apache.iceberg.gcp.bigquery.BigQueryMetastoreCatalog

import java.util
import scala.jdk.CollectionConverters._
@@ -89,6 +90,8 @@ class DelegatingBigQueryMetastoreCatalog extends TableCatalog with SupportsNames
private var catalogName: String =
null // This corresponds to `spark_catalog` in `spark.sql.catalog.spark_catalog`. This is necessary for Spark to correctly choose which implementation to use.

private var catalogProps: Map[String, String] = Map.empty[String, String]

override def listNamespaces: Array[Array[String]] = icebergCatalog.listNamespaces()

override def listNamespaces(namespace: Array[String]): Array[Array[String]] = icebergCatalog.listNamespaces(namespace)
@@ -114,13 +117,27 @@ class DelegatingBigQueryMetastoreCatalog extends TableCatalog with SupportsNames
override def listTables(namespace: Array[String]): Array[Identifier] = icebergCatalog.listTables(namespace)

override def loadTable(rawIdent: Identifier): Table = {
val ident = Identifier.of(rawIdent.namespace.flatMap(_.split("\\.")), rawIdent.name)
Try { icebergCatalog.loadTable(ident) }
// Remove the catalog segment. We've already consumed it, now it's time to figure out the namespace.
val identNoCatalog = Identifier.of(
rawIdent.namespace.flatMap(_.split("\\.")).toList match {
case catalog :: namespace :: Nil => Array(namespace)
case namespace :: Nil => Array(namespace)
},
rawIdent.name
)
Try {
val icebergSparkTable = icebergCatalog.loadTable(identNoCatalog)
DelegatingTable(icebergSparkTable,
additionalProperties =
Map(TableCatalog.PROP_EXTERNAL -> "false", TableCatalog.PROP_PROVIDER -> "ICEBERG"))
}
.recover {
case _ => {
val tId = ident.namespace().toList match {
case database :: Nil => TableId.of(database, ident.name())
case project :: database :: Nil => TableId.of(project, database, ident.name())
val project =
catalogProps.getOrElse(BigQueryMetastoreCatalog.PROPERTIES_KEY_GCP_PROJECT, bqOptions.getProjectId)
val tId = identNoCatalog.namespace().toList match {
case database :: Nil => TableId.of(project, database, identNoCatalog.name())
case catalog :: database :: Nil => TableId.of(project, database, identNoCatalog.name())
case Nil =>
throw new IllegalArgumentException(s"Table identifier namespace ${rawIdent} must have at least one part.")
}
@@ -143,7 +160,9 @@ class DelegatingBigQueryMetastoreCatalog extends TableCatalog with SupportsNames
None,
classOf[ParquetFileFormat])
DelegatingTable(fileBasedTable,
Map(TableCatalog.PROP_EXTERNAL -> "true", TableCatalog.PROP_LOCATION -> uri))
Map(TableCatalog.PROP_EXTERNAL -> "true",
TableCatalog.PROP_LOCATION -> uri,
TableCatalog.PROP_PROVIDER -> "PARQUET"))
}
case _: StandardTableDefinition => {
//todo(tchow): Support partitioning
@@ -153,13 +172,14 @@ class DelegatingBigQueryMetastoreCatalog extends TableCatalog with SupportsNames
val connectorTable = connectorCatalog.loadTable(Identifier.of(Array(tId.getDataset), tId.getTable))
// ideally it should be the below:
// val connectorTable = connectorCatalog.loadTable(ident)
DelegatingTable(connectorTable, Map(TableCatalog.PROP_EXTERNAL -> "false"))
DelegatingTable(connectorTable,
Map(TableCatalog.PROP_EXTERNAL -> "false", TableCatalog.PROP_PROVIDER -> "BIGQUERY"))
}
case _ => throw new IllegalStateException(s"Cannot support table of type: ${table.getFriendlyName}")
}
}
}
.getOrElse(throw new NoSuchTableException(f"Tgable: ${ident} not found in bigquery catalog."))
.getOrElse(throw new NoSuchTableException(f"Table: ${identNoCatalog} not found in bigquery catalog."))
}

override def createTable(ident: Identifier,
@@ -187,6 +207,7 @@ class DelegatingBigQueryMetastoreCatalog extends TableCatalog with SupportsNames
icebergCatalog.initialize(name, options)
connectorCatalog.initialize(name, options)
catalogName = name
catalogProps = options.asCaseSensitiveMap.asScala.toMap
}

override def name(): String = catalogName
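
Since initialize now retains the catalog options in catalogProps, the GCP project used by the BigQuery fallback in loadTable can be supplied through the catalog's Spark configuration. A minimal sketch of that wiring — the catalog name default_iceberg matches the integration tests below, while the literal gcp_project key is an assumption (the code only references it via BigQueryMetastoreCatalog.PROPERTIES_KEY_GCP_PROJECT):

import org.apache.spark.sql.SparkSession

// Hypothetical session setup: every spark.sql.catalog.default_iceberg.* entry is handed to
// initialize(name, options) and kept in catalogProps for later lookups.
val spark = SparkSession.builder()
  .appName("bq-catalog-sketch")
  .config("spark.sql.catalog.default_iceberg",
    "ai.chronon.integrations.cloud_gcp.DelegatingBigQueryMetastoreCatalog")
  // Assumed property key; read back via PROPERTIES_KEY_GCP_PROJECT, with bqOptions.getProjectId as the fallback.
  .config("spark.sql.catalog.default_iceberg.gcp_project", "my-gcp-project")
  .getOrCreate()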
@@ -1,13 +1,13 @@
package ai.chronon.integrations.cloud_gcp
import ai.chronon.spark.format.{DefaultFormatProvider, Format, Iceberg}
import com.google.cloud.bigquery._
import com.google.cloud.iceberg.bigquery.relocated.com.google.api.services.bigquery.model.TableReference
import org.apache.iceberg.exceptions.NoSuchIcebergTableException
import org.apache.iceberg.gcp.bigquery.{BigQueryClient, BigQueryClientImpl}
import org.apache.iceberg.gcp.bigquery.BigQueryMetastoreCatalog
import org.apache.iceberg.spark.SparkCatalog
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.catalog.TableCatalog

import scala.jdk.CollectionConverters._
import scala.util.{Failure, Success, Try}
import scala.util.Try

class GcpFormatProvider(override val sparkSession: SparkSession) extends DefaultFormatProvider(sparkSession) {

@@ -18,46 +18,29 @@ class GcpFormatProvider(override val sparkSession: SparkSession) extends Default
* - Active project in the gcloud CLI configuration.
* - No default project: An error will occur if no project ID is available.
*/
private lazy val bqOptions = BigQueryOptions.getDefaultInstance
private lazy val bigQueryClient: BigQuery = bqOptions.getService
private lazy val icebergClient: BigQueryClient = new BigQueryClientImpl()

override def readFormat(tableName: String): scala.Option[Format] = {
logger.info(s"Retrieving read format for table: ${tableName}")

// order is important here. we want the Hive case where we just check for table in catalog to be last
Try {
val btTableIdentifier = SparkBQUtils.toTableId(tableName)(sparkSession)
val bqTable = bigQueryClient.getTable(btTableIdentifier)
getFormat(bqTable)
} match {
case Success(format) => scala.Option(format)
case Failure(e) =>
logger.info(s"${tableName} is not a BigQuery table")
super.readFormat(tableName)
}
}

private[cloud_gcp] def getFormat(table: Table): Format = {
table.getDefinition.asInstanceOf[TableDefinition] match {
case _: ExternalTableDefinition =>
val parsedCatalog = getCatalog(tableName)

Review comment (💡 verification agent, 🧩 analysis chain):

Missing getCatalog implementation.

Method is called but appears to be missing implementation.


🏁 Script executed:

#!/bin/bash
# Check for getCatalog method implementation
grep -r "def getCatalog" --include="*.scala" .

Length of output: 171


Implement or Inherit getCatalog

The call to getCatalog in GcpFormatProvider.scala (line 23) doesn’t resolve locally. Although a similar method exists in DefaultFormatProvider.scala, it isn’t automatically available here. Please either inherit from a common base that provides the implementation or add a GCP-specific getCatalog method.
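
For reference, the getCatalog this comment asks for is added to DefaultFormatProvider later in this diff, and GcpFormatProvider extends DefaultFormatProvider, so the method is inherited here. A condensed sketch of its resolution logic, assuming Spark's multipart identifier parser:

// Three-part names carry an explicit catalog; one- and two-part names fall back to the session's current catalog.
def getCatalog(tableName: String): String =
  sparkSession.sessionState.sqlParser.parseMultipartIdentifier(tableName).toList match {
    case catalog :: _ :: _ :: Nil => catalog
    case _ :: _ :: Nil => sparkSession.catalog.currentCatalog()
    case _ :: Nil => sparkSession.catalog.currentCatalog()
    case _ => throw new IllegalStateException(s"Invalid table naming convention specified: ${tableName}")
  }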

val identifier = SparkBQUtils.toIdentifier(tableName)(sparkSession)
val cat = sparkSession.sessionState.catalogManager.catalog(parsedCatalog)
cat match {
case delegating: DelegatingBigQueryMetastoreCatalog =>
Try {
val tableRef = new TableReference()
.setProjectId(table.getTableId.getProject)
.setDatasetId(table.getTableId.getDataset)
.setTableId(table.getTableId.getTable)

icebergClient.getTable(tableRef) // Just try to load it. It'll fail if it's not an iceberg table.
Iceberg
}.recover {
case _: NoSuchIcebergTableException => BigQueryExternal
case e: Exception => throw e
}.get

case _: StandardTableDefinition => BigQueryNative

case _ =>
throw new IllegalStateException(s"Cannot support table of type: ${table.getFriendlyName}")
delegating
.loadTable(identifier)
.properties
.asScala
.getOrElse(TableCatalog.PROP_PROVIDER, "")
.toUpperCase match {
case "ICEBERG" => Iceberg
case "BIGQUERY" => BigQueryNative
case "PARQUET" => BigQueryExternal
case unsupported => throw new IllegalStateException(s"Unsupported provider type: ${unsupported}")
}
}.toOption
case iceberg: SparkCatalog if (iceberg.icebergCatalog().isInstanceOf[BigQueryMetastoreCatalog]) =>
scala.Option(Iceberg)
case _ => super.readFormat(tableName)
}
}
}
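
Taken together with DelegatingBigQueryMetastoreCatalog above, format detection is now a round trip: the catalog tags every table it loads with TableCatalog.PROP_PROVIDER ("ICEBERG", "BIGQUERY", or "PARQUET"), and readFormat maps that tag back onto a Chronon Format. A stand-alone sketch of that dispatch — formatFromProvider is a hypothetical helper; the real code performs this match inline on delegating.loadTable(identifier).properties:

import org.apache.spark.sql.connector.catalog.TableCatalog
import ai.chronon.spark.format.{Format, Iceberg}

// Hypothetical helper mirroring the inline provider match above.
// BigQueryNative and BigQueryExternal are the GCP Format objects from this package.
def formatFromProvider(props: Map[String, String]): Format =
  props.getOrElse(TableCatalog.PROP_PROVIDER, "").toUpperCase match {
    case "ICEBERG" => Iceberg
    case "BIGQUERY" => BigQueryNative
    case "PARQUET" => BigQueryExternal
    case other => throw new IllegalStateException(s"Unsupported provider type: ${other}")
  }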
@@ -2,6 +2,7 @@ package ai.chronon.integrations.cloud_gcp
import com.google.cloud.bigquery.connector.common.BigQueryUtil
import org.apache.spark.sql.SparkSession
import com.google.cloud.bigquery.TableId
import org.apache.spark.sql.connector.catalog.Identifier

object SparkBQUtils {

@@ -14,4 +15,10 @@ object SparkBQUtils {
.getOrElse(TableId.of(shadedTid.getDataset, shadedTid.getTable))
}

def toIdentifier(tableName: String)(implicit spark: SparkSession): Identifier = {
val parseIdentifier = spark.sessionState.sqlParser.parseMultipartIdentifier(tableName).reverse
Identifier.of(parseIdentifier.tail.reverse.toArray, parseIdentifier.head)

}

}
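
To make the tail/reverse shuffle in toIdentifier concrete, a couple of worked examples (table names borrowed from the integration tests in this PR; parsing assumes Spark's default SQL parser):

// "data.checkouts_parquet"                 => Identifier.of(Array("data"), "checkouts_parquet")
// "default_iceberg.data.checkouts_parquet" => Identifier.of(Array("default_iceberg", "data"), "checkouts_parquet")
// i.e. the last part becomes the table name and everything before it becomes the namespace.
val ident = SparkBQUtils.toIdentifier("default_iceberg.data.checkouts_parquet")(spark)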
@@ -10,6 +10,8 @@ import com.google.cloud.hadoop.fs.gcs.{
GoogleHadoopFileSystemConfiguration,
HadoopConfigurationProperty
}
import ai.chronon.spark.format.Iceberg

import com.google.cloud.spark.bigquery.SparkBigQueryUtil
import org.apache.iceberg.gcp.bigquery.{BigQueryMetastoreCatalog => BQMSCatalog}
import org.apache.iceberg.gcp.gcs.GCSFileIO
@@ -113,6 +115,37 @@ class BigQueryCatalogTest extends AnyFlatSpec with MockitoSugar {
println(allParts)
}

it should "integration testing formats" ignore {
val externalTable = "default_iceberg.data.checkouts_parquet"
val externalFormat = FormatProvider.from(spark).readFormat(externalTable)
assertEquals(Some(BigQueryExternal), externalFormat)

val externalTableNoCat = "data.checkouts_parquet"
val externalFormatNoCat = FormatProvider.from(spark).readFormat(externalTableNoCat)
assertEquals(Some(BigQueryExternal), externalFormatNoCat)

val nativeTable = "default_iceberg.data.checkouts_native"
val nativeFormat = FormatProvider.from(spark).readFormat(nativeTable)
assertEquals(Some(BigQueryNative), nativeFormat)

val nativeTableNoCat = "data.checkouts_native"
val nativeFormatNoCat = FormatProvider.from(spark).readFormat(nativeTableNoCat)
assertEquals(Some(BigQueryNative), nativeFormatNoCat)

val icebergTable = "default_iceberg.data.quickstart_purchases_davidhan_v1_dev_davidhan"
val icebergFormat = FormatProvider.from(spark).readFormat(icebergTable)
assertEquals(Some(Iceberg), icebergFormat)

val icebergTableNoCat = "data.quickstart_purchases_davidhan_v1_dev_davidhan"
val icebergFormatNoCat = FormatProvider.from(spark).readFormat(icebergTableNoCat)
assertEquals(Some(Iceberg), icebergFormatNoCat)

val dneTable = "default_iceberg.data.dne"
val dneFormat = FormatProvider.from(spark).readFormat(dneTable)
assertTrue(dneFormat.isEmpty)
}


it should "integration testing bigquery partitions" ignore {
// TODO(tchow): This test is ignored because it requires a running instance of the bigquery. Need to figure out stubbing locally.
// to run, set `GOOGLE_APPLICATION_CREDENTIALS=<path_to_application_default_credentials.json>
@@ -33,6 +33,6 @@ class GcpFormatProviderTest extends AnyFlatSpec with MockitoSugar {
.build())
when(mockTable.getTableId).thenReturn(TableId.of("project", "dataset", "table"))

val gcsFormat = gcpFormatProvider.getFormat(mockTable)
val gcsFormat = gcpFormatProvider.readFormat(tableName)

Review comment (⚠️ potential issue):

Method changed but test still ignored.

Test was updated to use readFormat instead of getFormat but remains ignored and lacks assertions.

-val gcsFormat = gcpFormatProvider.readFormat(tableName)
+val gcsFormat = gcpFormatProvider.readFormat(tableName)
+assert(gcsFormat.isDefined, "Format should be detected")

}
}
@@ -22,6 +22,18 @@ class DefaultFormatProvider(val sparkSession: SparkSession) extends FormatProvid
} else { null })
}

def getCatalog(tableName: String): String = {
logger.info(s"Resolving catalog for table: ${tableName}")
val parsed = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(tableName)
val parsedCatalog = parsed.toList match {
case catalog :: namespace :: tableName :: Nil => catalog
case namespace :: tableName :: Nil => sparkSession.catalog.currentCatalog()
case tableName :: Nil => sparkSession.catalog.currentCatalog()
case _ => throw new IllegalStateException(s"Invalid table naming convention specified: ${tableName}")
}
parsedCatalog
}

private def isIcebergTable(tableName: String): Boolean =
Try {
sparkSession.read.format("iceberg").load(tableName)
16 changes: 16 additions & 0 deletions spark/src/test/scala/ai/chronon/spark/test/TableUtilsTest.scala
@@ -25,6 +25,10 @@ import org.apache.spark.sql.functions.col
import org.junit.Assert.{assertEquals, assertTrue}
import org.scalatest.flatspec.AnyFlatSpec

import ai.chronon.spark.format.FormatProvider
import ai.chronon.spark.format.DefaultFormatProvider
import org.apache.spark.sql.catalyst.parser.ParseException

import scala.util.Try

case class TestRecord(ds: String, id: String)
@@ -36,6 +40,7 @@ class SimpleAddUDF extends UDF {
}

class TableUtilsTest extends AnyFlatSpec {

lazy val spark: SparkSession = SparkSessionBuilder.build("TableUtilsTest", local = true)
private val tableUtils = TableTestUtils(spark)
private implicit val partitionSpec: PartitionSpec = tableUtils.partitionSpec
@@ -639,4 +644,15 @@ class TableUtilsTest extends AnyFlatSpec {
}
}

it should "test catalog detection" in {
val fp = FormatProvider.from(spark).asInstanceOf[DefaultFormatProvider]
assertEquals("catalogA", fp.getCatalog("catalogA.foo.bar"))
assertEquals("catalogA", fp.getCatalog("`catalogA`.foo.bar"))
assertEquals("spark_catalog", fp.getCatalog("`catalogA.foo`.bar"))
assertEquals("spark_catalog", fp.getCatalog("`catalogA.foo.bar`"))
assertEquals("spark_catalog", fp.getCatalog("foo.bar"))
assertEquals("spark_catalog", fp.getCatalog("bar"))
assertThrows[ParseException](fp.getCatalog(""))
}

}