Skip to content

Commit b87fdbe

Browse files
committed
test fixes
1 parent a355eb5 commit b87fdbe

File tree

9 files changed

+37
-43
lines changed

9 files changed

+37
-43
lines changed

api/src/main/scala/ai/chronon/api/SerdeUtils.scala

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,6 @@ import ai.chronon.api.thrift.protocol.{TBinaryProtocol, TCompactProtocol}
44
import ai.chronon.api.thrift.{TDeserializer, TSerializer}
55

66
object SerdeUtils {
7-
@transient
8-
lazy val binarySerializer: ThreadLocal[TSerializer] = new ThreadLocal[TSerializer] {
9-
override def initialValue(): TSerializer = new TSerializer(new TBinaryProtocol.Factory())
10-
}
11-
12-
@transient
13-
lazy val binaryDeserializer: ThreadLocal[TDeserializer] = new ThreadLocal[TDeserializer] {
14-
override def initialValue(): TDeserializer = new TDeserializer(new TBinaryProtocol.Factory())
15-
}
16-
177
@transient
188
lazy val compactSerializer: ThreadLocal[TSerializer] = new ThreadLocal[TSerializer] {
199
override def initialValue(): TSerializer = new TSerializer(new TCompactProtocol.Factory())

online/src/main/scala/ai/chronon/online/TTLCache.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ class TTLCache[I, O](f: I => O,
4545

4646
case class Entry(value: O, updatedAtMillis: Long, var markedForUpdate: AtomicBoolean = new AtomicBoolean(false))
4747
@transient implicit lazy val logger: Logger = LoggerFactory.getLogger(getClass)
48+
4849
private val updateWhenNull =
4950
new function.BiFunction[I, Entry, Entry] {
5051
override def apply(t: I, u: Entry): Entry = {

online/src/main/scala/ai/chronon/online/fetcher/GroupByResponseHandler.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ class GroupByResponseHandler(fetchContext: FetchContext, metadataStore: Metadata
186186
val groupByFlag: Option[Boolean] = Option(fetchContext.flagStore)
187187
.map(_.isSet(
188188
"disable_streaming_decoding_error_throws",
189-
Map("groupby_streaming_dataset" -> servingInfo.groupByServingInfo.groupBy.getMetaData.getName).toJava))
189+
Map("group_by_streaming_dataset" -> servingInfo.groupByServingInfo.groupBy.getMetaData.getName).toJava))
190190
if (groupByFlag.getOrElse(fetchContext.disableErrorThrows)) {
191191
Array.empty[TiledIr]
192192
} else {

online/src/main/scala/ai/chronon/online/fetcher/JoinPartFetcher.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class JoinPartFetcher(fetchContext: FetchContext, metadataStore: MetadataStore)
3131

3232
@transient implicit lazy val logger: Logger = LoggerFactory.getLogger(getClass)
3333

34-
private val groupByFetcher = new GroupByFetcher(fetchContext, metadataStore)
34+
private[online] val groupByFetcher = new GroupByFetcher(fetchContext, metadataStore)
3535
private implicit val executionContext: ExecutionContext = fetchContext.getOrCreateExecutionContext
3636

3737
def fetchGroupBys(requests: Seq[Request]): Future[Seq[Response]] = {

online/src/main/scala/ai/chronon/online/stats/DriftStore.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import ai.chronon.observability._
88
import ai.chronon.online.KVStore
99
import ai.chronon.online.KVStore.GetRequest
1010
import ai.chronon.online.fetcher.{FetchContext, MetadataStore}
11+
import org.slf4j.LoggerFactory
1112

1213
import scala.collection.Seq
1314
import scala.concurrent.{ExecutionContext, Future}
@@ -21,6 +22,8 @@ class DriftStore(kvStore: KVStore,
2122
private val metadataStore = new MetadataStore(fetchContext)
2223
implicit private val executionContext: ExecutionContext = fetchContext.getOrCreateExecutionContext
2324

25+
@transient private lazy val logger = LoggerFactory.getLogger(this.getClass)
26+
2427
def tileKeysForJoin(join: api.Join,
2528
slice: Option[String] = None,
2629
columnNamePrefix: Option[String] = None): Map[String, Array[TileKey]] = {
@@ -138,7 +141,9 @@ class DriftStore(kvStore: KVStore,
138141
_ match {
139142
case Success(responseContext) => Some(responseContext)
140143
// TODO instrument failures
141-
case Failure(exception) => exception.printStackTrace(); None
144+
case Failure(exception) =>
145+
logger.error("Failed to fetch summary response", exception)
146+
None
142147
}
143148
}
144149

online/src/test/scala/ai/chronon/online/test/FetcherBaseTest.scala

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import ai.chronon.online.fetcher.Fetcher.Request
2626
import ai.chronon.online.fetcher.Fetcher.Response
2727
import ai.chronon.online.fetcher.FetcherCache.BatchResponses
2828
import ai.chronon.online.KVStore.TimedValue
29-
import ai.chronon.online.fetcher.{FetchContext, MetadataStore}
29+
import ai.chronon.online.fetcher.{FetchContext, GroupByFetcher, MetadataStore}
3030
import ai.chronon.online.{fetcher, _}
3131
import org.junit.Assert.assertEquals
3232
import org.junit.Assert.assertFalse
@@ -57,7 +57,8 @@ class FetcherBaseTest extends AnyFlatSpec with MockitoSugar with Matchers with M
5757
val HostKey = "host"
5858
val GuestId: AnyRef = 123.asInstanceOf[AnyRef]
5959
val HostId = "456"
60-
var fetcherBase: fetcher.JoinPartFetcher = _
60+
var joinPartFetcher: fetcher.JoinPartFetcher = _
61+
var groupByFetcher: fetcher.GroupByFetcher = _
6162
var kvStore: KVStore = _
6263
var fetchContext: FetchContext = _
6364
var metadataStore: MetadataStore = _
@@ -69,32 +70,32 @@ class FetcherBaseTest extends AnyFlatSpec with MockitoSugar with Matchers with M
6970
// the mock to prevent hanging.
7071
when(kvStore.executionContext).thenReturn(ExecutionContext.global)
7172
fetchContext = FetchContext(kvStore)
72-
metadataStore = new MetadataStore(fetchContext)
73-
fetcherBase = spy[fetcher.JoinPartFetcher](new fetcher.JoinPartFetcher(fetchContext, metadataStore))
73+
metadataStore = spy[fetcher.MetadataStore](new MetadataStore(fetchContext))
74+
joinPartFetcher = spy[fetcher.JoinPartFetcher](new fetcher.JoinPartFetcher(fetchContext, metadataStore))
75+
groupByFetcher = spy[fetcher.GroupByFetcher](new GroupByFetcher(fetchContext, metadataStore))
7476
}
7577

7678
it should "fetch columns single query" in {
7779
// Fetch a single query
7880
val keyMap = Map(GuestKey -> GuestId)
7981
val query = ColumnSpec(GroupBy, Column, None, Some(keyMap))
80-
8182
doAnswer(new Answer[Future[Seq[fetcher.Fetcher.Response]]] {
8283
def answer(invocation: InvocationOnMock): Future[Seq[Response]] = {
8384
val requests = invocation.getArgument(0).asInstanceOf[Seq[Request]]
8485
val request = requests.head
8586
val response = Response(request, Success(Map(request.name -> "100")))
8687
Future.successful(Seq(response))
8788
}
88-
}).when(fetcherBase).fetchGroupBys(any())
89+
}).when(groupByFetcher).fetchGroupBys(any())
8990

9091
// Map should contain query with valid response
91-
val queryResults = Await.result(fetcherBase.fetchColumns(Seq(query)), 1.second)
92+
val queryResults = Await.result(groupByFetcher.fetchColumns(Seq(query)), 1.second)
9293
queryResults.contains(query) shouldBe true
9394
queryResults.get(query).map(_.values) shouldBe Some(Success(Map(s"$GroupBy.$Column" -> "100")))
9495

9596
// GroupBy request sent to KV store for the query
9697
val requestsCaptor = ArgumentCaptor.forClass(classOf[Seq[_]])
97-
verify(fetcherBase, times(1)).fetchGroupBys(requestsCaptor.capture().asInstanceOf[Seq[Request]])
98+
verify(groupByFetcher, times(1)).fetchGroupBys(requestsCaptor.capture().asInstanceOf[Seq[Request]])
9899
val actualRequest = requestsCaptor.getValue.asInstanceOf[Seq[Request]].headOption
99100
actualRequest shouldNot be(None)
100101
actualRequest.get.name shouldBe s"${query.groupByName}.${query.columnName}"
@@ -114,18 +115,18 @@ class FetcherBaseTest extends AnyFlatSpec with MockitoSugar with Matchers with M
114115
val responses = requests.map(r => Response(r, Success(Map(r.name -> "100"))))
115116
Future.successful(responses)
116117
}
117-
}).when(fetcherBase).fetchGroupBys(any())
118+
}).when(groupByFetcher).fetchGroupBys(any())
118119

119120
// Map should contain query with valid response
120-
val queryResults = Await.result(fetcherBase.fetchColumns(Seq(guestQuery, hostQuery)), 1.second)
121+
val queryResults = Await.result(groupByFetcher.fetchColumns(Seq(guestQuery, hostQuery)), 1.second)
121122
queryResults.contains(guestQuery) shouldBe true
122123
queryResults.get(guestQuery).map(_.values) shouldBe Some(Success(Map(s"${GuestKey}_$GroupBy.$Column" -> "100")))
123124
queryResults.contains(hostQuery) shouldBe true
124125
queryResults.get(hostQuery).map(_.values) shouldBe Some(Success(Map(s"${HostKey}_$GroupBy.$Column" -> "100")))
125126

126127
// GroupBy request sent to KV store for the query
127128
val requestsCaptor = ArgumentCaptor.forClass(classOf[Seq[_]])
128-
verify(fetcherBase, times(1)).fetchGroupBys(requestsCaptor.capture().asInstanceOf[Seq[Request]])
129+
verify(groupByFetcher, times(1)).fetchGroupBys(requestsCaptor.capture().asInstanceOf[Seq[Request]])
129130
val actualRequests = requestsCaptor.getValue.asInstanceOf[Seq[Request]]
130131
actualRequests.length shouldBe 2
131132
actualRequests.head.name shouldBe s"${guestQuery.groupByName}.${guestQuery.columnName}"
@@ -143,10 +144,10 @@ class FetcherBaseTest extends AnyFlatSpec with MockitoSugar with Matchers with M
143144
def answer(invocation: InvocationOnMock): Future[Seq[Response]] = {
144145
Future.successful(Seq())
145146
}
146-
}).when(fetcherBase).fetchGroupBys(any())
147+
}).when(groupByFetcher).fetchGroupBys(any())
147148

