diff --git a/api/src/main/scala/ai/chronon/api/DataPointer.scala b/api/src/main/scala/ai/chronon/api/DataPointer.scala
index 713b5b3271..812f8a0e56 100644
--- a/api/src/main/scala/ai/chronon/api/DataPointer.scala
+++ b/api/src/main/scala/ai/chronon/api/DataPointer.scala
@@ -1,10 +1,20 @@
 package ai.chronon.api
 import scala.util.parsing.combinator._
 
-case class DataPointer(catalog: Option[String],
-                       tableOrPath: String,
-                       format: Option[String],
-                       options: Map[String, String])
+abstract class DataPointer {
+  def tableOrPath: String
+  def readFormat: Option[String]
+  def writeFormat: Option[String]
+  def options: Map[String, String]
+
+}
+
+case class URIDataPointer(
+    override val tableOrPath: String,
+    override val readFormat: Option[String],
+    override val writeFormat: Option[String],
+    override val options: Map[String, String]
+) extends DataPointer
 
 // parses string representations of data pointers
 // ex: namespace.table
@@ -27,21 +37,26 @@ object DataPointer extends RegexParsers {
     opt(catalogWithOptionalFormat ~ opt(options) ~ "://") ~ tableOrPath ^^ {
       // format is specified in the prefix s3+parquet://bucket/path/to/data/*/*/
       // note that if you have s3+parquet://bucket/path/to/data.csv, format is still parquet
-      case Some((ctl, Some(fmt)) ~ opts ~ _) ~ path =>
-        DataPointer(Some(ctl), path, Some(fmt), opts.getOrElse(Map.empty))
+      case Some((ctl, Some(fmt)) ~ opts ~ sep) ~ path =>
+        URIDataPointer(ctl + sep + path, Some(fmt), Some(fmt), opts.getOrElse(Map.empty))
 
       // format is extracted from the path for relevant sources
       // ex: s3://bucket/path/to/data.parquet
       // ex: file://path/to/data.csv
       // ex: hdfs://path/to/data.with.dots.parquet
       // for other sources like bigquery, snowflake, format is None
-      case Some((ctl, None) ~ opts ~ _) ~ path =>
-        val (pathWithoutFormat, fmt) = extractFormatFromPath(path, ctl)
-        DataPointer(Some(ctl), path, fmt, opts.getOrElse(Map.empty))
+      case Some((ctl, None) ~ opts ~ sep) ~ path =>
+        val (_, fmt) = extractFormatFromPath(path, ctl)
+
+        fmt match {
+          // Retain the full uri if it's a path.
+          case Some(ft) => URIDataPointer(ctl + sep + path, Some(ft), Some(ft), opts.getOrElse(Map.empty))
+          case None     => URIDataPointer(path, Some(ctl), Some(ctl), opts.getOrElse(Map.empty))
+        }
 
       case None ~ path =>
         // No prefix case (direct table reference)
-        DataPointer(None, path, None, Map.empty)
+        URIDataPointer(path, None, None, Map.empty)
     }
 
   private def catalogWithOptionalFormat: Parser[(String, Option[String])] =
diff --git a/api/src/test/scala/ai/chronon/api/test/DataPointerTest.scala b/api/src/test/scala/ai/chronon/api/test/DataPointerTest.scala
index 6e1df2ae08..c4a2442acf 100644
--- a/api/src/test/scala/ai/chronon/api/test/DataPointerTest.scala
+++ b/api/src/test/scala/ai/chronon/api/test/DataPointerTest.scala
@@ -1,6 +1,7 @@
 package ai.chronon.api.test
 
 import ai.chronon.api.DataPointer
+import ai.chronon.api.URIDataPointer
 import org.scalatest.flatspec.AnyFlatSpec
 import org.scalatest.matchers.should.Matchers
 
@@ -9,60 +10,62 @@ class DataPointerTest extends AnyFlatSpec with Matchers {
   "DataPointer.apply" should "parse a simple s3 path" in {
     val result = DataPointer("s3://bucket/path/to/data.parquet")
-    result should be(DataPointer(Some("s3"), "bucket/path/to/data.parquet", Some("parquet"), Map.empty))
+    result should be(URIDataPointer("s3://bucket/path/to/data.parquet", Some("parquet"), Some("parquet"), Map.empty))
   }
 
   it should "parse a bigquery table with options" in {
     val result = DataPointer("bigquery(option1=value1,option2=value2)://project-id.dataset.table")
     result should be(
-      DataPointer(Some("bigquery"),
-                  "project-id.dataset.table",
-                  None,
-                  Map("option1" -> "value1", "option2" -> "value2")))
+      URIDataPointer("project-id.dataset.table",
+                     Some("bigquery"),
+                     Some("bigquery"),
+                     Map("option1" -> "value1", "option2" -> "value2")))
   }
 
   it should "parse a bigquery table without options" in {
     val result = DataPointer("bigquery://project-id.dataset.table")
-    result should be(DataPointer(Some("bigquery"), "project-id.dataset.table", None, Map.empty))
+    result should be(URIDataPointer("project-id.dataset.table", Some("bigquery"), Some("bigquery"), Map.empty))
   }
 
   it should "parse a kafka topic" in {
     val result = DataPointer("kafka://my-topic")
-    result should be(DataPointer(Some("kafka"), "my-topic", None, Map.empty))
+    result should be(URIDataPointer("my-topic", Some("kafka"), Some("kafka"), Map.empty))
   }
 
   it should "parse a file path with format" in {
     val result = DataPointer("file://path/to/data.csv")
-    result should be(DataPointer(Some("file"), "path/to/data.csv", Some("csv"), Map.empty))
+    result should be(URIDataPointer("file://path/to/data.csv", Some("csv"), Some("csv"), Map.empty))
  }
 
   it should "parse options with spaces" in {
     val result = DataPointer("hive(key1 = value1, key2 = value2)://database.table")
-    result should be(DataPointer(Some("hive"), "database.table", None, Map("key1" -> "value1", "key2" -> "value2")))
+    result should be(
+      URIDataPointer("database.table", Some("hive"), Some("hive"), Map("key1" -> "value1", "key2" -> "value2")))
   }
 
   it should "handle paths with dots" in {
     val result = DataPointer("hdfs://path/to/data.with.dots.parquet")
-    result should be(DataPointer(Some("hdfs"), "path/to/data.with.dots.parquet", Some("parquet"), Map.empty))
+    result should be(
+      URIDataPointer("hdfs://path/to/data.with.dots.parquet", Some("parquet"), Some("parquet"), Map.empty))
   }
 
   it should "handle paths with multiple dots and no format" in {
     val result = DataPointer("file://path/to/data.with.dots")
-    result should be(DataPointer(Some("file"), "path/to/data.with.dots", Some("dots"), Map.empty))
+    result should be(URIDataPointer("file://path/to/data.with.dots", Some("dots"), Some("dots"), Map.empty))
   }
 
   it should "handle paths with multiple dots and prefixed format" in {
     val result = DataPointer("file+csv://path/to/data.with.dots")
-    result should be(DataPointer(Some("file"), "path/to/data.with.dots", Some("csv"), Map.empty))
+    result should be(URIDataPointer("file://path/to/data.with.dots", Some("csv"), Some("csv"), Map.empty))
   }
 
   it should "handle paths with format and pointer to folder with glob matching" in {
     val result = DataPointer("s3+parquet://path/to/*/*/")
-    result should be(DataPointer(Some("s3"), "path/to/*/*/", Some("parquet"), Map.empty))
+    result should be(URIDataPointer("s3://path/to/*/*/", Some("parquet"), Some("parquet"), Map.empty))
   }
 
   it should "handle no catalog, just table" in {
     val result = DataPointer("namespace.table")
-    result should be(DataPointer(None, "namespace.table", None, Map.empty))
+    result should be(URIDataPointer("namespace.table", None, None, Map.empty))
   }
 }
