Commit 08e3315

tile codec code cleanup
1 parent 4604e61

File tree (3 files changed: +36 −20 lines)

  spark/src/main/scala/ai/chronon/spark/utils/InMemoryStream.scala
  spark/src/test/scala/ai/chronon/spark/test/OnlineUtils.scala
  spark/src/test/scala/ai/chronon/spark/test/groupby/GroupByUploadTest.scala


spark/src/main/scala/ai/chronon/spark/utils/InMemoryStream.scala

Lines changed: 14 additions & 8 deletions

```diff
@@ -23,6 +23,7 @@ import ai.chronon.api.StructType
 import ai.chronon.online.AvroConversions
 import ai.chronon.online.SparkConversions
 import ai.chronon.online.TileCodec
+import ai.chronon.spark.utils.InMemoryStream.TileUpdate
 import ai.chronon.spark.{FastHashing, GenericRowHandler, KeyWithHash, TableUtils}
 import org.apache.avro.data.TimeConversions
 import org.apache.avro.generic.GenericData
@@ -39,6 +40,10 @@ import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.slf4j.Logger
 import org.slf4j.LoggerFactory
 
+object InMemoryStream {
+  case class TileUpdate(keys: Array[Any], ir: Array[Any], tileTimestamp: Long, updateTimestamp: Long)
+}
+
 class InMemoryStream {
   @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass)
 
@@ -119,7 +124,7 @@
     */
   def getInMemoryTiledStreamArray(spark: SparkSession,
                                   inputDf: Dataset[Row],
-                                  groupBy: GroupBy): Array[(Array[Any], Long, Array[Byte])] = {
+                                  groupBy: GroupBy): (Array[TileUpdate], TileCodec) = {
 
     val chrononSchema: StructType = StructType.from("input", SparkConversions.toChrononSchema(inputDf.schema))
     val schema = chrononSchema.iterator.map { field =>
@@ -155,7 +160,10 @@
       (keyWithHash, row.getLong(tsIndex))
     })
 
-    entityTimestampGroupedRows.toArray.map { keyedRow =>
+    val tileCodec = new TileCodec(groupBy, schema)
+    // val preAgg: Array[Byte] = tileCodec.makeTileIr(aggIr, isComplete = false)
+
+    val updates = entityTimestampGroupedRows.toArray.flatMap { keyedRow =>
       val ((KeyWithHash(keys, _, _), tileTimestamp), rows) = keyedRow
 
       val rowAggregator = TileCodec.buildRowAggregator(groupBy, schema)
@@ -174,13 +182,11 @@
         } else {
           rowAggregator.delete(aggIr, chrononRow)
         }
-
+        val updateTimestamp = row.getLong(tsIndex)
+        TileUpdate(keys, aggIr, tileTimestamp, updateTimestamp)
       }
-
-      val tileCodec = new TileCodec(groupBy, schema)
-      val preAgg: Array[Byte] = tileCodec.makeTileIr(aggIr, true)
-
-      (keys, tileTimestamp, preAgg)
     }
+
+    updates -> tileCodec
   }
 }
```
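
With this refactor, `getInMemoryTiledStreamArray` stops serializing tile IRs itself: it emits one `TileUpdate` per input row (the raw intermediate aggregation state plus tile and update timestamps) and returns the single `TileCodec` it built, leaving serialization to the caller. A minimal sketch of the new call shape, assuming only the signatures shown in the diff; the wrapper function and its arguments are hypothetical stand-ins:

```scala
import ai.chronon.api.GroupBy
import ai.chronon.online.TileCodec
import ai.chronon.spark.utils.InMemoryStream
import ai.chronon.spark.utils.InMemoryStream.TileUpdate
import org.apache.spark.sql.{Dataset, Row, SparkSession}

// Sketch only, assuming the signatures in the diff above; `groupBy` is the
// api-level GroupBy config, as in the method signature.
def writeTilesSketch(spark: SparkSession, inputDf: Dataset[Row], groupBy: GroupBy): Unit = {
  val stream = new InMemoryStream
  val (updates: Array[TileUpdate], tileCodec: TileCodec) =
    stream.getInMemoryTiledStreamArray(spark, inputDf, groupBy)

  updates.foreach { update =>
    // Serialization is now deferred to the caller; isComplete = false marks
    // the tile IR as still open to further streaming updates.
    val tileIrBytes: Array[Byte] = tileCodec.makeTileIr(update.ir, isComplete = false)
    println(s"tile=${update.tileTimestamp} update=${update.updateTimestamp} bytes=${tileIrBytes.length}")
  }
}
```

One consequence visible in the diff: the old code constructed a `TileCodec` per key group inside the loop, while the new code constructs it once per stream and hands it back for reuse.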

spark/src/test/scala/ai/chronon/spark/test/OnlineUtils.scala

Lines changed: 21 additions & 11 deletions

```diff
@@ -25,8 +25,7 @@ import ai.chronon.api.Extensions.GroupByOps
 import ai.chronon.api.Extensions.MetadataOps
 import ai.chronon.api.Extensions.SourceOps
 import ai.chronon.api.TilingUtils
-import ai.chronon.online.AvroConversions
-import ai.chronon.online.KVStore
+import ai.chronon.online.{AvroConversions, KVStore, TileCodec}
 import ai.chronon.spark.GenericRowHandler
 import ai.chronon.spark.GroupByUpload
 import ai.chronon.spark.SparkSessionBuilder
@@ -35,6 +34,7 @@ import ai.chronon.spark.streaming.GroupBy
 import ai.chronon.spark.streaming.JoinSourceRunner
 import ai.chronon.spark.utils.InMemoryKvStore
 import ai.chronon.spark.utils.InMemoryStream
+import ai.chronon.spark.utils.InMemoryStream.TileUpdate
 import ai.chronon.spark.utils.MockApi
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.streaming.Trigger
@@ -52,12 +52,16 @@ object OnlineUtils {
                      debug: Boolean,
                      dropDsOnWrite: Boolean,
                      isTiled: Boolean): Unit = {
+
     val inputStreamDf = groupByConf.dataModel match {
+
       case DataModel.Entities =>
         val entity = groupByConf.streamingSource.get
         val df = tableUtils.sql(s"SELECT * FROM ${entity.getEntities.mutationTable} WHERE ds = '$ds'")
+
         df.withColumnRenamed(entity.query.reversalColumn, Constants.ReversalColumn)
           .withColumnRenamed(entity.query.mutationTimeColumn, Constants.MutationTimeColumn)
+
       case DataModel.Events =>
         val table = groupByConf.streamingSource.get.table
         tableUtils.sql(s"SELECT * FROM $table WHERE ds >= '$ds'")
@@ -66,6 +70,7 @@
     val inputStream = new InMemoryStream
     val mockApi = new MockApi(kvStore, namespace)
     var inputModified = inputStreamDf
+
     if (dropDsOnWrite && inputStreamDf.schema.fieldNames.contains(tableUtils.partitionColumn)) {
       inputModified = inputStreamDf.drop(tableUtils.partitionColumn)
     }
@@ -79,43 +84,48 @@
     }
 
     if (isTiled) {
-      val memoryStream: Array[(Array[Any], Long, Array[Byte])] =
+
+      val (memoryStream: Array[TileUpdate], tileCodec: TileCodec) =
         inputStream.getInMemoryTiledStreamArray(session, inputModified, groupByConf)
       val inMemoryKvStore: KVStore = kvStore()
 
-      val fetcher = mockApi.buildFetcher(false)
+      val fetcher = mockApi.buildFetcher(debug = false)
       val groupByServingInfo = fetcher.metadataStore.getGroupByServingInfo(groupByConf.getMetaData.getName).get
 
       val keyZSchema: api.StructType = groupByServingInfo.keyChrononSchema
       val keyToBytes = AvroConversions.encodeBytes(keyZSchema, GenericRowHandler.func)
 
-      val putRequests = memoryStream.map { entry =>
-        val keys = entry._1
-        val timestamp = entry._2
-        val tileBytes = entry._3
+      val putRequests = memoryStream.map { entry: TileUpdate =>
+        val keyBytes = keyToBytes(entry.keys)
+        val tileIrBytes = tileCodec.makeTileIr(entry.ir, isComplete = false)
 
-        val keyBytes = keyToBytes(keys)
         val tileKey = TilingUtils.buildTileKey(
           groupByConf.streamingDataset,
           keyBytes,
           Some(ResolutionUtils.getSmallestWindowResolutionInMillis(groupByServingInfo.groupBy)),
           None)
+
         KVStore.PutRequest(TilingUtils.serializeTileKey(tileKey),
-                           tileBytes,
+                           tileIrBytes,
                            groupByConf.streamingDataset,
-                           Some(timestamp))
+                           Some(entry.tileTimestamp))
       }
+
       inMemoryKvStore.multiPut(putRequests)
+
     } else {
+
       val groupByStreaming =
         new GroupBy(inputStream.getInMemoryStreamDF(session, inputModified),
                     session,
                     groupByConf,
                     mockApi,
                     debug = debug)
+
       // We modify the arguments for running to make sure all data gets into the KV Store before fetching.
       val dataStream = groupByStreaming.buildDataStream()
       val query = dataStream.trigger(Trigger.Once()).start()
+
       query.awaitTermination()
     }
   }
```
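
Combined with the `InMemoryStream` change, the tiled branch now serializes each update's IR at the point where the `PutRequest` is built, instead of receiving pre-serialized tile bytes. A condensed sketch of that write path; `updates`, `tileCodec`, `keyToBytes`, and `inMemoryKvStore` mirror the diff, while the dataset name and the 5-minute resolution are hypothetical placeholders:

```scala
import ai.chronon.api.TilingUtils
import ai.chronon.online.KVStore

// Condensed sketch of the tiled write path above; `updates`, `tileCodec`,
// `keyToBytes`, and `inMemoryKvStore` are assumed in scope as in the diff.
val dataset = "MY_GROUPBY_STREAMING" // hypothetical streaming dataset name
val putRequests = updates.map { update =>
  val tileKey = TilingUtils.buildTileKey(
    dataset,
    keyToBytes(update.keys),  // Avro-encoded entity keys
    Some(5 * 60 * 1000L),     // placeholder tile resolution, in millis
    None)
  KVStore.PutRequest(TilingUtils.serializeTileKey(tileKey),
                     tileCodec.makeTileIr(update.ir, isComplete = false),
                     dataset,
                     Some(update.tileTimestamp))
}
inMemoryKvStore.multiPut(putRequests)
```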

spark/src/test/scala/ai/chronon/spark/test/groupby/GroupByUploadTest.scala

Lines changed: 1 addition & 1 deletion

```diff
@@ -314,7 +314,7 @@ class GroupByUploadTest extends AnyFlatSpec {
     def cRating(location: Double, cleanliness: Double): java.util.Map[String, Double] =
       Map("location" -> location, "cleanliness" -> cleanliness).toJava
     val gson = new Gson()
-    assertEquals(results, requestResponse.map(_._2))
+    assertEquals(requestResponse.map(_._2), results)
 
     val expectedCategoryRatings = Array(
       cRating(4.5, 4.0),
```
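
The one-line fix swaps the arguments into JUnit's `assertEquals(expected, actual)` order; with them reversed, a failing test transposes the values in its message. A small illustration with made-up values:

```scala
import org.junit.Assert.assertEquals

// Made-up values standing in for the test's data.
val expected = Seq(4.5, 4.0)
val actual = Seq(4.5, 3.0)

// JUnit's contract is assertEquals(expected, actual). On failure this prints:
//   java.lang.AssertionError: expected:<List(4.5, 4.0)> but was:<List(4.5, 3.0)>
// With the arguments reversed, the "expected" and "but was" values would be
// swapped, making the failure message misleading.
assertEquals(expected, actual)
```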