148149
// Map should contain query with Failure response
149-
val queryResults = Await.result(fetcherBase.fetchColumns(Seq(query)), 1.second)
150+
val queryResults = Await.result(groupByFetcher.fetchColumns(Seq(query)), 1.second)
150151
queryResults.contains(query) shouldBe true
151152
queryResults.get(query).map(_.values) match {
152153
case Some(Failure(_: IllegalStateException)) => succeed
@@ -155,7 +156,7 @@ class FetcherBaseTest extends AnyFlatSpec with MockitoSugar with Matchers with M
155156

156157
// GroupBy request sent to KV store for the query
157158
val requestsCaptor = ArgumentCaptor.forClass(classOf[Seq[_]])
158-
verify(fetcherBase, times(1)).fetchGroupBys(requestsCaptor.capture().asInstanceOf[Seq[Request]])
159+
verify(groupByFetcher, times(1)).fetchGroupBys(requestsCaptor.capture().asInstanceOf[Seq[Request]])
159160
val actualRequest = requestsCaptor.getValue.asInstanceOf[Seq[Request]].headOption
160161
actualRequest shouldNot be(None)
161162
actualRequest.get.name shouldBe query.groupByName + "." + query.columnName
@@ -166,16 +167,16 @@ class FetcherBaseTest extends AnyFlatSpec with MockitoSugar with Matchers with M
166167
it should "get serving info should call update serving info if batch response is from kv store" in {
167168
val oldServingInfo = mock[GroupByServingInfoParsed]
168169
val updatedServingInfo = mock[GroupByServingInfoParsed]
169-
doReturn(updatedServingInfo).when(fetcherBase).getServingInfo(any(), any())
170+
doReturn(updatedServingInfo).when(joinPartFetcher).getServingInfo(any(), any())
170171

171172
val batchTimedValuesSuccess = Success(Seq(TimedValue(Array(1.toByte), 2000L)))
172173
val kvStoreBatchResponses = BatchResponses(batchTimedValuesSuccess)
173174

174-
val result = fetcherBase.getServingInfo(oldServingInfo, kvStoreBatchResponses)
175+
val result = joinPartFetcher.getServingInfo(oldServingInfo, kvStoreBatchResponses)
175176

176177
// updateServingInfo is called
177178
result shouldEqual updatedServingInfo
178-
verify(fetcherBase).getServingInfo(any(), any())
179+
verify(joinPartFetcher).getServingInfo(any(), any())
179180
}
180181

181182
// If a batch response is cached, the serving info should be refreshed. This is needed to prevent
@@ -194,23 +195,23 @@ class FetcherBaseTest extends AnyFlatSpec with MockitoSugar with Matchers with M
194195
doReturn(groupByOpsMock).when(oldServingInfo).groupByOps
195196

196197
val cachedBatchResponses = BatchResponses(mock[FinalBatchIr])
197-
val result = fetcherBase.getServingInfo(oldServingInfo, cachedBatchResponses)
198+
val result = groupByFetcher.getServingInfo(oldServingInfo, cachedBatchResponses)
198199

199200
// FetcherBase.updateServingInfo is not called, but getGroupByServingInfo.refresh() is.
200201
result shouldEqual oldServingInfo
201202
verify(ttlCache).refresh(any())
202-
verify(fetcherBase, never()).getServingInfo(any(), any())
203+
verify(ttlCache, never()).apply(any())
203204
}
204205

