
Commit 2cfe42d

[WIP] Drift metrics (#59)
## Summary

## Checklist
- [ ] Added Unit Tests
- [ ] Covered by existing CI
- [ ] Integration tested
- [ ] Documentation update

## Summary by CodeRabbit

- **New Features**
  - Introduced a method for converting objects to a pretty-printed JSON string format.
  - Added functionality for calculating drift metrics between `TileSummary` instances.
  - Enhanced drift analysis capabilities with new metrics and structures.
  - New endpoints for model prediction and model drift in the API.
  - Introduced utility functions for transforming and aggregating data related to `TileSummary` and `TileDrift`.
  - Enhanced metadata handling with new constants and improved dataset references.
  - Added a method for processing percentiles and breakpoints to generate interval assignments.
- **Bug Fixes**
  - Improved error handling in various methods for better clarity and logging.
- **Refactor**
  - Renamed variables and methods for clarity and consistency.
  - Updated method signatures to accommodate new features and improve usability.
  - Consolidated import statements for better organization.
  - Removed deprecated objects and methods to streamline functionality.
- **Tests**
  - Added comprehensive unit tests for drift metrics and pivot functionality.
  - Enhanced test coverage for new and modified features.
  - Removed outdated tests and added new tests for handling key mappings in joins.
1 parent 6bd677f commit 2cfe42d


44 files changed: +1704 additions, −964 deletions

api/src/main/scala/ai/chronon/api/Builders.scala

Lines changed: 5 additions & 2 deletions
@@ -267,7 +267,8 @@ object Builders {
         samplePercent: Double = 100,
         consistencySamplePercent: Double = 5,
         tableProperties: Map[String, String] = Map.empty,
-        historicalBackill: Boolean = true
+        historicalBackfill: Boolean = true,
+        driftSpec: DriftSpec = null
     ): MetaData = {
       val result = new MetaData()
       result.setName(name)
@@ -283,7 +284,7 @@
       }

       result.setTeam(effectiveTeam)
-      result.setHistoricalBackfill(historicalBackill)
+      result.setHistoricalBackfill(historicalBackfill)
       if (dependencies != null)
         result.setDependencies(dependencies.toSeq.toJava)
       if (samplePercent > 0)
@@ -292,6 +293,8 @@
       result.setConsistencySamplePercent(consistencySamplePercent)
       if (tableProperties.nonEmpty)
         result.setTableProperties(tableProperties.toJava)
+      if (driftSpec != null)
+        result.setDriftSpec(driftSpec)
       result
     }
   }
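The new `driftSpec` parameter defaults to `null`, so existing call sites compile unchanged and drift computation stays opt-in. A minimal opt-in sketch (the join name and spec values are illustrative, not from this commit):

```scala
import ai.chronon.api.{Builders, DriftMetric, DriftSpec, TimeUnit, Window}

// hypothetical values - any tile size / metric combination works
val spec = new DriftSpec()
spec.setTileSize(new Window(30, TimeUnit.MINUTES))
spec.setDriftMetric(DriftMetric.JENSEN_SHANNON)

val metaData = Builders.MetaData(
  name = "my_team.my_join.v1", // hypothetical join name
  driftSpec = spec             // defaults to null, so this is opt-in
)
```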

api/src/main/scala/ai/chronon/api/Constants.scala

Lines changed: 5 additions & 2 deletions
@@ -37,7 +37,7 @@ object Constants {
   val ChrononDynamicTable = "chronon_dynamic_table"
   val ChrononOOCTable: String = "chronon_ooc_table"
   val ChrononLogTable: String = "chronon_log_table"
-  val ChrononMetadataKey = "ZIPLINE_METADATA"
+  val MetadataDataset = "CHRONON_METADATA"
   val SchemaPublishEvent = "SCHEMA_PUBLISH_EVENT"
   val StatsBatchDataset = "CHRONON_STATS_BATCH"
   val ConsistencyMetricsDataset = "CHRONON_CONSISTENCY_METRICS_STATS_BATCH"
@@ -62,5 +62,8 @@ object Constants {
   val LabelViewPropertyFeatureTable: String = "feature_table"
   val LabelViewPropertyKeyLabelTable: String = "label_table"
   val ChrononRunDs: String = "CHRONON_RUN_DS"
-  val DriftStatsTable: String = "drift_statistics"
+
+  val TiledSummaryDataset: String = "TILE_SUMMARIES"
+
+  val DefaultDriftTileSize: Window = new Window(30, TimeUnit.MINUTES)
 }

api/src/main/scala/ai/chronon/api/Extensions.scala

Lines changed: 71 additions & 17 deletions
@@ -158,6 +158,15 @@ object Extensions {
       val teamOverride = Try(customJsonLookUp(Constants.TeamOverride).asInstanceOf[String]).toOption
       teamOverride.getOrElse(metaData.team)
     }
+
+    // if drift spec is set but tile size is not set, default to 30 minutes
+    def driftTileSize: Option[Window] = {
+      Option(metaData.getDriftSpec) match {
+        case Some(driftSpec) =>
+          Option(driftSpec.getTileSize).orElse(Some(Constants.DefaultDriftTileSize))
+        case None => None
+      }
+    }
   }

   // one per output column - so single window
@@ -879,24 +888,69 @@
       partHashes ++ Map(leftSourceKey -> leftHash, join.metaData.bootstrapTable -> bootstrapHash) ++ derivedHashMap
     }

-    /*
-      External features computed in online env and logged
-      This method will get the external feature column names
-     */
-    def getExternalFeatureCols: Seq[String] = {
-      Option(join.onlineExternalParts)
-        .map(_.toScala
-          .map { part =>
-            {
-              val keys = part.source.getKeySchema.params.toScala
-                .map(_.name)
-              val values = part.source.getValueSchema.params.toScala
-                .map(_.name)
-              keys ++ values
+    def externalPartColumns: Map[String, Array[String]] =
+      Option(join.onlineExternalParts) match {
+        case Some(parts) =>
+          parts.toScala.map { part =>
+            val keys = part.source.getKeySchema.params.toScala.map(_.name)
+            val values = part.source.getValueSchema.params.toScala.map(_.name)
+            part.fullName -> (keys ++ values).toArray
+          }.toMap
+        case None => Map.empty
+      }
+
+    def derivedColumns: Array[String] =
+      Option(join.getDerivations) match {
+        case Some(derivations) =>
+          derivations.toScala.flatMap { derivation =>
+            derivation.getName match {
+              case "*" => None
+              case _   => Some(derivation.getName)
             }
-          }
-          .flatMap(_.toSet))
-        .getOrElse(Seq.empty)
+          }.toArray
+        case None => Array.empty
+      }
+
+    // renamed cols are no longer part of the output
+    private def renamedColumns: Set[String] =
+      Option(join.derivations)
+        .map {
+          _.toScala.renameOnlyDerivations.map(_.expression).toSet
+        }
+        .getOrElse(Set.empty)
+
+    def joinPartColumns: Map[String, Array[String]] =
+      Option(join.getJoinParts) match {
+        case None => Map.empty
+        case Some(parts) =>
+          parts.toScala.map { part =>
+            val prefix = Option(part.prefix)
+            val groupByName = part.getGroupBy.getMetaData.cleanName
+            val partName = (prefix.toSeq :+ groupByName).mkString("_")
+
+            val outputColumns = part.getGroupBy.valueColumns
+            val cols = outputColumns.map { column =>
+              (prefix.toSeq :+ groupByName :+ column).mkString("_")
+            }
+            partName -> cols
+          }.toMap
+      }
+
+    def outputColumnsByGroup: Map[String, Array[String]] = {
+      val preDeriveCols = (joinPartColumns ++ externalPartColumns)
+      val preDerivedWithoutRenamed = preDeriveCols.mapValues(_.filterNot(renamedColumns.contains))
+      val derivedColumns: Array[String] = Option(join.derivations) match {
+        case Some(derivations) => derivations.toScala.map { _.getName }.filter(_ == "*").toArray
+        case None              => Array.empty
+      }
+      preDerivedWithoutRenamed ++ Map("derivations" -> derivedColumns)
+    }
+
+    def keyColumns: Array[String] = {
+      val joinPartKeys = join.joinParts.toScala.flatMap(_.groupBy.keyColumns.toScala).toSet
+      val externalKeys = join.onlineExternalParts.toScala.flatMap(_.source.keyNames).toSet
+      val bootstrapKeys = join.bootstrapParts.toScala.flatMap(_.keyColumns.toScala).toSet
+      (joinPartKeys ++ externalKeys ++ bootstrapKeys).toArray
     }

     /*
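The new `driftTileSize` helper only yields a tile size when a `DriftSpec` is present, and falls back to `Constants.DefaultDriftTileSize` (30 minutes) when the spec omits one. A small sketch of that contract, assuming the usual `Extensions._` implicit import brings the helper into scope on `MetaData`:

```scala
import ai.chronon.api.Extensions._

val md = new MetaData()
assert(md.driftTileSize.isEmpty)            // no driftSpec => drift is off, no tiling

md.setDriftSpec(new DriftSpec())            // spec present, tileSize left unset
assert(md.driftTileSize.contains(Constants.DefaultDriftTileSize)) // 30-minute default
```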

api/src/main/scala/ai/chronon/api/ThriftJsonCodec.scala

Lines changed: 8 additions & 0 deletions
@@ -25,6 +25,7 @@ import ai.chronon.api.thrift.protocol.TSimpleJSONProtocol
 import com.fasterxml.jackson.databind.DeserializationFeature
 import com.fasterxml.jackson.databind.JsonNode
 import com.fasterxml.jackson.databind.ObjectMapper
+import com.google.gson.GsonBuilder
 import org.slf4j.Logger
 import org.slf4j.LoggerFactory

@@ -48,6 +49,13 @@ object ThriftJsonCodec {
     new String(serializer.serialize(obj), Constants.UTF8)
   }

+  @transient private lazy val prettyGson = new GsonBuilder().setPrettyPrinting().create()
+  def toPrettyJsonStr[T <: TBase[_, _]: Manifest](obj: T): String = {
+    val raw = toJsonStr(obj)
+    val je = prettyGson.fromJson(raw, classOf[com.google.gson.JsonElement])
+    prettyGson.toJson(je)
+  }
+
   def toJsonList[T <: TBase[_, _]: Manifest](obj: util.List[T]): String = {
     if (obj == null) return ""
     obj.toScala
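`toPrettyJsonStr` round-trips the compact Thrift-JSON output through Gson purely for indentation; it changes formatting, not content. A usage sketch (the field value is illustrative):

```scala
val md = new MetaData()
md.setName("my_team.my_join.v1") // hypothetical value

val compact: String = ThriftJsonCodec.toJsonStr(md)      // single-line JSON
val pretty: String = ThriftJsonCodec.toPrettyJsonStr(md) // same JSON, indented
println(pretty)
```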

api/thrift/api.thrift

Lines changed: 83 additions & 5 deletions
@@ -234,16 +234,12 @@ enum Cardinality {
 +----------------------------------+-------------------+----------------+----------------------------------+
 | Hellinger Distance               | 0.1 - 0.25        | > 0.25         | Ranges from 0 to 1               |
 +----------------------------------+-------------------+----------------+----------------------------------+
-| Kolmogorov-Smirnov (K-S)         | 0.1 - 0.2         | > 0.2          | Ranges from 0 to 1               |
-| Distance                         |                   |                |                                  |
-+----------------------------------+-------------------+----------------+----------------------------------+
 | Population Stability Index (PSI) | 0.1 - 0.2         | > 0.2          | Industry standard in some fields |
 +----------------------------------+-------------------+----------------+----------------------------------+
 **/
 enum DriftMetric {
     JENSEN_SHANNON = 0,
     HELLINGER = 1,
-    KOLMOGOROV_SMIRNOV = 2,
     PSI = 3
 }

@@ -254,7 +250,10 @@ struct TileKey {
     4: optional i64 sizeMillis
 }

-struct TileSummaries {
+// summary of distribution & coverage etc for a given (table, column, slice, tileWindow)
+// for categorical types, distribution is histogram, otherwise percentiles
+// we also handle container types by counting inner value distribution and inner value coverage
+struct TileSummary {
     1: optional list<double> percentiles
     2: optional map<string, i64> histogram
     3: optional i64 count
@@ -269,6 +268,72 @@
     8: optional list<i32> stringLengthPercentiles
 }

+struct TileSeriesKey {
+    1: optional string column    // name of the column - avg_txns
+    2: optional string slice     // value of the slice - merchant_category
+    3: optional string groupName // name of the columnGroup within node, for join - joinPart name, externalPart name etc
+    4: optional string nodeName  // name of the node - join name etc
+}
+
+// array of tuples of (TileSummary, timestamp) ==(pivot)==> TileSummarySeries
+struct TileSummarySeries {
+    1: optional list<list<double>> percentiles
+    2: optional map<string, list<i64>> histogram
+    3: optional list<i64> count
+    4: optional list<i64> nullCount
+
+    // for container types
+    5: optional list<i64> innerCount // total of number of entries within all containers of this column
+    6: optional list<i64> innerNullCount
+    7: optional list<list<i32>> lengthPercentiles
+
+    // high cardinality string type
+    8: optional list<list<i32>> stringLengthPercentiles
+
+    200: optional list<i64> timestamps
+    300: optional TileSeriesKey key
+}
+
+// (DriftMetric + old TileSummary + new TileSummary) = TileDrift
+struct TileDrift {
+
+    // for continuous values - scalar values or within containers
+    // (lists - for eg. via last_k or maps for eg. via bucketing)
+    1: optional double percentileDrift
+    // for categorical values - scalar values or within containers
+    2: optional double histogramDrift
+
+    // for all types
+    3: optional double countChangePercent
+    4: optional double nullRatioChangePercent
+
+    // additional tracking for container types
+    5: optional double innerCountChangePercent // total of number of entries within all containers of this column
+    6: optional double innerNullCountChangePercent
+    7: optional double lengthPercentilesDrift
+
+    // additional tracking for string types
+    8: optional double stringLengthPercentilesDrift
+}
+
+// PivotUtils.pivot(Array[(Long, TileDrift)]) = TileDriftSeries
+// used in front end after this is computed
+struct TileDriftSeries {
+    1: optional list<double> percentileDriftSeries
+    2: optional list<double> histogramDriftSeries
+    3: optional list<double> countChangePercentSeries
+    4: optional list<double> nullRatioChangePercentSeries
+
+    5: optional list<double> innerCountChangePercentSeries
+    6: optional list<double> innerNullCountChangePercentSeries
+    7: optional list<double> lengthPercentilesDriftSeries
+    8: optional list<double> stringLengthPercentilesDriftSeries
+
+    200: optional list<i64> timestamps
+
+    300: optional TileSeriesKey key
+}
+
 struct DriftSpec {
     // slices is another key to summarize the data with - besides the column & slice
     // currently supports only one slice
@@ -279,9 +344,19 @@
     // likes_over_dislines = IF(dislikes > likes, 1, 0)
     // or any other expression that you care about
     2: optional map<string, string> derivations
+
     // we measure the unique counts of the columns and decide if they are categorical and numeric
     // you can use this to override that decision by setting cardinality hints
     3: optional map<string, Cardinality> columnCardinalityHints
+
+    4: optional Window tileSize
+
+    // the current tile summary will be compared with older summaries using the metric
+    // if the drift is more than the threshold, we will raise an alert
+    5: optional list<Window> lookbackWindows
+
+    // default drift metric to use
+    6: optional DriftMetric driftMetric = DriftMetric.JENSEN_SHANNON
 }

 struct MetaData {
@@ -315,6 +390,9 @@
     // Flag to indicate whether join backfill should backfill previous holes.
     // Setting to false will only backfill latest single partition
     14: optional bool historicalBackfill
+
+    // specify how to compute drift
+    15: optional DriftSpec driftSpec
 }

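The series structs invert the storage layout: a list of per-tile structs becomes one struct of per-field lists, aligned against a single `timestamps` list, which is the shape a charting front end wants. `PivotUtils.pivot` itself is not shown in this commit section, so the sketch below is an assumption about its shape, limited to two of the `TileDrift` fields:

```scala
import scala.collection.JavaConverters._

// assumed shape of the pivot: Array[(timestampMillis, TileDrift)] => TileDriftSeries
def pivotSketch(drifts: Array[(Long, TileDrift)]): TileDriftSeries = {
  val sorted = drifts.sortBy(_._1)
  val series = new TileDriftSeries()
  // one timestamps list, aligned index-by-index with every value series
  series.setTimestamps(sorted.map(d => java.lang.Long.valueOf(d._1)).toSeq.asJava)
  series.setPercentileDriftSeries(
    sorted.map(d => java.lang.Double.valueOf(d._2.getPercentileDrift)).toSeq.asJava)
  series
}
```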

build.sbt

Lines changed: 1 addition & 1 deletion
@@ -122,7 +122,7 @@ lazy val online = project
       "com.datadoghq" % "java-dogstatsd-client" % "4.4.1",
       "org.rogach" %% "scallop" % "5.1.0",
       "net.jodah" % "typetools" % "0.6.3",
-      "com.github.ben-manes.caffeine" % "caffeine" % "3.1.8"
+      "com.github.ben-manes.caffeine" % "caffeine" % "3.1.8",
     ),
     libraryDependencies ++= jackson,
     libraryDependencies ++= spark_all.map(_ % "provided"),

cloud_aws/src/main/scala/ai/chronon/integrations/aws/DynamoDBKVStoreImpl.scala

Lines changed: 2 additions & 2 deletions
@@ -127,7 +127,7 @@ class DynamoDBKVStoreImpl(dynamoDbClient: DynamoDbClient) extends KVStore {
   override def multiGet(requests: Seq[KVStore.GetRequest]): Future[Seq[KVStore.GetResponse]] = {
     // partition our requests into pure get style requests (where we're missing timestamps and only have key lookup)
     // and query requests (we want to query a range based on afterTsMillis -> endTsMillis or now() )
-    val (getLookups, queryLookups) = requests.partition(r => r.afterTsMillis.isEmpty)
+    val (getLookups, queryLookups) = requests.partition(r => r.startTsMillis.isEmpty)
     val getItemRequestPairs = getLookups.map { req =>
       val keyAttributeMap = primaryKeyMap(req.keyBytes)
       (req, GetItemRequest.builder.key(keyAttributeMap.asJava).tableName(req.dataset).build)
@@ -325,7 +325,7 @@
     val partitionAlias = "#pk"
     val timeAlias = "#ts"
     val attrNameAliasMap = Map(partitionAlias -> partitionKeyColumn, timeAlias -> sortKeyColumn)
-    val startTs = request.afterTsMillis.get
+    val startTs = request.startTsMillis.get
     val endTs = request.endTsMillis.getOrElse(System.currentTimeMillis())
     val attrValuesMap =
       Map(

cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/BigTableKVStoreImpl.scala

Lines changed: 1 addition & 1 deletion
@@ -47,7 +47,7 @@ class BigTableKVStoreImpl(projectId: String, instanceId: String) extends KVStore

     val queryTime = System.currentTimeMillis()
     // scan from afterTsMillis to now - skip events with future timestamps
-    request.afterTsMillis.foreach { ts =>
+    request.startTsMillis.foreach { ts =>
       // Bigtable uses microseconds
       query.filter(Filters.FILTERS.timestamp().range().startOpen(ts * 1000).endClosed(queryTime))
     }

docs/source/setup/Online_Integration.md

Lines changed: 2 additions & 2 deletions
@@ -18,8 +18,8 @@ If you'd to start with an example, please refer to the [MongoDB Implementation i

 ```scala
 object KVStore {
-  // `afterTsMillis` implies that this is a range scan of all values with `timestamp` >= to the specified one. This can be implemented efficiently, if `timestamp` can be a secondary key. Some databases have a native version id concept which also can map to timestamp.
-  case class GetRequest(keyBytes: Array[Byte], dataset: String, afterTsMillis: Option[Long] = None)
+  // `startTsMillis` implies that this is a range scan of all values with `timestamp` >= to the specified one. This can be implemented efficiently, if `timestamp` can be a secondary key. Some databases have a native version id concept which also can map to timestamp.
+  case class GetRequest(keyBytes: Array[Byte], dataset: String, startTsMillis: Option[Long] = None)

   // response is a series of values that are
   case class TimedValue(bytes: Array[Byte], millis: Long)
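With the rename, a request carrying `startTsMillis` is treated as a range scan (optionally bounded by `endTsMillis`), while a request without it stays a plain point lookup, as the DynamoDB partition above shows. A usage sketch (the key bytes and one-hour window are illustrative):

```scala
val oneHourAgo = System.currentTimeMillis() - 3600 * 1000L

val request = KVStore.GetRequest(
  keyBytes = "some-entity-key".getBytes("UTF-8"), // hypothetical key
  dataset = Constants.TiledSummaryDataset,
  startTsMillis = Some(oneHourAgo) // scan values with timestamp >= oneHourAgo
)
```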

online/src/main/java/ai/chronon/online/JavaFetcher.java

Lines changed: 0 additions & 12 deletions
@@ -141,18 +141,6 @@ private Metrics.Context getGroupByContext(String groupByName) {
     return new Metrics.Context("group_by.fetch", null, groupByName, null, false, null, null, null, null);
   }

-  public CompletableFuture<JavaSeriesStatsResponse> fetchStatsTimeseries(JavaStatsRequest request) {
-    Future<Fetcher.SeriesStatsResponse> response = this.fetcher.fetchStatsTimeseries(request.toScalaRequest());
-    // Convert responses to CompletableFuture
-    return FutureConverters.toJava(response).toCompletableFuture().thenApply(JavaFetcher::toJavaSeriesStatsResponse);
-  }
-
-  public CompletableFuture<JavaSeriesStatsResponse> fetchLogStatsTimeseries(JavaStatsRequest request) {
-    Future<Fetcher.SeriesStatsResponse> response = this.fetcher.fetchLogStatsTimeseries(request.toScalaRequest());
-    // Convert responses to CompletableFuture
-    return FutureConverters.toJava(response).toCompletableFuture().thenApply(JavaFetcher::toJavaSeriesStatsResponse);
-  }
-
   public CompletableFuture<JavaSeriesStatsResponse> fetchConsistencyMetricsTimeseries(JavaStatsRequest request) {
     Future<Fetcher.SeriesStatsResponse> response = this.fetcher.fetchConsistencyMetricsTimeseries(request.toScalaRequest());
     // Convert responses to CompletableFuture

online/src/main/scala/ai/chronon/online/Api.scala

Lines changed: 3 additions & 3 deletions
@@ -46,7 +46,7 @@ object KVStore {
   // endTsMillis - end range of the scan (starts from afterTsMillis to endTsMillis)
   case class GetRequest(keyBytes: Array[Byte],
                         dataset: String,
-                        afterTsMillis: Option[Long] = None,
+                        startTsMillis: Option[Long] = None,
                         endTsMillis: Option[Long] = None)
   case class TimedValue(bytes: Array[Byte], millis: Long)
   case class GetResponse(request: GetRequest, values: Try[Seq[TimedValue]]) {
@@ -261,7 +261,7 @@ abstract class Api(userConf: Map[String, String]) extends Serializable {
                   callerName: String = null,
                   disableErrorThrows: Boolean = false): Fetcher =
     new Fetcher(genKvStore,
-                Constants.ChrononMetadataKey,
+                Constants.MetadataDataset,
                 logFunc = responseConsumer,
                 debug = debug,
                 externalSourceRegistry = externalRegistry,
@@ -272,7 +272,7 @@

   final def buildJavaFetcher(callerName: String = null, disableErrorThrows: Boolean = false): JavaFetcher = {
     new JavaFetcher(genKvStore,
-                    Constants.ChrononMetadataKey,
+                    Constants.MetadataDataset,
                     timeoutMillis,
                     responseConsumer,
                     externalRegistry,
