
Commit 4c75619

backing out tiled fetching changes + adding tests for serving conf

1 parent: 08e3315

9 files changed: +178 additions, -48 deletions

api/py/ai/chronon/cli/compile/parse_teams.py

Lines changed: 2 additions & 0 deletions
@@ -123,6 +123,7 @@ def _merge_mode_maps(
         result.backfill = _merge_maps(result.common, result.backfill)
         result.upload = _merge_maps(result.common, result.upload)
         result.streaming = _merge_maps(result.common, result.streaming)
+        result.serving = _merge_maps(result.common, result.serving)
         result.common = None
         continue

@@ -135,5 +136,6 @@ def _merge_mode_maps(
         result.streaming = _merge_maps(
             result.streaming, mode_map.common, mode_map.streaming
         )
+        result.serving = _merge_maps(result.serving, mode_map.common, mode_map.serving)

     return result
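To make the layering concrete, here is an illustrative Scala sketch of the merge semantics that `_merge_maps` implements above. The map contents are made up, and it assumes the usual convention that the more specific map wins on key conflicts:

    // Later maps take priority on conflicting keys: company < team < object.
    def mergeMaps(maps: Map[String, String]*): Map[String, String] =
      maps.foldLeft(Map.empty[String, String])(_ ++ _)

    val companyServing = Map("spark.executor.memory" -> "2g") // teams.py "common" team
    val teamServing    = Map("spark.executor.memory" -> "4g") // team-wide default
    val objectServing  = Map("tiling" -> "true")              // set on one GroupBy

    // => Map("spark.executor.memory" -> "4g", "tiling" -> "true")
    val merged = mergeMaps(companyServing, teamServing, objectServing)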

api/src/main/scala/ai/chronon/api/Builders.scala

Lines changed: 11 additions & 4 deletions
@@ -271,7 +271,8 @@ object Builders {
       tableProperties: Map[String, String] = Map.empty,
       historicalBackfill: Boolean = true,
       driftSpec: DriftSpec = null,
-      additionalOutputPartitionColumns: Seq[String] = Seq.empty
+      additionalOutputPartitionColumns: Seq[String] = Seq.empty,
+      executionInfo: ExecutionInfo = null
   ): MetaData = {
     val result = new MetaData()
     result.setName(name)
@@ -287,9 +288,7 @@ object Builders {
     }

     result.setTeam(effectiveTeam)
-    val executionInfo = new ExecutionInfo()
-      .setHistoricalBackfill(historicalBackfill)
-    result.setExecutionInfo(executionInfo)
+
     if (samplePercent > 0)
       result.setSamplePercent(samplePercent)
     if (consistencySamplePercent > 0)
@@ -299,6 +298,14 @@ object Builders {
     if (driftSpec != null)
       result.setDriftSpec(driftSpec)

+    if (executionInfo != null) {
+      result.setExecutionInfo(executionInfo.setHistoricalBackfill(historicalBackfill))
+    } else {
+      result.setExecutionInfo(
+        new ExecutionInfo()
+          .setHistoricalBackfill(historicalBackfill))
+    }
+
     if (additionalOutputPartitionColumns.nonEmpty) {
       result.setAdditionalOutputPartitionColumns(additionalOutputPartitionColumns.toJava)
     }

api/src/main/scala/ai/chronon/api/Extensions.scala

Lines changed: 16 additions & 0 deletions
@@ -579,6 +579,22 @@ object Extensions {
       QueryParts(allSelects, wheres)
     }

+    def servingFlagValue(flag: String): Option[String] = {
+      for (
+        execInfo <- Option(groupBy.metaData.executionInfo);
+        conf <- Option(execInfo.conf);
+        servingConf <- Option(conf.serving);
+        value <- Option(servingConf.get(flag))
+      ) {
+        return Some(value)
+      }
+      None
+    }
+
+    def tilingFlag: Boolean = servingFlagValue("tiling").exists(_.toLowerCase() == "true")
+
+    def throwOnDecodeFailFlag: Boolean = servingFlagValue("decode.throw_on_fail").exists(_.toLowerCase() == "true")
+
     // build left streaming query for join source runner
     def buildLeftStreamingQuery(query: Query, defaultFieldNames: Seq[String]): String = {
       val queryParts = groupBy.buildQueryParts(query)
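A minimal usage sketch of the new accessors, using the same Builders API the test below exercises (the flag values here are illustrative):

    import ai.chronon.api.Extensions._
    import ai.chronon.api.ScalaJavaConversions._
    import ai.chronon.api.{Builders, ConfigProperties, ExecutionInfo}

    val execInfo = new ExecutionInfo()
      .setConf(new ConfigProperties().setServing(Map("tiling" -> "true").toJava))

    val groupBy = Builders.GroupBy(
      metaData = Builders.MetaData(name = "featureGroupName", executionInfo = execInfo))

    groupBy.tilingFlag            // true: serving conf sets "tiling" -> "true"
    groupBy.throwOnDecodeFailFlag // false: "decode.throw_on_fail" is unset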

api/src/test/scala/ai/chronon/api/test/ExtensionsTest.scala

Lines changed: 29 additions & 2 deletions
@@ -18,8 +18,8 @@ package ai.chronon.api.test

 import ai.chronon.api.Extensions._
 import ai.chronon.api.ScalaJavaConversions._
-import ai.chronon.api.{Accuracy, Builders, Constants, GroupBy}
-import org.junit.Assert.{assertEquals, assertTrue}
+import ai.chronon.api.{Accuracy, Builders, ConfigProperties, Constants, ExecutionInfo, GroupBy}
+import org.junit.Assert.{assertEquals, assertFalse, assertTrue}
 import org.mockito.Mockito.{spy, when}
 import org.scalatest.flatspec.AnyFlatSpec

@@ -116,4 +116,31 @@ class ExtensionsTest extends AnyFlatSpec {
     assertTrue(keys.contains(Constants.TimeColumn))
     assertEquals(4, keys.size)
   }
+
+  it should "is tiling enabled" in {
+    def buildGroupByWithServingFlags(flags: Map[String, String] = null): GroupByOps = {
+
+      val execInfo: ExecutionInfo = if (flags != null) {
+        new ExecutionInfo()
+          .setConf(new ConfigProperties().setServing(flags.toJava))
+      } else {
+        null
+      }
+
+      Builders.GroupBy(
+        metaData = Builders.MetaData(name = "featureGroupName", executionInfo = execInfo)
+      )
+
+    }
+
+    // serving conf not set defaults to false
+    assertFalse(buildGroupByWithServingFlags().tilingFlag)
+    assertFalse(buildGroupByWithServingFlags(Map.empty).tilingFlag)
+
+    val trueGb = buildGroupByWithServingFlags(Map("tiling" -> "true"))
+    assertTrue(trueGb.tilingFlag)
+    assertFalse(buildGroupByWithServingFlags(Map("tiling" -> "false")).tilingFlag)
+    assertFalse(buildGroupByWithServingFlags(Map("tiling" -> "invalid")).tilingFlag)
+
+  }
 }

api/thrift/common.thrift

Lines changed: 24 additions & 0 deletions
@@ -25,18 +25,42 @@ struct DateRange {
     2: string endDate
 }

+/**
+ * Env vars for different modes of execution - with "common" applying to all modes.
+ * The submitter will set these env vars prior to launching the job.
+ *
+ * These env vars are layered in order of priority:
+ * 1. company-wide defaults specified in teams.py - in the "common" team
+ * 2. team-wide defaults that apply to all objects in the team folder
+ * 3. object-specific defaults - apply only to the object that declares them
+ *
+ * All the maps from the above three places are merged to create the final env vars.
+ **/
 struct EnvironmentVariables {
     1: optional map<string, string> common
     2: optional map<string, string> backfill
     3: optional map<string, string> upload
     4: optional map<string, string> streaming
+    5: optional map<string, string> serving
 }

+/**
+ * Job config for different modes of execution - with "common" applying to all modes.
+ * Usually these are Spark or Flink conf params like "spark.executor.memory" etc.
+ *
+ * These confs are layered in order of priority:
+ * 1. company-wide defaults specified in teams.py - in the "common" team
+ * 2. team-wide defaults that apply to all objects in the team folder
+ * 3. object-specific defaults - apply only to the object that declares them
+ *
+ * All the maps from the above three places are merged to create the final conf map.
+ **/
 struct ConfigProperties {
     1: optional map<string, string> common
     2: optional map<string, string> backfill
     3: optional map<string, string> upload
     4: optional map<string, string> streaming
+    5: optional map<string, string> serving
 }

 struct TableDependency {

online/src/main/scala/ai/chronon/online/fetcher/GroupByFetcher.scala

Lines changed: 3 additions & 1 deletion
@@ -86,7 +86,7 @@ class GroupByFetcher(fetchContext: FetchContext, metadataStore: MetadataStore)
       case Accuracy.TEMPORAL =>
         // Build a tile key for the streaming request
        // When we build support for layering, we can expand this out into a utility that builds n tile keys for n layers
-        val keyBytes = {
+        val keyBytes = if (groupByServingInfo.groupByOps.tilingFlag) {

           val tileKey = TilingUtils.buildTileKey(
             groupByServingInfo.groupByOps.streamingDataset,
@@ -96,6 +96,8 @@ class GroupByFetcher(fetchContext: FetchContext, metadataStore: MetadataStore)
           )

           TilingUtils.serializeTileKey(tileKey)
+        } else {
+          streamingKeyBytes
         }

         Some(

online/src/main/scala/ai/chronon/online/fetcher/GroupByResponseHandler.scala

Lines changed: 84 additions & 25 deletions
@@ -112,38 +112,97 @@ class GroupByResponseHandler(fetchContext: FetchContext, metadataStore: Metadata
     val batchIr: FinalBatchIr =
       getBatchIrFromBatchResponse(batchResponses, batchBytes, servingInfo, toBatchIr, requestContext.keys)

-    mergeTiledIrsFromStreaming(requestContext.queryTimeMs, servingInfo, streamingResponses, aggregator, batchIr)
+    if (servingInfo.groupByOps.tilingFlag) {
+      mergeTiledIrsFromStreaming(requestContext.queryTimeMs, servingInfo, streamingResponses, aggregator, batchIr)
+    } else {
+      mergeRawEventsFromStreaming(requestContext.queryTimeMs,
+                                  servingInfo,
+                                  streamingResponses,
+                                  mutations,
+                                  aggregator,
+                                  batchIr)
+    }
+  }
+
+  private def mergeRawEventsFromStreaming(queryTimeMs: Long,
+                                          servingInfo: GroupByServingInfoParsed,
+                                          streamingResponses: Seq[TimedValue],
+                                          mutations: Boolean,
+                                          aggregator: SawtoothOnlineAggregator,
+                                          batchIr: FinalBatchIr): Array[Any] = {
+
+    val selectedCodec = servingInfo.groupByOps.dataModel match {
+      case DataModel.Events   => servingInfo.valueAvroCodec
+      case DataModel.Entities => servingInfo.mutationValueAvroCodec
+    }
+
+    def decodeRow(timedValue: TimedValue): Row = {
+      val gbName = servingInfo.groupByOps.metaData.getName
+      Try(selectedCodec.decodeRow(timedValue.bytes, timedValue.millis, mutations)) match {
+        case Success(row) => row
+        case Failure(_) =>
+          logger.error(
+            s"Failed to decode streaming row for groupBy $gbName. " +
+              "Streaming rows will be ignored")
+
+          if (servingInfo.groupByOps.throwOnDecodeFailFlag) {
+            throw new RuntimeException(s"Failed to decode streaming row for groupBy $gbName")
+          } else {
+            null
+          }
+      }
+    }
+
+    val streamingRows: Array[Row] =
+      if (streamingResponses == null) Array.empty
+      else
+        streamingResponses.iterator
+          .filter(tVal => tVal.millis >= servingInfo.batchEndTsMillis)
+          .map(decodeRow)
+          .filter(_ != null)
+          .toArray
+
+    if (fetchContext.debug) {
+      val gson = new Gson()
+      logger.info(s"""
+                     |batch ir: ${gson.toJson(batchIr)}
+                     |streamingRows: ${gson.toJson(streamingRows)}
+                     |batchEnd in millis: ${servingInfo.batchEndTsMillis}
+                     |queryTime in millis: $queryTimeMs
+                     |""".stripMargin)
+    }
+
+    aggregator.lambdaAggregateFinalized(batchIr, streamingRows.iterator, queryTimeMs, mutations)
   }

   private def mergeTiledIrsFromStreaming(queryTimeMs: Long,
                                          servingInfo: GroupByServingInfoParsed,
                                          streamingResponses: Seq[TimedValue],
                                          aggregator: SawtoothOnlineAggregator,
                                          batchIr: FinalBatchIr): Array[Any] = {
-    val streamingIrs: Option[Array[TiledIr]] = Option(streamingResponses)
-      .map(_.iterator
-        .filter(tVal => tVal.millis >= servingInfo.batchEndTsMillis)
-        .flatMap { tVal =>
-          Try(servingInfo.tiledCodec.decodeTileIr(tVal.bytes)) match {
-            case Success((tile, _)) => Array(TiledIr(tVal.millis, tile))
-            case Failure(_) =>
-              logger.error(
-                s"Failed to decode tile ir for groupBy ${servingInfo.groupByOps.metaData.getName}" +
-                  "Streaming tiled IRs will be ignored")
-              val groupByFlag: Option[Boolean] = Option(fetchContext.flagStore)
-                .map(_.isSet(
-                  "disable_streaming_decoding_error_throws",
-                  Map(
-                    "group_by_streaming_dataset" -> servingInfo.groupByServingInfo.groupBy.getMetaData.getName).toJava))
-              if (groupByFlag.getOrElse(fetchContext.disableErrorThrows)) {
-                Array.empty[TiledIr]
-              } else {
-                throw new RuntimeException(
-                  s"Failed to decode tile ir for groupBy ${servingInfo.groupByOps.metaData.getName}")
-              }
-          }
+    val streamingIrs: Iterator[TiledIr] = streamingResponses.iterator
+      .filter(tVal => tVal.millis >= servingInfo.batchEndTsMillis)
+      .flatMap { tVal =>
+        Try(servingInfo.tiledCodec.decodeTileIr(tVal.bytes)) match {
+          case Success((tile, _)) => Array(TiledIr(tVal.millis, tile))
+          case Failure(_) =>
+            logger.error(
+              s"Failed to decode tile ir for groupBy ${servingInfo.groupByOps.metaData.getName}. " +
+                "Streaming tiled IRs will be ignored")
+            val groupByFlag: Option[Boolean] = Option(fetchContext.flagStore)
+              .map(_.isSet(
+                "disable_streaming_decoding_error_throws",
+                Map("group_by_streaming_dataset" -> servingInfo.groupByServingInfo.groupBy.getMetaData.getName).toJava))
+            if (groupByFlag.getOrElse(fetchContext.disableErrorThrows)) {
+              Array.empty[TiledIr]
+            } else {
+              throw new RuntimeException(
+                s"Failed to decode tile ir for groupBy ${servingInfo.groupByOps.metaData.getName}")
+            }
         }
-      .toArray)
+      }
+      .toArray
+      .iterator

     if (fetchContext.debug) {
       val gson = new Gson()
@@ -155,7 +214,7 @@ class GroupByResponseHandler(fetchContext: FetchContext, metadataStore: Metadata
         |""".stripMargin)
     }

-    aggregator.lambdaAggregateFinalizedTiled(batchIr, streamingIrs.map(_.iterator).orNull, queryTimeMs)
+    aggregator.lambdaAggregateFinalizedTiled(batchIr, streamingIrs, queryTimeMs)
   }

   private def reportKvResponse(ctx: Metrics.Context, response: Seq[TimedValue], queryTsMillis: Long): Unit = {
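For completeness, a hedged sketch of opting a GroupBy into fail-fast decoding on the new raw-event path (flag keys as defined in Extensions.scala above; whether to fail fast or drop rows is a deployment choice):

    import ai.chronon.api.ConfigProperties
    import ai.chronon.api.ScalaJavaConversions._

    // With tiling off, decode failures in mergeRawEventsFromStreaming rethrow
    // instead of silently dropping the row, surfacing bad streaming data early.
    val strictServing = new ConfigProperties()
      .setServing(Map("tiling" -> "false", "decode.throw_on_fail" -> "true").toJava)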

spark/src/test/scala/ai/chronon/spark/test/OnlineUtils.scala

Lines changed: 6 additions & 0 deletions
@@ -171,20 +171,25 @@ object OnlineUtils {
      // we need to fix the quirk and drop this flag
      dropDsOnWrite: Boolean = false,
      tilingEnabled: Boolean = false): Unit = {
+
    val prevDs = tableUtils.partitionSpec.before(endDs)
    GroupByUpload.run(groupByConf, prevDs, Some(tableUtils))
    inMemoryKvStore.bulkPut(groupByConf.metaData.uploadTable, groupByConf.batchDataset, null)
+
    if (groupByConf.inferredAccuracy == Accuracy.TEMPORAL && groupByConf.streamingSource.isDefined) {
      val streamingSource = groupByConf.streamingSource.get
      inMemoryKvStore.create(groupByConf.streamingDataset)
+
      if (streamingSource.isSetJoinSource) {
        inMemoryKvStore.create(Constants.MetadataDataset)
        new MockApi(kvStoreGen, namespace)
          .buildFetcher()
          .metadataStore
          .putJoinConf(streamingSource.getJoinSource.getJoin)
        OnlineUtils.putStreamingNew(groupByConf, endDs, namespace, kvStoreGen, debug)(tableUtils.sparkSession)
+
      } else {
+
        OnlineUtils.putStreaming(tableUtils.sparkSession,
                                 groupByConf,
                                 kvStoreGen,
@@ -194,6 +199,7 @@ object OnlineUtils {
                                 debug,
                                 dropDsOnWrite,
                                 tilingEnabled)
+
      }
    }
  }

spark/src/test/scala/ai/chronon/spark/test/groupby/GroupByUploadTest.scala

Lines changed: 3 additions & 16 deletions
@@ -261,22 +261,9 @@ class GroupByUploadTest extends AnyFlatSpec {

    // DO-NOT-SET debug=true here since the streaming job won't put data into kv store
    joinConf.joinParts.toScala.foreach(jp =>
-      OnlineUtils.serve(tableUtils,
-                        kvStore,
-                        kvStoreFunc,
-                        "chaining_test",
-                        endDs,
-                        jp.groupBy,
-                        dropDsOnWrite = true,
-                        tilingEnabled = true))
+      OnlineUtils.serve(tableUtils, kvStore, kvStoreFunc, "chaining_test", endDs, jp.groupBy, dropDsOnWrite = true))

-    OnlineUtils.serve(tableUtils,
-                      kvStore,
-                      kvStoreFunc,
-                      "chaining_test",
-                      endDs,
-                      listingRatingGroupBy,
-                      tilingEnabled = true)
+    OnlineUtils.serve(tableUtils, kvStore, kvStoreFunc, "chaining_test", endDs, listingRatingGroupBy, debug = false)

    kvStoreFunc().show()

@@ -314,7 +301,7 @@ class GroupByUploadTest extends AnyFlatSpec {
    def cRating(location: Double, cleanliness: Double): java.util.Map[String, Double] =
      Map("location" -> location, "cleanliness" -> cleanliness).toJava
    val gson = new Gson()
-    assertEquals(requestResponse.map(_._2), results)
+    assertEquals(results, requestResponse.map(_._2))

    val expectedCategoryRatings = Array(
      cRating(4.5, 4.0),