diff --git a/cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/BigQueryFormat.scala b/cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/BigQueryFormat.scala
index 107f5941e2..7f1432bfbd 100644
--- a/cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/BigQueryFormat.scala
+++ b/cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/BigQueryFormat.scala
@@ -66,6 +66,8 @@ case class GcpFormatProvider(sparkSession: SparkSession) extends FormatProvider
 
 case class BQuery(project: String) extends Format {
 
+  override def name: String = "bigquery"
+
   override def primaryPartitions(tableName: String, partitionColumn: String, subPartitionsFilter: Map[String, String])(
       implicit sparkSession: SparkSession): Seq[String] =
     super.primaryPartitions(tableName, partitionColumn, subPartitionsFilter)
diff --git a/cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/GCSFormat.scala b/cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/GCSFormat.scala
index 10482322aa..42ac7fd290 100644
--- a/cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/GCSFormat.scala
+++ b/cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/GCSFormat.scala
@@ -10,6 +10,8 @@ import org.apache.spark.sql.functions.url_decode
 
 case class GCS(project: String) extends Format {
 
+  override def name: String = ""
+
   override def primaryPartitions(tableName: String, partitionColumn: String, subPartitionsFilter: Map[String, String])(
       implicit sparkSession: SparkSession): Seq[String] =
     super.primaryPartitions(tableName, partitionColumn, subPartitionsFilter)
diff --git a/spark/src/main/scala/ai/chronon/spark/CatalogAwareDataPointer.scala b/spark/src/main/scala/ai/chronon/spark/CatalogAwareDataPointer.scala
new file mode 100644
index 0000000000..0dcc0dfc8c
--- /dev/null
+++ b/spark/src/main/scala/ai/chronon/spark/CatalogAwareDataPointer.scala
@@ -0,0 +1,42 @@
+package ai.chronon.spark
+
+import ai.chronon.api.DataPointer
+import org.apache.spark.sql.SparkSession
+
+import scala.reflect.runtime.universe._
+
+case class CatalogAwareDataPointer(inputTableOrPath: String, formatProvider: FormatProvider) extends DataPointer {
+
+  override def tableOrPath: String = {
+    formatProvider.resolveTableName(inputTableOrPath)
+  }
+  override lazy val options: Map[String, String] = Map.empty
+
+  override lazy val readFormat: Option[String] = {
+    Option(formatProvider.readFormat(inputTableOrPath)).map(_.name)
+  }
+
+  override lazy val writeFormat: Option[String] = {
+    Option(formatProvider.writeFormat(inputTableOrPath)).map(_.name)
+  }
+
+}
+
+object DataPointer {
+
+  def apply(tableOrPath: String, sparkSession: SparkSession): DataPointer = {
+    val clazzName =
+      sparkSession.conf.get("spark.chronon.table.format_provider.class", classOf[DefaultFormatProvider].getName)
+    val mirror = runtimeMirror(getClass.getClassLoader)
+    val classSymbol = mirror.staticClass(clazzName)
+    val classMirror = mirror.reflectClass(classSymbol)
+    val constructor = classSymbol.primaryConstructor.asMethod
+    val constructorMirror = classMirror.reflectConstructor(constructor)
+    val reflected = constructorMirror(sparkSession)
+    val provider = reflected.asInstanceOf[FormatProvider]
+
+    CatalogAwareDataPointer(tableOrPath, provider)
+
+  }
+
+}
diff --git a/spark/src/main/scala/ai/chronon/spark/Extensions.scala b/spark/src/main/scala/ai/chronon/spark/Extensions.scala
index 3018f28e7a..1cb46deb23 100644
--- a/spark/src/main/scala/ai/chronon/spark/Extensions.scala
+++ b/spark/src/main/scala/ai/chronon/spark/Extensions.scala
@@ -325,8 +325,8 @@ object Extensions {
   implicit class DataPointerOps(dataPointer: DataPointer) {
     def toDf(implicit sparkSession: SparkSession): DataFrame = {
       val tableOrPath = dataPointer.tableOrPath
-      val format = dataPointer.format.getOrElse("parquet")
-      dataPointer.catalog.map(_.toLowerCase) match {
+      val format = dataPointer.readFormat.getOrElse("parquet")
+      dataPointer.readFormat.map(_.toLowerCase) match {
         case Some("bigquery") | Some("bq") =>
           // https://github.com/GoogleCloudDataproc/spark-bigquery-connector?tab=readme-ov-file#reading-data-from-a-bigquery-table
           sparkSession.read
@@ -367,7 +367,7 @@ object Extensions {
           sparkSession.table(tableOrPath)
 
         case _ =>
-          throw new UnsupportedOperationException(s"Unsupported catalog: ${dataPointer.catalog}")
+          throw new UnsupportedOperationException(s"Unsupported catalog: ${dataPointer.readFormat}")
       }
     }
   }
diff --git a/spark/src/main/scala/ai/chronon/spark/Format.scala b/spark/src/main/scala/ai/chronon/spark/Format.scala
index a1946c9ee8..fda8487c67 100644
--- a/spark/src/main/scala/ai/chronon/spark/Format.scala
+++ b/spark/src/main/scala/ai/chronon/spark/Format.scala
@@ -10,6 +10,8 @@ import scala.util.Try
 
 trait Format {
 
+  def name: String
+
   // Return the primary partitions (based on the 'partitionColumn') filtered down by sub-partition filters if provided
   // If subpartition filters are supplied and the format doesn't support it, we throw an error
   def primaryPartitions(tableName: String,
@@ -45,6 +47,7 @@ trait Format {
 
   // Does this format support sub partitions filters
   def supportSubPartitionsFilter: Boolean
+
 }
 
 /**
@@ -58,6 +61,8 @@ trait FormatProvider extends Serializable {
   def sparkSession: SparkSession
   def readFormat(tableName: String): Format
   def writeFormat(tableName: String): Format
+
+  def resolveTableName(tableName: String) = tableName
 }
 
 /**
@@ -134,6 +139,8 @@ case class DefaultFormatProvider(sparkSession: SparkSession) extends FormatProvi
 }
 
 case object Hive extends Format {
+
+  override def name: String = "hive"
   override def primaryPartitions(tableName: String, partitionColumn: String, subPartitionsFilter: Map[String, String])(
       implicit sparkSession: SparkSession): Seq[String] =
     super.primaryPartitions(tableName, partitionColumn, subPartitionsFilter)
@@ -167,6 +174,8 @@ case object Hive extends Format {
 }
 
 case object Iceberg extends Format {
+
+  override def name: String = "iceberg"
   override def primaryPartitions(tableName: String, partitionColumn: String, subPartitionsFilter: Map[String, String])(
      implicit sparkSession: SparkSession): Seq[String] = {
    if (!supportSubPartitionsFilter && subPartitionsFilter.nonEmpty) {
@@ -216,6 +225,8 @@ case object Iceberg extends Format {
 // In such cases, you should implement your own FormatProvider built on the newer Delta lake version
 case object DeltaLake extends Format {
 
+  override def name: String = "delta"
+
   override def primaryPartitions(tableName: String, partitionColumn: String, subPartitionsFilter: Map[String, String])(
       implicit sparkSession: SparkSession): Seq[String] =
     super.primaryPartitions(tableName, partitionColumn, subPartitionsFilter)
diff --git a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala
index 6fae79595a..903352fa43 100644
--- a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala
+++ b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala
@@ -19,7 +19,6 @@ package ai.chronon.spark
 import ai.chronon.aggregator.windowing.TsUtils
 import ai.chronon.api.ColorPrinter.ColorString
 import ai.chronon.api.Constants
-import ai.chronon.api.DataPointer
 import ai.chronon.api.Extensions._
 import ai.chronon.api.PartitionSpec
 import ai.chronon.api.Query
@@ -747,13 +746,13 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
                  wheres: Seq[String],
                  rangeWheres: Seq[String],
                  fallbackSelects: Option[Map[String, String]] = None): DataFrame = {
-    val dp = DataPointer(table)
+    val dp = ai.chronon.api.DataPointer.apply(table)
     var df = dp.toDf(sparkSession)
     val selects = QueryUtils.buildSelects(selectMap, fallbackSelects)
     logger.info(s""" Scanning data:
                    | table: ${dp.tableOrPath.green}
                    | options: ${dp.options}
-                   | format: ${dp.format}
+                   | format: ${dp.readFormat}
                    | selects:
                    | ${selects.mkString("\n ").green}
                    | wheres: