@@ -26,7 +26,7 @@ import ai.chronon.online.fetcher.Fetcher.Request
26
26
import ai .chronon .online .fetcher .Fetcher .Response
27
27
import ai .chronon .online .fetcher .FetcherCache .BatchResponses
28
28
import ai .chronon .online .KVStore .TimedValue
29
- import ai .chronon .online .fetcher .{FetchContext , MetadataStore }
29
+ import ai .chronon .online .fetcher .{FetchContext , GroupByFetcher , MetadataStore }
30
30
import ai .chronon .online .{fetcher , _ }
31
31
import org .junit .Assert .assertEquals
32
32
import org .junit .Assert .assertFalse
@@ -57,7 +57,8 @@ class FetcherBaseTest extends AnyFlatSpec with MockitoSugar with Matchers with M
57
57
val HostKey = " host"
58
58
val GuestId : AnyRef = 123 .asInstanceOf [AnyRef ]
59
59
val HostId = " 456"
60
- var fetcherBase : fetcher.JoinPartFetcher = _
60
+ var joinPartFetcher : fetcher.JoinPartFetcher = _
61
+ var groupByFetcher : fetcher.GroupByFetcher = _
61
62
var kvStore : KVStore = _
62
63
var fetchContext : FetchContext = _
63
64
var metadataStore : MetadataStore = _
@@ -69,32 +70,32 @@ class FetcherBaseTest extends AnyFlatSpec with MockitoSugar with Matchers with M
69
70
// the mock to prevent hanging.
70
71
when(kvStore.executionContext).thenReturn(ExecutionContext .global)
71
72
fetchContext = FetchContext (kvStore)
72
- metadataStore = new MetadataStore (fetchContext)
73
- fetcherBase = spy[fetcher.JoinPartFetcher ](new fetcher.JoinPartFetcher (fetchContext, metadataStore))
73
+ metadataStore = spy[fetcher.MetadataStore ](new MetadataStore (fetchContext))
74
+ joinPartFetcher = spy[fetcher.JoinPartFetcher ](new fetcher.JoinPartFetcher (fetchContext, metadataStore))
75
+ groupByFetcher = spy[fetcher.GroupByFetcher ](new GroupByFetcher (fetchContext, metadataStore))
74
76
}
75
77
76
78
it should " fetch columns single query" in {
77
79
// Fetch a single query
78
80
val keyMap = Map (GuestKey -> GuestId )
79
81
val query = ColumnSpec (GroupBy , Column , None , Some (keyMap))
80
-
81
82
doAnswer(new Answer [Future [Seq [fetcher.Fetcher .Response ]]] {
82
83
def answer (invocation : InvocationOnMock ): Future [Seq [Response ]] = {
83
84
val requests = invocation.getArgument(0 ).asInstanceOf [Seq [Request ]]
84
85
val request = requests.head
85
86
val response = Response (request, Success (Map (request.name -> " 100" )))
86
87
Future .successful(Seq (response))
87
88
}
88
- }).when(fetcherBase ).fetchGroupBys(any())
89
+ }).when(groupByFetcher ).fetchGroupBys(any())
89
90
90
91
// Map should contain query with valid response
91
- val queryResults = Await .result(fetcherBase .fetchColumns(Seq (query)), 1 .second)
92
+ val queryResults = Await .result(groupByFetcher .fetchColumns(Seq (query)), 1 .second)
92
93
queryResults.contains(query) shouldBe true
93
94
queryResults.get(query).map(_.values) shouldBe Some (Success (Map (s " $GroupBy. $Column" -> " 100" )))
94
95
95
96
// GroupBy request sent to KV store for the query
96
97
val requestsCaptor = ArgumentCaptor .forClass(classOf [Seq [_]])
97
- verify(fetcherBase , times(1 )).fetchGroupBys(requestsCaptor.capture().asInstanceOf [Seq [Request ]])
98
+ verify(groupByFetcher , times(1 )).fetchGroupBys(requestsCaptor.capture().asInstanceOf [Seq [Request ]])
98
99
val actualRequest = requestsCaptor.getValue.asInstanceOf [Seq [Request ]].headOption
99
100
actualRequest shouldNot be(None )
100
101
actualRequest.get.name shouldBe s " ${query.groupByName}. ${query.columnName}"
@@ -114,18 +115,18 @@ class FetcherBaseTest extends AnyFlatSpec with MockitoSugar with Matchers with M
114
115
val responses = requests.map(r => Response (r, Success (Map (r.name -> " 100" ))))
115
116
Future .successful(responses)
116
117
}
117
- }).when(fetcherBase ).fetchGroupBys(any())
118
+ }).when(groupByFetcher ).fetchGroupBys(any())
118
119
119
120
// Map should contain query with valid response
120
- val queryResults = Await .result(fetcherBase .fetchColumns(Seq (guestQuery, hostQuery)), 1 .second)
121
+ val queryResults = Await .result(groupByFetcher .fetchColumns(Seq (guestQuery, hostQuery)), 1 .second)
121
122
queryResults.contains(guestQuery) shouldBe true
122
123
queryResults.get(guestQuery).map(_.values) shouldBe Some (Success (Map (s " ${GuestKey }_ $GroupBy. $Column" -> " 100" )))
123
124
queryResults.contains(hostQuery) shouldBe true
124
125
queryResults.get(hostQuery).map(_.values) shouldBe Some (Success (Map (s " ${HostKey }_ $GroupBy. $Column" -> " 100" )))
125
126
126
127
// GroupBy request sent to KV store for the query
127
128
val requestsCaptor = ArgumentCaptor .forClass(classOf [Seq [_]])
128
- verify(fetcherBase , times(1 )).fetchGroupBys(requestsCaptor.capture().asInstanceOf [Seq [Request ]])
129
+ verify(groupByFetcher , times(1 )).fetchGroupBys(requestsCaptor.capture().asInstanceOf [Seq [Request ]])
129
130
val actualRequests = requestsCaptor.getValue.asInstanceOf [Seq [Request ]]
130
131
actualRequests.length shouldBe 2
131
132
actualRequests.head.name shouldBe s " ${guestQuery.groupByName}. ${guestQuery.columnName}"
@@ -143,10 +144,10 @@ class FetcherBaseTest extends AnyFlatSpec with MockitoSugar with Matchers with M
143
144
def answer (invocation : InvocationOnMock ): Future [Seq [Response ]] = {
144
145
Future .successful(Seq ())
145
146
}
146
- }).when(fetcherBase ).fetchGroupBys(any())
147
+ }).when(groupByFetcher ).fetchGroupBys(any())
147
148
148
149
// Map should contain query with Failure response
149
- val queryResults = Await .result(fetcherBase .fetchColumns(Seq (query)), 1 .second)
150
+ val queryResults = Await .result(groupByFetcher .fetchColumns(Seq (query)), 1 .second)
150
151
queryResults.contains(query) shouldBe true
151
152
queryResults.get(query).map(_.values) match {
152
153
case Some (Failure (_ : IllegalStateException )) => succeed
@@ -155,7 +156,7 @@ class FetcherBaseTest extends AnyFlatSpec with MockitoSugar with Matchers with M
155
156
156
157
// GroupBy request sent to KV store for the query
157
158
val requestsCaptor = ArgumentCaptor .forClass(classOf [Seq [_]])
158
- verify(fetcherBase , times(1 )).fetchGroupBys(requestsCaptor.capture().asInstanceOf [Seq [Request ]])
159
+ verify(groupByFetcher , times(1 )).fetchGroupBys(requestsCaptor.capture().asInstanceOf [Seq [Request ]])
159
160
val actualRequest = requestsCaptor.getValue.asInstanceOf [Seq [Request ]].headOption
160
161
actualRequest shouldNot be(None )
161
162
actualRequest.get.name shouldBe query.groupByName + " ." + query.columnName
@@ -166,16 +167,16 @@ class FetcherBaseTest extends AnyFlatSpec with MockitoSugar with Matchers with M
166
167
it should " get serving info should call update serving info if batch response is from kv store" in {
167
168
val oldServingInfo = mock[GroupByServingInfoParsed ]
168
169
val updatedServingInfo = mock[GroupByServingInfoParsed ]
169
- doReturn(updatedServingInfo).when(fetcherBase ).getServingInfo(any(), any())
170
+ doReturn(updatedServingInfo).when(joinPartFetcher ).getServingInfo(any(), any())
170
171
171
172
val batchTimedValuesSuccess = Success (Seq (TimedValue (Array (1 .toByte), 2000L )))
172
173
val kvStoreBatchResponses = BatchResponses (batchTimedValuesSuccess)
173
174
174
- val result = fetcherBase .getServingInfo(oldServingInfo, kvStoreBatchResponses)
175
+ val result = joinPartFetcher .getServingInfo(oldServingInfo, kvStoreBatchResponses)
175
176
176
177
// updateServingInfo is called
177
178
result shouldEqual updatedServingInfo
178
- verify(fetcherBase ).getServingInfo(any(), any())
179
+ verify(joinPartFetcher ).getServingInfo(any(), any())
179
180
}
180
181
181
182
// If a batch response is cached, the serving info should be refreshed. This is needed to prevent
@@ -194,23 +195,23 @@ class FetcherBaseTest extends AnyFlatSpec with MockitoSugar with Matchers with M
194
195
doReturn(groupByOpsMock).when(oldServingInfo).groupByOps
195
196
196
197
val cachedBatchResponses = BatchResponses (mock[FinalBatchIr ])
197
- val result = fetcherBase .getServingInfo(oldServingInfo, cachedBatchResponses)
198
+ val result = groupByFetcher .getServingInfo(oldServingInfo, cachedBatchResponses)
198
199
199
200
// FetcherBase.updateServingInfo is not called, but getGroupByServingInfo.refresh() is.
200
201
result shouldEqual oldServingInfo
201
202
verify(ttlCache).refresh(any())
202
- verify(fetcherBase , never()).getServingInfo(any(), any())
203
+ verify(ttlCache , never()).apply( any())
203
204
}
204
205
205
- it should " is caching enabled correctly determine if cache is enabled" in {
206
+ it should " determine if caching is enabled correctly " in {
206
207
val flagStore : FlagStore = (flagName : String , attributes : java.util.Map [String , String ]) => {
207
208
flagName match {
208
209
case " enable_fetcher_batch_ir_cache" =>
209
- attributes.get(" groupby_streaming_dataset " ) match {
210
+ attributes.get(" group_by_streaming_dataset " ) match {
210
211
case " test_groupby_2" => false
211
212
case " test_groupby_3" => true
212
213
case other @ _ =>
213
- fail(s " Unexpected groupby_streaming_dataset : $other" )
214
+ fail(s " Unexpected group_by_streaming_dataset : $other" )
214
215
false
215
216
}
216
217
case _ => false
@@ -226,10 +227,7 @@ class FetcherBaseTest extends AnyFlatSpec with MockitoSugar with Matchers with M
226
227
spy[fetcher.JoinPartFetcher ](new fetcher.JoinPartFetcher (fetchContext, new MetadataStore (fetchContext)))
227
228
when(fetcherBaseWithFlagStore.isCacheSizeConfigured).thenReturn(true )
228
229
229
- def buildGroupByWithCustomJson (name : String ): GroupBy = Builders .GroupBy (metaData = Builders .MetaData (name = name))
230
-
231
230
// no name set
232
- assertFalse(fetchContext.isCachingEnabled(null ))
233
231
assertFalse(fetchContext.isCachingEnabled(" test_groupby_2" ))
234
232
assertTrue(fetchContext.isCachingEnabled(" test_groupby_3" ))
235
233
}
0 commit comments