
Commit fa27cf8

feat: unit tests for local iteration (#148)
## Summary

## Checklist
- [ ] Added Unit Tests
- [ ] Covered by existing CI
- [ ] Integration tested
- [ ] Documentation update

## Summary by CodeRabbit

- **New Features**
  - Introduced a new BigQuery client for enhanced interaction with BigQuery services.
  - Added functionality for managing partitions in Spark SQL tables through a new utility class.
- **Bug Fixes**
  - Improved error handling in the database creation process.
- **Tests**
  - Added a new test class for verifying BigQuery catalog functionality.
  - Updated existing test classes to utilize the new partition management utilities.
- **Chores**
  - Cleaned up deprecated methods in the TableUtils class.
  - Refactored comments for clarity regarding method dependencies.
1 parent b031ebc commit fa27cf8

8 files changed (+145, -79 lines)

build.sbt

Lines changed: 1 addition & 1 deletion
```diff
@@ -213,7 +213,7 @@ lazy val cloud_gcp = project
     libraryDependencies += "com.google.cloud" % "google-cloud-pubsub" % "1.131.0",
     libraryDependencies += "com.google.cloud" % "google-cloud-dataproc" % "4.51.0",
     libraryDependencies += "com.google.cloud.bigdataoss" % "gcs-connector" % "3.0.3", // it's what's on the cluster
-    libraryDependencies += "com.google.cloud.bigdataoss" % "gcs-connector" % "hadoop3-2.2.26", // it's what's on the cluster
+    libraryDependencies += "com.google.cloud.bigdataoss" % "gcs-connector" % "hadoop3-2.2.26",
     libraryDependencies += "com.google.cloud.bigdataoss" % "gcsio" % "3.0.3", // need it for https://github.com/GoogleCloudDataproc/hadoop-connectors/blob/master/gcsio/src/main/java/com/google/cloud/hadoop/gcsio/GoogleCloudStorageFileSystem.java
     libraryDependencies += "io.circe" %% "circe-yaml" % "1.15.0",
     libraryDependencies += "org.mockito" % "mockito-core" % "5.12.0" % Test,
```

cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/BigQueryFormat.scala

Lines changed: 2 additions & 0 deletions
```diff
@@ -12,6 +12,7 @@ import com.google.cloud.spark.bigquery.repackaged.com.google.cloud.bigquery.Tabl
 import org.apache.spark.sql.SparkSession
 
 case class GcpFormatProvider(sparkSession: SparkSession) extends FormatProvider {
+
   lazy val bigQueryClient = BigQueryOptions.getDefaultInstance.getService
   def readFormat(tableName: String): Format = {
 
@@ -126,6 +127,7 @@ case class BQuery(project: String) extends Format {
     sparkSession.conf.set("viewsEnabled", originalViewsEnabled)
     sparkSession.conf.set("materializationDataset", originalMaterializationDataset)
   }
+
 }
 
 def createTableTypeString: String = "BIGQUERY"
```
Lines changed: 61 additions & 1 deletion
```diff
@@ -1,6 +1,66 @@
 package ai.chronon.integrations.cloud_gcp.test
 
+import ai.chronon.integrations.cloud_gcp.BQuery
+import ai.chronon.integrations.cloud_gcp.GcpFormatProvider
+import ai.chronon.spark.SparkSessionBuilder
+import ai.chronon.spark.TableUtils
+import com.google.cloud.hadoop.fs.gcs.GoogleHadoopFS
+import com.google.cloud.hadoop.fs.gcs.GoogleHadoopFileSystem
+import org.apache.spark.sql.SparkSession
+import org.junit.Assert.assertEquals
+import org.junit.Assert.assertTrue
 import org.scalatest.funsuite.AnyFunSuite
 import org.scalatestplus.mockito.MockitoSugar
 
-class BigQueryCatalogTest extends AnyFunSuite with MockitoSugar {}
+class BigQueryCatalogTest extends AnyFunSuite with MockitoSugar {
+
+  lazy val spark: SparkSession = SparkSessionBuilder.build(
+    "BigQuerySparkTest",
+    local = true,
+    additionalConfig = Some(
+      Map(
+        "spark.chronon.table.format_provider.class" -> classOf[GcpFormatProvider].getName,
+        "hive.metastore.uris" -> "thrift://localhost:9083",
+        "spark.chronon.partition.column" -> "c",
+        "spark.hadoop.fs.gs.impl" -> classOf[GoogleHadoopFileSystem].getName,
+        "spark.hadoop.fs.AbstractFileSystem.gs.impl" -> classOf[GoogleHadoopFS].getName,
+        "spark.hadoop.google.cloud.auth.service.account.enable" -> true.toString,
+        "spark.hadoop.fs.gs.impl" -> classOf[GoogleHadoopFileSystem].getName
+      ))
+  )
+  lazy val tableUtils: TableUtils = TableUtils(spark)
+
+  test("hive uris are set") {
+    assertEquals("thrift://localhost:9083", spark.sqlContext.getConf("hive.metastore.uris"))
+  }
+
+  test("verify dynamic classloading of GCP providers") {
+    assertTrue(tableUtils.tableReadFormat("data.sample_native") match {
+      case BQuery(_) => true
+      case _         => false
+    })
+  }
+
+  ignore("integration testing bigquery load table") {
+    val externalTable = "data.checkouts_parquet"
+    val table = tableUtils.loadTable(externalTable)
+    tableUtils.isPartitioned(externalTable)
+    tableUtils.createDatabase("test_database")
+    tableUtils.allPartitions(externalTable)
+    table.show
+  }
+
+  ignore("integration testing bigquery partitions") {
+    // TODO(tchow): This test is ignored because it requires a running instance of the bigquery. Need to figure out stubbing locally.
+    // to run this:
+    // 1. Set up a tunnel to dataproc federation proxy:
+    //      gcloud compute ssh zipline-canary-cluster-m \
+    //        --zone us-central1-c \
+    //        -- -f -N -L 9083:localhost:9083
+    // 2. enable this test and off you go.
+    val externalPartitions = tableUtils.partitions("data.checkouts_parquet")
+    println(externalPartitions)
+    val nativePartitions = tableUtils.partitions("data.sample_native")
+    println(nativePartitions)
+  }
+}
```
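The ignored integration tests above need the tunnelled Hive metastore and a reachable BigQuery project; only the first two tests run purely locally. For local iteration outside the suite, the same wiring can be reproduced with `SparkSessionBuilder` — a minimal sketch reusing the configuration and table names from `BigQueryCatalogTest` (the object and app names here are placeholders, not part of the change):

```scala
import ai.chronon.integrations.cloud_gcp.GcpFormatProvider
import ai.chronon.spark.{SparkSessionBuilder, TableUtils}
import com.google.cloud.hadoop.fs.gcs.{GoogleHadoopFS, GoogleHadoopFileSystem}

object LocalBigQueryIteration {
  def main(args: Array[String]): Unit = {
    // Local session wired to the GCP format provider and the tunnelled Hive metastore,
    // mirroring the configuration used in BigQueryCatalogTest above.
    val spark = SparkSessionBuilder.build(
      "LocalBigQueryIteration", // placeholder app name
      local = true,
      additionalConfig = Some(
        Map(
          "spark.chronon.table.format_provider.class" -> classOf[GcpFormatProvider].getName,
          "hive.metastore.uris" -> "thrift://localhost:9083", // via the gcloud SSH tunnel noted in the test
          "spark.hadoop.fs.gs.impl" -> classOf[GoogleHadoopFileSystem].getName,
          "spark.hadoop.fs.AbstractFileSystem.gs.impl" -> classOf[GoogleHadoopFS].getName,
          "spark.hadoop.google.cloud.auth.service.account.enable" -> "true"
        ))
    )
    val tableUtils = TableUtils(spark)
    // Table name taken from the test; partitions() resolves the format dynamically via GcpFormatProvider.
    println(tableUtils.partitions("data.sample_native"))
  }
}
```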

spark/src/main/scala/ai/chronon/spark/TableUtils.scala

Lines changed: 30 additions & 65 deletions
```diff
@@ -63,7 +63,7 @@ import scala.util.Try
  * retrieve metadata / configure it appropriately at creation time
  */
 
-case class TableUtils(sparkSession: SparkSession) {
+class TableUtils(@transient val sparkSession: SparkSession) extends Serializable {
   @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass)
 
   private val ARCHIVE_TIMESTAMP_FORMAT = "yyyyMMddHHmmss"
@@ -141,16 +141,24 @@ case class TableUtils(sparkSession: SparkSession) {
     rdd
   }
 
-  def tableExists(tableName: String): Boolean = sparkSession.catalog.tableExists(tableName)
+  // Needs provider
+  def tableExists(tableName: String): Boolean = {
+    sparkSession.catalog.tableExists(tableName)
+  }
 
-  def loadTable(tableName: String): DataFrame = sparkSession.table(tableName)
+  // Needs provider
+  def loadTable(tableName: String): DataFrame = {
+    sparkSession.table(tableName)
+  }
 
+  // Needs provider
   def isPartitioned(tableName: String): Boolean = {
     // TODO: use proper way to detect if a table is partitioned or not
     val schema = getSchemaFromTable(tableName)
     schema.fieldNames.contains(partitionColumn)
   }
 
+  // Needs provider
   def createDatabase(database: String): Boolean = {
     try {
       val command = s"CREATE DATABASE IF NOT EXISTS $database"
@@ -168,6 +176,7 @@
 
   def tableReadFormat(tableName: String): Format = tableFormatProvider.readFormat(tableName)
 
+  // Needs provider
   // return all specified partition columns in a table in format of Map[partitionName, PartitionValue]
   def allPartitions(tableName: String, partitionColumnsFilter: Seq[String] = Seq.empty): Seq[Map[String, String]] = {
     if (!tableExists(tableName)) return Seq.empty[Map[String, String]]
@@ -182,6 +191,7 @@
     }
   }
 
+  // Needs provider
   def partitions(tableName: String, subPartitionsFilter: Map[String, String] = Map.empty): Seq[String] = {
     if (!tableExists(tableName)) return Seq.empty[String]
     val format = tableReadFormat(tableName)
@@ -222,11 +232,13 @@
     }
   }
 
+  // Needs provider
   def getSchemaFromTable(tableName: String): StructType = {
     sparkSession.sql(s"SELECT * FROM $tableName LIMIT 1").schema
   }
 
   // method to check if a user has access to a table
+  // Needs provider
   def checkTablePermission(tableName: String,
                            fallbackPartition: String =
                              partitionSpec.before(partitionSpec.at(System.currentTimeMillis()))): Boolean = {
@@ -252,12 +264,15 @@
     }
   }
 
+  // Needs provider
   def lastAvailablePartition(tableName: String, subPartitionFilters: Map[String, String] = Map.empty): Option[String] =
     partitions(tableName, subPartitionFilters).reduceOption((x, y) => Ordering[String].max(x, y))
 
+  // Needs provider
   def firstAvailablePartition(tableName: String, subPartitionFilters: Map[String, String] = Map.empty): Option[String] =
     partitions(tableName, subPartitionFilters).reduceOption((x, y) => Ordering[String].min(x, y))
 
+  // Needs provider
   def insertPartitions(df: DataFrame,
                        tableName: String,
                        tableProperties: Map[String, String] = null,
@@ -351,6 +366,7 @@
     }
   }
 
+  // Needs provider
   def insertUnPartitioned(df: DataFrame,
                           tableName: String,
                           tableProperties: Map[String, String] = null,
@@ -412,6 +428,7 @@
     }.get
   }
 
+  // Needs provider
   private def repartitionAndWriteInternal(df: DataFrame,
                                           tableName: String,
                                           saveMode: SaveMode,
@@ -488,6 +505,7 @@
     }
   }
 
+  // Needs provider
   private def createTableSql(tableName: String,
                              schema: StructType,
                              partitionColumns: Seq[String],
@@ -526,6 +544,7 @@
     Seq(createFragment, partitionFragment, fileFormatString, propertiesFragment).mkString("\n")
   }
 
+  // Needs provider
   private def alterTablePropertiesSql(tableName: String, properties: Map[String, String]): String = {
     // Only SQL api exists for setting TBLPROPERTIES
     val propertiesString = properties
@@ -612,6 +631,7 @@
     Some(missingChunks)
   }
 
+  // Needs provider
   def getTableProperties(tableName: String): Option[Map[String, String]] = {
     try {
       val tableId = sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName)
@@ -621,6 +641,7 @@
     }
   }
 
+  // Needs provider
   def dropTableIfExists(tableName: String): Unit = {
     val command = s"DROP TABLE IF EXISTS $tableName"
     logger.info(s"Dropping table with command: $command")
@@ -648,68 +669,6 @@
     }
   }
 
-  @deprecated
-  def dropPartitionsAfterHole(inputTable: String,
-                              outputTable: String,
-                              partitionRange: PartitionRange,
-                              subPartitionFilters: Map[String, String] = Map.empty): Option[String] = {
-
-    def partitionsInRange(table: String, partitionFilter: Map[String, String] = Map.empty): Set[String] = {
-      val allParts = partitions(table, partitionFilter)
-      val startPrunedParts = Option(partitionRange.start).map(start => allParts.filter(_ >= start)).getOrElse(allParts)
-      Option(partitionRange.end).map(end => startPrunedParts.filter(_ <= end)).getOrElse(startPrunedParts).toSet
-    }
-
-    val inputPartitions = partitionsInRange(inputTable)
-    val outputPartitions = partitionsInRange(outputTable, subPartitionFilters)
-    val earliestHoleOpt = (inputPartitions -- outputPartitions).reduceLeftOption(Ordering[String].min)
-    earliestHoleOpt.foreach { hole =>
-      val toDrop = outputPartitions.filter(_ > hole)
-      logger.info(s"""
-                     |Earliest hole at $hole in output table $outputTable, relative to $inputTable
-                     |Input Parts : ${inputPartitions.toArray.sorted.mkString("Array(", ", ", ")")}
-                     |Output Parts : ${outputPartitions.toArray.sorted.mkString("Array(", ", ", ")")}
-                     |Dropping Parts: ${toDrop.toArray.sorted.mkString("Array(", ", ", ")")}
-                     |Sub Partitions: ${subPartitionFilters.map(kv => s"${kv._1}=${kv._2}").mkString("Array(", ", ", ")")}
-                     """.stripMargin)
-      dropPartitions(outputTable, toDrop.toArray.sorted, partitionColumn, subPartitionFilters)
-    }
-    earliestHoleOpt
-  }
-
-  def dropPartitions(tableName: String,
-                     partitions: Seq[String],
-                     partitionColumn: String = partitionColumn,
-                     subPartitionFilters: Map[String, String] = Map.empty): Unit = {
-    if (partitions.nonEmpty && tableExists(tableName)) {
-      val partitionSpecs = partitions
-        .map { partition =>
-          val mainSpec = s"$partitionColumn='$partition'"
-          val specs = mainSpec +: subPartitionFilters.map {
-            case (key, value) => s"$key='$value'"
-          }.toSeq
-          specs.mkString("PARTITION (", ",", ")")
-        }
-        .mkString(",")
-      val dropSql = s"ALTER TABLE $tableName DROP IF EXISTS $partitionSpecs"
-      sql(dropSql)
-    } else {
-      logger.info(s"$tableName doesn't exist, please double check before drop partitions")
-    }
-  }
-
-  def dropPartitionRange(tableName: String,
-                         startDate: String,
-                         endDate: String,
-                         subPartitionFilters: Map[String, String] = Map.empty): Unit = {
-    if (tableExists(tableName)) {
-      val toDrop = Stream.iterate(startDate)(partitionSpec.after).takeWhile(_ <= endDate)
-      dropPartitions(tableName, toDrop, partitionColumn, subPartitionFilters)
-    } else {
-      logger.info(s"$tableName doesn't exist, please double check before drop partitions")
-    }
-  }
-
   /*
    * This method detects new columns that appear in newSchema but not in current table,
    * and append those new columns at the end of the existing table. This allows continuous evolution
@@ -837,6 +796,12 @@
   }
 }
 
+object TableUtils {
+  def apply(sparkSession: SparkSession): TableUtils = {
+    new TableUtils(sparkSession)
+  }
+}
+
 sealed case class IncompatibleSchemaException(inconsistencies: Seq[(String, DataType, DataType)]) extends Exception {
   override def getMessage: String = {
     val inconsistenciesStr =
```
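The headline change above is that `TableUtils` is no longer a case class: it becomes a plain `Serializable` class with a companion `apply`, so existing call sites such as `TableUtils(spark)` keep compiling, while subclassing by another case class (which Scala forbids for case classes) becomes possible — that is what lets the test-only `TableTestUtils` case class below extend it. A minimal sketch of both call patterns, assuming a local session; the subclass here is hypothetical and only illustrates the new extension point:

```scala
import ai.chronon.spark.TableUtils
import org.apache.spark.sql.SparkSession

object TableUtilsCallSites {
  def demo(spark: SparkSession): Unit = {
    // Unchanged call site: resolves through the new companion TableUtils.apply.
    val tableUtils: TableUtils = TableUtils(spark)
    println(tableUtils.tableExists("some_db.some_table")) // placeholder table name

    // Subclassing is now possible; TableTestUtils (added in this commit) relies on this
    // to host the partition-drop helpers removed from TableUtils.
    class LoggingTableUtils(session: SparkSession) extends TableUtils(session) { // hypothetical subclass
      override def createDatabase(database: String): Boolean = {
        logger.info(s"creating database $database") // logger is inherited from TableUtils
        super.createDatabase(database)
      }
    }
    new LoggingTableUtils(spark).createDatabase("sketch_db") // placeholder database name
  }
}
```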

spark/src/test/scala/ai/chronon/spark/test/JoinTest.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -57,7 +57,7 @@ object TestRow {
 class JoinTest extends AnyFunSuite with TaggedFilterSuite {
 
   val spark: SparkSession = SparkSessionBuilder.build("JoinTest", local = true)
-  private implicit val tableUtils = TableUtils(spark)
+  private implicit val tableUtils = TableTestUtils(spark)
 
   private val today = tableUtils.partitionSpec.at(System.currentTimeMillis())
   private val monthAgo = tableUtils.partitionSpec.minus(today, new Window(30, TimeUnit.DAYS))
```

spark/src/test/scala/ai/chronon/spark/test/LabelJoinTest.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -37,7 +37,7 @@ class LabelJoinTest {
   private val namespace = "label_join"
   private val tableName = "test_label_join"
   private val labelDS = "2022-10-30"
-  private val tableUtils = TableUtils(spark)
+  private val tableUtils = TableTestUtils(spark)
   tableUtils.createDatabase(namespace)
 
   private val viewsGroupBy = TestUtils.createViewsGroupBy(namespace, spark)
```
Lines changed: 41 additions & 0 deletions
```diff
@@ -0,0 +1,41 @@
+package ai.chronon.spark.test
+
+import ai.chronon.spark.TableUtils
+import org.apache.spark.sql.SparkSession
+
+case class TableTestUtils(override val sparkSession: SparkSession) extends TableUtils(sparkSession: SparkSession) {
+
+  def dropPartitions(tableName: String,
+                     partitions: Seq[String],
+                     partitionColumn: String = partitionColumn,
+                     subPartitionFilters: Map[String, String] = Map.empty): Unit = {
+    if (partitions.nonEmpty && tableExists(tableName)) {
+      val partitionSpecs = partitions
+        .map { partition =>
+          val mainSpec = s"$partitionColumn='$partition'"
+          val specs = mainSpec +: subPartitionFilters.map {
+            case (key, value) => s"$key='$value'"
+          }.toSeq
+          specs.mkString("PARTITION (", ",", ")")
+        }
+        .mkString(",")
+      val dropSql = s"ALTER TABLE $tableName DROP IF EXISTS $partitionSpecs"
+      sql(dropSql)
+    } else {
+      logger.info(s"$tableName doesn't exist, please double check before drop partitions")
+    }
+  }
+
+  def dropPartitionRange(tableName: String,
+                         startDate: String,
+                         endDate: String,
+                         subPartitionFilters: Map[String, String] = Map.empty): Unit = {
+    if (tableExists(tableName)) {
+      val toDrop = Stream.iterate(startDate)(partitionSpec.after).takeWhile(_ <= endDate)
+      dropPartitions(tableName, toDrop, partitionColumn, subPartitionFilters)
+    } else {
+      logger.info(s"$tableName doesn't exist, please double check before drop partitions")
+    }
+  }
+
+}
```
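`dropPartitions` and `dropPartitionRange` were removed from `TableUtils` and now live only in this test utility, which `JoinTest` and `LabelJoinTest` pick up via `TableTestUtils(spark)`; the deprecated `dropPartitionsAfterHole` was dropped entirely. A minimal usage sketch, assuming a local session as in the test suites above (namespace, table, and dates are placeholders):

```scala
import ai.chronon.spark.SparkSessionBuilder
import ai.chronon.spark.test.TableTestUtils

object TableTestUtilsSketch {
  def main(args: Array[String]): Unit = {
    // Build a local session the same way the test suites do.
    val spark = SparkSessionBuilder.build("TableTestUtilsSketch", local = true)
    val tableUtils = TableTestUtils(spark)

    tableUtils.createDatabase("sketch_db") // inherited from TableUtils; placeholder namespace

    // Drop a contiguous date range of partitions, then a single explicit partition.
    tableUtils.dropPartitionRange("sketch_db.events", "2023-01-01", "2023-01-07") // placeholder table and dates
    tableUtils.dropPartitions("sketch_db.events", Seq("2023-01-08"))
  }
}
```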
