Skip to content

Commit 94df38d

Browse files
authored
adding again: Retrieve source uri prefix from hive partitioning options when building the GCS format (#230)
## Summary Basically a re-application of this PR: https://github.com/zipline-ai/chronon/pull/204/files ## Checklist - [x] Added Unit Tests - [ ] Covered by existing CI - [ ] Integration tested - [ ] Documentation update <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Enhanced BigQuery API integration with more granular import statements - Added comprehensive test coverage for external table format handling - **Refactor** - Updated method visibility to improve package-level access - Refined external table processing logic - **Tests** - Introduced new test class for `GcpFormatProvider` - Added test case for URI handling with wildcard scenarios <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 0bde573 commit 94df38d

File tree

2 files changed

+57
-10
lines changed

2 files changed

+57
-10
lines changed

cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/GcpFormatProvider.scala

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
package ai.chronon.integrations.cloud_gcp
22

3-
import ai.chronon.api.ScalaJavaConversions.ListOps
43
import ai.chronon.spark.TableUtils
54
import ai.chronon.spark.format.Format
65
import ai.chronon.spark.format.FormatProvider
76
import ai.chronon.spark.format.Hive
8-
import com.google.cloud.bigquery._
7+
import com.google.cloud.bigquery.BigQuery
8+
import com.google.cloud.bigquery.BigQueryOptions
9+
import com.google.cloud.bigquery.ExternalTableDefinition
10+
import com.google.cloud.bigquery.FormatOptions
11+
import com.google.cloud.bigquery.StandardTableDefinition
12+
import com.google.cloud.bigquery.Table
13+
import com.google.cloud.bigquery.TableDefinition
914
import com.google.cloud.bigquery.connector.common.BigQueryUtil
1015
import com.google.cloud.spark.bigquery.repackaged.com.google.cloud.bigquery.TableId
1116
import org.apache.spark.sql.SparkSession
@@ -51,19 +56,21 @@ case class GcpFormatProvider(sparkSession: SparkSession) extends FormatProvider
5156
BigQueryFormat(tableId.getProject, sparkOptions)
5257
}
5358

54-
private def getFormat(table: Table): Format =
59+
private[cloud_gcp] def getFormat(table: Table): Format =
5560
table.getDefinition.asInstanceOf[TableDefinition] match {
5661

5762
case definition: ExternalTableDefinition =>
58-
val uris = definition.getSourceUris.toScala
59-
.map(uri => uri.stripSuffix("/*") + "/")
60-
61-
assert(uris.length == 1, s"External table ${table.getFriendlyName} can be backed by only one URI.")
62-
6363
val formatOptions = definition.getFormatOptions
6464
.asInstanceOf[FormatOptions]
65-
66-
GCS(table.getTableId.getProject, uris.head, formatOptions.getType)
65+
val externalTable = table.getDefinition.asInstanceOf[ExternalTableDefinition]
66+
val uri = Option(externalTable.getHivePartitioningOptions)
67+
.map(_.getSourceUriPrefix)
68+
.getOrElse {
69+
val uris = externalTable.getSourceUris
70+
require(uris.size == 1, s"External table ${table} can be backed by only one URI.")
71+
uris.get(0).replaceAll("/\\*\\.parquet$", "")
72+
}
73+
GCS(table.getTableId.getProject, uri, formatOptions.getType)
6774

6875
case _: StandardTableDefinition =>
6976
BigQueryFormat(table.getTableId.getProject, Map.empty)
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package ai.chronon.integrations.cloud_gcp
2+
3+
import ai.chronon.spark.SparkSessionBuilder
4+
import com.google.cloud.bigquery._
5+
import org.apache.spark.sql.SparkSession
6+
import org.mockito.Mockito.when
7+
import org.scalatest.flatspec.AnyFlatSpec
8+
import org.scalatestplus.mockito.MockitoSugar
9+
10+
import java.util
11+
12+
/** Unit tests for [[GcpFormatProvider]]'s external-table format resolution. */
class GcpFormatProviderTest extends AnyFlatSpec with MockitoSugar {

  // Built lazily so the (comparatively expensive) local Spark session only
  // starts when a test actually touches it.
  lazy val spark: SparkSession = SparkSessionBuilder.build(
    "GcpFormatProviderTest",
    local = true
  )

  it should "check getFormat works for URI's that have a wildcard in between" in {
    val provider = GcpFormatProvider(spark)
    val wildcardUri = "gs://bucket-name/path/to/data/*.parquet"
    val prefixUri = "gs://bucket-name/path/to/data"

    // An external Parquet table whose hive-partitioning options carry the
    // source URI prefix; getFormat should prefer that prefix over the
    // wildcard source URI.
    val externalDefinition = ExternalTableDefinition
      .newBuilder("external")
      .setSourceUris(util.Collections.singletonList(wildcardUri))
      .setHivePartitioningOptions(HivePartitioningOptions.newBuilder().setSourceUriPrefix(prefixUri).build())
      .setFormatOptions(FormatOptions.parquet())
      .build()

    // bigquery's Table has no public constructor, so stub the two accessors
    // getFormat reads.
    val table = mock[Table]
    when(table.getDefinition).thenReturn(externalDefinition)
    when(table.getTableId).thenReturn(TableId.of("project", "dataset", "table"))

    val resolved = provider.getFormat(table).asInstanceOf[GCS]
    assert(resolved.sourceUri == prefixUri)
    assert(resolved.project == "project")
    assert(resolved.fileFormat == "PARQUET")
  }
}

0 commit comments

Comments
 (0)