diff --git a/.plugin-versions b/.plugin-versions
index b6a761e8a3..3523520299 100644
--- a/.plugin-versions
+++ b/.plugin-versions
@@ -1,5 +1,7 @@
 asdf-plugin-manager https://github.com/asdf-community/asdf-plugin-manager.git b5862c1
+gcloud https://github.com/jthegedus/asdf-gcloud.git 00cdf06
 java https://github.com/halcyon/asdf-java.git 0ec69b2
 python https://github.com/danhper/asdf-python.git a3a0185
 sbt https://github.com/lerencao/asdf-sbt 53c9f4b
 scala https://github.com/asdf-community/asdf-scala.git 0533444
+thrift https://github.com/alisaifee/asdf-thrift.git fecdd6c
diff --git a/.tool-versions b/.tool-versions
index 6eccdf9548..8dae302a9c 100644
--- a/.tool-versions
+++ b/.tool-versions
@@ -5,3 +5,4 @@
 sbt 1.8.2
 python 3.7.17 3.11.0
+gcloud 504.0.1
diff --git a/api/src/main/scala/ai/chronon/api/DataPointer.scala b/api/src/main/scala/ai/chronon/api/DataPointer.scala
index 812f8a0e56..d6aefffbd1 100644
--- a/api/src/main/scala/ai/chronon/api/DataPointer.scala
+++ b/api/src/main/scala/ai/chronon/api/DataPointer.scala
@@ -5,7 +5,9 @@ abstract class DataPointer {
   def tableOrPath: String
   def readFormat: Option[String]
   def writeFormat: Option[String]
-  def options: Map[String, String]
+
+  def readOptions: Map[String, String]
+  def writeOptions: Map[String, String]
 }
 
@@ -13,8 +15,12 @@ case class URIDataPointer(
     override val tableOrPath: String,
     override val readFormat: Option[String],
     override val writeFormat: Option[String],
-    override val options: Map[String, String]
-) extends DataPointer
+    options: Map[String, String]
+) extends DataPointer {
+
+  override val readOptions: Map[String, String] = options
+  override val writeOptions: Map[String, String] = options
+}
 
 // parses string representations of data pointers
 // ex: namespace.table
diff --git a/api/src/test/scala/ai/chronon/api/test/DataPointerTest.scala b/api/src/test/scala/ai/chronon/api/test/DataPointerTest.scala
index c4a2442acf..d1d7f08a54 100644
--- a/api/src/test/scala/ai/chronon/api/test/DataPointerTest.scala
+++ b/api/src/test/scala/ai/chronon/api/test/DataPointerTest.scala
@@ -5,8 +5,7 @@ import ai.chronon.api.URIDataPointer
 import org.scalatest.flatspec.AnyFlatSpec
 import org.scalatest.matchers.should.Matchers
 
-class
-DataPointerTest extends AnyFlatSpec with Matchers {
+class DataPointerTest extends AnyFlatSpec with Matchers {
 
   "DataPointer.apply" should "parse a simple s3 path" in {
     val result = DataPointer("s3://bucket/path/to/data.parquet")
diff --git a/cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/GcpFormatProvider.scala b/cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/GcpFormatProvider.scala
index 11af422f01..00262a3d77 100644
--- a/cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/GcpFormatProvider.scala
+++ b/cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/GcpFormatProvider.scala
@@ -32,17 +32,21 @@ case class GcpFormatProvider(sparkSession: SparkSession) extends FormatProvider
   override def readFormat(tableName: String): Format = format(tableName)
 
   override def writeFormat(table: String): Format = {
+    val tableId = BigQueryUtil.parseTableId(table)
+    assert(scala.Option(tableId.getProject).isDefined, s"project required for ${table}")
+    assert(scala.Option(tableId.getDataset).isDefined, s"dataset required for ${table}")
 
     val tu = TableUtils(sparkSession)
+    val partitionColumnOption =
+      if (tu.tableReachable(table)) Map.empty else Map("partitionField" -> tu.partitionColumn)
 
     val sparkOptions: Map[String, String] = Map(
-      "partitionField" -> tu.partitionColumn,
      // todo(tchow): No longer needed after https://github.com/GoogleCloudDataproc/spark-bigquery-connector/pull/1320
       "temporaryGcsBucket" -> sparkSession.conf.get("spark.chronon.table.gcs.temporary_gcs_bucket"),
       "writeMethod" -> "indirect"
-    )
+    ) ++ partitionColumnOption
 
-    BigQueryFormat(bqOptions.getProjectId, sparkOptions)
+    BigQueryFormat(tableId.getProject, sparkOptions)
   }
 
   private def getFormat(table: Table): Format =
@@ -72,7 +76,10 @@ case class GcpFormatProvider(sparkSession: SparkSession) extends FormatProvider
     val table = bigQueryClient.getTable(btTableIdentifier.getDataset, btTableIdentifier.getTable)
 
     // lookup bq for the table, if not fall back to hive
-    scala.Option(table).map(getFormat).getOrElse(Hive)
+    scala
+      .Option(table)
+      .map(getFormat)
+      .getOrElse(scala.Option(btTableIdentifier.getProject).map(BigQueryFormat(_, Map.empty)).getOrElse(Hive))
   }
 }
diff --git a/spark/src/main/scala/ai/chronon/spark/CatalogAwareDataPointer.scala b/spark/src/main/scala/ai/chronon/spark/CatalogAwareDataPointer.scala
index 8c010b3523..f8a5098a0e 100644
--- a/spark/src/main/scala/ai/chronon/spark/CatalogAwareDataPointer.scala
+++ b/spark/src/main/scala/ai/chronon/spark/CatalogAwareDataPointer.scala
@@ -10,10 +10,12 @@ case class CatalogAwareDataPointer(inputTableOrPath: String, formatProvider: For
     formatProvider.resolveTableName(inputTableOrPath)
   }
 
-  override lazy val options: Map[String, String] = {
-    // Hack for now, include both read and write options for the datapointer.
-    // todo(tchow): rework this abstraction. https://app.asana.com/0/1208785567265389/1209026103291854/f
-    formatProvider.readFormat(inputTableOrPath).options ++ formatProvider.writeFormat(inputTableOrPath).options
+  override lazy val readOptions: Map[String, String] = {
+    formatProvider.readFormat(inputTableOrPath).options
+  }
+
+  override lazy val writeOptions: Map[String, String] = {
+    formatProvider.writeFormat(inputTableOrPath).options
   }
 
   override lazy val readFormat: Option[String] = {
@@ -28,7 +30,7 @@
 
 object DataPointer {
 
-  def apply(tableOrPath: String, sparkSession: SparkSession): DataPointer = {
+  def from(tableOrPath: String, sparkSession: SparkSession): DataPointer = {
 
     CatalogAwareDataPointer(tableOrPath, FormatProvider.from(sparkSession))
diff --git a/spark/src/main/scala/ai/chronon/spark/Extensions.scala b/spark/src/main/scala/ai/chronon/spark/Extensions.scala
index bafa888928..303859ebab 100644
--- a/spark/src/main/scala/ai/chronon/spark/Extensions.scala
+++ b/spark/src/main/scala/ai/chronon/spark/Extensions.scala
@@ -304,28 +304,26 @@ object Extensions {
 
     def save(dataPointer: DataPointer): Unit = {
+      val optionDfw = dfw.options(dataPointer.writeOptions)
 
       dataPointer.writeFormat
        .map((wf) => {
          val normalized = wf.toLowerCase
          normalized match {
            case "bigquery" | "bq" =>
-              dfw
+              optionDfw
                .format("bigquery")
-                .options(dataPointer.options)
                .save(dataPointer.tableOrPath)
            case "snowflake" | "sf" =>
-              dfw
+              optionDfw
                .format("net.snowflake.spark.snowflake")
-                .options(dataPointer.options)
                .option("dbtable", dataPointer.tableOrPath)
                .save()
            case "parquet" | "csv" =>
-              dfw
+              optionDfw
                .format(normalized)
-                .options(dataPointer.options)
                .save(dataPointer.tableOrPath)
            case "hive" | "delta" | "iceberg" =>
-              dfw
+              optionDfw
                .format(normalized)
                .insertInto(dataPointer.tableOrPath)
            case _ =>
@@ -334,7 +332,7 @@ object Extensions {
        })
        .getOrElse(
          // None case is just table against default catalog
-          dfw
+          optionDfw
            .format("hive")
            .insertInto(dataPointer.tableOrPath))
    }
@@ -345,6 +343,8 @@
 
    def load(dataPointer: DataPointer): DataFrame = {
      val tableOrPath = dataPointer.tableOrPath
+      val optionDfr = dfr.options(dataPointer.readOptions)
+
      dataPointer.readFormat
        .map { fmt =>
          val fmtLower = fmt.toLowerCase
@@ -352,22 +352,19 @@
 
          fmtLower match {
            case "bigquery" | "bq" =>
-              dfr
+              optionDfr
                .format("bigquery")
-                .options(dataPointer.options)
                .load(tableOrPath)
 
            case "snowflake" | "sf" =>
-              dfr
+              optionDfr
                .format("net.snowflake.spark.snowflake")
-                .options(dataPointer.options)
                .option("dbtable", tableOrPath)
                .load()
 
            case "parquet" | "csv" =>
-              dfr
-                .format(fmtLower)
-                .options(dataPointer.options)
+              optionDfr
+                .format(fmt)
                .load(tableOrPath)
            case "hive" | "delta" | "iceberg" =>
              dfr.table(tableOrPath)
@@ -379,7 +376,7 @@
        }
        .getOrElse {
          // None case is just table against default catalog
-          dfr.table(tableOrPath)
+          optionDfr.table(tableOrPath)
        }
    }
  }
diff --git a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala
index f0df8c0751..b16d87aadd 100644
--- a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala
+++ b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala
@@ -125,7 +125,7 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
       true
     } catch {
       case ex: Exception =>
-        logger.info(s"""Couldn't reach $tableName. Error: ${ex.getMessage.red}
+        logger.debug(s"""Couldn't reach $tableName. Error: ${ex.getMessage.red}
                        |Call path:
                        |${cleanStackTrace(ex).yellow}
                        |""".stripMargin)
@@ -135,7 +135,7 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
 
   // Needs provider
   def loadTable(tableName: String): DataFrame = {
-    sparkSession.read.load(DataPointer(tableName, sparkSession))
+    sparkSession.read.load(DataPointer.from(tableName, sparkSession))
   }
 
   def isPartitioned(tableName: String): Boolean = {
@@ -241,7 +241,7 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
   }
 
   def getSchemaFromTable(tableName: String): StructType = {
-    sparkSession.read.load(DataPointer(tableName, sparkSession)).limit(1).schema
+    sparkSession.read.load(DataPointer.from(tableName, sparkSession)).limit(1).schema
   }
 
   // method to check if a user has access to a table
@@ -254,7 +254,7 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
       // retrieve one row from the table
       val partitionFilter = lastAvailablePartition(tableName).getOrElse(fallbackPartition)
       sparkSession.read
-        .load(DataPointer(tableName, sparkSession))
+        .load(DataPointer.from(tableName, sparkSession))
         .where(s"$partitionColumn='$partitionFilter'")
         .limit(1)
         .collect()
@@ -545,8 +545,8 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
       (Seq(partitionColumn, saltCol), Seq(partitionColumn) ++ sortByCols)
     } else { (Seq(saltCol), sortByCols) }
     logger.info(s"Sorting within partitions with cols: $partitionSortCols")
+    val dataPointer = DataPointer.from(tableName, sparkSession)
 
-    val dataPointer = DataPointer(tableName, sparkSession)
     saltedDf
       .select(saltedDf.columns.map {
         case c if c == partitionColumn && dataPointer.writeFormat.map(_.toUpperCase).exists("BIGQUERY".equals) =>
@@ -763,14 +763,13 @@
       wheres: Seq[String],
       rangeWheres: Seq[String],
       fallbackSelects: Option[Map[String, String]] = None): DataFrame = {
-
-    val dp = DataPointer(table, sparkSession)
+    val dp = DataPointer.from(table, sparkSession)
     var df = sparkSession.read.load(dp)
     val selects = QueryUtils.buildSelects(selectMap, fallbackSelects)
     logger.info(s""" Scanning data:
                    | table: ${dp.tableOrPath.green}
-                   | options: ${dp.options}
+                   | options: ${dp.readOptions}
                    | format: ${dp.readFormat}
                    | selects:
                    | ${selects.mkString("\n ").green}
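
Reviewer note, not part of the patch: a minimal sketch of how the read/write option split introduced here might be exercised. The pointer values, option keys, and the example object name below are made up for illustration; the commented spark-side calls mirror the TableUtils/Extensions changes above and assume an active SparkSession.

// Illustrative only -- names and values are hypothetical, not taken from this diff.
import ai.chronon.api.URIDataPointer

object DataPointerOptionsExample {
  def main(args: Array[String]): Unit = {
    // A URI-style pointer still carries a single options map; the new
    // readOptions/writeOptions accessors both surface it unchanged.
    val pointer = URIDataPointer(
      tableOrPath = "s3://bucket/path/to/data.parquet",
      readFormat = Some("parquet"),
      writeFormat = Some("parquet"),
      options = Map("mergeSchema" -> "true") // hypothetical reader option
    )
    assert(pointer.readOptions == pointer.writeOptions)

    // Catalog-aware pointers (spark module) can now diverge per direction,
    // e.g. GcpFormatProvider only injects BigQuery options on the write path:
    //   val dp = DataPointer.from("some_namespace.some_table", sparkSession)
    //   df.write.options(dp.writeOptions) ...   // as in Extensions.save
    //   spark.read.options(dp.readOptions) ...  // as in Extensions.load
  }
}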