
Commit 886be09

varant-zlai, ezvz, and nikhil-zlai authored

Allow setting partition column name in sources (#381)
## Summary

Allow setting the partition column name in sources. The configured column is mapped to the default partition column name on read and during partition checking.

## Checklist
- [x] Added Unit Tests
- [ ] Covered by existing CI
- [ ] Integration tested
- [ ] Documentation update

## Summary by CodeRabbit

- **New Features**
  - Enabled configurable partition columns in query, join, and data generation operations for improved data partitioning.
- **Refactor**
  - Streamlined partition handling and consolidated import structures to enhance workflow efficiency.
- **Tests**
  - Added test cases verifying partition column functionality and adjusted data generation volumes for better validation.
  - Introduced new tests specifically for different partition columns to ensure accurate handling of partitioned data.

These enhancements provide increased flexibility and accuracy in managing partitioned datasets during data processing and join operations.

---------

Co-authored-by: ezvz <[email protected]>
Co-authored-by: Nikhil Simha <[email protected]>
1 parent c36b7d2 commit 886be09
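
A minimal sketch of how a source can opt into a non-default partition column with the Scala Builders API touched in this change. The table and column names are hypothetical, and `Builders.Source.events` / `Builders.Selects` are assumed to be the existing, unchanged helpers:

```scala
import ai.chronon.api.Builders

// Hypothetical source table partitioned by "date_key" rather than the warehouse default (e.g. "ds").
val paymentsSource = Builders.Source.events(
  query = Builders.Query(
    selects = Builders.Selects("user_id", "amount"),
    timeColumn = "event_ts",
    partitionColumn = "date_key" // new field introduced by this change
  ),
  table = "analytics.payments"
)
```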

File tree

13 files changed, +271 / -78 lines changed


api/src/main/scala/ai/chronon/api/Builders.scala

Lines changed: 3 additions & 1 deletion
@@ -46,7 +46,8 @@ object Builders {
               timeColumn: String = null,
               setups: Seq[String] = null,
               mutationTimeColumn: String = null,
-              reversalColumn: String = null): Query = {
+              reversalColumn: String = null,
+              partitionColumn: String = null): Query = {
       val result = new Query()
       if (selects != null)
         result.setSelects(selects.toJava)
@@ -59,6 +60,7 @@ object Builders {
       result.setSetups(setups.toJava)
       result.setMutationTimeColumn(mutationTimeColumn)
       result.setReversalColumn(reversalColumn)
+      result.setPartitionColumn(partitionColumn)
       result
     }
   }

api/thrift/api.thrift

Lines changed: 6 additions & 0 deletions
@@ -16,6 +16,7 @@ struct Query {
     6: optional list<string> setups = []
     7: optional string mutationTimeColumn
     8: optional string reversalColumn
+    9: optional string partitionColumn
 }

 /**
@@ -48,6 +49,11 @@ struct StagingQuery {
     * Spark SQL setup statements. Used typically to register UDFs.
     **/
     4: optional list<string> setups
+
+    /**
+    * Only needed for `max_date` template
+    **/
+    5: optional string partitionColumn
 }

 struct EventSource {

spark/src/main/scala/ai/chronon/spark/Analyzer.scala

Lines changed: 6 additions & 1 deletion
@@ -32,6 +32,7 @@ import ai.chronon.api.Window
 import ai.chronon.online.PartitionRange
 import ai.chronon.online.SparkConversions
 import ai.chronon.spark.Driver.parseConf
+import ai.chronon.spark.Extensions.QuerySparkOps
 import org.apache.datasketches.common.ArrayOfStringsSerDe
 import org.apache.datasketches.frequencies.ErrorType
 import org.apache.datasketches.frequencies.ItemsSketch
@@ -90,6 +91,7 @@ class Analyzer(tableUtils: TableUtils,
                skewDetection: Boolean = false,
                silenceMode: Boolean = false) {

+  implicit val tu = tableUtils
   @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass)
   // include ts into heavy hitter analysis - useful to surface timestamps that have wrong units
   // include total approx row count - so it is easy to understand the percentage of skewed data
@@ -311,7 +313,10 @@ class Analyzer(tableUtils: TableUtils,
       JoinUtils.getRangesToFill(joinConf.left, tableUtils, endDate, historicalBackfill = joinConf.historicalBackfill)
     logger.info(s"Join range to fill $rangeToFill")
     val unfilledRanges = tableUtils
-      .unfilledRanges(joinConf.metaData.outputTable, rangeToFill, Some(Seq(joinConf.left.table)))
+      .unfilledRanges(joinConf.metaData.outputTable,
+                      rangeToFill,
+                      Some(Seq(joinConf.left.table)),
+                      inputPartitionColumnName = joinConf.left.query.effectivePartitionColumn)
       .getOrElse(Seq.empty)

     joinConf.joinParts.toScala.foreach { part =>

spark/src/main/scala/ai/chronon/spark/BootstrapInfo.scala

Lines changed: 10 additions & 8 deletions
@@ -86,6 +86,7 @@ object BootstrapInfo {
             tableUtils: TableUtils,
             leftSchema: Option[StructType]): BootstrapInfo = {

+    implicit val tu = tableUtils
     implicit val partitionSpec: PartitionSpec = tableUtils.partitionSpec
     // Enrich each join part with the expected output schema
     logger.info(s"\nCreating BootstrapInfo for GroupBys for Join ${joinConf.metaData.name}")
@@ -183,12 +184,11 @@ object BootstrapInfo {
       .foreach(part => {
         // practically there should only be one logBootstrapPart per Join, but nevertheless we will loop here
         val schema = tableUtils.getSchemaFromTable(part.table)
-        val missingKeys = part.keys(joinConf, tableUtils.partitionColumn).filterNot(schema.fieldNames.contains)
-        collectException(
-          assert(
-            missingKeys.isEmpty,
-            s"Log table ${part.table} does not contain some specified keys: ${missingKeys.prettyInline}"
-          ))
+        val missingKeys = part.keys(joinConf, part.query.effectivePartitionColumn).filterNot(schema.fieldNames.contains)
+        collectException(assert(
+          missingKeys.isEmpty,
+          s"Log table ${part.table} does not contain some specified keys: ${missingKeys.prettyInline}, table schema: ${schema.pretty}"
+        ))
       })

     // Retrieve schema_hash mapping info from Hive table properties
@@ -205,13 +205,15 @@ object BootstrapInfo {
       .map(part => {
         val range = PartitionRange(part.startPartition, part.endPartition)
         val bootstrapDf =
-          tableUtils.scanDf(part.query, part.table, Some(Map(tableUtils.partitionColumn -> null)), Some(range))
+          tableUtils
+            .scanDf(part.query, part.table, Some(Map(part.query.effectivePartitionColumn -> null)), range = Some(range))
         val schema = bootstrapDf.schema
+        // We expect partition column and not effectivePartitionColumn because of the scanDf rename
         val missingKeys = part.keys(joinConf, tableUtils.partitionColumn).filterNot(schema.fieldNames.contains)
         collectException(
           assert(
             missingKeys.isEmpty,
-            s"Table ${part.table} does not contain some specified keys: ${missingKeys.prettyInline}"
+            s"Table ${part.table} does not contain some specified keys: ${missingKeys.prettyInline}, schema: ${schema.pretty}"
           ))

         collectException(

spark/src/main/scala/ai/chronon/spark/Extensions.scala

Lines changed: 12 additions & 0 deletions
@@ -19,6 +19,7 @@ package ai.chronon.spark
 import ai.chronon.api
 import ai.chronon.api.Constants
 import ai.chronon.api.DataPointer
+import ai.chronon.api.Extensions.SourceOps
 import ai.chronon.api.PartitionSpec
 import ai.chronon.api.ScalaJavaConversions._
 import ai.chronon.online.AvroConversions
@@ -383,4 +384,15 @@ object Extensions {
       }
     }
   }
+  implicit class SourceSparkOps(source: api.Source) {
+
+    def partitionColumn(implicit tableUtils: TableUtils): String = {
+      Option(source.query.partitionColumn).getOrElse(tableUtils.partitionColumn)
+    }
+  }
+
+  implicit class QuerySparkOps(query: api.Query) {
+    def effectivePartitionColumn(implicit tableUtils: TableUtils): String =
+      Option(query).flatMap(q => Option(q.partitionColumn)).getOrElse(tableUtils.partitionColumn)
+  }
 }
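
The new implicits centralize the fallback rule: a query's own partitionColumn wins when present, otherwise the TableUtils default applies, and a null query also falls back. A rough sketch of the expected resolution, assuming a SparkSession named `spark` and a TableUtils whose default partition column is "ds":

```scala
import ai.chronon.api.Builders
import ai.chronon.spark.Extensions.QuerySparkOps

implicit val tu: TableUtils = TableUtils(spark) // assume tu.partitionColumn == "ds"

Builders.Query(partitionColumn = "date_key").effectivePartitionColumn // "date_key"
Builders.Query().effectivePartitionColumn                             // "ds" (falls back to tu.partitionColumn)
(null: ai.chronon.api.Query).effectivePartitionColumn                 // "ds" (Option(query) guards the null case)
```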

spark/src/main/scala/ai/chronon/spark/GroupBy.scala

Lines changed: 1 addition & 1 deletion
@@ -626,7 +626,7 @@ object GroupBy {

     val intersectedRange: PartitionRange = getIntersectedRange(source, queryRange, tableUtils, window)

-    var metaColumns: Map[String, String] = Map(tableUtils.partitionColumn -> null)
+    var metaColumns: Map[String, String] = Map(tableUtils.partitionColumn -> source.query.partitionColumn)
     if (mutations) {
       metaColumns ++= Map(
         Constants.ReversalColumn -> source.query.reversalColumn,
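
One way to read the metaColumns change, assuming (as in the pre-existing code) that the map goes from output column name to a source-side expression, where a null value means "select the column as-is":

```scala
// tableUtils.partitionColumn == "ds"; the source declares query.partitionColumn = "date_key"
val metaColumns: Map[String, String] = Map("ds" -> "date_key") // scan "date_key" and expose it as "ds"

// source with no partitionColumn set: query.partitionColumn is null, preserving the old behavior
val metaColumnsDefault: Map[String, String] = Map("ds" -> null) // select "ds" directly
```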

spark/src/main/scala/ai/chronon/spark/Join.scala

Lines changed: 1 addition & 1 deletion
@@ -515,7 +515,7 @@ class Join(joinConf: api.Join,
           var bootstrapDf =
             tableUtils.scanDf(part.query,
                               part.table,
-                              Some(Map(tableUtils.partitionColumn -> null)),
+                              Some(Map(part.query.effectivePartitionColumn -> null)),
                               range = Some(bootstrapRange))

           // attach semantic_hash for either log or regular table bootstrap

spark/src/main/scala/ai/chronon/spark/JoinBase.scala

Lines changed: 17 additions & 3 deletions
@@ -51,6 +51,7 @@ abstract class JoinBase(val joinConfCloned: api.Join,
                         showDf: Boolean = false,
                         selectedJoinParts: Option[Seq[String]] = None) {
   @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass)
+  implicit val tu = tableUtils
   private implicit val partitionSpec: PartitionSpec = tableUtils.partitionSpec
   assert(Option(joinConfCloned.metaData.outputNamespace).nonEmpty, "output namespace could not be empty or null")
   val metrics: Metrics.Context = Metrics.Context(Metrics.Environment.JoinOffline, joinConfCloned)
@@ -169,7 +170,8 @@ abstract class JoinBase(val joinConfCloned: api.Join,
           inputToOutputShift = shiftDays,
           // never skip hole during partTable's range determination logic because we don't want partTable
           // and joinTable to be out of sync. skipping behavior is already handled in the outer loop.
-          skipFirstHole = false
+          skipFirstHole = false,
+          inputPartitionColumnName = joinConfCloned.left.query.effectivePartitionColumn
         )
         .getOrElse(Seq())

@@ -345,7 +347,13 @@ abstract class JoinBase(val joinConfCloned: api.Join,

     (rangeToFill,
      tableUtils
-       .unfilledRanges(outputTable, rangeToFill, Some(Seq(joinConfCloned.left.table)), skipFirstHole = skipFirstHole)
+       .unfilledRanges(
+         outputTable,
+         rangeToFill,
+         Some(Seq(joinConfCloned.left.table)),
+         skipFirstHole = skipFirstHole,
+         inputPartitionColumnName = joinConfCloned.left.query.effectivePartitionColumn
+       )
        .getOrElse(Seq.empty))
   }

@@ -473,7 +481,13 @@ abstract class JoinBase(val joinConfCloned: api.Join,
                                   joinConfCloned.historicalBackfill)
     logger.info(s"Join range to fill $rangeToFill")
     val unfilledRanges = tableUtils
-      .unfilledRanges(outputTable, rangeToFill, Some(Seq(joinConfCloned.left.table)), skipFirstHole = skipFirstHole)
+      .unfilledRanges(
+        outputTable,
+        rangeToFill,
+        Some(Seq(joinConfCloned.left.table)),
+        skipFirstHole = skipFirstHole,
+        inputPartitionColumnName = joinConfCloned.left.query.effectivePartitionColumn
+      )
       .getOrElse(Seq.empty)

     def finalResult: DataFrame = tableUtils.scanDf(null, outputTable, range = Some(rangeToFill))

spark/src/main/scala/ai/chronon/spark/LabelJoin.scala

Lines changed: 10 additions & 11 deletions
@@ -17,14 +17,10 @@
 package ai.chronon.spark

 import ai.chronon.api
-import ai.chronon.api.Constants
+import ai.chronon.api.{Builders, Constants, JoinPart, PartitionSpec, TimeUnit, Window}
 import ai.chronon.api.DataModel.Entities
 import ai.chronon.api.DataModel.Events
 import ai.chronon.api.Extensions._
-import ai.chronon.api.JoinPart
-import ai.chronon.api.PartitionSpec
-import ai.chronon.api.TimeUnit
-import ai.chronon.api.Window
 import ai.chronon.online.Metrics
 import ai.chronon.online.PartitionRange
 import ai.chronon.spark.Extensions._
@@ -196,10 +192,13 @@ class LabelJoin(joinConf: api.Join, tableUtils: TableUtils, labelDS: String) {
                       s"${joinConf.metaData.name}/${labelJoinPart.groupBy.getMetaData.getName}")
           throw e
       }
-      tableUtils.scanDf(query = null,
-                        partTable,
-                        range = Some(labelOutputRange),
-                        partitionColumn = Constants.LabelPartitionColumn)
+      // We need to drop the partition column on the scanned DF because label join doesn't expect a second `ds`
+      // On the right side, which will result in a duplicated column error (scan df renames non-default partition cols)
+      tableUtils
+        .scanDf(query = Builders.Query(partitionColumn = Constants.LabelPartitionColumn),
+                partTable,
+                range = Some(labelOutputRange))
+        .drop(tableUtils.partitionColumn)
     }
   }

@@ -266,8 +265,8 @@ class LabelJoin(joinConf: api.Join, tableUtils: TableUtils, labelDS: String) {
     }

     // apply key-renaming to key columns
-    val keyRenamedRight = joinPart.rightToLeft.foldLeft(rightDf) { case (rightDf, (rightKey, leftKey)) =>
-      rightDf.withColumnRenamed(rightKey, leftKey)
+    val keyRenamedRight = joinPart.rightToLeft.foldLeft(rightDf) { case (updatedRight, (rightKey, leftKey)) =>
+      updatedRight.withColumnRenamed(rightKey, leftKey)
     }

     val nonValueColumns = joinPart.rightToLeft.keys.toArray ++ Array(Constants.TimeColumn,

spark/src/main/scala/ai/chronon/spark/TableUtils.scala

Lines changed: 18 additions & 8 deletions
@@ -199,11 +199,13 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
     }
   }

-  def partitions(tableName: String, subPartitionsFilter: Map[String, String] = Map.empty): Seq[String] = {
+  def partitions(tableName: String,
+                 subPartitionsFilter: Map[String, String] = Map.empty,
+                 partitionColumnName: String = partitionColumn): Seq[String] = {

     tableReadFormat(tableName)
       .map((format) => {
-        val partitions = format.primaryPartitions(tableName, partitionColumn, subPartitionsFilter)(sparkSession)
+        val partitions = format.primaryPartitions(tableName, partitionColumnName, subPartitionsFilter)(sparkSession)

         if (partitions.isEmpty) {
           logger.info(s"No partitions found for table: $tableName")
@@ -597,7 +599,8 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
                       inputTables: Option[Seq[String]] = None,
                       inputTableToSubPartitionFiltersMap: Map[String, Map[String, String]] = Map.empty,
                       inputToOutputShift: Int = 0,
-                      skipFirstHole: Boolean = true): Option[Seq[PartitionRange]] = {
+                      skipFirstHole: Boolean = true,
+                      inputPartitionColumnName: String = partitionColumn): Option[Seq[PartitionRange]] = {

     val validPartitionRange = if (outputPartitionRange.start == null) { // determine partition range automatically
       val inputStart = inputTables.flatMap(_.map(table =>
@@ -637,7 +640,7 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
         inputTables <- inputTables.toSeq;
         table <- inputTables;
         subPartitionFilters = inputTableToSubPartitionFiltersMap.getOrElse(table, Map.empty);
-        partitionStr <- partitions(table, subPartitionFilters)
+        partitionStr <- partitions(table, subPartitionFilters, inputPartitionColumnName)
       ) yield {
         partitionSpec.shift(partitionStr, inputToOutputShift)
       }
@@ -828,16 +831,23 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
   def scanDf(query: Query,
              table: String,
              fallbackSelects: Option[Map[String, String]] = None,
-             range: Option[PartitionRange] = None,
-             partitionColumn: String = partitionColumn): DataFrame = {
+             range: Option[PartitionRange] = None): DataFrame = {

-    val rangeWheres = range.map(whereClauses(_, partitionColumn)).getOrElse(Seq.empty)
+    val queryPartitionColumn = query.effectivePartitionColumn(this)
+
+    val rangeWheres = range.map(whereClauses(_, queryPartitionColumn)).getOrElse(Seq.empty)
     val queryWheres = Option(query).flatMap(q => Option(q.wheres)).map(_.toScala).getOrElse(Seq.empty)
     val wheres: Seq[String] = rangeWheres ++ queryWheres

     val selects = Option(query).flatMap(q => Option(q.selects)).map(_.toScala).getOrElse(Map.empty)

-    scanDfBase(selects, table, wheres, rangeWheres, fallbackSelects)
+    val scanDf = scanDfBase(selects, table, wheres, rangeWheres, fallbackSelects)
+
+    if (queryPartitionColumn != partitionColumn) {
+      scanDf.withColumnRenamed(queryPartitionColumn, partitionColumn)
+    } else {
+      scanDf
+    }
   }

   def partitionRange(table: String): PartitionRange = {
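
From the caller's perspective, scanDf now derives the partition column from the query and normalizes the result, so downstream code keeps seeing the default column name. A hedged sketch with hypothetical table and column names, assuming a SparkSession named `spark`:

```scala
import ai.chronon.api.{Builders, PartitionSpec}
import ai.chronon.online.PartitionRange

val tu = TableUtils(spark)                              // assume default partition column is "ds"
implicit val spec: PartitionSpec = tu.partitionSpec     // PartitionRange needs an implicit spec

val query = Builders.Query(partitionColumn = "date_key")
val df = tu.scanDf(query, "analytics.payments", range = Some(PartitionRange("2024-01-01", "2024-01-07")))

// Range filtering happened against "date_key", but the returned frame is renamed to the default
// column, so callers no longer pass a partitionColumn argument to scanDf.
assert(df.columns.contains("ds") && !df.columns.contains("date_key"))
```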

spark/src/test/scala/ai/chronon/spark/test/DataFrameGen.scala

Lines changed: 17 additions & 6 deletions
@@ -38,26 +38,37 @@ import scala.collection.Seq
 // String types are nulled at row level and also at the set level (some strings are always absent)
 object DataFrameGen {
   //  The main api: that generates dataframes given certain properties of data
-  def gen(spark: SparkSession, columns: Seq[Column], count: Int): DataFrame = {
+  def gen(spark: SparkSession, columns: Seq[Column], count: Int, partitionColumn: Option[String] = None): DataFrame = {
     val tableUtils = TableUtils(spark)
-    val RowsWithSchema(rows, schema) = CStream.gen(columns, count, tableUtils.partitionColumn, tableUtils.partitionSpec)
+    val RowsWithSchema(rows, schema) =
+      CStream.gen(columns, count, partitionColumn.getOrElse(tableUtils.partitionColumn), tableUtils.partitionSpec)
     val genericRows = rows.map { row => new GenericRow(row.fieldsSeq.toArray) }.toArray
     val data: RDD[Row] = spark.sparkContext.parallelize(genericRows)
     val sparkSchema = SparkConversions.fromChrononSchema(schema)
     spark.createDataFrame(data, sparkSchema)
   }

   //  The main api: that generates dataframes given certain properties of data
-  def events(spark: SparkSession, columns: Seq[Column], count: Int, partitions: Int): DataFrame = {
+  def events(spark: SparkSession,
+             columns: Seq[Column],
+             count: Int,
+             partitions: Int,
+             partitionColumn: Option[String] = None): DataFrame = {
+    val partitionColumnString = partitionColumn.getOrElse(TableUtils(spark).partitionColumn)
     val generated = gen(spark, columns :+ Column(Constants.TimeColumn, LongType, partitions), count)
     generated.withColumn(
-      TableUtils(spark).partitionColumn,
+      partitionColumnString,
       from_unixtime(generated.col(Constants.TimeColumn) / 1000, TableUtils(spark).partitionSpec.format))
   }

   // Generates Entity data
-  def entities(spark: SparkSession, columns: Seq[Column], count: Int, partitions: Int): DataFrame = {
-    gen(spark, columns :+ Column(TableUtils(spark).partitionColumn, StringType, partitions), count)
+  def entities(spark: SparkSession,
+               columns: Seq[Column],
+               count: Int,
+               partitions: Int,
+               partitionColumn: Option[String] = None): DataFrame = {
+    val partitionColumnString = partitionColumn.getOrElse(TableUtils(spark).partitionColumn)
+    gen(spark, columns :+ Column(partitionColumnString, StringType, partitions), count, partitionColumn)
   }

   /** Mutations and snapshots generation.
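
Tests can now generate fixtures directly under a non-default partition column. A minimal sketch; the `Column` import path, schema, and table name below are assumptions for illustration, and `save` is assumed to be the existing Extensions helper:

```scala
import ai.chronon.aggregator.test.Column
import ai.chronon.api.{LongType, StringType}
import ai.chronon.spark.Extensions._

// Hypothetical event schema, generated under "date_key" instead of the default partition column.
val schema = Seq(
  Column("user", StringType, 100),          // third arg is the generator's cardinality knob
  Column("amount_dollars", LongType, 1000)
)

val eventsDf = DataFrameGen.events(spark, schema, count = 10000, partitions = 30,
                                   partitionColumn = Some("date_key"))
eventsDf.save("analytics.payments_with_date_key") // table partitioned by "date_key" rather than "ds"
```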
