Commit 8548a63

rebase

1 parent 76f83d5
2 files changed: +21 −10

spark/src/main/scala/ai/chronon/spark/GroupBy.scala (+18 −10)
@@ -165,12 +165,13 @@ class GroupBy(val aggregations: Seq[api.Aggregation],
     // add 1 day to the end times to include data [ds 00:00:00.000, ds + 1 00:00:00.000)
     val shiftedEndTimes = endTimes.map(_ + tableUtils.partitionSpec.spanMillis)
     val sawtoothAggregator = new SawtoothAggregator(aggregations, selectedSchema, resolution)
+    val sawtoothAggregatorBroadcast = sparkSession.sparkContext.broadcast(sawtoothAggregator)
     val hops = hopsAggregate(endTimes.min, resolution)
 
     hops
       .flatMap { case (keys, hopsArrays) =>
         // filter out if the all the irs are nulls
-        val irs = sawtoothAggregator.computeWindows(hopsArrays, shiftedEndTimes)
+        val irs = sawtoothAggregatorBroadcast.value.computeWindows(hopsArrays, shiftedEndTimes)
         irs.indices.flatMap { i =>
           val result = normalizeOrFinalize(irs(i))
           if (result.forall(_ == null)) None
@@ -230,16 +231,21 @@ class GroupBy(val aggregations: Seq[api.Aggregation],
     val snapshotKeyHashFx = FastHashing.generateKeyBuilder(keyColumns.toArray, expandedInputDf.schema)
     val sawtoothAggregator =
       new SawtoothMutationAggregator(aggregations, SparkConversions.toChrononSchema(expandedInputDf.schema), resolution)
+
+    val sawtoothAggregatorBroadcast = sparkSession.sparkContext.broadcast(sawtoothAggregator)
     val updateFunc = (ir: BatchIr, row: Row) => {
-      sawtoothAggregator.update(row.getLong(shiftedColumnIndexTs), ir, SparkConversions.toChrononRow(row, tsIndex))
+      sawtoothAggregatorBroadcast.value.update(row.getLong(shiftedColumnIndexTs),
+                                               ir,
+                                               SparkConversions.toChrononRow(row, tsIndex))
       ir
     }
 
     // end of day IR
     val snapshotByKeys = expandedInputDf.rdd
       .keyBy(row => (snapshotKeyHashFx(row), row.getString(shiftedColumnIndex)))
-      .aggregateByKey(sawtoothAggregator.init)(seqOp = updateFunc, combOp = sawtoothAggregator.merge)
-      .mapValues(sawtoothAggregator.finalizeSnapshot)
+      .aggregateByKey(sawtoothAggregatorBroadcast.value.init)(seqOp = updateFunc,
+                                                              combOp = sawtoothAggregatorBroadcast.value.merge)
+      .mapValues(sawtoothAggregatorBroadcast.value.finalizeSnapshot)
 
     // Preprocess for mutations: Add a ds of mutation ts column, collect sorted mutations by keys and ds of mutation.
     val mutationDf = mutationDfFn()
@@ -270,10 +276,10 @@ class GroupBy(val aggregations: Seq[api.Aggregation],
         val sortedQueries = timeQueries.map { TimeTuple.getTs }
         val finalizedEodIr = eodIr.orNull
 
-        val irs = sawtoothAggregator.lambdaAggregateIrMany(tableUtils.partitionSpec.epochMillis(ds),
-                                                           finalizedEodIr,
-                                                           dayMutations.orNull,
-                                                           sortedQueries)
+        val irs = sawtoothAggregatorBroadcast.value.lambdaAggregateIrMany(tableUtils.partitionSpec.epochMillis(ds),
+                                                                          finalizedEodIr,
+                                                                          dayMutations.orNull,
+                                                                          sortedQueries)
         ((keyWithHash, ds), (timeQueries, sortedQueries.indices.map(i => normalizeOrFinalize(irs(i)))))
       }
 
@@ -329,14 +335,16 @@ class GroupBy(val aggregations: Seq[api.Aggregation],
     val sawtoothAggregator =
       new SawtoothAggregator(aggregations, selectedSchema, resolution)
 
+    val sawtoothBroadcast = sparkSession.sparkContext.broadcast(sawtoothAggregator)
+
     // create the IRs up to minHop accuracy
     val headStartsWithIrs = queriesByHeadStarts.keys
       .groupByKey()
       .leftOuterJoin(hopsRdd)
       .flatMap { case (keys, (headStarts, hopsOpt)) =>
         val headStartsArray = headStarts.toArray
         util.Arrays.sort(headStartsArray)
-        val headStartIrs = sawtoothAggregator.computeWindows(hopsOpt.orNull, headStartsArray)
+        val headStartIrs = sawtoothBroadcast.value.computeWindows(hopsOpt.orNull, headStartsArray)
         headStartsArray.indices.map { i => (keys, headStartsArray(i)) -> headStartIrs(i) }
       }
 
@@ -361,7 +369,7 @@ class GroupBy(val aggregations: Seq[api.Aggregation],
           eventsOpt.map(_.map(SparkConversions.toChrononRow(_, tsIndex)).iterator).orNull
         }
         val queries = queriesWithPartition.map { TimeTuple.getTs }
-        val irs = sawtoothAggregator.cumulate(inputsIt, queries, headStartIrOpt.orNull)
+        val irs = sawtoothBroadcast.value.cumulate(inputsIt, queries, headStartIrOpt.orNull)
         queries.indices.map { i =>
          (keys.data ++ queriesWithPartition(i).toArray, normalizeOrFinalize(irs(i)))
         }
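
Note: the edits above wrap the sawtooth aggregators in Spark broadcast variables, so executors fetch one shared copy instead of serializing the aggregator into every task closure, and tasks always dereference it through `.value`. Below is a minimal self-contained sketch of that pattern, using a hypothetical `HeavyAggregator` class and made-up window values (illustrative only, not Chronon code):

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical heavyweight, serializable aggregator -- illustrative only.
class HeavyAggregator(val windowsMillis: Seq[Long]) extends Serializable {
  def aggregate(ts: Long): Long = windowsMillis.map(w => ts % w).sum
}

object BroadcastSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("broadcast-sketch").getOrCreate()

    val agg = new HeavyAggregator(Seq(3600000L, 86400000L)) // built once on the driver

    // Ship a single read-only copy to each executor instead of embedding
    // the aggregator in every task closure.
    val aggBroadcast = spark.sparkContext.broadcast(agg)

    val total = spark.sparkContext
      .parallelize(1L to 1000L)
      .map(ts => aggBroadcast.value.aggregate(ts)) // dereference via .value inside the task
      .sum()

    println(total)
    spark.stop()
  }
}
```

Dereferencing the broadcast with `.value` inside the closure is what keeps the driver-side object out of each serialized task; capturing `agg` directly would fall back to per-task closure serialization.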

spark/src/main/scala/ai/chronon/spark/submission/ChrononKryoRegistrator.scala (+3 −0)
@@ -15,6 +15,8 @@
  */
 package ai.chronon.spark.submission
 
+import ai.chronon.aggregator.windowing.SawtoothAggregator
+import com.esotericsoftware.kryo.serializers.JavaSerializer
 import ai.chronon.aggregator.base.FrequentItemType.{DoubleItemType, LongItemType, StringItemType}
 import ai.chronon.aggregator.base.FrequentItemsFriendly._
 import ai.chronon.aggregator.base.{FrequentItemType, FrequentItemsFriendly, ItemsSketchIR}
@@ -216,6 +218,7 @@ class ChrononKryoRegistrator extends KryoRegistrator {
       try {
         kryo.register(Class.forName(name))
         kryo.register(Class.forName(s"[L$name;")) // represents array of a type to jvm
+        kryo.register(classOf[SawtoothAggregator], new JavaSerializer)
       } catch {
         case _: ClassNotFoundException => // do nothing
       }
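
Note: this change pairs `SawtoothAggregator` with Kryo's `JavaSerializer`, the usual escape hatch when a class does not round-trip cleanly through Kryo's default field serializer. Below is a minimal sketch of that wiring with a hypothetical `NonKryoFriendly` class (only the Kryo and Spark registrator APIs shown are real):

```scala
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.serializers.JavaSerializer
import org.apache.spark.serializer.KryoRegistrator

// Hypothetical class that Kryo's default FieldSerializer handles poorly
// (e.g. it captures a function value) -- illustrative only.
class NonKryoFriendly(val shift: Long => Long) extends Serializable

class SketchKryoRegistrator extends KryoRegistrator {
  override def registerClasses(kryo: Kryo): Unit = {
    // Route just this class through plain Java serialization;
    // everything else keeps Kryo's default serializers.
    kryo.register(classOf[NonKryoFriendly], new JavaSerializer)
  }
}
```

A registrator like this only takes effect when the job runs with `spark.serializer=org.apache.spark.serializer.KryoSerializer` and `spark.kryo.registrator` pointing at the registrator class.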
