
Commit 6f228f2

pr feedback
Co-authored-by: Thomas Chow <[email protected]>
1 parent b2ff399 commit 6f228f2

File tree

5 files changed: +25 -8 lines changed

cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/BigQueryFormat.scala

Lines changed: 8 additions & 1 deletion
@@ -11,6 +11,7 @@ import com.google.cloud.bigquery.Table
 import com.google.cloud.bigquery.connector.common.BigQueryUtil
 import com.google.cloud.spark.bigquery.repackaged.com.google.cloud.bigquery.TableId
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.functions.{col, to_date}

 import scala.collection.JavaConverters._

@@ -87,7 +88,6 @@ case class GcpFormatProvider(sparkSession: SparkSession) extends FormatProvider
   * case None => throw new IllegalStateException("Dataproc federation service must be available.")
   *
   * }
-  * }
   *
   * case Some("hive") | None => Hive
   * }

@@ -151,6 +151,13 @@ case class BQuery(project: String) extends Format {
       .option("project", project)
       .option("query", partValsSql)
       .load()
+      .select(
+        to_date(col("partition_id"),
+                "yyyyMMdd"
+        ) // Note: this "yyyyMMdd" format is hardcoded but we need to change it to be something else.
+          .as("partition_id"))
+      .na // Should filter out '__NULL__' and '__UNPARTITIONED__'. See: https://cloud.google.com/bigquery/docs/partitioned-tables#date_timestamp_partitioned_tables
+      .drop()
       .as[String]
       .collect
       .toList
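
Aside on why the `.na.drop()` works here: with Spark's default (non-ANSI) settings, `to_date` returns null for any `partition_id` that does not match the pattern, which covers BigQuery's `__NULL__` and `__UNPARTITIONED__` sentinel partitions, so dropping null rows removes them. A minimal, self-contained sketch of that behavior (the sample values are illustrative, not from this commit):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, to_date}

object PartitionIdFilterSketch extends App {
  val spark = SparkSession.builder().master("local[*]").appName("partition-id-sketch").getOrCreate()
  import spark.implicits._

  // Partition ids as INFORMATION_SCHEMA.PARTITIONS would return them for a
  // daily-partitioned table, plus the two sentinel values BigQuery emits.
  val raw = Seq("20240101", "20240102", "__NULL__", "__UNPARTITIONED__").toDF("partition_id")

  val filtered = raw
    .select(to_date(col("partition_id"), "yyyyMMdd").as("partition_id"))
    .na
    .drop() // the sentinels parse to null and are dropped here

  filtered.show() // only 2024-01-01 and 2024-01-02 remain
  spark.stop()
}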

cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/GCSFormat.scala

Lines changed: 10 additions & 5 deletions
@@ -7,10 +7,9 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.execution.FileSourceScanExec
 import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
+case class GCS(project: String, sourceUri: String, fileFormat: String) extends Format {

-case class GCS(project: String, sourceUri: String, format: String) extends Format {
-
-  override def name: String = format
+  override def name: String = fileFormat

   override def primaryPartitions(tableName: String, partitionColumn: String, subPartitionsFilter: Map[String, String])(
     implicit sparkSession: SparkSession): Seq[String] =

@@ -39,7 +38,8 @@ case class GCS(project: String, sourceUri: String, format: String) extends Format
      *
      */
     val partitionSpec = sparkSession.read
-      .parquet(sourceUri)
+      .format(fileFormat)
+      .load(sourceUri)
      .queryExecution
      .sparkPlan
      .asInstanceOf[FileSourceScanExec]

@@ -52,7 +52,12 @@ case class GCS(project: String, sourceUri: String, format: String) extends Format
     val partitions = partitionSpec.partitions.map(_.values)

     val deserializer =
-      Encoders.row(partitionColumns).asInstanceOf[ExpressionEncoder[Row]].resolveAndBind().createDeserializer()
+      try {
+        Encoders.row(partitionColumns).asInstanceOf[ExpressionEncoder[Row]].resolveAndBind().createDeserializer()
+      } catch {
+        case e: Exception =>
+          throw new RuntimeException(s"Failed to create deserializer for partition columns: ${e.getMessage}", e)
+      }

     val roundTripped = sparkSession
       .createDataFrame(sparkSession.sparkContext.parallelize(partitions.map(deserializer)), partitionColumns)
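
Two things change in this file: the reader goes through Spark's generic source API so the format is a value rather than a hardcoded method call, and deserializer creation is wrapped to surface a clearer error. A minimal sketch of the generic-reader pattern (the format and URI values are illustrative):

import org.apache.spark.sql.{DataFrame, SparkSession}

// The source format is plain data, so "parquet", "orc", "avro", etc.
// all flow through the same call site instead of format-specific methods
// like spark.read.parquet(...).
def loadBySourceFormat(spark: SparkSession, fileFormat: String, sourceUri: String): DataFrame =
  spark.read
    .format(fileFormat) // e.g. "parquet"
    .load(sourceUri)    // e.g. "gs://some-bucket/some/path/"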

spark/src/main/scala/ai/chronon/spark/Driver.scala

Lines changed: 1 addition & 1 deletion
@@ -278,7 +278,7 @@ object Driver {
     val join = new Join(
       args.joinConf,
       args.endDate(),
-      args.buildTableUtils(),
+      tableUtils,
       !args.runFirstHole(),
       selectedJoinParts = args.selectedJoinParts.toOption
     )
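
This swap reuses the `tableUtils` instance already built earlier in the command instead of constructing a second one via `args.buildTableUtils()`. A tiny sketch of the build-once, pass-through pattern (names mirror the diff; the surrounding code is illustrative):

// Build the helper once, then thread the same instance through every consumer,
// rather than re-invoking the builder (and re-deriving Spark state) per call site.
val tableUtils = args.buildTableUtils()
val join = new Join(
  args.joinConf,
  args.endDate(),
  tableUtils, // reuse, instead of calling args.buildTableUtils() again
  !args.runFirstHole(),
  selectedJoinParts = args.selectedJoinParts.toOption
)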

spark/src/main/scala/ai/chronon/spark/Extensions.scala

Lines changed: 1 addition & 0 deletions
@@ -309,6 +309,7 @@ object Extensions {
         dfw
           .format("bigquery")
           .options(dataPointer.options)
+          .option("writeMethod", "direct")
           .save(dataPointer.tableOrPath)
       case "snowflake" | "sf" =>
         dfw
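
For reference, `writeMethod` is the spark-bigquery-connector option that selects between the BigQuery Storage Write API (`direct`) and staging intermediate files in GCS before a load job (`indirect`). A hedged usage sketch; the table name is a placeholder:

import org.apache.spark.sql.DataFrame

// Direct write: rows stream through the BigQuery Storage Write API,
// so no temporary GCS bucket is needed for staging.
def writeToBigQueryDirect(df: DataFrame, table: String): Unit =
  df.write
    .format("bigquery")
    .option("writeMethod", "direct") // alternative: "indirect" (stage files in GCS, then load)
    .save(table) // e.g. "my-project.my_dataset.my_table"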

spark/src/main/scala/ai/chronon/spark/TableUtils.scala

Lines changed: 5 additions & 1 deletion
@@ -233,7 +233,11 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
     try {
       // retrieve one row from the table
       val partitionFilter = lastAvailablePartition(tableName).getOrElse(fallbackPartition)
-      sparkSession.sql(s"SELECT * FROM $tableName where $partitionColumn='$partitionFilter' LIMIT 1").collect()
+      sparkSession.read
+        .load(DataPointer(tableName, sparkSession))
+        .where(s"$partitionColumn='$partitionFilter'")
+        .limit(1)
+        .collect()
       true
     } catch {
       case e: SparkException =>
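
The probe now goes through the DataFrame reader instead of raw Spark SQL, so the `DataPointer` indirection can resolve the read to BigQuery or GCS as well as Hive. A simplified sketch of the same one-row reachability check, assuming a plain catalog table name in place of Chronon's DataPointer:

import org.apache.spark.SparkException
import org.apache.spark.sql.SparkSession

// Returns true if reading one row from the given partition succeeds;
// success means the table and partition are reachable, not that rows exist.
// Simplified: resolves the table by name instead of through a DataPointer.
def partitionReachable(spark: SparkSession,
                       tableName: String,
                       partitionColumn: String,
                       partitionValue: String): Boolean =
  try {
    spark.read
      .table(tableName)
      .where(s"$partitionColumn = '$partitionValue'")
      .limit(1)
      .collect()
    true
  } catch {
    case _: SparkException => false
  }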