205-
it should "is caching enabled correctly determine if cache is enabled" in {
206+
it should "determine if caching is enabled correctly" in {
206207
val flagStore: FlagStore = (flagName: String, attributes: java.util.Map[String, String]) => {
207208
flagName match {
208209
case "enable_fetcher_batch_ir_cache" =>
209-
attributes.get("groupby_streaming_dataset") match {
210+
attributes.get("group_by_streaming_dataset") match {
210211
case "test_groupby_2" => false
211212
case "test_groupby_3" => true
212213
case other @ _ =>
213-
fail(s"Unexpected groupby_streaming_dataset: $other")
214+
fail(s"Unexpected group_by_streaming_dataset: $other")
214215
false
215216
}
216217
case _ => false
@@ -226,10 +227,7 @@ class FetcherBaseTest extends AnyFlatSpec with MockitoSugar with Matchers with M
226227
spy[fetcher.JoinPartFetcher](new fetcher.JoinPartFetcher(fetchContext, new MetadataStore(fetchContext)))
227228
when(fetcherBaseWithFlagStore.isCacheSizeConfigured).thenReturn(true)
228229

229-
def buildGroupByWithCustomJson(name: String): GroupBy = Builders.GroupBy(metaData = Builders.MetaData(name = name))
230-
231230
// no name set
232-
assertFalse(fetchContext.isCachingEnabled(null))
233231
assertFalse(fetchContext.isCachingEnabled("test_groupby_2"))
234232
assertTrue(fetchContext.isCachingEnabled("test_groupby_3"))
235233
}

spark/src/main/scala/ai/chronon/spark/scripts/DataServer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class DataServer(driftSeries: Seq[TileDriftSeries], summarySeries: Seq[TileSumma
3737
private def convertToBytesMap[T <: TBase[_, _]: Manifest: ClassTag](
3838
series: T,
3939
keyF: T => TileSeriesKey): Map[String, String] = {
40-
val serializerInstance = SerdeUtils.binarySerializer.get()
40+
val serializerInstance = SerdeUtils.compactSerializer.get()
4141
val encoder = Base64.getEncoder
4242
val keyBytes = serializerInstance.serialize(keyF(series))
4343
val valueBytes = serializerInstance.serialize(series)

spark/src/main/scala/ai/chronon/spark/stats/drift/Summarizer.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package ai.chronon.spark.stats.drift
33
import ai.chronon.api.ColorPrinter.ColorString
44
import ai.chronon.api.Extensions._
55
import ai.chronon.api.ScalaJavaConversions._
6-
import ai.chronon.api.SerdeUtils.binarySerializer
6+
import ai.chronon.api.SerdeUtils.compactSerializer
77
import ai.chronon.api._
88
import ai.chronon.observability.Cardinality
99
import ai.chronon.observability.TileKey
@@ -325,7 +325,7 @@ class SummaryPacker(confPath: String,
325325

326326
val packedRdd: RDD[sql.Row] = df.rdd.flatMap(func).map { tileRow =>
327327
// pack into bytes
328-
val serializer = binarySerializer.get()
328+
val serializer = compactSerializer.get()
329329

330330
val partition = tileRow.partition
331331
val timestamp = tileRow.tileTs

spark/src/test/scala/ai/chronon/spark/test/stats/drift/DriftTest.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class DriftTest extends AnyFlatSpec with Matchers {
4848
logger.info(s" ${pad(f.name)} : ${f.dataType.typeName}".yellow)
4949
}
5050

51-
df.show(10, truncate = false)
51+
df.show(10)
5252
}
5353

5454
"end_to_end" should "fetch prepare anomalous data, summarize, upload and fetch without failures" in {
@@ -57,7 +57,7 @@ class DriftTest extends AnyFlatSpec with Matchers {
5757
val prepareData = PrepareData(namespace)
5858
val join = prepareData.generateAnomalousFraudJoin
5959
val df = prepareData.generateFraudSampleData(600000, "2023-01-01", "2023-02-30", join.metaData.loggedTable)
60-
df.show(10, truncate = false)
60+
df.show(10)
6161

6262
// mock api impl for online fetching and uploading
6363
val kvStoreFunc: () => KVStore = () => {

0 commit comments

Comments
 (0)