
fix: handle partition overwrite #206


Merged
merged 3 commits on Jan 15, 2025
2 changes: 2 additions & 0 deletions .plugin-versions
@@ -1,5 +1,7 @@
asdf-plugin-manager https://github.com/asdf-community/asdf-plugin-manager.git b5862c1
gcloud https://github.com/jthegedus/asdf-gcloud.git 00cdf06
java https://github.com/halcyon/asdf-java.git 0ec69b2
python https://github.com/danhper/asdf-python.git a3a0185
sbt https://github.com/lerencao/asdf-sbt 53c9f4b
scala https://github.com/asdf-community/asdf-scala.git 0533444
thrift https://github.com/alisaifee/asdf-thrift.git fecdd6c
1 change: 1 addition & 0 deletions .tool-versions
@@ -5,3 +5,4 @@ sbt 1.8.2
python 3.7.17 3.11.0
gcloud 504.0.1
12 changes: 9 additions & 3 deletions api/src/main/scala/ai/chronon/api/DataPointer.scala
@@ -5,16 +5,22 @@ abstract class DataPointer {
def tableOrPath: String
def readFormat: Option[String]
def writeFormat: Option[String]
def options: Map[String, String]

def readOptions: Map[String, String]
def writeOptions: Map[String, String]

}

case class URIDataPointer(
override val tableOrPath: String,
override val readFormat: Option[String],
override val writeFormat: Option[String],
override val options: Map[String, String]
) extends DataPointer
options: Map[String, String]
) extends DataPointer {

override val readOptions: Map[String, String] = options
Collaborator (author) commented: for URIDataPointer, just treat the options the same.

override val writeOptions: Map[String, String] = options
}

// parses string representations of data pointers
// ex: namespace.table
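A quick usage sketch of the split options API introduced above — a minimal example only, assuming the `URIDataPointer` constructor shown in this diff; the path and option values are made up.

```scala
import ai.chronon.api.URIDataPointer

object DataPointerOptionsExample {
  def main(args: Array[String]): Unit = {
    // Hypothetical pointer to a parquet path with a single shared option.
    val pointer = URIDataPointer(
      tableOrPath = "s3://bucket/path/to/data.parquet",
      readFormat = Some("parquet"),
      writeFormat = Some("parquet"),
      options = Map("mergeSchema" -> "true")
    )

    // Per the review note above, URIDataPointer exposes the same map through both views:
    // readOptions and writeOptions are each backed by the `options` constructor argument.
    assert(pointer.readOptions == pointer.writeOptions)
    assert(pointer.writeOptions("mergeSchema") == "true")
  }
}
```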
3 changes: 1 addition & 2 deletions api/src/test/scala/ai/chronon/api/test/DataPointerTest.scala
@@ -5,8 +5,7 @@ import ai.chronon.api.URIDataPointer
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

class
DataPointerTest extends AnyFlatSpec with Matchers {
class DataPointerTest extends AnyFlatSpec with Matchers {

"DataPointer.apply" should "parse a simple s3 path" in {
val result = DataPointer("s3://bucket/path/to/data.parquet")
@@ -32,17 +32,21 @@ case class GcpFormatProvider(sparkSession: SparkSession) extends FormatProvider
override def readFormat(tableName: String): Format = format(tableName)

override def writeFormat(table: String): Format = {
val tableId = BigQueryUtil.parseTableId(table)
assert(scala.Option(tableId.getProject).isDefined, s"project required for ${table}")
assert(scala.Option(tableId.getDataset).isDefined, s"dataset required for ${table}")

val tu = TableUtils(sparkSession)
val partitionColumnOption =
if (tu.tableReachable(table)) Map.empty else Map("partitionField" -> tu.partitionColumn)

val sparkOptions: Map[String, String] = Map(
"partitionField" -> tu.partitionColumn,
// todo(tchow): No longer needed after https://github.com/GoogleCloudDataproc/spark-bigquery-connector/pull/1320
"temporaryGcsBucket" -> sparkSession.conf.get("spark.chronon.table.gcs.temporary_gcs_bucket"),
"writeMethod" -> "indirect"
)
) ++ partitionColumnOption

BigQueryFormat(bqOptions.getProjectId, sparkOptions)
BigQueryFormat(tableId.getProject, sparkOptions)
}

private def getFormat(table: Table): Format =
@@ -72,7 +76,10 @@ case class GcpFormatProvider(sparkSession: SparkSession) extends FormatProvider
val table = bigQueryClient.getTable(btTableIdentifier.getDataset, btTableIdentifier.getTable)

// lookup bq for the table, if not fall back to hive
scala.Option(table).map(getFormat).getOrElse(Hive)
scala
.Option(table)
.map(getFormat)
.getOrElse(scala.Option(btTableIdentifier.getProject).map(BigQueryFormat(_, Map.empty)).getOrElse(Hive))

}
}
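A standalone sketch of the conditional partition option above, with the BigQuery and Spark pieces stubbed out: `tableReachable` stands in for `TableUtils.tableReachable`, and the option keys mirror the ones in the diff. The intent, as read from the change, is to pass `partitionField` only when the table does not exist yet, so the connector sets up partitioning at creation time instead of trying to overwrite the spec of an existing table.

```scala
object WriteOptionsSketch {

  // Stand-in for TableUtils.tableReachable: true when the table already exists.
  def tableReachable(table: String): Boolean = false

  def writeOptions(table: String, partitionColumn: String, tempBucket: String): Map[String, String] = {
    // Only request a partitioned table when it is being created for the first time;
    // an existing table already has its partition spec fixed.
    val partitionColumnOption =
      if (tableReachable(table)) Map.empty[String, String]
      else Map("partitionField" -> partitionColumn)

    Map(
      "temporaryGcsBucket" -> tempBucket,
      "writeMethod" -> "indirect"
    ) ++ partitionColumnOption
  }

  def main(args: Array[String]): Unit = {
    // Prints Map(temporaryGcsBucket -> tmp-bucket, writeMethod -> indirect, partitionField -> ds)
    println(writeOptions("project.dataset.table", "ds", "tmp-bucket"))
  }
}
```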
@@ -10,10 +10,12 @@ case class CatalogAwareDataPointer(inputTableOrPath: String, formatProvider: For
formatProvider.resolveTableName(inputTableOrPath)
}

override lazy val options: Map[String, String] = {
// Hack for now, include both read and write options for the datapointer.
// todo(tchow): rework this abstraction. https://app.asana.com/0/1208785567265389/1209026103291854/f
formatProvider.readFormat(inputTableOrPath).options ++ formatProvider.writeFormat(inputTableOrPath).options
override lazy val readOptions: Map[String, String] = {
formatProvider.readFormat(inputTableOrPath).options
}

override lazy val writeOptions: Map[String, String] = {
formatProvider.writeFormat(inputTableOrPath).options
}

override lazy val readFormat: Option[String] = {
@@ -28,7 +30,7 @@

object DataPointer {

def apply(tableOrPath: String, sparkSession: SparkSession): DataPointer = {
def from(tableOrPath: String, sparkSession: SparkSession): DataPointer = {

CatalogAwareDataPointer(tableOrPath, FormatProvider.from(sparkSession))

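A small call-site sketch for the renamed factory. Assumptions: the spark-side `DataPointer` companion is importable under the package shown (the path is a guess; adjust to the actual package), and a local SparkSession is enough for the format provider to resolve against the default catalog.

```scala
import org.apache.spark.sql.SparkSession

object DataPointerFromExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("data-pointer-example").getOrCreate()

    // DataPointer.apply(table, session) is renamed to DataPointer.from(table, session);
    // the resulting pointer resolves formats and options through FormatProvider.
    // Package path assumed here: ai.chronon.spark.DataPointer
    val pointer = ai.chronon.spark.DataPointer.from("namespace.table", spark)

    println(pointer.readFormat)   // e.g. Some("hive") against the default catalog
    println(pointer.readOptions)  // options from the resolved read format
    println(pointer.writeOptions) // options from the resolved write format

    spark.stop()
  }
}
```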
29 changes: 13 additions & 16 deletions spark/src/main/scala/ai/chronon/spark/Extensions.scala
@@ -304,28 +304,26 @@ object Extensions {

def save(dataPointer: DataPointer): Unit = {

val optionDfw = dfw.options(dataPointer.writeOptions)
dataPointer.writeFormat
.map((wf) => {
val normalized = wf.toLowerCase
normalized match {
case "bigquery" | "bq" =>
dfw
optionDfw
.format("bigquery")
.options(dataPointer.options)
.save(dataPointer.tableOrPath)
case "snowflake" | "sf" =>
dfw
optionDfw
.format("net.snowflake.spark.snowflake")
.options(dataPointer.options)
.option("dbtable", dataPointer.tableOrPath)
.save()
case "parquet" | "csv" =>
dfw
optionDfw
.format(normalized)
.options(dataPointer.options)
.save(dataPointer.tableOrPath)
case "hive" | "delta" | "iceberg" =>
dfw
optionDfw
.format(normalized)
.insertInto(dataPointer.tableOrPath)
case _ =>
@@ -334,7 +332,7 @@
})
.getOrElse(
// None case is just table against default catalog
dfw
optionDfw
.format("hive")
.insertInto(dataPointer.tableOrPath))
}
@@ -345,29 +343,28 @@
def load(dataPointer: DataPointer): DataFrame = {
val tableOrPath = dataPointer.tableOrPath

val optionDfr = dfr.options(dataPointer.readOptions)

dataPointer.readFormat
.map { fmt =>
val fmtLower = fmt.toLowerCase

fmtLower match {

case "bigquery" | "bq" =>
dfr
optionDfr
.format("bigquery")
.options(dataPointer.options)
.load(tableOrPath)

case "snowflake" | "sf" =>
dfr
optionDfr
.format("net.snowflake.spark.snowflake")
.options(dataPointer.options)
.option("dbtable", tableOrPath)
.load()

case "parquet" | "csv" =>
dfr
.format(fmtLower)
.options(dataPointer.options)
optionDfr
.format(fmt)
.load(tableOrPath)

case "hive" | "delta" | "iceberg" => dfr.table(tableOrPath)
@@ -379,7 +376,7 @@
}
.getOrElse {
// None case is just table against default catalog
dfr.table(tableOrPath)
optionDfr.table(tableOrPath)
}
}
}
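A condensed, standalone restatement of the new write path, mainly to show that the pointer's write options are now applied once via `dfw.options(...)` before the per-format branch rather than inside each case. The DataPointer fields are passed as plain parameters here, and the helper name and error message are illustrative.

```scala
import org.apache.spark.sql.{DataFrameWriter, Row}

object SaveSketch {
  def save(dfw: DataFrameWriter[Row],
           tableOrPath: String,
           writeFormat: Option[String],
           writeOptions: Map[String, String]): Unit = {
    // Apply the pointer's write options once, up front.
    val optionDfw = dfw.options(writeOptions)

    writeFormat.map(_.toLowerCase) match {
      case Some("bigquery") | Some("bq") =>
        optionDfw.format("bigquery").save(tableOrPath)
      case Some("snowflake") | Some("sf") =>
        optionDfw.format("net.snowflake.spark.snowflake").option("dbtable", tableOrPath).save()
      case Some(fmt @ ("parquet" | "csv")) =>
        optionDfw.format(fmt).save(tableOrPath)
      case Some(fmt @ ("hive" | "delta" | "iceberg")) =>
        optionDfw.format(fmt).insertInto(tableOrPath)
      case None =>
        // No explicit write format: a plain table in the default catalog.
        optionDfw.format("hive").insertInto(tableOrPath)
      case Some(other) =>
        throw new UnsupportedOperationException(s"Unsupported write format: $other")
    }
  }
}
```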
15 changes: 7 additions & 8 deletions spark/src/main/scala/ai/chronon/spark/TableUtils.scala
@@ -125,7 +125,7 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
true
} catch {
case ex: Exception =>
logger.info(s"""Couldn't reach $tableName. Error: ${ex.getMessage.red}
logger.debug(s"""Couldn't reach $tableName. Error: ${ex.getMessage.red}
|Call path:
|${cleanStackTrace(ex).yellow}
|""".stripMargin)
Expand All @@ -135,7 +135,7 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable

// Needs provider
def loadTable(tableName: String): DataFrame = {
sparkSession.read.load(DataPointer(tableName, sparkSession))
sparkSession.read.load(DataPointer.from(tableName, sparkSession))
}

def isPartitioned(tableName: String): Boolean = {
@@ -241,7 +241,7 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
}

def getSchemaFromTable(tableName: String): StructType = {
sparkSession.read.load(DataPointer(tableName, sparkSession)).limit(1).schema
sparkSession.read.load(DataPointer.from(tableName, sparkSession)).limit(1).schema
}

// method to check if a user has access to a table
@@ -254,7 +254,7 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
// retrieve one row from the table
val partitionFilter = lastAvailablePartition(tableName).getOrElse(fallbackPartition)
sparkSession.read
.load(DataPointer(tableName, sparkSession))
.load(DataPointer.from(tableName, sparkSession))
.where(s"$partitionColumn='$partitionFilter'")
.limit(1)
.collect()
@@ -545,8 +545,8 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
(Seq(partitionColumn, saltCol), Seq(partitionColumn) ++ sortByCols)
} else { (Seq(saltCol), sortByCols) }
logger.info(s"Sorting within partitions with cols: $partitionSortCols")
val dataPointer = DataPointer.from(tableName, sparkSession)

val dataPointer = DataPointer(tableName, sparkSession)
saltedDf
.select(saltedDf.columns.map {
case c if c == partitionColumn && dataPointer.writeFormat.map(_.toUpperCase).exists("BIGQUERY".equals) =>
@@ -763,14 +763,13 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
wheres: Seq[String],
rangeWheres: Seq[String],
fallbackSelects: Option[Map[String, String]] = None): DataFrame = {

val dp = DataPointer(table, sparkSession)
val dp = DataPointer.from(table, sparkSession)
var df = sparkSession.read.load(dp)
val selects = QueryUtils.buildSelects(selectMap, fallbackSelects)

logger.info(s""" Scanning data:
| table: ${dp.tableOrPath.green}
| options: ${dp.options}
| options: ${dp.readOptions}
| format: ${dp.readFormat}
| selects:
| ${selects.mkString("\n ").green}