Commit d0deb93

extensions
1 parent 14f7ef9 commit d0deb93

File tree

1 file changed: +75 additions, -47 deletions

spark/src/main/scala/ai/chronon/spark/Extensions.scala

Lines changed: 75 additions & 47 deletions
@@ -26,8 +26,9 @@ import ai.chronon.online.SparkConversions
 import ai.chronon.online.TimeRange
 import org.apache.avro.Schema
 import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.DataFrameReader
+import org.apache.spark.sql.DataFrameWriter
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.expressions.UserDefinedFunction
 import org.apache.spark.sql.functions._
@@ -322,53 +323,80 @@ object Extensions {
     }
   }
 
-  implicit class DataPointerOps(dataPointer: DataPointer) {
-    def toDf(implicit sparkSession: SparkSession): DataFrame = {
+  implicit class DataPointerAwareDataFrameWriter[T](dfw: DataFrameWriter[T]) {
+
+    def save(dataPointer: DataPointer): Unit = {
+
+      dataPointer.writeFormat
+        .map((wf) => {
+          val normalized = wf.toLowerCase
+          normalized match {
+            case "bigquery" | "bq" =>
+              dfw
+                .format("bigquery")
+                .options(dataPointer.options)
+                .save(dataPointer.tableOrPath)
+            case "snowflake" | "sf" =>
+              dfw
+                .format("net.snowflake.spark.snowflake")
+                .options(dataPointer.options)
+                .option("dbtable", dataPointer.tableOrPath)
+                .save()
+            case "parquet" | "csv" =>
+              dfw
+                .format(normalized)
+                .options(dataPointer.options)
+                .save(dataPointer.tableOrPath)
+            case "hive" =>
+              dfw
+                .format("hive")
+                .saveAsTable(dataPointer.tableOrPath)
+            case _ =>
+              throw new UnsupportedOperationException(s"Unsupported write catalog: ${normalized}")
+          }
+        })
+        .getOrElse(
+          // None case is just table against default catalog
+          dfw
+            .format("hive")
+            .saveAsTable(dataPointer.tableOrPath))
+    }
+  }
+
+  implicit class DataPointerAwareDataFrameReader(dfr: DataFrameReader) {
+
+    def load(dataPointer: DataPointer): DataFrame = {
       val tableOrPath = dataPointer.tableOrPath
-      val format = dataPointer.format.getOrElse("parquet")
-      dataPointer.catalog.map(_.toLowerCase) match {
-        case Some("bigquery") | Some("bq") =>
-          // https://github.com/GoogleCloudDataproc/spark-bigquery-connector?tab=readme-ov-file#reading-data-from-a-bigquery-table
-          sparkSession.read
-            .format("bigquery")
-            .options(dataPointer.options)
-            .load(tableOrPath)
-
-        case Some("snowflake") | Some("sf") =>
-          // https://docs.snowflake.com/en/user-guide/spark-connector-use#moving-data-from-snowflake-to-spark
-          val sfOptions = dataPointer.options
-          sparkSession.read
-            .format("net.snowflake.spark.snowflake")
-            .options(sfOptions)
-            .option("dbtable", tableOrPath)
-            .load()
-
-        case Some("s3") | Some("s3a") | Some("s3n") =>
-          // https://sites.google.com/site/hellobenchen/home/wiki/big-data/spark/read-data-files-from-multiple-sub-folders
-          // "To get spark to read through all subfolders and subsubfolders, etc. simply use the wildcard *"
-          // "df= spark.read.parquet('/datafolder/*/*')"
-          //
-          // https://docs.aws.amazon.com/emr/latest/ManagementGuide/emr-plan-file-systems.html
-          // "Previously, Amazon EMR used the s3n and s3a file systems. While both still work, "
-          // "we recommend that you use the s3 URI scheme for the best performance, security, and reliability."
-          // TODO: figure out how to scan subfolders in a date range without reading the entire folder
-          sparkSession.read
-            .format(format)
-            .options(dataPointer.options)
-            .load("s3://" + tableOrPath)
-
-        case Some("file") =>
-          sparkSession.read
-            .format(format)
-            .options(dataPointer.options)
-            .load(tableOrPath)
-
-        case Some("hive") | None =>
-          sparkSession.table(tableOrPath)
-
-        case _ =>
-          throw new UnsupportedOperationException(s"Unsupported catalog: ${dataPointer.catalog}")
-      }
+
+      dataPointer.readFormat
+        .map((fmt) => {
+          val normalized = fmt.toLowerCase
+          normalized match {
+            case "bigquery" | "bq" =>
+              dfr
+                .format("bigquery")
+                .options(dataPointer.options)
+                .load(tableOrPath)
+            case "snowflake" | "sf" =>
+              dfr
+                .format("net.snowflake.spark.snowflake")
+                .options(dataPointer.options)
+                .option("dbtable", tableOrPath)
+                .load()
+            case "parquet" | "csv" =>
+              dfr
+                .format(normalized)
+                .options(dataPointer.options)
+                .load(tableOrPath)
+            case "hive" => dfr.table(tableOrPath)
+            case _ =>
+              throw new UnsupportedOperationException(s"Unsupported read catalog: ${normalized}")
+          }
+        })
+        .getOrElse {
+          // None case is just table against default catalog
+          dfr.table(tableOrPath)
+        }
     }
   }
 }
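For orientation, a minimal usage sketch of the two new implicits follows. It assumes an active SparkSession named spark and a DataPointer value constructed elsewhere; the DataPointer constructor and its package are not shown in this diff, so the value is left abstract. Only readFormat, writeFormat, options, and tableOrPath are taken from the code above.

import ai.chronon.spark.Extensions._
import org.apache.spark.sql.{DataFrame, SparkSession}

val spark: SparkSession = SparkSession.builder().getOrCreate()

// Hypothetical pointer: how a DataPointer is built is outside this
// diff's scope, so it is left abstract here.
val pointer: DataPointer = ???

// Reader side: DataPointerAwareDataFrameReader adds load(DataPointer).
// readFormat selects the connector ("bigquery"/"bq", "snowflake"/"sf",
// "parquet", "csv", "hive"); an empty readFormat falls back to reading
// the table from the default catalog.
val df: DataFrame = spark.read.load(pointer)

// Writer side: DataPointerAwareDataFrameWriter adds save(DataPointer),
// dispatching on writeFormat with the same Hive fallback.
df.write.save(pointer)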
