Skip to content

Commit e1a9aad

Browse files
committed
add unit test
1 parent 7c2290a commit e1a9aad

File tree

2 files changed

+57
-10
lines changed

2 files changed

+57
-10
lines changed

cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/GcpFormatProvider.scala

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
package ai.chronon.integrations.cloud_gcp
22

3-
import ai.chronon.api.ScalaJavaConversions.ListOps
43
import ai.chronon.spark.TableUtils
54
import ai.chronon.spark.format.Format
65
import ai.chronon.spark.format.FormatProvider
76
import ai.chronon.spark.format.Hive
8-
import com.google.cloud.bigquery._
7+
import com.google.cloud.bigquery.BigQuery
8+
import com.google.cloud.bigquery.BigQueryOptions
9+
import com.google.cloud.bigquery.ExternalTableDefinition
10+
import com.google.cloud.bigquery.FormatOptions
11+
import com.google.cloud.bigquery.StandardTableDefinition
12+
import com.google.cloud.bigquery.Table
13+
import com.google.cloud.bigquery.TableDefinition
914
import com.google.cloud.bigquery.connector.common.BigQueryUtil
1015
import com.google.cloud.spark.bigquery.repackaged.com.google.cloud.bigquery.TableId
1116
import org.apache.spark.sql.SparkSession
@@ -49,19 +54,21 @@ case class GcpFormatProvider(sparkSession: SparkSession) extends FormatProvider
4954
BigQueryFormat(tableId.getProject, sparkOptions)
5055
}
5156

52-
private def getFormat(table: Table): Format =
57+
private[cloud_gcp] def getFormat(table: Table): Format =
5358
table.getDefinition.asInstanceOf[TableDefinition] match {
5459

5560
case definition: ExternalTableDefinition =>
56-
val uris = definition.getSourceUris.toScala
57-
.map(uri => uri.stripSuffix("/*") + "/")
58-
59-
assert(uris.length == 1, s"External table ${table.getFriendlyName} can be backed by only one URI.")
60-
6161
val formatOptions = definition.getFormatOptions
6262
.asInstanceOf[FormatOptions]
63-
64-
GCS(table.getTableId.getProject, uris.head, formatOptions.getType)
63+
val externalTable = table.getDefinition.asInstanceOf[ExternalTableDefinition]
64+
val uri = Option(externalTable.getHivePartitioningOptions)
65+
.map(_.getSourceUriPrefix)
66+
.getOrElse {
67+
val uris = externalTable.getSourceUris
68+
require(uris.size == 1, s"External table ${table} can be backed by only one URI.")
69+
uris.get(0).replaceAll("/\\*\\.parquet$", "")
70+
}
71+
GCS(table.getTableId.getProject, uri, formatOptions.getType)
6572

6673
case _: StandardTableDefinition =>
6774
BigQueryFormat(table.getTableId.getProject, Map.empty)
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package ai.chronon.integrations.cloud_gcp
2+
3+
import ai.chronon.spark.SparkSessionBuilder
4+
import org.apache.spark.sql.SparkSession
5+
import org.scalatest.flatspec.AnyFlatSpec
6+
import com.google.cloud.bigquery._
7+
import org.mockito.Mockito.when
8+
import org.scalatestplus.mockito.MockitoSugar
9+
10+
import java.util
11+
12+
/** Unit test for `GcpFormatProvider.getFormat` on an external (GCS-backed) BigQuery table.
  *
  * The table is stubbed with Mockito; the external table definition itself is built with the
  * regular BigQuery builders.
  */
class GcpFormatProviderTest extends AnyFlatSpec with MockitoSugar {

  // Local Spark session handed to the provider under test; only built on first access.
  lazy val spark: SparkSession =
    SparkSessionBuilder.build("GcpFormatProviderTest", local = true)

  it should "check getFormat works for URI's that have a wildcard in between" in {
    val provider = GcpFormatProvider(spark)

    // The source URI carries a trailing wildcard, while the hive partitioning prefix is the
    // bare directory; the assertions below expect the prefix to be reported as the source URI.
    val wildcardUri = "gs://bucket-name/path/to/data/*.parquet"
    val expectedPrefix = "gs://bucket-name/path/to/data"

    val hiveOptions = HivePartitioningOptions
      .newBuilder()
      .setSourceUriPrefix(expectedPrefix)
      .build()

    val externalDefinition = ExternalTableDefinition
      .newBuilder("external")
      .setSourceUris(util.Arrays.asList(wildcardUri))
      .setHivePartitioningOptions(hiveOptions)
      .setFormatOptions(FormatOptions.parquet())
      .build()

    // A mock stands in for the BigQuery Table because it has no constructor we can call here.
    val tableStub = mock[Table]
    when(tableStub.getDefinition).thenReturn(externalDefinition)
    when(tableStub.getTableId).thenReturn(TableId.of("project", "dataset", "table"))

    val resolved = provider.getFormat(tableStub).asInstanceOf[GCS]

    assert(resolved.sourceUri == expectedPrefix)
    assert(resolved.project == "project")
    assert(resolved.fileFormat == "PARQUET")
  }
}

0 commit comments

Comments
 (0)