
feat: add col to partition-spec #731


Merged · 10 commits · May 7, 2025
21 changes: 10 additions & 11 deletions aggregator/src/test/scala/ai/chronon/aggregator/test/DataGen.scala
@@ -163,19 +163,17 @@ object CStream {
}

// The main api: that generates dataframes given certain properties of data
def gen(columns: Seq[Column],
count: Int,
partitionColumn: String = null,
partitionSpec: PartitionSpec = null): RowsWithSchema = {
def gen(columns: Seq[Column], count: Int, partitionSpec: PartitionSpec = null): RowsWithSchema = {
val schema = columns.map(_.schema)
val generators = columns.map(_.gen(partitionColumn, partitionSpec))
val generators = columns.map(_.gen(partitionSpec))
val zippedStream = new ZippedStream(generators.toSeq: _*)(schema.indexWhere(_._1 == Constants.TimeColumn))
RowsWithSchema(Seq.fill(count) { zippedStream.next() }.toArray, schema)
}
}

case class Column(name: String, `type`: DataType, cardinality: Int, chunkSize: Int = 10, nullRate: Double = 0.1) {
def genImpl(dtype: DataType, partitionColumn: String, partitionSpec: PartitionSpec, nullRate: Double): CStream[Any] =
def genImpl(dtype: DataType, partitionSpec: PartitionSpec, nullRate: Double): CStream[Any] = {
val partitionColumn = Option(partitionSpec).map(_.column).orNull
dtype match {
case StringType =>
name match {
@@ -191,16 +189,17 @@ case class Column(name: String, `type`: DataType, cardinality: Int, chunkSize: I
case _ => new LongStream(cardinality, nullRate)
}
case ListType(elementType) =>
genImpl(elementType, partitionColumn, partitionSpec, nullRate).chunk(chunkSize)
genImpl(elementType, partitionSpec, nullRate).chunk(chunkSize)
case MapType(keyType, valueType) =>
val keyStream = genImpl(keyType, partitionColumn, partitionSpec, 0)
val valueStream = genImpl(valueType, partitionColumn, partitionSpec, nullRate)
val keyStream = genImpl(keyType, partitionSpec, 0)
val valueStream = genImpl(valueType, partitionSpec, nullRate)
keyStream.zipChunk(valueStream, maxSize = chunkSize)
case otherType => throw new UnsupportedOperationException(s"Can't generate random data for $otherType yet.")
}
}

def gen(partitionColumn: String, partitionSpec: PartitionSpec): CStream[Any] =
genImpl(`type`, partitionColumn, partitionSpec, nullRate)
def gen(partitionSpec: PartitionSpec): CStream[Any] =
genImpl(`type`, partitionSpec, nullRate)
def schema: (String, DataType) = name -> `type`
}
case class RowsWithSchema(rows: Array[TestRow], schema: Seq[(String, DataType)])
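For illustration, a hedged sketch of the new `CStream.gen` call shape after this change; the column definitions, counts, and import paths are assumptions based on this file and may differ slightly.

```scala
// Sketch only: the partition column now travels inside PartitionSpec,
// so callers pass a single spec instead of (partitionColumn, partitionSpec).
import ai.chronon.aggregator.test.{CStream, Column, RowsWithSchema}
import ai.chronon.api.{LongType, PartitionSpec, StringType}

val spec = PartitionSpec(column = "ds", format = "yyyy-MM-dd", spanMillis = 24 * 60 * 60 * 1000L)

val RowsWithSchema(rows, schema) = CStream.gen(
  columns = Seq(Column("user", StringType, cardinality = 100), Column("amount", LongType, cardinality = 1000)),
  count = 50,
  partitionSpec = spec
)
```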
9 changes: 5 additions & 4 deletions api/src/main/scala/ai/chronon/api/DataRange.scala
@@ -71,12 +71,13 @@ case class PartitionRange(start: String, end: String)(implicit val partitionSpec
.toArray
}

def betweenClauses(partitionColumn: String): String = {
s"$partitionColumn BETWEEN '$start' AND '$end'"
def betweenClauses: String = {
s"${partitionSpec.column} BETWEEN '$start' AND '$end'"
}

def whereClauses(partitionColumn: String): Seq[String] = {
(Option(start).map(s => s"$partitionColumn >= '$s'") ++ Option(end).map(e => s"$partitionColumn <= '$e'")).toSeq
def whereClauses: Seq[String] = {
(Option(start).map(s => s"${partitionSpec.column} >= '$s'") ++ Option(end).map(e =>
s"${partitionSpec.column} <= '$e'")).toSeq
}

def steps(days: Int): Seq[PartitionRange] = {
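To make the change concrete, a small hedged sketch (spec values invented) of how the clause builders now read the partition column from the range's implicit spec:

```scala
// Sketch: the partition column comes from the implicit PartitionSpec,
// not from a method argument.
import ai.chronon.api.{PartitionRange, PartitionSpec}

implicit val spec: PartitionSpec = PartitionSpec("date_key", "yyyy-MM-dd", 24 * 60 * 60 * 1000L)
val range = PartitionRange("2025-01-01", "2025-01-07")

range.betweenClauses // "date_key BETWEEN '2025-01-01' AND '2025-01-07'"
range.whereClauses   // Seq("date_key >= '2025-01-01'", "date_key <= '2025-01-07'")
```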
27 changes: 5 additions & 22 deletions api/src/main/scala/ai/chronon/api/Extensions.scala
@@ -20,7 +20,6 @@ import ai.chronon.api
import ai.chronon.api.Constants._
import ai.chronon.api.DataModel._
import ai.chronon.api.Operation._
import ai.chronon.api.QueryUtils.buildSelects
import ai.chronon.api.ScalaJavaConversions._
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.expr
@@ -1133,27 +1132,11 @@ object Extensions {
result
}

// mutationsOnSnapshot table appends default values for mutation_ts and is_before column on the snapshotTable
Contributor Author: unused

// otherwise we will populate the query with the actual mutation_ts and is_before expressions specified in the query
def baseQuery(mutationInfoOnSnapshot: Boolean = false): String = {

val selects = enrichedSelects(mutationInfoOnSnapshot)
val wheres = query.wheres.toScala

val finalSelects = buildSelects(selects, None)

val whereClause = Option(wheres)
.filter(_.nonEmpty)
.map { ws =>
s"""
|WHERE
| ${ws.map(w => s"(${w})").mkString(" AND ")}""".stripMargin
}
.getOrElse("")

s"""SELECT
| ${finalSelects.mkString(",\n ")}
|$whereClause""".stripMargin
def partitionSpec(defaultSpec: PartitionSpec): PartitionSpec = {
val column = Option(query.partitionColumn).getOrElse(defaultSpec.column)
val format = Option(query.partitionFormat).getOrElse(defaultSpec.format)
val interval = Option(query.partitionInterval).getOrElse(WindowUtils.Day)
PartitionSpec(column, format, interval.millis)
}
}

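A hedged sketch of the fallback behaviour the new `partitionSpec` helper gives a source query; the query values and the `Extensions._` import are assumptions:

```scala
// Hypothetical query that only overrides the partition column; the format falls
// back to the default spec and the interval falls back to a daily window.
import ai.chronon.api.{PartitionSpec, Query}
import ai.chronon.api.Extensions._

val defaultSpec = PartitionSpec("ds", "yyyy-MM-dd", 24 * 60 * 60 * 1000L)

val query = new Query()
query.setPartitionColumn("date_key")

val effective = query.partitionSpec(defaultSpec)
// effective.column     == "date_key"    (from the query)
// effective.format     == "yyyy-MM-dd"  (from defaultSpec)
// effective.spanMillis == 86400000L     (WindowUtils.Day)
```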
6 changes: 2 additions & 4 deletions api/src/main/scala/ai/chronon/api/PartitionSpec.scala
@@ -25,7 +25,7 @@ import java.time.format.DateTimeFormatter
import java.util.Locale
import java.util.TimeZone

case class PartitionSpec(format: String, spanMillis: Long) {
case class PartitionSpec(column: String, format: String, spanMillis: Long) {

private def partitionFormatter =
DateTimeFormatter
@@ -83,7 +83,5 @@ case class PartitionSpec(format: String, spanMillis: Long) {
}

object PartitionSpec {
val daily: PartitionSpec = PartitionSpec("yyyy-MM-dd", 24 * 60 * 60 * 1000)
val hourly: PartitionSpec = PartitionSpec("yyyy-MM-dd-HH", 60 * 60 * 1000)
val fifteenMinutes: PartitionSpec = PartitionSpec("yyyy-MM-dd-HH-mm", 15 * 60 * 1000)
val daily: PartitionSpec = PartitionSpec("ds", "yyyy-MM-dd", 24 * 60 * 60 * 1000)
}
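Since the hourly and fifteen-minute presets were dropped here, a caller that still needs a finer-grained spec would now construct one directly; this is a sketch, not part of the PR:

```scala
// Hypothetical replacements for the removed presets, with the new column field.
val hourly         = PartitionSpec("ds", "yyyy-MM-dd-HH", 60 * 60 * 1000L)
val fifteenMinutes = PartitionSpec("ds", "yyyy-MM-dd-HH-mm", 15 * 60 * 1000L)
```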
2 changes: 1 addition & 1 deletion api/src/test/scala/ai/chronon/api/test/DateMacroSpec.scala
@@ -7,7 +7,7 @@ import org.scalatest.matchers.should.Matchers

class DateMacroSpec extends AnyFlatSpec with Matchers {

private val partitionSpec = PartitionSpec("yyyy-MM-dd", 24 * 3600 * 1000)
private val partitionSpec = PartitionSpec.daily

// Tests for remoteQuotesIfPresent
"remoteQuotesIfPresent" should "remove single quotes from the beginning and end of a string" in {
2 changes: 2 additions & 0 deletions api/thrift/api.thrift
@@ -17,6 +17,8 @@ struct Query {
7: optional string mutationTimeColumn
8: optional string reversalColumn
9: optional string partitionColumn
10: optional string partitionFormat
11: optional common.Window partitionInterval
}

/**
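A hedged sketch of how the two new thrift fields might surface on the generated Query class; the setter names follow standard thrift codegen and are assumptions, as are the values:

```scala
// Sketch: per-source partition overrides via the new optional fields.
import ai.chronon.api.{Query, TimeUnit, Window}

val query = new Query()
query.setPartitionColumn("date_key")                      // existing field 9
query.setPartitionFormat("yyyyMMdd")                      // new field 10
query.setPartitionInterval(new Window(1, TimeUnit.DAYS))  // new field 11
```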
@@ -392,7 +392,7 @@ class BigTableKVStoreImpl(dataClient: BigtableDataClient,

// we use the endDs + span to indicate the timestamp of all the cell data we upload for endDs
// this is used in the KV store multiget calls
val partitionSpec = PartitionSpec("yyyy-MM-dd", WindowUtils.Day.millis)
val partitionSpec = PartitionSpec("ds", "yyyy-MM-dd", WindowUtils.Day.millis)
val endDsPlusOne = partitionSpec.epochMillis(partition) + partitionSpec.spanMillis

// we need to sanitize and append the batch suffix to the groupBy name as that's
@@ -122,10 +122,7 @@ object FlinkTestUtils {
.toAvroSchema("Value")
.toString(true)
)
new GroupByServingInfoParsed(
groupByServingInfo,
PartitionSpec(format = "yyyy-MM-dd", spanMillis = WindowUtils.Day.millis)
)
new GroupByServingInfoParsed(groupByServingInfo)
}

def makeGroupBy(keyColumns: Seq[String], filters: Seq[String] = Seq.empty): GroupBy =
@@ -18,7 +18,7 @@ package ai.chronon.online

import ai.chronon.aggregator.windowing.{ResolutionUtils, SawtoothOnlineAggregator}
import ai.chronon.api.Constants.{ReversalField, TimeField}
import ai.chronon.api.Extensions.{GroupByOps, MetadataOps}
import ai.chronon.api.Extensions.{GroupByOps, MetadataOps, WindowOps, WindowUtils}
import ai.chronon.api.ScalaJavaConversions.ListOps
import ai.chronon.api._
import ai.chronon.online.OnlineDerivationUtil.{DerivationFunc, buildDerivationFunction}
@@ -28,10 +28,13 @@ import org.apache.avro.Schema
import scala.collection.Seq

// mixin class - with schema
class GroupByServingInfoParsed(val groupByServingInfo: GroupByServingInfo, partitionSpec: PartitionSpec)
class GroupByServingInfoParsed(val groupByServingInfo: GroupByServingInfo)
extends GroupByServingInfo(groupByServingInfo)
with Serializable {

// this is not really used - we just need the format
private val partitionSpec = PartitionSpec("ds", groupByServingInfo.dateFormat, WindowUtils.Day.millis)

// streaming starts scanning after batchEnd
lazy val batchEndTsMillis: Long = partitionSpec.epochMillis(batchEndDate)
private def parser = new Schema.Parser()
@@ -58,20 +58,8 @@ case class ConfPathOrName(confPath: Option[String] = None, confName: Option[Stri
class MetadataStore(fetchContext: FetchContext) {

@transient implicit lazy val logger: Logger = LoggerFactory.getLogger(getClass)
private var partitionSpec =
PartitionSpec(format = "yyyy-MM-dd", spanMillis = WindowUtils.Day.millis)
private val CONF_BATCH_SIZE = 50

// Note this should match with the format used in the warehouse
def setPartitionMeta(format: String, spanMillis: Long): Unit = {
partitionSpec = PartitionSpec(format = format, spanMillis = spanMillis)
}

// Note this should match with the format used in the warehouse
def setPartitionMeta(format: String): Unit = {
partitionSpec = PartitionSpec(format = format, spanMillis = partitionSpec.spanMillis)
}

implicit val executionContext: ExecutionContext = fetchContext.getOrCreateExecutionContext

def getConf[T <: TBase[_, _]: Manifest](confPathOrName: ConfPathOrName): Try[T] = {
@@ -411,7 +399,7 @@ class MetadataStore(fetchContext: FetchContext) {
.Context(metrics.Metrics.Environment.MetaDataFetching, groupByServingInfo.groupBy)
.withSuffix("group_by")
.distribution(metrics.Metrics.Name.LatencyMillis, System.currentTimeMillis() - startTimeMs)
Success(new GroupByServingInfoParsed(groupByServingInfo, partitionSpec))
Success(new GroupByServingInfoParsed(groupByServingInfo))
}
},
{ gb =>
@@ -12,7 +12,7 @@ import org.scalatest.matchers.should.Matchers
class DataRangeTest extends AnyFlatSpec with Matchers {

// Assuming you have a PartitionSpec and PartitionRange class defined somewhere
implicit val partitionSpec: PartitionSpec = new PartitionSpec("yyyy-MM-dd", new Window(1, TimeUnit.DAYS).millis)
implicit val partitionSpec: PartitionSpec = new PartitionSpec("ds", "yyyy-MM-dd", new Window(1, TimeUnit.DAYS).millis)

"collapseToRange" should "collapse consecutive partitions into ranges" in {
val partitions = List(
6 changes: 5 additions & 1 deletion spark/src/main/scala/ai/chronon/spark/GroupBy.scala
@@ -638,18 +638,22 @@ object GroupBy {
Constants.MutationTimeColumn -> source.query.mutationTimeColumn
)
}

val sourcePartitionSpec = source.query.partitionSpec(tableUtils.partitionSpec)

val timeMapping = if (source.dataModel == ENTITIES) {
Option(source.query.timeColumn).map(Constants.TimeColumn -> _)
} else {
if (accuracy == api.Accuracy.TEMPORAL) {
Some(Constants.TimeColumn -> source.query.timeColumn)
} else {
val dsBasedTimestamp = // 1 millisecond before ds + 1
s"(((UNIX_TIMESTAMP(${tableUtils.partitionColumn}, '${tableUtils.partitionSpec.format}') + 86400) * 1000) - 1)"
s"(((UNIX_TIMESTAMP(${sourcePartitionSpec.column}, '${sourcePartitionSpec.format}') + 86400) * 1000) - 1)"

Some(Constants.TimeColumn -> Option(source.query.timeColumn).getOrElse(dsBasedTimestamp))
}
}

logger.info(s"""
|Time Mapping: $timeMapping
|""".stripMargin)
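To make the interpolation concrete, here is what the fallback timestamp expression would evaluate to for a source whose query overrides both the partition column and format; the values are invented:

```scala
// Illustration only: the SQL string built above, for a custom source spec.
import ai.chronon.api.PartitionSpec

val sourcePartitionSpec = PartitionSpec("date_key", "yyyyMMdd", 24 * 60 * 60 * 1000L)

val dsBasedTimestamp =
  s"(((UNIX_TIMESTAMP(${sourcePartitionSpec.column}, '${sourcePartitionSpec.format}') + 86400) * 1000) - 1)"
// => "(((UNIX_TIMESTAMP(date_key, 'yyyyMMdd') + 86400) * 1000) - 1)"
```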
18 changes: 14 additions & 4 deletions spark/src/main/scala/ai/chronon/spark/GroupByUpload.scala
@@ -140,10 +140,13 @@ object GroupByUpload {

val groupBy = ai.chronon.spark.GroupBy
.from(groupByConf, PartitionRange(endDs, endDs), TableUtils(session), computeDependency = false)

groupByServingInfo.setBatchEndDate(nextDay)
groupByServingInfo.setGroupBy(groupByConf)
groupByServingInfo.setKeyAvroSchema(groupBy.keySchema.toAvroSchema("Key").toString(true))
groupByServingInfo.setSelectedAvroSchema(groupBy.preAggSchema.toAvroSchema("Value").toString(true))
groupByServingInfo.setDateFormat(tableUtils.partitionFormat)
Collaborator: shouldn't this come from tableUtils.partitionSpec?

Contributor Author: it's the same thing essentially


if (groupByConf.streamingSource.isDefined) {
val streamingSource = groupByConf.streamingSource.get

@@ -181,7 +184,7 @@ object GroupByUpload {
logger.info("Not setting InputAvroSchema to GroupByServingInfo as there is no streaming source defined.")
}

val result = new GroupByServingInfoParsed(groupByServingInfo, tableUtils.partitionSpec)
val result = new GroupByServingInfoParsed(groupByServingInfo)
val firstSource = groupByConf.sources.get(0)
logger.info(s"""
|Built GroupByServingInfo for ${groupByConf.metaData.name}:
@@ -259,6 +262,7 @@ object GroupByUpload {
))
val metaRdd = tableUtils.sparkSession.sparkContext.parallelize(metaRows.toSeq)
val metaDf = tableUtils.sparkSession.createDataFrame(metaRdd, kvDf.schema)

kvDf
.union(metaDf)
.withColumn("ds", lit(endDs))
@@ -270,9 +274,15 @@

val metricRow =
kvDfReloaded.selectExpr("sum(bit_length(key_bytes))/8", "sum(bit_length(value_bytes))/8", "count(*)").collect()
context.gauge(Metrics.Name.KeyBytes, metricRow(0).getDouble(0).toLong)
context.gauge(Metrics.Name.ValueBytes, metricRow(0).getDouble(1).toLong)
context.gauge(Metrics.Name.RowCount, metricRow(0).getLong(2))

if (metricRow.length > 0) {
context.gauge(Metrics.Name.KeyBytes, metricRow(0).getDouble(0).toLong)
context.gauge(Metrics.Name.ValueBytes, metricRow(0).getDouble(1).toLong)
context.gauge(Metrics.Name.RowCount, metricRow(0).getLong(2))
} else {
throw new RuntimeException("GroupBy upload resulted in zero rows.")
}

context.gauge(Metrics.Name.LatencyMinutes, (System.currentTimeMillis() - startTs) / (60 * 1000))
}
}
2 changes: 1 addition & 1 deletion spark/src/main/scala/ai/chronon/spark/Join.scala
@@ -204,7 +204,7 @@ class Join(joinConf: api.Join,
} else {
leftRange
}
val wheres = effectiveRange.whereClauses("ds")
val wheres = effectiveRange.whereClauses
val sql = QueryUtils.build(null, partTable, wheres)
logger.info(s"Pulling data from joinPart table with: $sql")
(joinPart, tableUtils.scanDfBase(null, partTable, List.empty, wheres, None))
2 changes: 1 addition & 1 deletion spark/src/main/scala/ai/chronon/spark/batch/MergeJob.scala
@@ -73,7 +73,7 @@ class MergeJob(node: JoinMergeNode, range: DateRange, joinParts: Seq[JoinPart])(
} else {
dayStep
}
val wheres = effectiveRange.whereClauses(tableUtils.partitionColumn)
val wheres = effectiveRange.whereClauses
val sql = QueryUtils.build(null, partTable, wheres)
logger.info(s"Pulling data from joinPart table with: $sql")
(joinPart, tableUtils.scanDfBase(null, partTable, List.empty, wheres, None))
@@ -50,7 +50,7 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
sparkSession.conf.get("spark.chronon.partition.column", "ds")
val partitionFormat: String =
sparkSession.conf.get("spark.chronon.partition.format", "yyyy-MM-dd")
val partitionSpec: PartitionSpec = PartitionSpec(partitionFormat, WindowUtils.Day.millis)
val partitionSpec: PartitionSpec = PartitionSpec(partitionColumn, partitionFormat, WindowUtils.Day.millis)
val smallModelEnabled: Boolean =
sparkSession.conf.get("spark.chronon.backfill.small_mode.enabled", "true").toBoolean
val smallModeNumRowsCutoff: Int =
9 changes: 5 additions & 4 deletions spark/src/main/scala/ai/chronon/spark/stats/CompareJob.scala
@@ -64,7 +64,7 @@ class CompareJob(
val leftDf = tableUtils.sql(s"""
|SELECT *
|FROM ${joinConf.metaData.outputTable}
|WHERE ${partitionRange.betweenClauses(partitionColumn = tableUtils.partitionColumn)}
|WHERE ${partitionRange.betweenClauses}
|""".stripMargin)

// Run the staging query sql directly
@@ -169,11 +169,12 @@ object CompareJob {
consolidatedData
}

def getJoinKeys(joinConf: api.Join, tableUtils: TableUtils): Seq[String] = {
def getJoinKeys(joinConf: api.Join, tableUtils: TableUtils): Array[String] = {
if (joinConf.isSetRowIds) {
joinConf.rowIds.toScala
joinConf.rowIds.toScala.toArray
} else {
val keyCols = joinConf.leftKeyCols ++ Seq(tableUtils.partitionColumn)
val leftPartitionCol = joinConf.left.query.partitionSpec(tableUtils.partitionSpec).column
val keyCols = joinConf.leftKeyCols :+ leftPartitionCol
if (joinConf.left.dataModel == EVENTS) {
keyCols ++ Seq(Constants.TimeColumn)
} else {
@@ -108,7 +108,7 @@ class PartitionRunner[T](verb: String,
|input: $inputTable (${inputRange.start} -> ${inputRange.end})
|output: $outputTable (${outputRange.start} -> ${outputRange.end})
|""".stripMargin.yellow)
val inputFilter = inputRange.whereClauses(tu.partitionColumn).mkString(" AND ")
val inputFilter = inputRange.whereClauses.mkString(" AND ")
val inputDf = tu.loadTable(inputTable).filter(inputFilter)
val (outputDf, sideVal) = computeFunc(inputDf)
side = Option(sideVal)
@@ -40,8 +40,10 @@ object DataFrameGen {
// The main api: that generates dataframes given certain properties of data
def gen(spark: SparkSession, columns: Seq[Column], count: Int, partitionColumn: Option[String] = None): DataFrame = {
val tableUtils = TableUtils(spark)
val effectiveSpec =
tableUtils.partitionSpec.copy(column = partitionColumn.getOrElse(tableUtils.partitionSpec.column))
val RowsWithSchema(rows, schema) =
CStream.gen(columns, count, partitionColumn.getOrElse(tableUtils.partitionColumn), tableUtils.partitionSpec)
CStream.gen(columns, count, effectiveSpec)
val genericRows = rows.map { row => new GenericRow(row.fieldsSeq.toArray) }.toArray
val data: RDD[Row] = spark.sparkContext.parallelize(genericRows)
val sparkSchema = SparkConversions.fromChrononSchema(schema)
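A hedged usage sketch of the updated `DataFrameGen.gen`, assuming a SparkSession in scope as `spark`; the import paths, column list, and partition column are assumptions:

```scala
// Sketch: an explicit partition column is copied into the table's default spec
// before row generation.
import ai.chronon.aggregator.test.Column    // path assumed
import ai.chronon.api.StringType
import ai.chronon.spark.utils.DataFrameGen  // path assumed

val df = DataFrameGen.gen(
  spark,
  columns = Seq(Column("user", StringType, cardinality = 100)),
  count = 1000,
  partitionColumn = Some("date_key")
)
```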
@@ -40,7 +40,7 @@ class ResultValidationAbilityTest extends AnyFlatSpec with BeforeAndAfter {

before {
when(mockTableUtils.partitionColumn).thenReturn("ds")
when(mockTableUtils.partitionSpec).thenReturn(PartitionSpec("yyyy-MM-dd", WindowUtils.Day.millis))
when(mockTableUtils.partitionSpec).thenReturn(PartitionSpec("ds", "yyyy-MM-dd", WindowUtils.Day.millis))
}

class TestArgs(args: Array[String]) extends ScallopConf(args) with OfflineSubcommand with ResultValidationAbility {