
Commit 177c4ca

CatalogAwareDataPointer
1 parent 603c857 commit 177c4ca

File tree

4 files changed: +159 −71 lines changed

api/src/main/scala/ai/chronon/api/DataPointer.scala

Lines changed: 25 additions & 10 deletions
@@ -1,10 +1,20 @@
 package ai.chronon.api
 import scala.util.parsing.combinator._
 
-case class DataPointer(catalog: Option[String],
-                       tableOrPath: String,
-                       format: Option[String],
-                       options: Map[String, String])
+abstract class DataPointer {
+  def tableOrPath: String
+  def readFormat: Option[String]
+  def writeFormat: Option[String]
+  def options: Map[String, String]
+
+}
+
+case class URIDataPointer(
+    override val tableOrPath: String,
+    override val readFormat: Option[String],
+    override val writeFormat: Option[String],
+    override val options: Map[String, String]
+) extends DataPointer
 
 // parses string representations of data pointers
 // ex: namespace.table
@@ -27,21 +37,26 @@ object DataPointer extends RegexParsers {
     opt(catalogWithOptionalFormat ~ opt(options) ~ "://") ~ tableOrPath ^^ {
       // format is specified in the prefix s3+parquet://bucket/path/to/data/*/*/
       // note that if you have s3+parquet://bucket/path/to/data.csv, format is still parquet
-      case Some((ctl, Some(fmt)) ~ opts ~ _) ~ path =>
-        DataPointer(Some(ctl), path, Some(fmt), opts.getOrElse(Map.empty))
+      case Some((ctl, Some(fmt)) ~ opts ~ sep) ~ path =>
+        URIDataPointer(ctl + sep + path, Some(fmt), Some(fmt), opts.getOrElse(Map.empty))
 
       // format is extracted from the path for relevant sources
       // ex: s3://bucket/path/to/data.parquet
       // ex: file://path/to/data.csv
      // ex: hdfs://path/to/data.with.dots.parquet
       // for other sources like bigquery, snowflake, format is None
-      case Some((ctl, None) ~ opts ~ _) ~ path =>
-        val (pathWithoutFormat, fmt) = extractFormatFromPath(path, ctl)
-        DataPointer(Some(ctl), path, fmt, opts.getOrElse(Map.empty))
+      case Some((ctl, None) ~ opts ~ sep) ~ path =>
+        val (_, fmt) = extractFormatFromPath(path, ctl)
+
+        fmt match {
+          // Retain the full uri if it's a path.
+          case Some(ft) => URIDataPointer(ctl + sep + path, Some(ft), Some(ft), opts.getOrElse(Map.empty))
+          case None     => URIDataPointer(path, Some(ctl), Some(ctl), opts.getOrElse(Map.empty))
+        }
 
       case None ~ path =>
         // No prefix case (direct table reference)
-        DataPointer(None, path, None, Map.empty)
+        URIDataPointer(path, None, None, Map.empty)
     }
 
   private def catalogWithOptionalFormat: Parser[(String, Option[String])] =
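
For orientation, here is a small sketch of what the reworked parser now returns, with expected values taken from the test changes below; the result comments are illustrative rather than captured output:

import ai.chronon.api.{DataPointer, URIDataPointer}

// Prefixed format: the scheme is folded back into the stored URI, and the same
// format is used for both reads and writes.
val glob = DataPointer("s3+parquet://path/to/*/*/")
// glob == URIDataPointer("s3://path/to/*/*/", Some("parquet"), Some("parquet"), Map.empty)

// No format in the path: the catalog name itself becomes the read/write format,
// and only the table identifier is retained.
val bq = DataPointer("bigquery://project-id.dataset.table")
// bq == URIDataPointer("project-id.dataset.table", Some("bigquery"), Some("bigquery"), Map.empty)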
Lines changed: 17 additions & 14 deletions
@@ -1,67 +1,70 @@
 package ai.chronon.api.test
 
 import ai.chronon.api.DataPointer
+import ai.chronon.api.URIDataPointer
 import org.scalatest.flatspec.AnyFlatSpec
 import org.scalatest.matchers.should.Matchers
 
 class DataPointerTest extends AnyFlatSpec with Matchers {
 
   "DataPointer.apply" should "parse a simple s3 path" in {
     val result = DataPointer("s3://bucket/path/to/data.parquet")
-    result should be(DataPointer(Some("s3"), "bucket/path/to/data.parquet", Some("parquet"), Map.empty))
+    result should be(URIDataPointer("s3://bucket/path/to/data.parquet", Some("parquet"), Some("parquet"), Map.empty))
   }
 
   it should "parse a bigquery table with options" in {
     val result = DataPointer("bigquery(option1=value1,option2=value2)://project-id.dataset.table")
     result should be(
-      DataPointer(Some("bigquery"),
-                  "project-id.dataset.table",
-                  None,
-                  Map("option1" -> "value1", "option2" -> "value2")))
+      URIDataPointer("project-id.dataset.table",
+                     Some("bigquery"),
+                     Some("bigquery"),
+                     Map("option1" -> "value1", "option2" -> "value2")))
   }
 
   it should "parse a bigquery table without options" in {
     val result = DataPointer("bigquery://project-id.dataset.table")
-    result should be(DataPointer(Some("bigquery"), "project-id.dataset.table", None, Map.empty))
+    result should be(URIDataPointer("project-id.dataset.table", Some("bigquery"), Some("bigquery"), Map.empty))
   }
 
   it should "parse a kafka topic" in {
     val result = DataPointer("kafka://my-topic")
-    result should be(DataPointer(Some("kafka"), "my-topic", None, Map.empty))
+    result should be(URIDataPointer("my-topic", Some("kafka"), Some("kafka"), Map.empty))
   }
 
   it should "parse a file path with format" in {
     val result = DataPointer("file://path/to/data.csv")
-    result should be(DataPointer(Some("file"), "path/to/data.csv", Some("csv"), Map.empty))
+    result should be(URIDataPointer("file://path/to/data.csv", Some("csv"), Some("csv"), Map.empty))
   }
 
   it should "parse options with spaces" in {
     val result = DataPointer("hive(key1 = value1, key2 = value2)://database.table")
-    result should be(DataPointer(Some("hive"), "database.table", None, Map("key1" -> "value1", "key2" -> "value2")))
+    result should be(
+      URIDataPointer("database.table", Some("hive"), Some("hive"), Map("key1" -> "value1", "key2" -> "value2")))
   }
 
   it should "handle paths with dots" in {
     val result = DataPointer("hdfs://path/to/data.with.dots.parquet")
-    result should be(DataPointer(Some("hdfs"), "path/to/data.with.dots.parquet", Some("parquet"), Map.empty))
+    result should be(
+      URIDataPointer("hdfs://path/to/data.with.dots.parquet", Some("parquet"), Some("parquet"), Map.empty))
   }
 
   it should "handle paths with multiple dots and no format" in {
     val result = DataPointer("file://path/to/data.with.dots")
-    result should be(DataPointer(Some("file"), "path/to/data.with.dots", Some("dots"), Map.empty))
+    result should be(URIDataPointer("file://path/to/data.with.dots", Some("dots"), Some("dots"), Map.empty))
   }
 
   it should "handle paths with multiple dots and prefixed format" in {
     val result = DataPointer("file+csv://path/to/data.with.dots")
-    result should be(DataPointer(Some("file"), "path/to/data.with.dots", Some("csv"), Map.empty))
+    result should be(URIDataPointer("file://path/to/data.with.dots", Some("csv"), Some("csv"), Map.empty))
   }
 
   it should "handle paths with format and pointer to folder with glob matching" in {
     val result = DataPointer("s3+parquet://path/to/*/*/")
-    result should be(DataPointer(Some("s3"), "path/to/*/*/", Some("parquet"), Map.empty))
+    result should be(URIDataPointer("s3://path/to/*/*/", Some("parquet"), Some("parquet"), Map.empty))
   }
 
   it should "handle no catalog, just table" in {
     val result = DataPointer("namespace.table")
-    result should be(DataPointer(None, "namespace.table", None, Map.empty))
+    result should be(URIDataPointer("namespace.table", None, None, Map.empty))
   }
 }
Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+package ai.chronon.spark
+
+import ai.chronon.api.DataPointer
+import org.apache.spark.sql.SparkSession
+
+import scala.reflect.runtime.universe._
+
+case class CatalogAwareDataPointer(inputTableOrPath: String, formatProvider: FormatProvider) extends DataPointer {
+
+  override def tableOrPath: String = {
+    formatProvider.resolveTableName(inputTableOrPath)
+  }
+  override lazy val options: Map[String, String] = Map.empty
+
+  override lazy val readFormat: Option[String] = {
+    Option(formatProvider.readFormat(inputTableOrPath)).map(_.name)
+  }
+
+  override lazy val writeFormat: Option[String] = {
+    Option(formatProvider.writeFormat(inputTableOrPath)).map(_.name)
+  }
+
+}
+
+object DataPointer {
+
+  def apply(tableOrPath: String, sparkSession: SparkSession): DataPointer = {
+    val clazzName =
+      sparkSession.conf.get("spark.chronon.table.format_provider.class", classOf[DefaultFormatProvider].getName)
+    val mirror = runtimeMirror(getClass.getClassLoader)
+    val classSymbol = mirror.staticClass(clazzName)
+    val classMirror = mirror.reflectClass(classSymbol)
+    val constructor = classSymbol.primaryConstructor.asMethod
+    val constructorMirror = classMirror.reflectConstructor(constructor)
+    val reflected = constructorMirror(sparkSession)
+    val provider = reflected.asInstanceOf[FormatProvider]
+
+    CatalogAwareDataPointer(tableOrPath, provider)
+
+  }
+
+}
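
A minimal usage sketch of the reflective factory above, assuming it runs inside a Spark job; only the config key and the DefaultFormatProvider fallback come from the diff, while the provider class and table name are hypothetical placeholders:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("pointer-demo").getOrCreate()

// Optional: plug in a custom FormatProvider. The factory instantiates the
// configured class through its primary constructor, passing the SparkSession;
// without the setting it falls back to DefaultFormatProvider.
spark.conf.set("spark.chronon.table.format_provider.class",
               "com.example.MyFormatProvider") // hypothetical implementation

// Table-name and read/write-format resolution are deferred to the provider.
val pointer = ai.chronon.spark.DataPointer("namespace.my_table", spark)
val fmt = pointer.readFormat // e.g. Some("hive"), depending on the provider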

spark/src/main/scala/ai/chronon/spark/Extensions.scala

Lines changed: 75 additions & 47 deletions
@@ -26,8 +26,9 @@ import ai.chronon.online.SparkConversions
 import ai.chronon.online.TimeRange
 import org.apache.avro.Schema
 import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.DataFrameReader
+import org.apache.spark.sql.DataFrameWriter
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.expressions.UserDefinedFunction
 import org.apache.spark.sql.functions._
@@ -322,53 +323,80 @@ object Extensions {
     }
   }
 
-  implicit class DataPointerOps(dataPointer: DataPointer) {
-    def toDf(implicit sparkSession: SparkSession): DataFrame = {
+  implicit class DataPointerAwareDataFrameWriter[T](dfw: DataFrameWriter[T]) {
+
+    def save(dataPointer: DataPointer): Unit = {
+
+      dataPointer.writeFormat
+        .map((wf) => {
+          val normalized = wf.toLowerCase
+          normalized match {
+            case "bigquery" | "bq" =>
+              dfw
+                .format("bigquery")
+                .options(dataPointer.options)
+                .save(dataPointer.tableOrPath)
+            case "snowflake" | "sf" =>
+              dfw
+                .format("net.snowflake.spark.snowflake")
+                .options(dataPointer.options)
+                .option("dbtable", dataPointer.tableOrPath)
+                .save()
+            case "parquet" | "csv" =>
+              dfw
+                .format(normalized)
+                .options(dataPointer.options)
+                .save(dataPointer.tableOrPath)
+            case "hive" =>
+              dfw
+                .format("hive")
+                .saveAsTable(dataPointer.tableOrPath)
+            case _ =>
+              throw new UnsupportedOperationException(s"Unsupported write catalog: ${normalized}")
+          }
+        })
+        .getOrElse(
+          // None case is just table against default catalog
+          dfw
+            .format("hive")
+            .saveAsTable(dataPointer.tableOrPath))
+    }
+  }
+
+  implicit class DataPointerAwareDataFrameReader(dfr: DataFrameReader) {
+
+    def load(dataPointer: DataPointer): DataFrame = {
       val tableOrPath = dataPointer.tableOrPath
-      val format = dataPointer.format.getOrElse("parquet")
-      dataPointer.catalog.map(_.toLowerCase) match {
-        case Some("bigquery") | Some("bq") =>
-          // https://github.com/GoogleCloudDataproc/spark-bigquery-connector?tab=readme-ov-file#reading-data-from-a-bigquery-table
-          sparkSession.read
-            .format("bigquery")
-            .options(dataPointer.options)
-            .load(tableOrPath)
-
-        case Some("snowflake") | Some("sf") =>
-          // https://docs.snowflake.com/en/user-guide/spark-connector-use#moving-data-from-snowflake-to-spark
-          val sfOptions = dataPointer.options
-          sparkSession.read
-            .format("net.snowflake.spark.snowflake")
-            .options(sfOptions)
-            .option("dbtable", tableOrPath)
-            .load()
-
-        case Some("s3") | Some("s3a") | Some("s3n") =>
-          // https://sites.google.com/site/hellobenchen/home/wiki/big-data/spark/read-data-files-from-multiple-sub-folders
-          // "To get spark to read through all subfolders and subsubfolders, etc. simply use the wildcard *"
-          // "df= spark.read.parquet('/datafolder/*/*')"
-          //
-          // https://docs.aws.amazon.com/emr/latest/ManagementGuide/emr-plan-file-systems.html
-          // "Previously, Amazon EMR used the s3n and s3a file systems. While both still work, "
-          // "we recommend that you use the s3 URI scheme for the best performance, security, and reliability."
-          // TODO: figure out how to scan subfolders in a date range without reading the entire folder
-          sparkSession.read
-            .format(format)
-            .options(dataPointer.options)
-            .load("s3://" + tableOrPath)
-
-        case Some("file") =>
-          sparkSession.read
-            .format(format)
-            .options(dataPointer.options)
-            .load(tableOrPath)
-
-        case Some("hive") | None =>
-          sparkSession.table(tableOrPath)
-
-        case _ =>
-          throw new UnsupportedOperationException(s"Unsupported catalog: ${dataPointer.catalog}")
-      }
+
+      dataPointer.readFormat
+        .map((fmt) => {
+          val normalized = fmt.toLowerCase
+          normalized match {
+            case "bigquery" | "bq" =>
+              dfr
+                .format("bigquery")
+                .options(dataPointer.options)
+                .load(tableOrPath)
+            case "snowflake" | "sf" =>
+              dfr
+                .format("net.snowflake.spark.snowflake")
+                .options(dataPointer.options)
+                .option("dbtable", tableOrPath)
+                .load()
+            case "parquet" | "csv" =>
+              dfr
+                .format(normalized)
+                .options(dataPointer.options)
+                .load(tableOrPath)
+            case "hive" => dfr.table(tableOrPath)
+            case _ =>
+              throw new UnsupportedOperationException(s"Unsupported read catalog: ${normalized}")
+          }
+        })
+        .getOrElse {
+          // None case is just table against default catalog
+          dfr.table(tableOrPath)
+        }
     }
   }
 }
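
Finally, a short sketch of how the new implicits are intended to be used, assuming ai.chronon.spark.Extensions._ is in scope; the pointer strings mirror the parser tests, and the output table name is illustrative:

import ai.chronon.api.DataPointer
import ai.chronon.spark.Extensions._
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("dp-demo").getOrCreate()

// Read side: readFormat selects the Spark source ("bigquery" here, which
// assumes the BigQuery connector is on the classpath).
val df = spark.read.load(DataPointer("bigquery://project-id.dataset.table"))

// Write side: writeFormat decides between save() and saveAsTable(); with no
// format the writer falls back to a Hive saveAsTable.
df.write.save(DataPointer("namespace.output_table")) // hypothetical output table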
