
Allow setting partition column name in sources #381

Merged: 7 commits, Feb 14, 2025

4 changes: 3 additions & 1 deletion api/src/main/scala/ai/chronon/api/Builders.scala
@@ -46,7 +46,8 @@ object Builders {
timeColumn: String = null,
setups: Seq[String] = null,
mutationTimeColumn: String = null,
reversalColumn: String = null): Query = {
reversalColumn: String = null,
partitionColumn: String = null): Query = {
val result = new Query()
if (selects != null)
result.setSelects(selects.toJava)
@@ -59,6 +60,7 @@
result.setSetups(setups.toJava)
result.setMutationTimeColumn(mutationTimeColumn)
result.setReversalColumn(reversalColumn)
result.setPartitionColumn(partitionColumn)
result
}
}
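
As a quick illustration of the new builder parameter, a source query can now pin the partition column of its underlying table. A minimal sketch using only parameters visible in this diff; the column names are hypothetical:

import ai.chronon.api.Builders

// Hypothetical source whose table is partitioned by `date_key` instead of the default `ds`.
val query = Builders.Query(
  timeColumn = "event_ts",     // assumed event-time column
  partitionColumn = "date_key" // new in this PR; leaving it null keeps the previous behavior
)
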
6 changes: 6 additions & 0 deletions api/thrift/api.thrift
@@ -16,6 +16,7 @@ struct Query {
6: optional list<string> setups = []
7: optional string mutationTimeColumn
8: optional string reversalColumn
9: optional string partitionColumn
}

/**
@@ -48,6 +49,11 @@ struct StagingQuery {
* Spark SQL setup statements. Used typically to register UDFs.
**/
4: optional list<string> setups

/**
* Only needed for `max_date` template
**/
5: optional string partitionColumn
}

struct EventSource {
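
For the new StagingQuery field, a minimal sketch of setting it through the generated thrift setter; the class path and the surrounding StagingQuery fields are assumptions, not shown in this diff:

import ai.chronon.api.StagingQuery

val stagingQuery = new StagingQuery()
// Only consulted when rendering the `max_date` template, per the comment above;
// staging queries that never use `max_date` can leave it unset.
stagingQuery.setPartitionColumn("date_key")
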
7 changes: 6 additions & 1 deletion spark/src/main/scala/ai/chronon/spark/Analyzer.scala
@@ -32,6 +32,7 @@ import ai.chronon.api.Window
import ai.chronon.online.PartitionRange
import ai.chronon.online.SparkConversions
import ai.chronon.spark.Driver.parseConf
import ai.chronon.spark.Extensions.QuerySparkOps
import org.apache.datasketches.common.ArrayOfStringsSerDe
import org.apache.datasketches.frequencies.ErrorType
import org.apache.datasketches.frequencies.ItemsSketch
@@ -90,6 +91,7 @@ class Analyzer(tableUtils: TableUtils,
skewDetection: Boolean = false,
silenceMode: Boolean = false) {

implicit val tu = tableUtils
@transient lazy val logger: Logger = LoggerFactory.getLogger(getClass)
// include ts into heavy hitter analysis - useful to surface timestamps that have wrong units
// include total approx row count - so it is easy to understand the percentage of skewed data
@@ -311,7 +313,10 @@ class Analyzer(tableUtils: TableUtils,
JoinUtils.getRangesToFill(joinConf.left, tableUtils, endDate, historicalBackfill = joinConf.historicalBackfill)
logger.info(s"Join range to fill $rangeToFill")
val unfilledRanges = tableUtils
.unfilledRanges(joinConf.metaData.outputTable, rangeToFill, Some(Seq(joinConf.left.table)))
.unfilledRanges(joinConf.metaData.outputTable,
rangeToFill,
Some(Seq(joinConf.left.table)),
inputPartitionColumnName = joinConf.left.query.effectivePartitionColumn)
.getOrElse(Seq.empty)

joinConf.joinParts.toScala.foreach { part =>
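
The Analyzer now forwards the left source's partition column into unfilledRanges. A hedged sketch of the standalone call with hypothetical table names; an implicit PartitionSpec is assumed for constructing PartitionRange:

import ai.chronon.api.PartitionSpec
import ai.chronon.online.PartitionRange

implicit val partitionSpec: PartitionSpec = tableUtils.partitionSpec

val unfilled: Seq[PartitionRange] = tableUtils
  .unfilledRanges(
    "demo_db.join_output",                      // output table, partitioned by the default column
    PartitionRange("2025-01-01", "2025-02-01"), // range the join wants to fill
    Some(Seq("demo_db.left_events")),           // input table partitioned by `date_key`
    inputPartitionColumnName = "date_key"       // new parameter; defaults to tableUtils.partitionColumn
  )
  .getOrElse(Seq.empty)
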
18 changes: 10 additions & 8 deletions spark/src/main/scala/ai/chronon/spark/BootstrapInfo.scala
@@ -86,6 +86,7 @@ object BootstrapInfo {
tableUtils: TableUtils,
leftSchema: Option[StructType]): BootstrapInfo = {

implicit val tu = tableUtils
implicit val partitionSpec: PartitionSpec = tableUtils.partitionSpec
// Enrich each join part with the expected output schema
logger.info(s"\nCreating BootstrapInfo for GroupBys for Join ${joinConf.metaData.name}")
@@ -183,12 +184,11 @@
.foreach(part => {
// practically there should only be one logBootstrapPart per Join, but nevertheless we will loop here
val schema = tableUtils.getSchemaFromTable(part.table)
val missingKeys = part.keys(joinConf, tableUtils.partitionColumn).filterNot(schema.fieldNames.contains)
collectException(
assert(
missingKeys.isEmpty,
s"Log table ${part.table} does not contain some specified keys: ${missingKeys.prettyInline}"
))
val missingKeys = part.keys(joinConf, part.query.effectivePartitionColumn).filterNot(schema.fieldNames.contains)
collectException(assert(
missingKeys.isEmpty,
s"Log table ${part.table} does not contain some specified keys: ${missingKeys.prettyInline}, table schema: ${schema.pretty}"
))
})

// Retrieve schema_hash mapping info from Hive table properties
@@ -205,13 +205,15 @@
.map(part => {
val range = PartitionRange(part.startPartition, part.endPartition)
val bootstrapDf =
tableUtils.scanDf(part.query, part.table, Some(Map(tableUtils.partitionColumn -> null)), Some(range))
tableUtils
.scanDf(part.query, part.table, Some(Map(part.query.effectivePartitionColumn -> null)), range = Some(range))
val schema = bootstrapDf.schema
// We expect partition column and not effectivePartitionColumn because of the scanDf rename
val missingKeys = part.keys(joinConf, tableUtils.partitionColumn).filterNot(schema.fieldNames.contains)
collectException(
assert(
missingKeys.isEmpty,
s"Table ${part.table} does not contain some specified keys: ${missingKeys.prettyInline}"
s"Table ${part.table} does not contain some specified keys: ${missingKeys.prettyInline}, schema: ${schema.pretty}"
))

collectException(
12 changes: 12 additions & 0 deletions spark/src/main/scala/ai/chronon/spark/Extensions.scala
@@ -19,6 +19,7 @@ package ai.chronon.spark
import ai.chronon.api
import ai.chronon.api.Constants
import ai.chronon.api.DataPointer
import ai.chronon.api.Extensions.SourceOps
import ai.chronon.api.PartitionSpec
import ai.chronon.api.ScalaJavaConversions._
import ai.chronon.online.AvroConversions
@@ -383,4 +384,15 @@
}
}
}
implicit class SourceSparkOps(source: api.Source) {

def partitionColumn(implicit tableUtils: TableUtils): String = {
Option(source.query.partitionColumn).getOrElse(tableUtils.partitionColumn)
}
}

implicit class QuerySparkOps(query: api.Query) {
def effectivePartitionColumn(implicit tableUtils: TableUtils): String =
Option(query).flatMap(q => Option(q.partitionColumn)).getOrElse(tableUtils.partitionColumn)
}
}
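
A small sketch of how the two helpers above resolve, assuming an implicit TableUtils whose default partition column is `ds` and an existing SparkSession named `spark`:

import ai.chronon.api
import ai.chronon.api.Builders
import ai.chronon.spark.TableUtils
import ai.chronon.spark.Extensions.QuerySparkOps

implicit val tableUtils: TableUtils = TableUtils(spark)

val explicitCol = Builders.Query(partitionColumn = "date_key").effectivePartitionColumn // "date_key"
val fallbackCol = Builders.Query().effectivePartitionColumn                             // tableUtils.partitionColumn, e.g. "ds"
val nullSafeCol = (null: api.Query).effectivePartitionColumn                            // also falls back; Option(query) guards the null
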
2 changes: 1 addition & 1 deletion spark/src/main/scala/ai/chronon/spark/GroupBy.scala
@@ -626,7 +626,7 @@ object GroupBy {

val intersectedRange: PartitionRange = getIntersectedRange(source, queryRange, tableUtils, window)

var metaColumns: Map[String, String] = Map(tableUtils.partitionColumn -> null)
var metaColumns: Map[String, String] = Map(tableUtils.partitionColumn -> source.query.partitionColumn)
if (mutations) {
metaColumns ++= Map(
Constants.ReversalColumn -> source.query.reversalColumn,
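
The one-line change above leans on the select-map convention used when scanning the source: a null value means "take the column by its own name", while a non-null value is used as the source expression for that output column. A hedged sketch with hypothetical names:

// Source leaves partitionColumn unset:  Map("ds" -> null)       => `ds` is selected as-is (previous behavior).
// Source sets partitionColumn:          Map("ds" -> "date_key") => roughly `date_key AS ds` in the rendered scan.
val metaColumns: Map[String, String] = Map("ds" -> "date_key")
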
2 changes: 1 addition & 1 deletion spark/src/main/scala/ai/chronon/spark/Join.scala
@@ -515,7 +515,7 @@ class Join(joinConf: api.Join,
var bootstrapDf =
tableUtils.scanDf(part.query,
part.table,
Some(Map(tableUtils.partitionColumn -> null)),
Some(Map(part.query.effectivePartitionColumn -> null)),
range = Some(bootstrapRange))

// attach semantic_hash for either log or regular table bootstrap
20 changes: 17 additions & 3 deletions spark/src/main/scala/ai/chronon/spark/JoinBase.scala
@@ -51,6 +51,7 @@ abstract class JoinBase(val joinConfCloned: api.Join,
showDf: Boolean = false,
selectedJoinParts: Option[Seq[String]] = None) {
@transient lazy val logger: Logger = LoggerFactory.getLogger(getClass)
implicit val tu = tableUtils
private implicit val partitionSpec: PartitionSpec = tableUtils.partitionSpec
assert(Option(joinConfCloned.metaData.outputNamespace).nonEmpty, "output namespace could not be empty or null")
val metrics: Metrics.Context = Metrics.Context(Metrics.Environment.JoinOffline, joinConfCloned)
@@ -169,7 +170,8 @@
inputToOutputShift = shiftDays,
// never skip hole during partTable's range determination logic because we don't want partTable
// and joinTable to be out of sync. skipping behavior is already handled in the outer loop.
skipFirstHole = false
skipFirstHole = false,
inputPartitionColumnName = joinConfCloned.left.query.effectivePartitionColumn
)
.getOrElse(Seq())

@@ -345,7 +347,13 @@

(rangeToFill,
tableUtils
.unfilledRanges(outputTable, rangeToFill, Some(Seq(joinConfCloned.left.table)), skipFirstHole = skipFirstHole)
.unfilledRanges(
outputTable,
rangeToFill,
Some(Seq(joinConfCloned.left.table)),
skipFirstHole = skipFirstHole,
inputPartitionColumnName = joinConfCloned.left.query.effectivePartitionColumn
)
.getOrElse(Seq.empty))
}

@@ -473,7 +481,13 @@
joinConfCloned.historicalBackfill)
logger.info(s"Join range to fill $rangeToFill")
val unfilledRanges = tableUtils
.unfilledRanges(outputTable, rangeToFill, Some(Seq(joinConfCloned.left.table)), skipFirstHole = skipFirstHole)
.unfilledRanges(
outputTable,
rangeToFill,
Some(Seq(joinConfCloned.left.table)),
skipFirstHole = skipFirstHole,
inputPartitionColumnName = joinConfCloned.left.query.effectivePartitionColumn
)
.getOrElse(Seq.empty)

def finalResult: DataFrame = tableUtils.scanDf(null, outputTable, range = Some(rangeToFill))
21 changes: 10 additions & 11 deletions spark/src/main/scala/ai/chronon/spark/LabelJoin.scala
@@ -17,14 +17,10 @@
package ai.chronon.spark

import ai.chronon.api
import ai.chronon.api.Constants
import ai.chronon.api.{Builders, Constants, JoinPart, PartitionSpec, TimeUnit, Window}
import ai.chronon.api.DataModel.Entities
import ai.chronon.api.DataModel.Events
import ai.chronon.api.Extensions._
import ai.chronon.api.JoinPart
import ai.chronon.api.PartitionSpec
import ai.chronon.api.TimeUnit
import ai.chronon.api.Window
import ai.chronon.online.Metrics
import ai.chronon.online.PartitionRange
import ai.chronon.spark.Extensions._
@@ -196,10 +192,13 @@ class LabelJoin(joinConf: api.Join, tableUtils: TableUtils, labelDS: String) {
s"${joinConf.metaData.name}/${labelJoinPart.groupBy.getMetaData.getName}")
throw e
}
tableUtils.scanDf(query = null,
partTable,
range = Some(labelOutputRange),
partitionColumn = Constants.LabelPartitionColumn)
// We need to drop the partition column from the scanned DF because the label join doesn't expect a second `ds`
// on the right side, which would result in a duplicated-column error (scanDf renames non-default partition columns)
tableUtils
.scanDf(query = Builders.Query(partitionColumn = Constants.LabelPartitionColumn),
partTable,
range = Some(labelOutputRange))
.drop(tableUtils.partitionColumn)
}
}

@@ -266,8 +265,8 @@ class LabelJoin(joinConf: api.Join, tableUtils: TableUtils, labelDS: String) {
}

// apply key-renaming to key columns
val keyRenamedRight = joinPart.rightToLeft.foldLeft(rightDf) { case (rightDf, (rightKey, leftKey)) =>
rightDf.withColumnRenamed(rightKey, leftKey)
val keyRenamedRight = joinPart.rightToLeft.foldLeft(rightDf) { case (updatedRight, (rightKey, leftKey)) =>
updatedRight.withColumnRenamed(rightKey, leftKey)
}

val nonValueColumns = joinPart.rightToLeft.keys.toArray ++ Array(Constants.TimeColumn,
26 changes: 18 additions & 8 deletions spark/src/main/scala/ai/chronon/spark/TableUtils.scala
@@ -199,11 +199,13 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
}
}

def partitions(tableName: String, subPartitionsFilter: Map[String, String] = Map.empty): Seq[String] = {
def partitions(tableName: String,
subPartitionsFilter: Map[String, String] = Map.empty,
partitionColumnName: String = partitionColumn): Seq[String] = {

tableReadFormat(tableName)
.map((format) => {
val partitions = format.primaryPartitions(tableName, partitionColumn, subPartitionsFilter)(sparkSession)
val partitions = format.primaryPartitions(tableName, partitionColumnName, subPartitionsFilter)(sparkSession)

if (partitions.isEmpty) {
logger.info(s"No partitions found for table: $tableName")
@@ -597,7 +599,8 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
inputTables: Option[Seq[String]] = None,
inputTableToSubPartitionFiltersMap: Map[String, Map[String, String]] = Map.empty,
inputToOutputShift: Int = 0,
skipFirstHole: Boolean = true): Option[Seq[PartitionRange]] = {
skipFirstHole: Boolean = true,
inputPartitionColumnName: String = partitionColumn): Option[Seq[PartitionRange]] = {

val validPartitionRange = if (outputPartitionRange.start == null) { // determine partition range automatically
val inputStart = inputTables.flatMap(_.map(table =>
@@ -637,7 +640,7 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
inputTables <- inputTables.toSeq;
table <- inputTables;
subPartitionFilters = inputTableToSubPartitionFiltersMap.getOrElse(table, Map.empty);
partitionStr <- partitions(table, subPartitionFilters)
partitionStr <- partitions(table, subPartitionFilters, inputPartitionColumnName)
) yield {
partitionSpec.shift(partitionStr, inputToOutputShift)
}
@@ -828,16 +831,23 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
def scanDf(query: Query,
table: String,
fallbackSelects: Option[Map[String, String]] = None,
range: Option[PartitionRange] = None,
partitionColumn: String = partitionColumn): DataFrame = {
range: Option[PartitionRange] = None): DataFrame = {

val rangeWheres = range.map(whereClauses(_, partitionColumn)).getOrElse(Seq.empty)
val queryPartitionColumn = query.effectivePartitionColumn(this)

val rangeWheres = range.map(whereClauses(_, queryPartitionColumn)).getOrElse(Seq.empty)
val queryWheres = Option(query).flatMap(q => Option(q.wheres)).map(_.toScala).getOrElse(Seq.empty)
val wheres: Seq[String] = rangeWheres ++ queryWheres

val selects = Option(query).flatMap(q => Option(q.selects)).map(_.toScala).getOrElse(Map.empty)

scanDfBase(selects, table, wheres, rangeWheres, fallbackSelects)
val scanDf = scanDfBase(selects, table, wheres, rangeWheres, fallbackSelects)

if (queryPartitionColumn != partitionColumn) {
scanDf.withColumnRenamed(queryPartitionColumn, partitionColumn)
} else {
scanDf
}
}

def partitionRange(table: String): PartitionRange = {
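
A sketch of the new scanDf behavior for a source with a non-default partition column; table and column names are hypothetical, `spark` is an assumed SparkSession, and an implicit PartitionSpec is needed for PartitionRange:

import ai.chronon.api.{Builders, PartitionSpec}
import ai.chronon.online.PartitionRange
import ai.chronon.spark.TableUtils

val tableUtils = TableUtils(spark)
implicit val partitionSpec: PartitionSpec = tableUtils.partitionSpec

val query = Builders.Query(partitionColumn = "date_key")
val df = tableUtils.scanDf(query, "demo_db.events_by_date_key", range = Some(PartitionRange("2025-01-01", "2025-01-07")))

// The range filter is rendered against `date_key`, and the result comes back with `date_key`
// renamed to the default partition column, so downstream code keeps seeing a single column name.
assert(df.columns.contains(tableUtils.partitionColumn))
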
23 changes: 17 additions & 6 deletions spark/src/test/scala/ai/chronon/spark/test/DataFrameGen.scala
@@ -38,26 +38,37 @@ import scala.collection.Seq
// String types are nulled at row level and also at the set level (some strings are always absent)
object DataFrameGen {
// The main api: that generates dataframes given certain properties of data
def gen(spark: SparkSession, columns: Seq[Column], count: Int): DataFrame = {
def gen(spark: SparkSession, columns: Seq[Column], count: Int, partitionColumn: Option[String] = None): DataFrame = {
val tableUtils = TableUtils(spark)
val RowsWithSchema(rows, schema) = CStream.gen(columns, count, tableUtils.partitionColumn, tableUtils.partitionSpec)
val RowsWithSchema(rows, schema) =
CStream.gen(columns, count, partitionColumn.getOrElse(tableUtils.partitionColumn), tableUtils.partitionSpec)
val genericRows = rows.map { row => new GenericRow(row.fieldsSeq.toArray) }.toArray
val data: RDD[Row] = spark.sparkContext.parallelize(genericRows)
val sparkSchema = SparkConversions.fromChrononSchema(schema)
spark.createDataFrame(data, sparkSchema)
}

// The main api: that generates dataframes given certain properties of data
def events(spark: SparkSession, columns: Seq[Column], count: Int, partitions: Int): DataFrame = {
def events(spark: SparkSession,
columns: Seq[Column],
count: Int,
partitions: Int,
partitionColumn: Option[String] = None): DataFrame = {
val partitionColumnString = partitionColumn.getOrElse(TableUtils(spark).partitionColumn)
val generated = gen(spark, columns :+ Column(Constants.TimeColumn, LongType, partitions), count)
generated.withColumn(
TableUtils(spark).partitionColumn,
partitionColumnString,
from_unixtime(generated.col(Constants.TimeColumn) / 1000, TableUtils(spark).partitionSpec.format))
}

// Generates Entity data
def entities(spark: SparkSession, columns: Seq[Column], count: Int, partitions: Int): DataFrame = {
gen(spark, columns :+ Column(TableUtils(spark).partitionColumn, StringType, partitions), count)
def entities(spark: SparkSession,
columns: Seq[Column],
count: Int,
partitions: Int,
partitionColumn: Option[String] = None): DataFrame = {
val partitionColumnString = partitionColumn.getOrElse(TableUtils(spark).partitionColumn)
gen(spark, columns :+ Column(partitionColumnString, StringType, partitions), count, partitionColumn)
}

/** Mutations and snapshots generation.
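
For tests, the new optional parameter lets generated data land in a custom partition column. A minimal sketch reusing the Column/StringType/LongType imports already present in this file; the schema, column names, and `spark` session are hypothetical:

// Hypothetical test schema: (name, chronon type, approximate cardinality).
val columns = Seq(Column("user_id", StringType, 100), Column("amount", LongType, 1000))

// Events written with `date_key` as the partition column instead of the default.
val eventsDf = DataFrameGen.events(spark, columns, count = 1000, partitions = 30, partitionColumn = Some("date_key"))
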