
Commit 3271f72

pr feedback
Co-authored-by: Thomas Chow <[email protected]>
1 parent b2ff399 commit 3271f72

File tree: 7 files changed (+120, -10 lines)


cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/BigQueryFormat.scala

Lines changed: 8 additions & 1 deletion
@@ -11,6 +11,7 @@ import com.google.cloud.bigquery.Table
 import com.google.cloud.bigquery.connector.common.BigQueryUtil
 import com.google.cloud.spark.bigquery.repackaged.com.google.cloud.bigquery.TableId
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.functions.{col, to_date}

 import scala.collection.JavaConverters._

@@ -87,7 +88,6 @@ case class GcpFormatProvider(sparkSession: SparkSession) extends FormatProvider
   * case None => throw new IllegalStateException("Dataproc federation service must be available.")
   *
   * }
-  * }
   *
   * case Some("hive") | None => Hive
   * }
@@ -151,6 +151,13 @@ case class BQuery(project: String) extends Format {
       .option("project", project)
       .option("query", partValsSql)
       .load()
+      .select(
+        to_date(col("partition_id"),
+                "yyyyMMdd"
+        ) // Note: this "yyyyMMdd" format is hardcoded but we need to change it to be something else.
+          .as("partition_id"))
+      .na // Should filter out '__NULL__' and '__UNPARTITIONED__'. See: https://cloud.google.com/bigquery/docs/partitioned-tables#date_timestamp_partitioned_tables
+      .drop()
       .as[String]
       .collect
       .toList
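For context on what the new lines do: daily partition_ids in BigQuery's INFORMATION_SCHEMA come back as yyyyMMdd strings, plus the sentinel values '__NULL__' and '__UNPARTITIONED__' mentioned in the linked docs; to_date turns the sentinels into nulls, which .na.drop() then removes. A small local sketch of just that parsing step (illustrative only, no BigQuery involved; the object name is made up):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, to_date}

object PartitionIdParsingSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("partition-id-parsing").getOrCreate()
    import spark.implicits._

    // Daily partition_ids are yyyyMMdd strings; the two sentinel partitions
    // do not parse as dates, so to_date maps them to null.
    val parsed = Seq("20241223", "20241224", "__NULL__", "__UNPARTITIONED__")
      .toDF("partition_id")
      .select(to_date(col("partition_id"), "yyyyMMdd").as("partition_id"))
      .na
      .drop() // removes the rows that became null, i.e. the sentinel partitions

    parsed.show() // 2024-12-23 and 2024-12-24 remain, now as DateType values
    spark.stop()
  }
}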

cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/GCSFormat.scala

Lines changed: 10 additions & 5 deletions
@@ -7,10 +7,9 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.execution.FileSourceScanExec
 import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
+case class GCS(project: String, sourceUri: String, fileFormat: String) extends Format {

-case class GCS(project: String, sourceUri: String, format: String) extends Format {
-
-  override def name: String = format
+  override def name: String = fileFormat

   override def primaryPartitions(tableName: String, partitionColumn: String, subPartitionsFilter: Map[String, String])(
     implicit sparkSession: SparkSession): Seq[String] =
@@ -39,7 +38,8 @@ case class GCS(project: String, sourceUri: String, format: String) extends Format
     *
     */
    val partitionSpec = sparkSession.read
-     .parquet(sourceUri)
+     .format(fileFormat)
+     .load(sourceUri)
      .queryExecution
      .sparkPlan
      .asInstanceOf[FileSourceScanExec]
@@ -52,7 +52,12 @@ case class GCS(project: String, sourceUri: String, format: String) extends Format
    val partitions = partitionSpec.partitions.map(_.values)

    val deserializer =
-     Encoders.row(partitionColumns).asInstanceOf[ExpressionEncoder[Row]].resolveAndBind().createDeserializer()
+     try {
+       Encoders.row(partitionColumns).asInstanceOf[ExpressionEncoder[Row]].resolveAndBind().createDeserializer()
+     } catch {
+       case e: Exception =>
+         throw new RuntimeException(s"Failed to create deserializer for partition columns: ${e.getMessage}", e)
+     }

    val roundTripped = sparkSession
      .createDataFrame(sparkSession.sparkContext.parallelize(partitions.map(deserializer)), partitionColumns)
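The switch from the hardcoded .parquet(sourceUri) to .format(fileFormat).load(sourceUri) keeps Spark's Hive-style partition discovery while letting the caller pick the file format. A self-contained sketch of that idea (illustrative names and a local temp directory, not part of the commit):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

object FormatAgnosticPartitionSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("format-agnostic-partitions").getOrCreate()
    import spark.implicits._

    val out = java.nio.file.Files.createTempDirectory("partition-discovery").toString
    val fileFormat = "parquet" // could just as well be "orc" or "json"

    // Write a tiny dataset partitioned by "ds".
    Seq(("20241223", "a"), ("20241224", "b"))
      .toDF("ds", "value")
      .write.partitionBy("ds").format(fileFormat).mode("overwrite").save(out)

    // Loading by format name still discovers the Hive-style ds=... directories.
    val partitions = spark.read
      .format(fileFormat)
      .load(out)
      .select(col("ds").cast("string"))
      .distinct()
      .as[String]
      .collect()
      .sorted

    println(partitions.mkString(", ")) // 20241223, 20241224
    spark.stop()
  }
}
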
GCSFormatTest.scala (new file)

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
+package ai.chronon.integrations.cloud_gcp
+
+import ai.chronon.spark.SparkSessionBuilder
+import org.apache.spark.sql.{Row, SaveMode, SparkSession}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{StringType, StructField, StructType}
+
+import org.junit.Assert.assertEquals
+import org.scalatest.funsuite.AnyFunSuite
+
+import java.nio.file.Files
+
+class GCSFormatTest extends AnyFunSuite {
+
+  lazy val spark: SparkSession = SparkSessionBuilder.build(
+    "BigQuerySparkTest",
+    local = true
+  )
+
+  test("partitions method should return correctly parsed partitions as maps") {
+
+    val testData = List(
+      ("20241223", "b", "c"),
+      ("20241224", "e", "f"),
+      ("20241225", "h", "i")
+    )
+
+    val dir = Files.createTempDirectory("spark-test-output").toFile
+    dir.deleteOnExit()
+
+    val df = spark.createDataFrame(testData).toDF("ds", "first", "second")
+    df.write.partitionBy("ds").format("parquet").mode(SaveMode.Overwrite).save(dir.getAbsolutePath)
+    val gcsFormat = GCS(project = "test-project", sourceUri = dir.getAbsolutePath, fileFormat = "parquet")
+    val partitions = gcsFormat.partitions("unused_table")(spark)
+
+    assertEquals(Set(Map("ds" -> "20241223"), Map("ds" -> "20241224"), Map("ds" -> "20241225")), partitions.toSet)
+
+  }
+
+  test("partitions method should handle empty partitions gracefully") {
+
+    val testData = List(
+      ("20241223", "b", "c"),
+      ("20241224", "e", "f"),
+      ("20241225", "h", "i")
+    )
+
+    val dir = Files.createTempDirectory("spark-test-output").toFile
+    dir.deleteOnExit()
+
+    val df = spark.createDataFrame(testData).toDF("ds", "first", "second")
+    df.write.format("parquet").mode(SaveMode.Overwrite).save(dir.getAbsolutePath)
+    val gcsFormat = GCS(project = "test-project", sourceUri = dir.getAbsolutePath, fileFormat = "parquet")
+    val partitions = gcsFormat.partitions("unused_table")(spark)
+
+    assertEquals(Set.empty, partitions.toSet)
+
+  }
+
+  test("partitions method should handle date types") {
+    val testData = List(
+      Row("2024-12-23", "b", "c"),
+      Row("2024-12-24", "e", "f"),
+      Row("2024-12-25", "h", "i")
+    )
+
+    val dir = Files.createTempDirectory("spark-test-output").toFile
+    dir.deleteOnExit()
+
+    val schema = StructType(
+      Seq(
+        StructField("ds", StringType, nullable = true),
+        StructField("first", StringType, nullable = true),
+        StructField("second", StringType, nullable = true)
+      ))
+
+    val df =
+      spark
+        .createDataFrame(spark.sparkContext.parallelize(testData), schema)
+        .toDF("ds", "first", "second")
+        .select(to_date(col("ds"), "yyyy-MM-dd").as("ds"), col("first"), col("second"))
+    df.write.format("parquet").partitionBy("ds").mode(SaveMode.Overwrite).save(dir.getAbsolutePath)
+    val gcsFormat = GCS(project = "test-project", sourceUri = dir.getAbsolutePath, fileFormat = "parquet")
+    val partitions = gcsFormat.partitions("unused_table")(spark)
+
+    assertEquals(Set(Map("ds" -> "2024-12-23"), Map("ds" -> "2024-12-24"), Map("ds" -> "2024-12-25")), partitions.toSet)
+
+  }
+}

spark/src/main/scala/ai/chronon/spark/Driver.scala

Lines changed: 1 addition & 1 deletion
@@ -278,7 +278,7 @@ object Driver {
     val join = new Join(
       args.joinConf,
       args.endDate(),
-      args.buildTableUtils(),
+      tableUtils,
       !args.runFirstHole(),
       selectedJoinParts = args.selectedJoinParts.toOption
     )

spark/src/main/scala/ai/chronon/spark/Extensions.scala

Lines changed: 1 addition & 0 deletions
@@ -309,6 +309,7 @@ object Extensions {
         dfw
           .format("bigquery")
           .options(dataPointer.options)
+          .option("writeMethod", "direct")
           .save(dataPointer.tableOrPath)
       case "snowflake" | "sf" =>
         dfw
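The added option switches the spark-bigquery-connector from its default indirect write (staging files in a GCS bucket) to the BigQuery Storage Write API. A hedged sketch of a write using that option, with a hypothetical helper and table id:

import org.apache.spark.sql.DataFrame

object BigQueryDirectWriteSketch {
  // Assumes the spark-bigquery-connector is on the classpath and the caller has
  // permissions on the target dataset; with "direct" no temporary GCS bucket is needed.
  def writeToBigQueryDirect(df: DataFrame, tableId: String): Unit =
    df.write
      .format("bigquery")
      .option("writeMethod", "direct")
      .mode("append")  // hypothetical choice; the commit sets its write mode elsewhere
      .save(tableId)   // e.g. "my-project.my_dataset.my_table" (hypothetical id)
}
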

spark/src/main/scala/ai/chronon/spark/JoinBase.scala

Lines changed: 6 additions & 2 deletions
@@ -441,7 +441,7 @@ abstract class JoinBase(joinConf: api.Join,
         if (tableUtils.backfillValidationEnforced) throw ex
       case e: Throwable =>
         metrics.gauge(Metrics.Name.validationFailure, 1)
-        logger.error(s"An unexpected error occurred during validation. ${e.getMessage}")
+        throw e
     }

     // First run command to archive tables that have changed semantically since the last run
@@ -494,7 +494,11 @@ abstract class JoinBase(joinConf: api.Join,
     val runSmallMode = {
       if (tableUtils.smallModelEnabled) {
         val thresholdCount =
-          leftDf(joinConf, wholeRange, tableUtils, limit = Some(tableUtils.smallModeNumRowsCutoff + 1)).get.count()
+          leftDf(joinConf,
+                 wholeRange,
+                 tableUtils,
+                 allowEmpty = true,
+                 limit = Some(tableUtils.smallModeNumRowsCutoff + 1)).get.count()
         val result = thresholdCount <= tableUtils.smallModeNumRowsCutoff
         if (result) {
           logger.info(s"Counted $thresholdCount rows, running join in small mode.")
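Two things change in the small-mode hunk: the count stays capped by limit = Some(cutoff + 1), so deciding small mode never requires scanning the whole left side, and allowEmpty = true is now passed, presumably so an empty left side still yields a DataFrame to count rather than None. A minimal sketch of the thresholding trick, with hypothetical names:

import org.apache.spark.sql.DataFrame

object SmallModeCheckSketch {
  // Read at most cutoff + 1 rows; if the limited count is still <= cutoff,
  // the input truly has at most cutoff rows and the join can run in small mode.
  def shouldRunSmallMode(leftDf: DataFrame, cutoff: Int): Boolean = {
    val thresholdCount = leftDf.limit(cutoff + 1).count()
    thresholdCount <= cutoff
  }
}
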

spark/src/main/scala/ai/chronon/spark/TableUtils.scala

Lines changed: 5 additions & 1 deletion
@@ -233,7 +233,11 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
     try {
       // retrieve one row from the table
       val partitionFilter = lastAvailablePartition(tableName).getOrElse(fallbackPartition)
-      sparkSession.sql(s"SELECT * FROM $tableName where $partitionColumn='$partitionFilter' LIMIT 1").collect()
+      sparkSession.read
+        .load(DataPointer(tableName, sparkSession))
+        .where(s"$partitionColumn='$partitionFilter'")
+        .limit(1)
+        .collect()
       true
     } catch {
       case e: SparkException =>
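The probe now builds a DataFrame read through DataPointer instead of issuing a raw SQL string, presumably so the check can resolve tables that are not addressable via sparkSession.sql. A rough catalog-only equivalent, with hypothetical names (DataPointer itself is Chronon-internal and omitted here):

import org.apache.spark.sql.SparkSession

object PartitionProbeSketch {
  // Try to read one row from the given partition and report whether the read succeeds.
  def canReadPartition(spark: SparkSession,
                       tableName: String,
                       partitionColumn: String,
                       partitionValue: String): Boolean =
    try {
      spark.read
        .table(tableName)
        .where(s"$partitionColumn = '$partitionValue'")
        .limit(1)
        .collect()
      true
    } catch {
      case _: Exception => false
    }
}
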
