Skip to content

Commit d324397

Browse files
committed
float up gb fetcher methods
1 parent 24c7c8a commit d324397

File tree

3 files changed

+38
-18
lines changed

3 files changed

+38
-18
lines changed

online/src/main/scala/ai/chronon/online/fetcher/Fetcher.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ class Fetcher(val kvStore: KVStore,
123123
}
124124

125125
def fetchGroupBys(requests: Seq[Request]): Future[Seq[Response]] = {
126-
joinPartFetcher.groupByFetcher.fetchGroupBys(requests)
126+
joinPartFetcher.fetchGroupBys(requests)
127127
}
128128

129129
def fetchJoin(requests: Seq[Request], joinConf: Option[api.Join] = None): Future[Seq[Response]] = {

online/src/main/scala/ai/chronon/online/fetcher/JoinPartFetcher.scala

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ package ai.chronon.online.fetcher
1919
import ai.chronon.api.Extensions._
2020
import ai.chronon.api._
2121
import ai.chronon.online._
22-
import ai.chronon.online.fetcher.Fetcher.{PrefixedRequest, Request, Response}
22+
import ai.chronon.online.fetcher.Fetcher.{ColumnSpec, PrefixedRequest, Request, Response}
23+
import ai.chronon.online.fetcher.FetcherCache.BatchResponses
2324
import org.slf4j.{Logger, LoggerFactory}
2425

2526
import scala.collection.Seq
@@ -30,9 +31,28 @@ class JoinPartFetcher(fetchContext: FetchContext, metadataStore: MetadataStore)
3031

3132
@transient implicit lazy val logger: Logger = LoggerFactory.getLogger(getClass)
3233

33-
private[online] val groupByFetcher = new GroupByFetcher(fetchContext, metadataStore)
34+
private val groupByFetcher = new GroupByFetcher(fetchContext, metadataStore)
3435
private implicit val executionContext: ExecutionContext = fetchContext.getOrCreateExecutionContext
3536

37+
def fetchGroupBys(requests: Seq[Request]): Future[Seq[Response]] = {
38+
groupByFetcher.fetchGroupBys(requests)
39+
}
40+
41+
// ----- START -----
42+
// floated up to makes tests easy
43+
def fetchColumns(specs: Seq[ColumnSpec]): Future[Map[ColumnSpec, Response]] = {
44+
groupByFetcher.fetchColumns(specs)
45+
}
46+
47+
def getServingInfo(existing: GroupByServingInfoParsed, batchResponses: BatchResponses): GroupByServingInfoParsed = {
48+
groupByFetcher.getServingInfo(existing, batchResponses)
49+
}
50+
51+
def isCacheSizeConfigured: Boolean = {
52+
groupByFetcher.isCacheSizeConfigured
53+
}
54+
// ---- END ----
55+
3656
// prioritize passed in joinOverrides over the ones in metadata store
3757
// used in stream-enrichment and in staging testing
3858
def fetchJoins(requests: Seq[Request], joinConf: Option[Join] = None): Future[Seq[Response]] = {

online/src/test/scala/ai/chronon/online/test/FetcherBaseTest.scala

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -85,16 +85,16 @@ class FetcherBaseTest extends AnyFlatSpec with MockitoSugar with Matchers with M
8585
val response = Response(request, Success(Map(request.name -> "100")))
8686
Future.successful(Seq(response))
8787
}
88-
}).when(fetcherBase).groupByFetcher.fetchGroupBys(any())
88+
}).when(fetcherBase).fetchGroupBys(any())
8989

9090
// Map should contain query with valid response
91-
val queryResults = Await.result(fetcherBase.groupByFetcher.fetchColumns(Seq(query)), 1.second)
91+
val queryResults = Await.result(fetcherBase.fetchColumns(Seq(query)), 1.second)
9292
queryResults.contains(query) shouldBe true
9393
queryResults.get(query).map(_.values) shouldBe Some(Success(Map(s"$GroupBy.$Column" -> "100")))
9494

9595
// GroupBy request sent to KV store for the query
9696
val requestsCaptor = ArgumentCaptor.forClass(classOf[Seq[_]])
97-
verify(fetcherBase, times(1)).groupByFetcher.fetchGroupBys(requestsCaptor.capture().asInstanceOf[Seq[Request]])
97+
verify(fetcherBase, times(1)).fetchGroupBys(requestsCaptor.capture().asInstanceOf[Seq[Request]])
9898
val actualRequest = requestsCaptor.getValue.asInstanceOf[Seq[Request]].headOption
9999
actualRequest shouldNot be(None)
100100
actualRequest.get.name shouldBe s"${query.groupByName}.${query.columnName}"
@@ -114,18 +114,18 @@ class FetcherBaseTest extends AnyFlatSpec with MockitoSugar with Matchers with M
114114
val responses = requests.map(r => Response(r, Success(Map(r.name -> "100"))))
115115
Future.successful(responses)
116116
}
117-
}).when(fetcherBase).groupByFetcher.fetchGroupBys(any())
117+
}).when(fetcherBase).fetchGroupBys(any())
118118

119119
// Map should contain query with valid response
120-
val queryResults = Await.result(fetcherBase.groupByFetcher.fetchColumns(Seq(guestQuery, hostQuery)), 1.second)
120+
val queryResults = Await.result(fetcherBase.fetchColumns(Seq(guestQuery, hostQuery)), 1.second)
121121
queryResults.contains(guestQuery) shouldBe true
122122
queryResults.get(guestQuery).map(_.values) shouldBe Some(Success(Map(s"${GuestKey}_$GroupBy.$Column" -> "100")))
123123
queryResults.contains(hostQuery) shouldBe true
124124
queryResults.get(hostQuery).map(_.values) shouldBe Some(Success(Map(s"${HostKey}_$GroupBy.$Column" -> "100")))
125125

