@@ -26,8 +26,9 @@ import ai.chronon.online.SparkConversions
 import ai.chronon.online.TimeRange
 import org.apache.avro.Schema
 import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.DataFrameReader
+import org.apache.spark.sql.DataFrameWriter
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.expressions.UserDefinedFunction
 import org.apache.spark.sql.functions._
@@ -322,53 +323,80 @@ object Extensions {
     }
   }

-  implicit class DataPointerOps(dataPointer: DataPointer) {
-    def toDf(implicit sparkSession: SparkSession): DataFrame = {
+  implicit class DataPointerAwareDataFrameWriter[T](dfw: DataFrameWriter[T]) {
+
+    def save(dataPointer: DataPointer): Unit = {
+
+      dataPointer.writeFormat
+        .map((wf) => {
+          val normalized = wf.toLowerCase
+          normalized match {
+            case "bigquery" | "bq" =>
+              dfw
+                .format("bigquery")
+                .options(dataPointer.options)
+                .save(dataPointer.tableOrPath)
+            case "snowflake" | "sf" =>
+              dfw
+                .format("net.snowflake.spark.snowflake")
+                .options(dataPointer.options)
+                .option("dbtable", dataPointer.tableOrPath)
+                .save()
+            case "parquet" | "csv" =>
+              dfw
+                .format(normalized)
+                .options(dataPointer.options)
+                .save(dataPointer.tableOrPath)
+            case "hive" =>
+              dfw
+                .format("hive")
+                .saveAsTable(dataPointer.tableOrPath)
+            case _ =>
+              throw new UnsupportedOperationException(s"Unsupported write catalog: ${normalized}")
+          }
+        })
+        .getOrElse(
+          // None case is just table against default catalog
+          dfw
+            .format("hive")
+            .saveAsTable(dataPointer.tableOrPath))
+    }
+  }
+
+  implicit class DataPointerAwareDataFrameReader(dfr: DataFrameReader) {
+
+    def load(dataPointer: DataPointer): DataFrame = {
       val tableOrPath = dataPointer.tableOrPath
-      val format = dataPointer.format.getOrElse("parquet")
-      dataPointer.catalog.map(_.toLowerCase) match {
-        case Some("bigquery") | Some("bq") =>
-          // https://github.com/GoogleCloudDataproc/spark-bigquery-connector?tab=readme-ov-file#reading-data-from-a-bigquery-table
-          sparkSession.read
-            .format("bigquery")
-            .options(dataPointer.options)
-            .load(tableOrPath)
-
-        case Some("snowflake") | Some("sf") =>
-          // https://docs.snowflake.com/en/user-guide/spark-connector-use#moving-data-from-snowflake-to-spark
-          val sfOptions = dataPointer.options
-          sparkSession.read
-            .format("net.snowflake.spark.snowflake")
-            .options(sfOptions)
-            .option("dbtable", tableOrPath)
-            .load()
-
-        case Some("s3") | Some("s3a") | Some("s3n") =>
-          // https://sites.google.com/site/hellobenchen/home/wiki/big-data/spark/read-data-files-from-multiple-sub-folders
-          // "To get spark to read through all subfolders and subsubfolders, etc. simply use the wildcard *"
-          // "df= spark.read.parquet('/datafolder/*/*')"
-          //
-          // https://docs.aws.amazon.com/emr/latest/ManagementGuide/emr-plan-file-systems.html
-          // "Previously, Amazon EMR used the s3n and s3a file systems. While both still work, "
-          // "we recommend that you use the s3 URI scheme for the best performance, security, and reliability."
-          // TODO: figure out how to scan subfolders in a date range without reading the entire folder
-          sparkSession.read
-            .format(format)
-            .options(dataPointer.options)
-            .load("s3://" + tableOrPath)
-
-        case Some("file") =>
-          sparkSession.read
-            .format(format)
-            .options(dataPointer.options)
-            .load(tableOrPath)
-
-        case Some("hive") | None =>
-          sparkSession.table(tableOrPath)
-
-        case _ =>
-          throw new UnsupportedOperationException(s"Unsupported catalog: ${dataPointer.catalog}")
-      }
+
+      dataPointer.readFormat
+        .map((fmt) => {
+          val normalized = fmt.toLowerCase
+          normalized match {
+            case "bigquery" | "bq" =>
+              dfr
+                .format("bigquery")
+                .options(dataPointer.options)
+                .load(tableOrPath)
+            case "snowflake" | "sf" =>
+              dfr
+                .format("net.snowflake.spark.snowflake")
+                .options(dataPointer.options)
+                .option("dbtable", tableOrPath)
+                .load()
+            case "parquet" | "csv" =>
+              dfr
+                .format(normalized)
+                .options(dataPointer.options)
+                .load(tableOrPath)
+            case "hive" => dfr.table(tableOrPath)
+            case _ =>
+              throw new UnsupportedOperationException(s"Unsupported read catalog: ${normalized}")
+          }
+        })
+        .getOrElse {
+          // None case is just table against default catalog
+          dfr.table(tableOrPath)
+        }
     }
   }
 }
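
For context, a minimal usage sketch of the new implicits (not part of the diff). The package paths for `Extensions` and `DataPointer` are assumed here, since they are not shown in this change; the `DataPointer` value is taken as already constructed elsewhere.

```scala
// Usage sketch only. Assumes the implicits above are imported from their
// package (e.g. ai.chronon.spark.Extensions._ in the Chronon repo) and that a
// DataPointer is built/parsed elsewhere -- its constructor is not part of this diff.
import org.apache.spark.sql.{DataFrame, SparkSession}

def copyTable(spark: SparkSession, source: DataPointer, sink: DataPointer): Unit = {
  // Reader side: spark.read is a DataFrameReader, so DataPointerAwareDataFrameReader
  // supplies load(DataPointer), dispatching on readFormat ("bigquery", "snowflake",
  // "parquet", "csv", "hive", or None for the default catalog).
  val df: DataFrame = spark.read.load(source)

  // Writer side: df.write is a DataFrameWriter, so save(DataPointer) dispatches on
  // writeFormat the same way; the None case falls back to a Hive saveAsTable.
  df.write.save(sink)
}
```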