Commit a41fdb4
feat: CatalogAwareDataPointer and refactoring existing DataPointer (#157)
## Summary

- Refactor `DataPointer`. Ideally, when parsing URIs, if we come across a prefix we should preserve the whole original prefix. We lose the benefits of the various `s3<c>` URIs, but we can fix that in a future iteration. This way, the Extensions code is simpler.
- Define `DataFrameWriter` and `DataFrameReader` implicit classes to support handling `DataPointer`. Ideally we want `DataPointer` to be a lightweight object that we can act on, similar to a table name or a URI.
- Introduce `CatalogAwareDataPointer`. It encapsulates `Format`, a runtime injection used to figure out the underlying storage read/write layers. This is ultimately what `DataPointer` represents; instead of defining it statically, we make remote calls to resolve it. A small usage sketch follows below.

## Checklist

- [ ] Added Unit Tests
- [ ] Covered by existing CI
- [ ] Integration tested
- [ ] Documentation update

## Summary by CodeRabbit

- **New Features**
  - Introduced a more flexible `DataPointer` architecture with an abstract base class and `URIDataPointer`.
  - Added support for dynamic format resolution in Spark data sources.
  - Enhanced the `BQuery` and `GCS` classes with specific `name` methods.
- **Refactor**
  - Restructured the `DataPointer` class to improve extensibility.
  - Enhanced format handling with standardized `name` methods for different data formats.
  - Updated `DataPointerOps` to streamline format and catalog handling.
  - Modified `TableUtils` to use the new `DataPointer` instantiation method.
- **Improvements**
  - Implemented more robust table and format parsing mechanisms.
  - Added utility methods for resolving table names and formats.
  - Refined logging for `DataPointer` instantiation and state representation.
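A minimal usage sketch of the two construction paths introduced here, assuming a running `SparkSession`; the bucket path and table name are placeholders, and the expected values mirror the parser tests in this commit:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()

// Static parse into a URIDataPointer: the whole prefix is preserved in tableOrPath.
val uriPointer = ai.chronon.api.DataPointer("s3+parquet://bucket/path/to/data/*/*/")
// uriPointer.tableOrPath == "s3://bucket/path/to/data/*/*/"
// uriPointer.readFormat  == Some("parquet")

// Catalog-aware pointer: formats are resolved at runtime through the configured FormatProvider.
val catalogPointer = ai.chronon.spark.DataPointer("namespace.table", spark)
// catalogPointer.readFormat is computed lazily via FormatProvider.readFormat(...)
```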
1 parent 0aa2ec4 commit a41fdb4

File tree

8 files changed: +104 −30 lines changed

api/src/main/scala/ai/chronon/api/DataPointer.scala

Lines changed: 25 additions & 10 deletions
Lines changed: 25 additions & 10 deletions

```diff
@@ -1,10 +1,20 @@
 package ai.chronon.api
 import scala.util.parsing.combinator._

-case class DataPointer(catalog: Option[String],
-                       tableOrPath: String,
-                       format: Option[String],
-                       options: Map[String, String])
+abstract class DataPointer {
+  def tableOrPath: String
+  def readFormat: Option[String]
+  def writeFormat: Option[String]
+  def options: Map[String, String]
+
+}
+
+case class URIDataPointer(
+    override val tableOrPath: String,
+    override val readFormat: Option[String],
+    override val writeFormat: Option[String],
+    override val options: Map[String, String]
+) extends DataPointer

 // parses string representations of data pointers
 // ex: namespace.table
@@ -27,21 +37,26 @@ object DataPointer extends RegexParsers {
     opt(catalogWithOptionalFormat ~ opt(options) ~ "://") ~ tableOrPath ^^ {
       // format is specified in the prefix s3+parquet://bucket/path/to/data/*/*/
       // note that if you have s3+parquet://bucket/path/to/data.csv, format is still parquet
-      case Some((ctl, Some(fmt)) ~ opts ~ _) ~ path =>
-        DataPointer(Some(ctl), path, Some(fmt), opts.getOrElse(Map.empty))
+      case Some((ctl, Some(fmt)) ~ opts ~ sep) ~ path =>
+        URIDataPointer(ctl + sep + path, Some(fmt), Some(fmt), opts.getOrElse(Map.empty))

       // format is extracted from the path for relevant sources
       // ex: s3://bucket/path/to/data.parquet
       // ex: file://path/to/data.csv
       // ex: hdfs://path/to/data.with.dots.parquet
       // for other sources like bigquery, snowflake, format is None
-      case Some((ctl, None) ~ opts ~ _) ~ path =>
-        val (pathWithoutFormat, fmt) = extractFormatFromPath(path, ctl)
-        DataPointer(Some(ctl), path, fmt, opts.getOrElse(Map.empty))
+      case Some((ctl, None) ~ opts ~ sep) ~ path =>
+        val (_, fmt) = extractFormatFromPath(path, ctl)
+
+        fmt match {
+          // Retain the full uri if it's a path.
+          case Some(ft) => URIDataPointer(ctl + sep + path, Some(ft), Some(ft), opts.getOrElse(Map.empty))
+          case None     => URIDataPointer(path, Some(ctl), Some(ctl), opts.getOrElse(Map.empty))
+        }

       case None ~ path =>
         // No prefix case (direct table reference)
-        DataPointer(None, path, None, Map.empty)
+        URIDataPointer(path, None, None, Map.empty)
     }

   private def catalogWithOptionalFormat: Parser[(String, Option[String])] =
```
Lines changed: 17 additions & 14 deletions
```diff
@@ -1,6 +1,7 @@
 package ai.chronon.api.test

 import ai.chronon.api.DataPointer
+import ai.chronon.api.URIDataPointer
 import org.scalatest.flatspec.AnyFlatSpec
 import org.scalatest.matchers.should.Matchers

@@ -9,60 +10,62 @@ class DataPointerTest extends AnyFlatSpec with Matchers {

   "DataPointer.apply" should "parse a simple s3 path" in {
     val result = DataPointer("s3://bucket/path/to/data.parquet")
-    result should be(DataPointer(Some("s3"), "bucket/path/to/data.parquet", Some("parquet"), Map.empty))
+    result should be(URIDataPointer("s3://bucket/path/to/data.parquet", Some("parquet"), Some("parquet"), Map.empty))
   }

   it should "parse a bigquery table with options" in {
     val result = DataPointer("bigquery(option1=value1,option2=value2)://project-id.dataset.table")
     result should be(
-      DataPointer(Some("bigquery"),
-                  "project-id.dataset.table",
-                  None,
-                  Map("option1" -> "value1", "option2" -> "value2")))
+      URIDataPointer("project-id.dataset.table",
+                     Some("bigquery"),
+                     Some("bigquery"),
+                     Map("option1" -> "value1", "option2" -> "value2")))
   }

   it should "parse a bigquery table without options" in {
     val result = DataPointer("bigquery://project-id.dataset.table")
-    result should be(DataPointer(Some("bigquery"), "project-id.dataset.table", None, Map.empty))
+    result should be(URIDataPointer("project-id.dataset.table", Some("bigquery"), Some("bigquery"), Map.empty))
   }

   it should "parse a kafka topic" in {
     val result = DataPointer("kafka://my-topic")
-    result should be(DataPointer(Some("kafka"), "my-topic", None, Map.empty))
+    result should be(URIDataPointer("my-topic", Some("kafka"), Some("kafka"), Map.empty))
   }

   it should "parse a file path with format" in {
     val result = DataPointer("file://path/to/data.csv")
-    result should be(DataPointer(Some("file"), "path/to/data.csv", Some("csv"), Map.empty))
+    result should be(URIDataPointer("file://path/to/data.csv", Some("csv"), Some("csv"), Map.empty))
   }

   it should "parse options with spaces" in {
     val result = DataPointer("hive(key1 = value1, key2 = value2)://database.table")
-    result should be(DataPointer(Some("hive"), "database.table", None, Map("key1" -> "value1", "key2" -> "value2")))
+    result should be(
+      URIDataPointer("database.table", Some("hive"), Some("hive"), Map("key1" -> "value1", "key2" -> "value2")))
   }

   it should "handle paths with dots" in {
     val result = DataPointer("hdfs://path/to/data.with.dots.parquet")
-    result should be(DataPointer(Some("hdfs"), "path/to/data.with.dots.parquet", Some("parquet"), Map.empty))
+    result should be(
+      URIDataPointer("hdfs://path/to/data.with.dots.parquet", Some("parquet"), Some("parquet"), Map.empty))
   }

   it should "handle paths with multiple dots and no format" in {
     val result = DataPointer("file://path/to/data.with.dots")
-    result should be(DataPointer(Some("file"), "path/to/data.with.dots", Some("dots"), Map.empty))
+    result should be(URIDataPointer("file://path/to/data.with.dots", Some("dots"), Some("dots"), Map.empty))
   }

   it should "handle paths with multiple dots and prefixed format" in {
     val result = DataPointer("file+csv://path/to/data.with.dots")
-    result should be(DataPointer(Some("file"), "path/to/data.with.dots", Some("csv"), Map.empty))
+    result should be(URIDataPointer("file://path/to/data.with.dots", Some("csv"), Some("csv"), Map.empty))
   }

   it should "handle paths with format and pointer to folder with glob matching" in {
     val result = DataPointer("s3+parquet://path/to/*/*/")
-    result should be(DataPointer(Some("s3"), "path/to/*/*/", Some("parquet"), Map.empty))
+    result should be(URIDataPointer("s3://path/to/*/*/", Some("parquet"), Some("parquet"), Map.empty))
   }

   it should "handle no catalog, just table" in {
     val result = DataPointer("namespace.table")
-    result should be(DataPointer(None, "namespace.table", None, Map.empty))
+    result should be(URIDataPointer("namespace.table", None, None, Map.empty))
   }
 }
```

cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/BigQueryFormat.scala

Lines changed: 2 additions & 0 deletions
```diff
@@ -66,6 +66,8 @@ case class GcpFormatProvider(sparkSession: SparkSession) extends FormatProvider

 case class BQuery(project: String) extends Format {

+  override def name: String = "bigquery"
+
   override def primaryPartitions(tableName: String, partitionColumn: String, subPartitionsFilter: Map[String, String])(
     implicit sparkSession: SparkSession): Seq[String] =
     super.primaryPartitions(tableName, partitionColumn, subPartitionsFilter)
```

cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/GCSFormat.scala

Lines changed: 2 additions & 0 deletions
```diff
@@ -10,6 +10,8 @@ import org.apache.spark.sql.functions.url_decode

 case class GCS(project: String) extends Format {

+  override def name: String = ""
+
   override def primaryPartitions(tableName: String, partitionColumn: String, subPartitionsFilter: Map[String, String])(
     implicit sparkSession: SparkSession): Seq[String] =
     super.primaryPartitions(tableName, partitionColumn, subPartitionsFilter)
```
Lines changed: 42 additions & 0 deletions
```diff
@@ -0,0 +1,42 @@
+package ai.chronon.spark
+
+import ai.chronon.api.DataPointer
+import org.apache.spark.sql.SparkSession
+
+import scala.reflect.runtime.universe._
+
+case class CatalogAwareDataPointer(inputTableOrPath: String, formatProvider: FormatProvider) extends DataPointer {
+
+  override def tableOrPath: String = {
+    formatProvider.resolveTableName(inputTableOrPath)
+  }
+  override lazy val options: Map[String, String] = Map.empty
+
+  override lazy val readFormat: Option[String] = {
+    Option(formatProvider.readFormat(inputTableOrPath)).map(_.name)
+  }
+
+  override lazy val writeFormat: Option[String] = {
+    Option(formatProvider.writeFormat(inputTableOrPath)).map(_.name)
+  }
+
+}
+
+object DataPointer {
+
+  def apply(tableOrPath: String, sparkSession: SparkSession): DataPointer = {
+    val clazzName =
+      sparkSession.conf.get("spark.chronon.table.format_provider.class", classOf[DefaultFormatProvider].getName)
+    val mirror = runtimeMirror(getClass.getClassLoader)
+    val classSymbol = mirror.staticClass(clazzName)
+    val classMirror = mirror.reflectClass(classSymbol)
+    val constructor = classSymbol.primaryConstructor.asMethod
+    val constructorMirror = classMirror.reflectConstructor(constructor)
+    val reflected = constructorMirror(sparkSession)
+    val provider = reflected.asInstanceOf[FormatProvider]
+
+    CatalogAwareDataPointer(tableOrPath, provider)
+
+  }
+
+}
```

spark/src/main/scala/ai/chronon/spark/Extensions.scala

Lines changed: 3 additions & 3 deletions
```diff
@@ -299,8 +299,8 @@ object Extensions {
   implicit class DataPointerOps(dataPointer: DataPointer) {
     def toDf(implicit sparkSession: SparkSession): DataFrame = {
       val tableOrPath = dataPointer.tableOrPath
-      val format = dataPointer.format.getOrElse("parquet")
-      dataPointer.catalog.map(_.toLowerCase) match {
+      val format = dataPointer.readFormat.getOrElse("parquet")
+      dataPointer.readFormat.map(_.toLowerCase) match {
         case Some("bigquery") | Some("bq") =>
           // https://github.com/GoogleCloudDataproc/spark-bigquery-connector?tab=readme-ov-file#reading-data-from-a-bigquery-table
           sparkSession.read
@@ -341,7 +341,7 @@ object Extensions {
           sparkSession.table(tableOrPath)

         case _ =>
-          throw new UnsupportedOperationException(s"Unsupported catalog: ${dataPointer.catalog}")
+          throw new UnsupportedOperationException(s"Unsupported catalog: ${dataPointer.readFormat}")
       }
     }
   }
```
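`DataPointerOps.toDf` now dispatches on `readFormat` rather than the old `catalog` field. A small sketch, assuming the spark-bigquery connector is on the classpath; the project/dataset/table id is a placeholder taken from the parser tests:

```scala
import ai.chronon.spark.Extensions._
import org.apache.spark.sql.{DataFrame, SparkSession}

implicit val spark: SparkSession = SparkSession.builder().master("local[*]").getOrCreate()

// readFormat is Some("bigquery"), so toDf routes through the BigQuery connector read path.
val df: DataFrame = ai.chronon.api.DataPointer("bigquery://project-id.dataset.table").toDf
```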

spark/src/main/scala/ai/chronon/spark/Format.scala

Lines changed: 11 additions & 0 deletions
```diff
@@ -10,6 +10,8 @@ import scala.util.Try

 trait Format {

+  def name: String
+
   // Return the primary partitions (based on the 'partitionColumn') filtered down by sub-partition filters if provided
   // If subpartition filters are supplied and the format doesn't support it, we throw an error
   def primaryPartitions(tableName: String,
@@ -45,6 +47,7 @@ trait Format {

   // Does this format support sub partitions filters
   def supportSubPartitionsFilter: Boolean
+
 }

 /**
@@ -58,6 +61,8 @@ trait FormatProvider extends Serializable {
   def sparkSession: SparkSession
   def readFormat(tableName: String): Format
   def writeFormat(tableName: String): Format
+
+  def resolveTableName(tableName: String) = tableName
 }

 /**
@@ -134,6 +139,8 @@ case class DefaultFormatProvider(sparkSession: SparkSession) extends FormatProvider {
 }

 case object Hive extends Format {
+
+  override def name: String = "hive"
   override def primaryPartitions(tableName: String, partitionColumn: String, subPartitionsFilter: Map[String, String])(
     implicit sparkSession: SparkSession): Seq[String] =
     super.primaryPartitions(tableName, partitionColumn, subPartitionsFilter)
@@ -167,6 +174,8 @@ case object Hive extends Format {
 }

 case object Iceberg extends Format {
+
+  override def name: String = "iceberg"
   override def primaryPartitions(tableName: String, partitionColumn: String, subPartitionsFilter: Map[String, String])(
     implicit sparkSession: SparkSession): Seq[String] = {
     if (!supportSubPartitionsFilter && subPartitionsFilter.nonEmpty) {
@@ -216,6 +225,8 @@ case object Iceberg extends Format {
 // In such cases, you should implement your own FormatProvider built on the newer Delta lake version
 case object DeltaLake extends Format {

+  override def name: String = "delta"
+
   override def primaryPartitions(tableName: String, partitionColumn: String, subPartitionsFilter: Map[String, String])(
     implicit sparkSession: SparkSession): Seq[String] =
     super.primaryPartitions(tableName, partitionColumn, subPartitionsFilter)
```
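A minimal sketch of a custom `FormatProvider` against the trait shapes above; `MyFormatProvider`, the `_iceberg` suffix convention, and the `prod_` prefix are hypothetical. The primary constructor takes the `SparkSession` because the provider is instantiated reflectively:

```scala
import org.apache.spark.sql.SparkSession
import ai.chronon.spark.{Format, FormatProvider, Hive, Iceberg}

case class MyFormatProvider(sparkSession: SparkSession) extends FormatProvider {

  // Pick a Format per table; Hive and Iceberg now expose name = "hive" / "iceberg".
  override def readFormat(tableName: String): Format =
    if (tableName.endsWith("_iceberg")) Iceberg else Hive

  override def writeFormat(tableName: String): Format = Hive

  // New hook in this commit: map a logical name onto the physical table.
  override def resolveTableName(tableName: String): String = s"prod_$tableName"
}
```

Registering the class via `spark.chronon.table.format_provider.class` routes `CatalogAwareDataPointer` lookups through it.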

spark/src/main/scala/ai/chronon/spark/TableUtils.scala

Lines changed: 2 additions & 3 deletions
```diff
@@ -19,7 +19,6 @@ package ai.chronon.spark
 import ai.chronon.aggregator.windowing.TsUtils
 import ai.chronon.api.ColorPrinter.ColorString
 import ai.chronon.api.Constants
-import ai.chronon.api.DataPointer
 import ai.chronon.api.Extensions._
 import ai.chronon.api.PartitionSpec
 import ai.chronon.api.Query
@@ -747,13 +746,13 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
                  wheres: Seq[String],
                  rangeWheres: Seq[String],
                  fallbackSelects: Option[Map[String, String]] = None): DataFrame = {
-    val dp = DataPointer(table)
+    val dp = ai.chronon.api.DataPointer.apply(table)
     var df = dp.toDf(sparkSession)
     val selects = QueryUtils.buildSelects(selectMap, fallbackSelects)
     logger.info(s""" Scanning data:
                    |  table: ${dp.tableOrPath.green}
                    |  options: ${dp.options}
-                   |  format: ${dp.format}
+                   |  format: ${dp.readFormat}
                    |  selects:
                    |  ${selects.mkString("\n    ").green}
                    |  wheres:
```