126126
// GroupBy request sent to KV store for the query
127127
val requestsCaptor = ArgumentCaptor.forClass(classOf[Seq[_]])
128-
verify(fetcherBase, times(1)).groupByFetcher.fetchGroupBys(requestsCaptor.capture().asInstanceOf[Seq[Request]])
128+
verify(fetcherBase, times(1)).fetchGroupBys(requestsCaptor.capture().asInstanceOf[Seq[Request]])
129129
val actualRequests = requestsCaptor.getValue.asInstanceOf[Seq[Request]]
130130
actualRequests.length shouldBe 2
131131
actualRequests.head.name shouldBe s"${guestQuery.groupByName}.${guestQuery.columnName}"
@@ -143,10 +143,10 @@ class FetcherBaseTest extends AnyFlatSpec with MockitoSugar with Matchers with M
143143
def answer(invocation: InvocationOnMock): Future[Seq[Response]] = {
144144
Future.successful(Seq())
145145
}
146-
}).when(fetcherBase).groupByFetcher.fetchGroupBys(any())
146+
}).when(fetcherBase).fetchGroupBys(any())
147147

148148
// Map should contain query with Failure response
149-
val queryResults = Await.result(fetcherBase.groupByFetcher.fetchColumns(Seq(query)), 1.second)
149+
val queryResults = Await.result(fetcherBase.fetchColumns(Seq(query)), 1.second)
150150
queryResults.contains(query) shouldBe true
151151
queryResults.get(query).map(_.values) match {
152152
case Some(Failure(_: IllegalStateException)) => succeed
@@ -155,7 +155,7 @@ class FetcherBaseTest extends AnyFlatSpec with MockitoSugar with Matchers with M
155155

156156
// GroupBy request sent to KV store for the query
157157
val requestsCaptor = ArgumentCaptor.forClass(classOf[Seq[_]])
158-
verify(fetcherBase, times(1)).groupByFetcher.fetchGroupBys(requestsCaptor.capture().asInstanceOf[Seq[Request]])
158+
verify(fetcherBase, times(1)).fetchGroupBys(requestsCaptor.capture().asInstanceOf[Seq[Request]])
159159
val actualRequest = requestsCaptor.getValue.asInstanceOf[Seq[Request]].headOption
160160
actualRequest shouldNot be(None)
161161
actualRequest.get.name shouldBe query.groupByName + "." + query.columnName
@@ -166,16 +166,16 @@ class FetcherBaseTest extends AnyFlatSpec with MockitoSugar with Matchers with M
166166
it should "get serving info should call update serving info if batch response is from kv store" in {
167167
val oldServingInfo = mock[GroupByServingInfoParsed]
168168
val updatedServingInfo = mock[GroupByServingInfoParsed]
169-
doReturn(updatedServingInfo).when(fetcherBase).groupByFetcher.getServingInfo(any(), any())
169+
doReturn(updatedServingInfo).when(fetcherBase).getServingInfo(any(), any())
170170

171171
val batchTimedValuesSuccess = Success(Seq(TimedValue(Array(1.toByte), 2000L)))
172172
val kvStoreBatchResponses = BatchResponses(batchTimedValuesSuccess)
173173

174-
val result = fetcherBase.groupByFetcher.getServingInfo(oldServingInfo, kvStoreBatchResponses)
174+
val result = fetcherBase.getServingInfo(oldServingInfo, kvStoreBatchResponses)
175175

176176
// updateServingInfo is called
177177
result shouldEqual updatedServingInfo
178-
verify(fetcherBase).groupByFetcher.getServingInfo(any(), any())
178+
verify(fetcherBase).getServingInfo(any(), any())
179179
}
180180

181181
// If a batch response is cached, the serving info should be refreshed. This is needed to prevent
@@ -194,12 +194,12 @@ class FetcherBaseTest extends AnyFlatSpec with MockitoSugar with Matchers with M
194194
doReturn(groupByOpsMock).when(oldServingInfo).groupByOps
195195

196196
val cachedBatchResponses = BatchResponses(mock[FinalBatchIr])
197-
val result = fetcherBase.groupByFetcher.getServingInfo(oldServingInfo, cachedBatchResponses)
197+
val result = fetcherBase.getServingInfo(oldServingInfo, cachedBatchResponses)
198198

199199
// FetcherBase.updateServingInfo is not called, but getGroupByServingInfo.refresh() is.
200200
result shouldEqual oldServingInfo
201201
verify(ttlCache).refresh(any())
202-
verify(fetcherBase, never()).groupByFetcher.getServingInfo(any(), any())
202+
verify(fetcherBase, never()).getServingInfo(any(), any())
203203
}
204204

205205
it should "is caching enabled correctly determine if cache is enabled" in {
@@ -224,7 +224,7 @@ class FetcherBaseTest extends AnyFlatSpec with MockitoSugar with Matchers with M
224224

225225
val fetcherBaseWithFlagStore =
226226
spy[fetcher.JoinPartFetcher](new fetcher.JoinPartFetcher(fetchContext, new MetadataStore(fetchContext)))
227-
when(fetcherBaseWithFlagStore.groupByFetcher.isCacheSizeConfigured).thenReturn(true)
227+
when(fetcherBaseWithFlagStore.isCacheSizeConfigured).thenReturn(true)
228228

229229
def buildGroupByWithCustomJson(name: String): GroupBy = Builders.GroupBy(metaData = Builders.MetaData(name = name))
230230

0 commit comments

Comments
 (0)