Commit a4f162c

chore: simplify some Extensions (#165)
## Summary

- Slim down `Extensions`. Remove some unused methods and use builtins for the others.

## Checklist

- [ ] Added Unit Tests
- [ ] Covered by existing CI
- [ ] Integration tested
- [ ] Documentation update

## Summary by CodeRabbit

- **Refactor**
  - Simplified collection conversions between Scala and Java Maps in multiple Spark utility classes
  - Removed two implicit utility classes related to internal row and tuple handling
- **Chores**
  - Updated code to use more idiomatic Scala collection transformations
  - Maintained existing functionality while improving code clarity
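For reference, the idiom this PR standardizes on is sketched below, assuming the Scala 2.12-style `scala.collection.JavaConverters` import (on 2.13+ the equivalent lives in `scala.jdk.CollectionConverters`); the key columns and length-based values here are placeholders standing in for Chronon's real bloom-filter construction:

```scala
import java.util
import scala.collection.JavaConverters._

object JMapIdiomSketch {
  def main(args: Array[String]): Unit = {
    // Placeholder inputs; the real code maps key columns to bloom filters.
    val keyCols = Seq("user_id", "item_id")

    // Old style: a hand-rolled toJMap helper filled a util.HashMap from an
    // Iterator[(K, V)]. New style: build an immutable Scala Map, then wrap
    // it for Java callers with the builtin asJava converter.
    val blooms: util.Map[String, Int] =
      keyCols.iterator.map(key => key -> key.length).toMap.asJava

    println(blooms) // e.g. {user_id=7, item_id=7}
  }
}
```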
1 parent 476c432 commit a4f162c

File tree

4 files changed: +20 -37 lines changed


spark/src/main/scala/ai/chronon/spark/Extensions.scala

Lines changed: 0 additions & 26 deletions
```diff
@@ -26,9 +26,7 @@ import ai.chronon.online.SparkConversions
 import ai.chronon.online.TimeRange
 import org.apache.avro.Schema
 import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.Row
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.expressions.UserDefinedFunction
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.LongType
@@ -298,30 +296,6 @@ object Extensions {
     }
   }
 
-  implicit class InternalRowOps(internalRow: InternalRow) {
-    def toRow: Row = {
-      new Row() {
-        override def length: Int = {
-          internalRow.numFields
-        }
-
-        override def get(i: Int): Any = {
-          internalRow.get(i, schema.fields(i).dataType)
-        }
-
-        override def copy(): Row = internalRow.copy().toRow
-      }
-    }
-  }
-
-  implicit class TupleToJMapOps[K, V](tuples: Iterator[(K, V)]) {
-    def toJMap: util.Map[K, V] = {
-      val map = new util.HashMap[K, V]()
-      tuples.foreach { case (k, v) => map.put(k, v) }
-      map
-    }
-  }
-
   implicit class DataPointerOps(dataPointer: DataPointer) {
     def toDf(implicit sparkSession: SparkSession): DataFrame = {
       val tableOrPath = dataPointer.tableOrPath
```
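As a sanity check on the removal, a REPL-style sketch (the `toJMapOld` name is hypothetical, reconstructing the deleted helper) shows the builtin chain producing the same entries:

```scala
import java.util
import scala.collection.JavaConverters._

// Hypothetical re-creation of the removed TupleToJMapOps body, for comparison.
def toJMapOld[K, V](tuples: Iterator[(K, V)]): util.Map[K, V] = {
  val map = new util.HashMap[K, V]()
  tuples.foreach { case (k, v) => map.put(k, v) }
  map
}

val pairs = Seq("a" -> 1, "b" -> 2)
// java.util.Map equality compares entry sets, so the two conversions agree.
assert(toJMapOld(pairs.iterator) == pairs.iterator.toMap.asJava)
```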

spark/src/main/scala/ai/chronon/spark/Join.scala

Lines changed: 6 additions & 3 deletions
```diff
@@ -274,9 +274,12 @@ class Join(joinConf: api.Join,
     if (skipBloomFilter) {
       None
     } else {
-      val leftBlooms = joinConf.leftKeyCols.iterator.map { key =>
-        key -> bootstrapDf.generateBloomFilter(key, leftRowCount, joinConf.left.table, leftRange)
-      }.toJMap
+      val leftBlooms = joinConf.leftKeyCols.iterator
+        .map { key =>
+          key -> bootstrapDf.generateBloomFilter(key, leftRowCount, joinConf.left.table, leftRange)
+        }
+        .toMap
+        .asJava
       Some(leftBlooms)
     }
   }
```

spark/src/main/scala/ai/chronon/spark/JoinUtils.scala

Lines changed: 7 additions & 4 deletions
```diff
@@ -304,10 +304,13 @@ object JoinUtils {
     joinLevelBloomMapOpt: Option[util.Map[String, BloomFilter]]): Option[util.Map[String, BloomFilter]] = {
 
     val rightBlooms = joinLevelBloomMapOpt.map { joinBlooms =>
-      joinPart.rightToLeft.iterator.map {
-        case (rightCol, leftCol) =>
-          rightCol -> joinBlooms.get(leftCol)
-      }.toJMap
+      joinPart.rightToLeft.iterator
+        .map {
+          case (rightCol, leftCol) =>
+            rightCol -> joinBlooms.get(leftCol)
+        }
+        .toMap
+        .asJava
     }
 
     // print bloom sizes
```

spark/src/main/scala/ai/chronon/spark/LabelJoin.scala

Lines changed: 7 additions & 4 deletions
```diff
@@ -156,9 +156,12 @@ class LabelJoin(joinConf: api.Join, tableUtils: TableUtils, labelDS: String) {
 
   def computeRange(leftDf: DataFrame, leftRange: PartitionRange, sanitizedLabelDs: String): DataFrame = {
     val leftDfCount = leftDf.count()
-    val leftBlooms = labelJoinConf.leftKeyCols.iterator.map { key =>
-      key -> leftDf.generateBloomFilter(key, leftDfCount, joinConf.left.table, leftRange)
-    }.toJMap
+    val leftBlooms = labelJoinConf.leftKeyCols.iterator
+      .map { key =>
+        key -> leftDf.generateBloomFilter(key, leftDfCount, joinConf.left.table, leftRange)
+      }
+      .toMap
+      .asJava
 
     // compute joinParts in parallel
     val rightDfs = labelJoinConf.labels.asScala.map { labelJoinPart =>
@@ -241,7 +244,7 @@ class LabelJoin(joinConf: api.Join, tableUtils: TableUtils, labelDS: String) {
       PartitionRange(labelDS, labelDS),
       tableUtils,
       computeDependency = true,
-      Option(rightBloomMap.iterator.toJMap),
+      Option(rightBloomMap.toMap.asJava),
       rightSkewFilter)
 
     val df = (joinConf.left.dataModel, joinPart.groupBy.dataModel, joinPart.groupBy.inferredAccuracy) match {
```

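One nuance worth a small sketch (names hypothetical): `.asJava` wraps rather than copies, so converting with `.toMap` first, as `rightBloomMap.toMap.asJava` does above, hands downstream code an immutable snapshot instead of a live view of a mutable Scala map:

```scala
import scala.collection.JavaConverters._
import scala.collection.mutable

val live = mutable.Map("user_id" -> 1L)
val snapshot: java.util.Map[String, Long] = live.toMap.asJava // copy, then wrap
val view: java.util.Map[String, Long] = live.asJava           // wrap only

live("user_id") = 2L
assert(snapshot.get("user_id") == 1L) // snapshot kept the old value
assert(view.get("user_id") == 2L)     // view reflects the mutation
```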