
feat: CatalogAwareDataPointer and refactoring existing DataPointer #157


Merged: 5 commits merged on Jan 3, 2025
35 changes: 25 additions & 10 deletions api/src/main/scala/ai/chronon/api/DataPointer.scala
@@ -1,10 +1,20 @@
package ai.chronon.api
import scala.util.parsing.combinator._

case class DataPointer(catalog: Option[String],
tableOrPath: String,
format: Option[String],
options: Map[String, String])
abstract class DataPointer {
def tableOrPath: String
def readFormat: Option[String]
def writeFormat: Option[String]
def options: Map[String, String]

}

case class URIDataPointer(
override val tableOrPath: String,
override val readFormat: Option[String],
override val writeFormat: Option[String],
override val options: Map[String, String]
) extends DataPointer

// parses string representations of data pointers
// ex: namespace.table
@@ -27,21 +37,26 @@ object DataPointer extends RegexParsers {
opt(catalogWithOptionalFormat ~ opt(options) ~ "://") ~ tableOrPath ^^ {
// format is specified in the prefix s3+parquet://bucket/path/to/data/*/*/
// note that if you have s3+parquet://bucket/path/to/data.csv, format is still parquet
case Some((ctl, Some(fmt)) ~ opts ~ _) ~ path =>
DataPointer(Some(ctl), path, Some(fmt), opts.getOrElse(Map.empty))
case Some((ctl, Some(fmt)) ~ opts ~ sep) ~ path =>
URIDataPointer(ctl + sep + path, Some(fmt), Some(fmt), opts.getOrElse(Map.empty))

// format is extracted from the path for relevant sources
// ex: s3://bucket/path/to/data.parquet
// ex: file://path/to/data.csv
// ex: hdfs://path/to/data.with.dots.parquet
// for other sources like bigquery, snowflake, format is None
case Some((ctl, None) ~ opts ~ _) ~ path =>
val (pathWithoutFormat, fmt) = extractFormatFromPath(path, ctl)
DataPointer(Some(ctl), path, fmt, opts.getOrElse(Map.empty))
case Some((ctl, None) ~ opts ~ sep) ~ path =>
val (_, fmt) = extractFormatFromPath(path, ctl)

fmt match {
// Retain the full uri if it's a path.
case Some(ft) => URIDataPointer(ctl + sep + path, Some(ft), Some(ft), opts.getOrElse(Map.empty))
case None => URIDataPointer(path, Some(ctl), Some(ctl), opts.getOrElse(Map.empty))
}

case None ~ path =>
// No prefix case (direct table reference)
DataPointer(None, path, None, Map.empty)
URIDataPointer(path, None, None, Map.empty)
}

private def catalogWithOptionalFormat: Parser[(String, Option[String])] =
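
Note (not part of this diff): a small sketch of how the refactored parser resolves URI-style pointers; the expected values mirror the updated tests in DataPointerTest below.

import ai.chronon.api.{DataPointer, URIDataPointer}

// An explicit "+format" keeps the full URI and uses that format for both read and write.
val s3Pointer = DataPointer("s3+parquet://bucket/path/to/data/*/*/")
// == URIDataPointer("s3://bucket/path/to/data/*/*/", Some("parquet"), Some("parquet"), Map.empty)

// When no format can be extracted from the path, the catalog itself becomes the read/write format.
val bqPointer = DataPointer("bigquery://project-id.dataset.table")
// == URIDataPointer("project-id.dataset.table", Some("bigquery"), Some("bigquery"), Map.empty)

// No prefix at all is treated as a direct table reference.
val tablePointer = DataPointer("namespace.table")
// == URIDataPointer("namespace.table", None, None, Map.empty)
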
31 changes: 17 additions & 14 deletions api/src/test/scala/ai/chronon/api/test/DataPointerTest.scala
@@ -1,6 +1,7 @@
package ai.chronon.api.test

import ai.chronon.api.DataPointer
import ai.chronon.api.URIDataPointer
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

@@ -9,60 +10,62 @@ DataPointerTest extends AnyFlatSpec with Matchers {

"DataPointer.apply" should "parse a simple s3 path" in {
val result = DataPointer("s3://bucket/path/to/data.parquet")
result should be(DataPointer(Some("s3"), "bucket/path/to/data.parquet", Some("parquet"), Map.empty))
result should be(URIDataPointer("s3://bucket/path/to/data.parquet", Some("parquet"), Some("parquet"), Map.empty))
}

it should "parse a bigquery table with options" in {
val result = DataPointer("bigquery(option1=value1,option2=value2)://project-id.dataset.table")
result should be(
DataPointer(Some("bigquery"),
"project-id.dataset.table",
None,
Map("option1" -> "value1", "option2" -> "value2")))
URIDataPointer("project-id.dataset.table",
Some("bigquery"),
Some("bigquery"),
Map("option1" -> "value1", "option2" -> "value2")))
}

it should "parse a bigquery table without options" in {
val result = DataPointer("bigquery://project-id.dataset.table")
result should be(DataPointer(Some("bigquery"), "project-id.dataset.table", None, Map.empty))
result should be(URIDataPointer("project-id.dataset.table", Some("bigquery"), Some("bigquery"), Map.empty))
}

it should "parse a kafka topic" in {
val result = DataPointer("kafka://my-topic")
result should be(DataPointer(Some("kafka"), "my-topic", None, Map.empty))
result should be(URIDataPointer("my-topic", Some("kafka"), Some("kafka"), Map.empty))
}

it should "parse a file path with format" in {
val result = DataPointer("file://path/to/data.csv")
result should be(DataPointer(Some("file"), "path/to/data.csv", Some("csv"), Map.empty))
result should be(URIDataPointer("file://path/to/data.csv", Some("csv"), Some("csv"), Map.empty))
}

it should "parse options with spaces" in {
val result = DataPointer("hive(key1 = value1, key2 = value2)://database.table")
result should be(DataPointer(Some("hive"), "database.table", None, Map("key1" -> "value1", "key2" -> "value2")))
result should be(
URIDataPointer("database.table", Some("hive"), Some("hive"), Map("key1" -> "value1", "key2" -> "value2")))
}

it should "handle paths with dots" in {
val result = DataPointer("hdfs://path/to/data.with.dots.parquet")
result should be(DataPointer(Some("hdfs"), "path/to/data.with.dots.parquet", Some("parquet"), Map.empty))
result should be(
URIDataPointer("hdfs://path/to/data.with.dots.parquet", Some("parquet"), Some("parquet"), Map.empty))
}

it should "handle paths with multiple dots and no format" in {
val result = DataPointer("file://path/to/data.with.dots")
result should be(DataPointer(Some("file"), "path/to/data.with.dots", Some("dots"), Map.empty))
result should be(URIDataPointer("file://path/to/data.with.dots", Some("dots"), Some("dots"), Map.empty))
}

it should "handle paths with multiple dots and prefixed format" in {
val result = DataPointer("file+csv://path/to/data.with.dots")
result should be(DataPointer(Some("file"), "path/to/data.with.dots", Some("csv"), Map.empty))
result should be(URIDataPointer("file://path/to/data.with.dots", Some("csv"), Some("csv"), Map.empty))
}

it should "handle paths with format and pointer to folder with glob matching" in {
val result = DataPointer("s3+parquet://path/to/*/*/")
result should be(DataPointer(Some("s3"), "path/to/*/*/", Some("parquet"), Map.empty))
result should be(URIDataPointer("s3://path/to/*/*/", Some("parquet"), Some("parquet"), Map.empty))
}

it should "handle no catalog, just table" in {
val result = DataPointer("namespace.table")
result should be(DataPointer(None, "namespace.table", None, Map.empty))
result should be(URIDataPointer("namespace.table", None, None, Map.empty))
}
}
@@ -66,6 +66,8 @@ case class GcpFormatProvider(sparkSession: SparkSession) extends FormatProvider

case class BQuery(project: String) extends Format {

override def name: String = "bigquery"

override def primaryPartitions(tableName: String, partitionColumn: String, subPartitionsFilter: Map[String, String])(
implicit sparkSession: SparkSession): Seq[String] =
super.primaryPartitions(tableName, partitionColumn, subPartitionsFilter)
@@ -10,6 +10,8 @@ import org.apache.spark.sql.functions.url_decode

case class GCS(project: String) extends Format {

override def name: String = ""

override def primaryPartitions(tableName: String, partitionColumn: String, subPartitionsFilter: Map[String, String])(
implicit sparkSession: SparkSession): Seq[String] =
super.primaryPartitions(tableName, partitionColumn, subPartitionsFilter)
@@ -0,0 +1,42 @@
package ai.chronon.spark

import ai.chronon.api.DataPointer
import org.apache.spark.sql.SparkSession

import scala.reflect.runtime.universe._

case class CatalogAwareDataPointer(inputTableOrPath: String, formatProvider: FormatProvider) extends DataPointer {

override def tableOrPath: String = {
formatProvider.resolveTableName(inputTableOrPath)
}
override lazy val options: Map[String, String] = Map.empty

override lazy val readFormat: Option[String] = {
Option(formatProvider.readFormat(inputTableOrPath)).map(_.name)
}

override lazy val writeFormat: Option[String] = {
Option(formatProvider.writeFormat(inputTableOrPath)).map(_.name)
}

}

object DataPointer {

def apply(tableOrPath: String, sparkSession: SparkSession): DataPointer = {
val clazzName =
sparkSession.conf.get("spark.chronon.table.format_provider.class", classOf[DefaultFormatProvider].getName)
val mirror = runtimeMirror(getClass.getClassLoader)
val classSymbol = mirror.staticClass(clazzName)
val classMirror = mirror.reflectClass(classSymbol)
val constructor = classSymbol.primaryConstructor.asMethod
val constructorMirror = classMirror.reflectConstructor(constructor)
val reflected = constructorMirror(sparkSession)
val provider = reflected.asInstanceOf[FormatProvider]

CatalogAwareDataPointer(tableOrPath, provider)

}
Comment on lines +27 to +40

⚠️ Potential issue

Add error handling for reflection failures.

Reflection could fail if the configured class doesn't exist or lacks a suitable constructor.

Add try-catch:

 def apply(tableOrPath: String, sparkSession: SparkSession): DataPointer = {
+  try {
     val clazzName =
       sparkSession.conf.get("spark.chronon.table.format_provider.class", classOf[DefaultFormatProvider].getName)
     val mirror = runtimeMirror(getClass.getClassLoader)
     val classSymbol = mirror.staticClass(clazzName)
     val classMirror = mirror.reflectClass(classSymbol)
     val constructor = classSymbol.primaryConstructor.asMethod
     val constructorMirror = classMirror.reflectConstructor(constructor)
     val reflected = constructorMirror(sparkSession)
     val provider = reflected.asInstanceOf[FormatProvider]

     CatalogAwareDataPointer(tableOrPath, provider)
+  } catch {
+    case e: Exception => throw new IllegalArgumentException(s"Failed to initialize FormatProvider: ${e.getMessage}", e)
+  }
 }


}
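
Rough usage sketch (not part of this diff): the config key and apply signature come from the code above; the provider class and table names below are placeholders.

import org.apache.spark.sql.SparkSession
import ai.chronon.spark.DataPointer

val spark = SparkSession.builder().master("local[*]").appName("pointer-example").getOrCreate()

// Optional: swap in a custom FormatProvider; otherwise DefaultFormatProvider is used.
// "com.example.MyFormatProvider" is a placeholder class name.
spark.conf.set("spark.chronon.table.format_provider.class", "com.example.MyFormatProvider")

// The provider is instantiated reflectively and wrapped in a CatalogAwareDataPointer;
// the table name and read/write formats are resolved lazily through the provider.
val pointer = DataPointer("sample_namespace.sample_table", spark)
println(pointer.tableOrPath) // provider.resolveTableName(...)
println(pointer.readFormat)  // e.g. Some("hive"), depending on the provider
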
6 changes: 3 additions & 3 deletions spark/src/main/scala/ai/chronon/spark/Extensions.scala
@@ -325,8 +325,8 @@ object Extensions {
implicit class DataPointerOps(dataPointer: DataPointer) {
def toDf(implicit sparkSession: SparkSession): DataFrame = {
val tableOrPath = dataPointer.tableOrPath
val format = dataPointer.format.getOrElse("parquet")
dataPointer.catalog.map(_.toLowerCase) match {
val format = dataPointer.readFormat.getOrElse("parquet")
dataPointer.readFormat.map(_.toLowerCase) match {
case Some("bigquery") | Some("bq") =>
// https://github.com/GoogleCloudDataproc/spark-bigquery-connector?tab=readme-ov-file#reading-data-from-a-bigquery-table
sparkSession.read
@@ -367,7 +367,7 @@ object Extensions {
sparkSession.table(tableOrPath)

case _ =>
throw new UnsupportedOperationException(s"Unsupported catalog: ${dataPointer.catalog}")
throw new UnsupportedOperationException(s"Unsupported catalog: ${dataPointer.readFormat}")
}
}
}
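
Illustrative read path with the updated extension (not part of this diff; the table name below is a placeholder):

import org.apache.spark.sql.{DataFrame, SparkSession}
import ai.chronon.spark.Extensions._

implicit val spark: SparkSession = SparkSession.builder().master("local[*]").getOrCreate()

// toDf now dispatches on readFormat rather than the removed catalog field:
// Some("bigquery")/Some("bq") goes through the BigQuery connector, table-backed
// formats fall through to sparkSession.table, and unsupported formats throw.
val df: DataFrame = ai.chronon.spark.DataPointer("sample_db.sample_table", spark).toDf
df.show(5)
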
11 changes: 11 additions & 0 deletions spark/src/main/scala/ai/chronon/spark/Format.scala
@@ -10,6 +10,8 @@ import scala.util.Try

trait Format {

def name: String

// Return the primary partitions (based on the 'partitionColumn') filtered down by sub-partition filters if provided
// If subpartition filters are supplied and the format doesn't support it, we throw an error
def primaryPartitions(tableName: String,
@@ -45,6 +47,7 @@ trait Format {

// Does this format support sub partitions filters
def supportSubPartitionsFilter: Boolean

}

/**
@@ -58,6 +61,8 @@ trait FormatProvider extends Serializable {
def sparkSession: SparkSession
def readFormat(tableName: String): Format
def writeFormat(tableName: String): Format

def resolveTableName(tableName: String) = tableName
}

/**
@@ -134,6 +139,8 @@ case class DefaultFormatProvider(sparkSession: SparkSession) extends FormatProvider
}

case object Hive extends Format {

override def name: String = "hive"
override def primaryPartitions(tableName: String, partitionColumn: String, subPartitionsFilter: Map[String, String])(
implicit sparkSession: SparkSession): Seq[String] =
super.primaryPartitions(tableName, partitionColumn, subPartitionsFilter)
@@ -167,6 +174,8 @@ case object Hive extends Format {
}

case object Iceberg extends Format {

override def name: String = "iceberg"
override def primaryPartitions(tableName: String, partitionColumn: String, subPartitionsFilter: Map[String, String])(
implicit sparkSession: SparkSession): Seq[String] = {
if (!supportSubPartitionsFilter && subPartitionsFilter.nonEmpty) {
@@ -216,6 +225,8 @@ case object Iceberg extends Format {
// In such cases, you should implement your own FormatProvider built on the newer Delta lake version
case object DeltaLake extends Format {

override def name: String = "delta"

override def primaryPartitions(tableName: String, partitionColumn: String, subPartitionsFilter: Map[String, String])(
implicit sparkSession: SparkSession): Seq[String] =
super.primaryPartitions(tableName, partitionColumn, subPartitionsFilter)
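
For illustration only, a minimal custom provider against the updated traits. The class name and routing rule are made up; this assumes FormatProvider's only abstract members are the ones shown above, and that the reflective call in DataPointer.apply expects a single SparkSession constructor argument.

import org.apache.spark.sql.SparkSession
import ai.chronon.spark.{Format, FormatProvider, Hive, Iceberg}

// Hypothetical provider: routes "iceberg."-prefixed names to the Iceberg format,
// everything else to Hive, and strips the prefix when resolving the physical table name.
case class PrefixRoutingFormatProvider(sparkSession: SparkSession) extends FormatProvider {

  override def readFormat(tableName: String): Format =
    if (tableName.startsWith("iceberg.")) Iceberg else Hive

  override def writeFormat(tableName: String): Format = readFormat(tableName)

  override def resolveTableName(tableName: String): String =
    tableName.stripPrefix("iceberg.")
}
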
5 changes: 2 additions & 3 deletions spark/src/main/scala/ai/chronon/spark/TableUtils.scala
@@ -19,7 +19,6 @@ package ai.chronon.spark
import ai.chronon.aggregator.windowing.TsUtils
import ai.chronon.api.ColorPrinter.ColorString
import ai.chronon.api.Constants
import ai.chronon.api.DataPointer
import ai.chronon.api.Extensions._
import ai.chronon.api.PartitionSpec
import ai.chronon.api.Query
@@ -747,13 +746,13 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
wheres: Seq[String],
rangeWheres: Seq[String],
fallbackSelects: Option[Map[String, String]] = None): DataFrame = {
val dp = DataPointer(table)
val dp = ai.chronon.api.DataPointer.apply(table)
var df = dp.toDf(sparkSession)
val selects = QueryUtils.buildSelects(selectMap, fallbackSelects)
logger.info(s""" Scanning data:
| table: ${dp.tableOrPath.green}
| options: ${dp.options}
| format: ${dp.format}
| format: ${dp.readFormat}
| selects:
| ${selects.mkString("\n ").green}
| wheres: