8
8
import io .vertx .core .Future ;
9
9
import io .vertx .core .Vertx ;
10
10
import com .google .cloud .storage .Blob ;
11
-
12
11
import org .junit .jupiter .api .BeforeEach ;
13
12
import org .junit .jupiter .api .extension .ExtendWith ;
14
13
import org .junit .jupiter .api .Test ;
15
14
import org .mockito .Mock ;
16
- import org .mockito .Mockito ;
17
15
import org .mockito .junit .jupiter .MockitoExtension ;
18
16
import org .prebid .server .exception .PreBidException ;
19
17
import org .prebid .server .hooks .modules .greenbids .real .time .data .model .predictor .ModelCache ;
33
31
34
32
@ ExtendWith (MockitoExtension .class )
35
33
public class ModelCacheTest {
34
+
36
35
private static final String GCS_BUCKET_NAME = "test_bucket" ;
37
36
private static final String MODEL_CACHE_KEY_PREFIX = "onnxModelRunner_" ;
38
37
private static final String PBUUID = "test-pbuid" ;
@@ -71,11 +70,11 @@ public void setUp() {
71
70
@ Test
72
71
public void getShouldReturnModelFromCacheWhenPresent () {
73
72
// given
74
- String cacheKey = MODEL_CACHE_KEY_PREFIX + PBUUID ;
73
+ final String cacheKey = MODEL_CACHE_KEY_PREFIX + PBUUID ;
75
74
when (cache .getIfPresent (eq (cacheKey ))).thenReturn (onnxModelRunner );
76
75
77
76
// when
78
- Future <OnnxModelRunner > future = target .get (ONNX_MODEL_PATH , PBUUID );
77
+ final Future <OnnxModelRunner > future = target .get (ONNX_MODEL_PATH , PBUUID );
79
78
80
79
// then
81
80
assertThat (future .succeeded ()).isTrue ();
@@ -86,32 +85,19 @@ public void getShouldReturnModelFromCacheWhenPresent() {
86
85
@ Test
87
86
public void getShouldSkipFetchingWhenFetchingInProgress () throws NoSuchFieldException , IllegalAccessException {
88
87
// given
89
- String cacheKey = MODEL_CACHE_KEY_PREFIX + PBUUID ;
88
+ final String cacheKey = MODEL_CACHE_KEY_PREFIX + PBUUID ;
90
89
91
- // Create a spy of the ModelCache class
92
- ModelCache spyModelCache = spy ( target );
90
+ final ModelCache spyModelCache = spy ( target );
91
+ final AtomicBoolean mockFetchingState = mock ( AtomicBoolean . class );
93
92
94
- // Mock the cache to simulate that the model is not present
95
93
when (cache .getIfPresent (eq (cacheKey ))).thenReturn (null );
96
-
97
- // Spy the isFetching AtomicBoolean behavior
98
- AtomicBoolean mockFetchingState = mock (AtomicBoolean .class );
99
-
100
- // Mock fetching state for 2 calls
101
94
when (mockFetchingState .compareAndSet (false , true )).thenReturn (false );
102
-
103
- // Use reflection to set the private field 'isFetching' in the spy accessible
104
- Field isFetchingField = ModelCache .class .getDeclaredField ("isFetching" );
95
+ final Field isFetchingField = ModelCache .class .getDeclaredField ("isFetching" );
105
96
isFetchingField .setAccessible (true );
106
97
isFetchingField .set (spyModelCache , mockFetchingState );
107
98
108
99
// when
109
- Future <OnnxModelRunner > result = spyModelCache .get (ONNX_MODEL_PATH , PBUUID );
110
-
111
- System .out .println (
112
- "firstCall.cause().getMessage(): " + result .cause ().getMessage () + "\n " +
113
- "firstCall.succeeded(): " + result .succeeded ()
114
- );
100
+ final Future <OnnxModelRunner > result = spyModelCache .get (ONNX_MODEL_PATH , PBUUID );
115
101
116
102
// then
117
103
assertThat (result .failed ()).isTrue ();
@@ -132,18 +118,10 @@ public void getShouldFetchModelWhenNotInCache() throws OrtException {
132
118
lenient ().when (onnxModelRunnerFactory .create (bytes )).thenReturn (onnxModelRunner );
133
119
134
120
// when
135
- Future <OnnxModelRunner > future = target .get (ONNX_MODEL_PATH , PBUUID );
121
+ final Future <OnnxModelRunner > future = target .get (ONNX_MODEL_PATH , PBUUID );
136
122
137
123
// then
138
124
future .onComplete (ar -> {
139
-
140
- System .out .println (
141
- "future.onComplete: \n " +
142
- " ar: " + ar + "\n " +
143
- " ar.result(): " + ar .result () + "\n " +
144
- " cache: " + cache
145
- );
146
-
147
125
assertThat (ar .succeeded ()).isTrue ();
148
126
assertThat (ar .result ()).isEqualTo (onnxModelRunner );
149
127
verify (cache ).put (eq (cacheKey ), eq (onnxModelRunner ));
@@ -155,27 +133,14 @@ public void getShouldThrowExceptionWhenStorageFails() {
155
133
// given
156
134
final String cacheKey = MODEL_CACHE_KEY_PREFIX + PBUUID ;
157
135
158
- // Mock that the model is not in cache
159
136
when (cache .getIfPresent (eq (cacheKey ))).thenReturn (null );
160
-
161
- // Simulate an error when accessing the storage bucket
162
137
lenient ().when (storage .get (GCS_BUCKET_NAME )).thenThrow (new StorageException (500 , "Storage Error" ));
163
138
164
139
// when
165
- Future <OnnxModelRunner > future = target .get (ONNX_MODEL_PATH , PBUUID );
140
+ final Future <OnnxModelRunner > future = target .get (ONNX_MODEL_PATH , PBUUID );
166
141
167
142
// then
168
143
future .onComplete (ar -> {
169
-
170
- System .out .println (
171
- "future.onComplete: \n " +
172
- " ar: " + ar + "\n " +
173
- " ar.failed(): " + ar .failed () + "\n " +
174
- " ar.cause(): " + ar .cause () + "\n " +
175
- " ar.cause().getMessage(): " + ar .cause ().getMessage () + "\n " +
176
- " ar.result(): " + ar .result ()
177
- );
178
-
179
144
assertThat (ar .cause ()).isInstanceOf (PreBidException .class );
180
145
assertThat (ar .cause ().getMessage ()).contains ("Error accessing GCS artefact for model" );
181
146
});
@@ -187,31 +152,18 @@ public void getShouldThrowExceptionWhenOnnxModelFails() throws OrtException {
187
152
final String cacheKey = MODEL_CACHE_KEY_PREFIX + PBUUID ;
188
153
final byte [] bytes = new byte []{1 , 2 , 3 };
189
154
190
- // Mock that the model is not in cache
191
155
when (cache .getIfPresent (eq (cacheKey ))).thenReturn (null );
192
-
193
- // Simulate an error when accessing the storage bucket
194
- when (storage .get (GCS_BUCKET_NAME )).thenReturn (bucket );;
156
+ when (storage .get (GCS_BUCKET_NAME )).thenReturn (bucket );
195
157
lenient ().when (bucket .get (ONNX_MODEL_PATH )).thenReturn (blob );
196
158
lenient ().when (blob .getContent ()).thenReturn (bytes );
197
- lenient ().when (onnxModelRunnerFactory .create (bytes )).thenThrow (new OrtException ("Failed to convert blob to ONNX model" ));
159
+ lenient ().when (onnxModelRunnerFactory .create (bytes )).thenThrow (
160
+ new OrtException ("Failed to convert blob to ONNX model" ));
198
161
199
162
// when
200
- Future <OnnxModelRunner > future = target .get (ONNX_MODEL_PATH , PBUUID );
163
+ final Future <OnnxModelRunner > future = target .get (ONNX_MODEL_PATH , PBUUID );
201
164
202
165
// then
203
166
future .onComplete (ar -> {
204
-
205
- System .out .println (
206
- "future.onComplete: \n " +
207
- " ar: " + ar + "\n " +
208
- " ar.failed(): " + ar .failed () + "\n " +
209
- " ar.cause(): " + ar .cause () + "\n " +
210
- " ar.cause().getMessage(): " + ar .cause ().getMessage () + "\n " +
211
- " ar.result(): " + ar .result ()
212
- );
213
-
214
-
215
167
assertThat (ar .failed ()).isTrue ();
216
168
assertThat (ar .cause ()).isInstanceOf (PreBidException .class );
217
169
assertThat (ar .cause ().getMessage ()).contains ("Failed to convert blob to ONNX model" );
@@ -223,29 +175,16 @@ public void getShouldThrowExceptionWhenBucketNotFound() {
223
175
// given
224
176
final String cacheKey = MODEL_CACHE_KEY_PREFIX + PBUUID ;
225
177
226
- // Mock that the model is not in cache
227
178
when (cache .getIfPresent (eq (cacheKey ))).thenReturn (null );
228
-
229
- // Simulate an error when accessing the storage bucket
230
- when (storage .get (GCS_BUCKET_NAME )).thenReturn (bucket );;
179
+ when (storage .get (GCS_BUCKET_NAME )).thenReturn (bucket );
231
180
lenient ().when (bucket .get (ONNX_MODEL_PATH )).thenReturn (blob );
232
181
lenient ().when (blob .getContent ()).thenThrow (new PreBidException ("Bucket not found" ));
233
182
234
183
// when
235
- Future <OnnxModelRunner > future = target .get (ONNX_MODEL_PATH , PBUUID );
184
+ final Future <OnnxModelRunner > future = target .get (ONNX_MODEL_PATH , PBUUID );
236
185
237
186
// then
238
187
future .onComplete (ar -> {
239
-
240
- System .out .println (
241
- "future.onComplete: \n " +
242
- " ar: " + ar + "\n " +
243
- " ar.failed(): " + ar .failed () + "\n " +
244
- " ar.cause(): " + ar .cause () + "\n " +
245
- " ar.cause().getMessage(): " + ar .cause ().getMessage () + "\n " +
246
- " ar.result(): " + ar .result ()
247
- );
248
-
249
188
assertThat (ar .failed ()).isTrue ();
250
189
assertThat (ar .cause ()).isInstanceOf (PreBidException .class );
251
190
assertThat (ar .cause ().getMessage ()).contains ("Bucket not found" );
0 commit comments