@@ -165,12 +165,13 @@ class GroupBy(val aggregations: Seq[api.Aggregation],
     // add 1 day to the end times to include data [ds 00:00:00.000, ds + 1 00:00:00.000)
     val shiftedEndTimes = endTimes.map(_ + tableUtils.partitionSpec.spanMillis)
     val sawtoothAggregator = new SawtoothAggregator(aggregations, selectedSchema, resolution)
+    val sawtoothAggregatorBroadcast = sparkSession.sparkContext.broadcast(sawtoothAggregator)
     val hops = hopsAggregate(endTimes.min, resolution)
 
     hops
       .flatMap { case (keys, hopsArrays) =>
         // filter out if the all the irs are nulls
-        val irs = sawtoothAggregator.computeWindows(hopsArrays, shiftedEndTimes)
+        val irs = sawtoothAggregatorBroadcast.value.computeWindows(hopsArrays, shiftedEndTimes)
         irs.indices.flatMap { i =>
           val result = normalizeOrFinalize(irs(i))
           if (result.forall(_ == null)) None
@@ -230,16 +231,21 @@ class GroupBy(val aggregations: Seq[api.Aggregation],
     val snapshotKeyHashFx = FastHashing.generateKeyBuilder(keyColumns.toArray, expandedInputDf.schema)
     val sawtoothAggregator =
       new SawtoothMutationAggregator(aggregations, SparkConversions.toChrononSchema(expandedInputDf.schema), resolution)
+
+    val sawtoothAggregatorBroadcast = sparkSession.sparkContext.broadcast(sawtoothAggregator)
     val updateFunc = (ir: BatchIr, row: Row) => {
-      sawtoothAggregator.update(row.getLong(shiftedColumnIndexTs), ir, SparkConversions.toChrononRow(row, tsIndex))
+      sawtoothAggregatorBroadcast.value.update(row.getLong(shiftedColumnIndexTs),
+                                               ir,
+                                               SparkConversions.toChrononRow(row, tsIndex))
       ir
     }
 
     // end of day IR
     val snapshotByKeys = expandedInputDf.rdd
       .keyBy(row => (snapshotKeyHashFx(row), row.getString(shiftedColumnIndex)))
-      .aggregateByKey(sawtoothAggregator.init)(seqOp = updateFunc, combOp = sawtoothAggregator.merge)
-      .mapValues(sawtoothAggregator.finalizeSnapshot)
+      .aggregateByKey(sawtoothAggregatorBroadcast.value.init)(seqOp = updateFunc,
+                                                              combOp = sawtoothAggregatorBroadcast.value.merge)
+      .mapValues(sawtoothAggregatorBroadcast.value.finalizeSnapshot)
 
     // Preprocess for mutations: Add a ds of mutation ts column, collect sorted mutations by keys and ds of mutation.
     val mutationDf = mutationDfFn()
@@ -270,10 +276,10 @@ class GroupBy(val aggregations: Seq[api.Aggregation],
         val sortedQueries = timeQueries.map { TimeTuple.getTs }
         val finalizedEodIr = eodIr.orNull
 
-        val irs = sawtoothAggregator.lambdaAggregateIrMany(tableUtils.partitionSpec.epochMillis(ds),
-                                                           finalizedEodIr,
-                                                           dayMutations.orNull,
-                                                           sortedQueries)
+        val irs = sawtoothAggregatorBroadcast.value.lambdaAggregateIrMany(tableUtils.partitionSpec.epochMillis(ds),
+                                                                          finalizedEodIr,
+                                                                          dayMutations.orNull,
+                                                                          sortedQueries)
         ((keyWithHash, ds), (timeQueries, sortedQueries.indices.map(i => normalizeOrFinalize(irs(i)))))
       }
 
@@ -329,14 +335,16 @@ class GroupBy(val aggregations: Seq[api.Aggregation],
     val sawtoothAggregator =
       new SawtoothAggregator(aggregations, selectedSchema, resolution)
 
+    val sawtoothBroadcast = sparkSession.sparkContext.broadcast(sawtoothAggregator)
+
     // create the IRs up to minHop accuracy
     val headStartsWithIrs = queriesByHeadStarts.keys
       .groupByKey()
       .leftOuterJoin(hopsRdd)
       .flatMap { case (keys, (headStarts, hopsOpt)) =>
         val headStartsArray = headStarts.toArray
         util.Arrays.sort(headStartsArray)
-        val headStartIrs = sawtoothAggregator.computeWindows(hopsOpt.orNull, headStartsArray)
+        val headStartIrs = sawtoothBroadcast.value.computeWindows(hopsOpt.orNull, headStartsArray)
         headStartsArray.indices.map { i => (keys, headStartsArray(i)) -> headStartIrs(i) }
       }
 
@@ -361,7 +369,7 @@ class GroupBy(val aggregations: Seq[api.Aggregation],
           eventsOpt.map(_.map(SparkConversions.toChrononRow(_, tsIndex)).iterator).orNull
         }
         val queries = queriesWithPartition.map { TimeTuple.getTs }
-        val irs = sawtoothAggregator.cumulate(inputsIt, queries, headStartIrOpt.orNull)
+        val irs = sawtoothBroadcast.value.cumulate(inputsIt, queries, headStartIrOpt.orNull)
        queries.indices.map { i =>
           (keys.data ++ queriesWithPartition(i).toArray, normalizeOrFinalize(irs(i)))
         }
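
Note on the pattern: the diff replaces direct closure capture of the sawtooth aggregators with Spark broadcast variables, so each executor JVM fetches one read-only cached copy of the aggregator instead of deserializing it with every task closure. Below is a minimal, self-contained sketch of that broadcast-then-`.value` pattern under assumed names; the `Aggregator` case class, the toy dataset, and the `main` scaffolding are illustrative only and are not part of Chronon.

```scala
import org.apache.spark.sql.SparkSession

object BroadcastSketch {
  // Stand-in for a reusable, serializable aggregator (hypothetical; not Chronon's SawtoothAggregator).
  case class Aggregator(offset: Long) {
    def update(acc: Long, value: Long): Long = acc + value + offset
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("broadcast-sketch").getOrCreate()
    val sc = spark.sparkContext

    // Broadcast once per job; executors cache a single read-only copy per JVM
    // instead of re-serializing the object into every task closure.
    val aggregatorBroadcast = sc.broadcast(Aggregator(offset = 1L))

    val sums = sc
      .parallelize(Seq(("a", 1L), ("a", 2L), ("b", 3L)))
      .aggregateByKey(0L)(
        // The closure captures only the lightweight broadcast handle; `.value`
        // resolves to the locally cached Aggregator on the executor.
        seqOp = (acc, v) => aggregatorBroadcast.value.update(acc, v),
        combOp = _ + _
      )
      .collect()

    sums.foreach(println)
    spark.stop()
  }
}
```

Inside closures such as `seqOp`, only the broadcast handle is serialized per task, which mirrors how `sawtoothAggregatorBroadcast.value` and `sawtoothBroadcast.value` are used in the hunks above.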