Skip to content

Commit 97fb751

Browse files
committed
fix review9: OnnxModelRunnerTest
1 parent 346281a commit 97fb751

File tree

8 files changed

+142
-249
lines changed

8 files changed

+142
-249
lines changed

extra/modules/greenbids-real-time-data/src/main/java/org/prebid/server/hooks/modules/greenbids/real/time/data/core/ThrottlingThresholdsFactory.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
import java.io.IOException;
77

88
public class ThrottlingThresholdsFactory {
9+
910
public ThrottlingThresholds create(byte[] bytes, ObjectMapper mapper) throws IOException {
10-
JsonNode thresholdsJsonNode = mapper.readTree(bytes);
11+
final JsonNode thresholdsJsonNode = mapper.readTree(bytes);
1112
return mapper.treeToValue(thresholdsJsonNode, ThrottlingThresholds.class);
1213
}
1314
}

extra/modules/greenbids-real-time-data/src/main/java/org/prebid/server/hooks/modules/greenbids/real/time/data/model/predictor/ModelCache.java

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import ai.onnxruntime.OrtException;
44
import com.github.benmanes.caffeine.cache.Cache;
55
import com.google.cloud.storage.Blob;
6-
import com.google.cloud.storage.Bucket;
76
import com.google.cloud.storage.Storage;
87
import com.google.cloud.storage.StorageException;
98
import io.vertx.core.Future;
@@ -12,7 +11,6 @@
1211
import org.prebid.server.log.Logger;
1312
import org.prebid.server.log.LoggerFactory;
1413

15-
import java.util.Arrays;
1614
import java.util.Objects;
1715
import java.util.Optional;
1816
import java.util.concurrent.atomic.AtomicBoolean;
@@ -59,11 +57,6 @@ public Future<OnnxModelRunner> get(String onnxModelPath, String pbuid) {
5957
return Future.succeededFuture(cachedOnnxModelRunner);
6058
}
6159

62-
System.out.println(
63-
"get: " + "\n" +
64-
" isFetching: " + isFetching
65-
);
66-
6760
if (isFetching.compareAndSet(false, true)) {
6861
try {
6962
return fetchAndCacheModelRunner(onnxModelPath, cacheKey);
@@ -75,13 +68,7 @@ public Future<OnnxModelRunner> get(String onnxModelPath, String pbuid) {
7568
return Future.failedFuture("ModelRunner fetching in progress. Skip current request");
7669
}
7770

78-
private Future<OnnxModelRunner> fetchAndCacheModelRunner(String onnxModelPath, String cacheKey) {
79-
System.out.println(
80-
"fetchAndCacheModelRunner: \n" +
81-
" onnxModelPath: " + onnxModelPath + "\n" +
82-
" cacheKey: " + cacheKey
83-
);
84-
71+
private Future<OnnxModelRunner> fetchAndCacheModelRunner(String onnxModelPath, String cacheKey) {
8572
return vertx.executeBlocking(() -> getBlob(onnxModelPath))
8673
.map(this::loadModelRunner)
8774
.onSuccess(onnxModelRunner -> cache.put(cacheKey, onnxModelRunner))
@@ -94,24 +81,15 @@ private Blob getBlob(String onnxModelPath) {
9481
.map(bucket -> bucket.get(onnxModelPath))
9582
.orElseThrow(() -> new PreBidException("Bucket not found: " + gcsBucketName));
9683
} catch (StorageException e) {
97-
System.out.println("StorageException trigger PreBidException");
9884
throw new PreBidException("Error accessing GCS artefact for model: ", e);
9985
}
10086
}
10187

10288
private OnnxModelRunner loadModelRunner(Blob blob) {
10389
try {
10490
final byte[] onnxModelBytes = blob.getContent();
105-
106-
System.out.println(
107-
"loadModelRunner: \n" +
108-
" blob: " + blob + "\n" +
109-
" onnxModelBytes: " + Arrays.toString(onnxModelBytes) + "\n"
110-
);
111-
11291
return onnxModelRunnerFactory.create(onnxModelBytes);
11392
} catch (OrtException e) {
114-
System.out.println("OrtException trigger PreBidException");
11593
throw new PreBidException("Failed to convert blob to ONNX model", e);
11694
}
11795
}

extra/modules/greenbids-real-time-data/src/main/java/org/prebid/server/hooks/modules/greenbids/real/time/data/model/predictor/OnnxModelRunnerFactory.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import ai.onnxruntime.OrtException;
44

55
public class OnnxModelRunnerFactory {
6+
67
public OnnxModelRunner create(byte[] bytes) throws OrtException {
78
return new OnnxModelRunner(bytes);
89
}

extra/modules/greenbids-real-time-data/src/main/java/org/prebid/server/hooks/modules/greenbids/real/time/data/model/predictor/ThresholdCache.java

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
package org.prebid.server.hooks.modules.greenbids.real.time.data.model.predictor;
22

3-
import com.fasterxml.jackson.databind.JsonNode;
43
import com.fasterxml.jackson.databind.ObjectMapper;
54
import com.github.benmanes.caffeine.cache.Cache;
65
import com.google.cloud.storage.Blob;
@@ -15,7 +14,6 @@
1514
import org.prebid.server.log.LoggerFactory;
1615

1716
import java.io.IOException;
18-
import java.util.Arrays;
1917
import java.util.Objects;
2018
import java.util.Optional;
2119
import java.util.concurrent.atomic.AtomicBoolean;
@@ -78,11 +76,6 @@ public Future<ThrottlingThresholds> get(String thresholdJsonPath, String pbuid)
7876
}
7977

8078
private Future<ThrottlingThresholds> fetchAndCacheThrottlingThresholds(String thresholdJsonPath, String cacheKey) {
81-
System.out.println(
82-
"fetchAndCacheThrottlingThresholds: \n" +
83-
" thresholdJsonPath: " + thresholdJsonPath
84-
);
85-
8679
return vertx.executeBlocking(() -> getBlob(thresholdJsonPath))
8780
.map(this::loadThrottlingThresholds)
8881
.onSuccess(thresholds -> cache.put(cacheKey, thresholds))
@@ -91,11 +84,6 @@ private Future<ThrottlingThresholds> fetchAndCacheThrottlingThresholds(String th
9184

9285
private Blob getBlob(String thresholdJsonPath) {
9386
try {
94-
System.out.println(
95-
"getBlob: \n" +
96-
" thresholdJsonPath: " + thresholdJsonPath + "\n"
97-
);
98-
9987
return Optional.ofNullable(storage.get(gcsBucketName))
10088
.map(bucket -> bucket.get(thresholdJsonPath))
10189
.orElseThrow(() -> new PreBidException("Bucket not found: " + gcsBucketName));
@@ -105,22 +93,9 @@ private Blob getBlob(String thresholdJsonPath) {
10593
}
10694

10795
private ThrottlingThresholds loadThrottlingThresholds(Blob blob) {
108-
final JsonNode thresholdsJsonNode;
10996
try {
11097
final byte[] jsonBytes = blob.getContent();
111-
// thresholdsJsonNode = mapper.readTree(jsonBytes);
112-
// ThrottlingThresholds tempThrottlingThresholds = mapper.treeToValue(thresholdsJsonNode, ThrottlingThresholds.class);
113-
ThrottlingThresholds tempThrottlingThresholds = throttlingThresholdsFactory.create(jsonBytes, mapper);
114-
115-
System.out.println(
116-
"loadThrottlingThresholds: \n" +
117-
" blob: " + blob + "\n" +
118-
" jsonBytes: " + Arrays.toString(jsonBytes) + "\n" +
119-
//" thresholdsJsonNode: " + thresholdsJsonNode + "\n" +
120-
" tempThrottlingThresholds: " + tempThrottlingThresholds + "\n"
121-
);
122-
123-
return tempThrottlingThresholds;
98+
return throttlingThresholdsFactory.create(jsonBytes, mapper);
12499
} catch (IOException e) {
125100
throw new PreBidException("Failed to load throttling thresholds json", e);
126101
}

extra/modules/greenbids-real-time-data/src/test/java/org/prebid/server/hooks/modules/greenbids/real/time/data/v1/GreenbidsInvocationServiceTest.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,10 @@ public void createGreenbidsInvocationResultShouldReturnUpdateBidRequestWhenNotEx
6262
partner, bidRequest, impsBiddersFilterMap);
6363

6464
// then
65-
JsonNode updatedBidRequestExtPrebidBidders = result.getUpdatedBidRequest().getImp().getFirst().getExt()
65+
final JsonNode updatedBidRequestExtPrebidBidders = result.getUpdatedBidRequest().getImp().getFirst().getExt()
6666
.get("prebid").get("bidder");
67-
Ortb2ImpExtResult ortb2ImpExtResult = result.getAnalyticsResult().getValues().get("adunitcodevalue");
68-
Map<String, Boolean> keptInAuction = ortb2ImpExtResult.getGreenbids().getKeptInAuction();
67+
final Ortb2ImpExtResult ortb2ImpExtResult = result.getAnalyticsResult().getValues().get("adunitcodevalue");
68+
final Map<String, Boolean> keptInAuction = ortb2ImpExtResult.getGreenbids().getKeptInAuction();
6969

7070
assertThat(result.getInvocationAction()).isEqualTo(InvocationAction.update);
7171
assertThat(updatedBidRequestExtPrebidBidders.has("rubicon")).isTrue();
@@ -99,10 +99,10 @@ public void createGreenbidsInvocationResultShouldReturnNoActionWhenExploration()
9999
partner, bidRequest, impsBiddersFilterMap);
100100

101101
// then
102-
JsonNode updatedBidRequestExtPrebidBidders = result.getUpdatedBidRequest().getImp().getFirst().getExt()
102+
final JsonNode updatedBidRequestExtPrebidBidders = result.getUpdatedBidRequest().getImp().getFirst().getExt()
103103
.get("prebid").get("bidder");
104-
Ortb2ImpExtResult ortb2ImpExtResult = result.getAnalyticsResult().getValues().get("adunitcodevalue");
105-
Map<String, Boolean> keptInAuction = ortb2ImpExtResult.getGreenbids().getKeptInAuction();
104+
final Ortb2ImpExtResult ortb2ImpExtResult = result.getAnalyticsResult().getValues().get("adunitcodevalue");
105+
final Map<String, Boolean> keptInAuction = ortb2ImpExtResult.getGreenbids().getKeptInAuction();
106106

107107
assertThat(result.getInvocationAction()).isEqualTo(InvocationAction.no_action);
108108
assertThat(updatedBidRequestExtPrebidBidders.has("rubicon")).isTrue();

extra/modules/greenbids-real-time-data/src/test/java/org/prebid/server/hooks/modules/greenbids/real/time/data/v1/ModelCacheTest.java

Lines changed: 16 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,10 @@
88
import io.vertx.core.Future;
99
import io.vertx.core.Vertx;
1010
import com.google.cloud.storage.Blob;
11-
1211
import org.junit.jupiter.api.BeforeEach;
1312
import org.junit.jupiter.api.extension.ExtendWith;
1413
import org.junit.jupiter.api.Test;
1514
import org.mockito.Mock;
16-
import org.mockito.Mockito;
1715
import org.mockito.junit.jupiter.MockitoExtension;
1816
import org.prebid.server.exception.PreBidException;
1917
import org.prebid.server.hooks.modules.greenbids.real.time.data.model.predictor.ModelCache;
@@ -33,6 +31,7 @@
3331

3432
@ExtendWith(MockitoExtension.class)
3533
public class ModelCacheTest {
34+
3635
private static final String GCS_BUCKET_NAME = "test_bucket";
3736
private static final String MODEL_CACHE_KEY_PREFIX = "onnxModelRunner_";
3837
private static final String PBUUID = "test-pbuid";
@@ -71,11 +70,11 @@ public void setUp() {
7170
@Test
7271
public void getShouldReturnModelFromCacheWhenPresent() {
7372
// given
74-
String cacheKey = MODEL_CACHE_KEY_PREFIX + PBUUID;
73+
final String cacheKey = MODEL_CACHE_KEY_PREFIX + PBUUID;
7574
when(cache.getIfPresent(eq(cacheKey))).thenReturn(onnxModelRunner);
7675

7776
// when
78-
Future<OnnxModelRunner> future = target.get(ONNX_MODEL_PATH, PBUUID);
77+
final Future<OnnxModelRunner> future = target.get(ONNX_MODEL_PATH, PBUUID);
7978

8079
// then
8180
assertThat(future.succeeded()).isTrue();
@@ -86,32 +85,19 @@ public void getShouldReturnModelFromCacheWhenPresent() {
8685
@Test
8786
public void getShouldSkipFetchingWhenFetchingInProgress() throws NoSuchFieldException, IllegalAccessException {
8887
// given
89-
String cacheKey = MODEL_CACHE_KEY_PREFIX + PBUUID;
88+
final String cacheKey = MODEL_CACHE_KEY_PREFIX + PBUUID;
9089

91-
// Create a spy of the ModelCache class
92-
ModelCache spyModelCache = spy(target);
90+
final ModelCache spyModelCache = spy(target);
91+
final AtomicBoolean mockFetchingState = mock(AtomicBoolean.class);
9392

94-
// Mock the cache to simulate that the model is not present
9593
when(cache.getIfPresent(eq(cacheKey))).thenReturn(null);
96-
97-
// Spy the isFetching AtomicBoolean behavior
98-
AtomicBoolean mockFetchingState = mock(AtomicBoolean.class);
99-
100-
// Mock fetching state for 2 calls
10194
when(mockFetchingState.compareAndSet(false, true)).thenReturn(false);
102-
103-
// Use reflection to set the private field 'isFetching' in the spy accessible
104-
Field isFetchingField = ModelCache.class.getDeclaredField("isFetching");
95+
final Field isFetchingField = ModelCache.class.getDeclaredField("isFetching");
10596
isFetchingField.setAccessible(true);
10697
isFetchingField.set(spyModelCache, mockFetchingState);
10798

10899
// when
109-
Future<OnnxModelRunner> result = spyModelCache.get(ONNX_MODEL_PATH, PBUUID);
110-
111-
System.out.println(
112-
"firstCall.cause().getMessage(): " + result.cause().getMessage() + "\n" +
113-
"firstCall.succeeded(): " + result.succeeded()
114-
);
100+
final Future<OnnxModelRunner> result = spyModelCache.get(ONNX_MODEL_PATH, PBUUID);
115101

116102
// then
117103
assertThat(result.failed()).isTrue();
@@ -132,18 +118,10 @@ public void getShouldFetchModelWhenNotInCache() throws OrtException {
132118
lenient().when(onnxModelRunnerFactory.create(bytes)).thenReturn(onnxModelRunner);
133119

134120
// when
135-
Future<OnnxModelRunner> future = target.get(ONNX_MODEL_PATH, PBUUID);
121+
final Future<OnnxModelRunner> future = target.get(ONNX_MODEL_PATH, PBUUID);
136122

137123
// then
138124
future.onComplete(ar -> {
139-
140-
System.out.println(
141-
"future.onComplete: \n" +
142-
" ar: " + ar + "\n" +
143-
" ar.result(): " + ar.result() + "\n" +
144-
" cache: " + cache
145-
);
146-
147125
assertThat(ar.succeeded()).isTrue();
148126
assertThat(ar.result()).isEqualTo(onnxModelRunner);
149127
verify(cache).put(eq(cacheKey), eq(onnxModelRunner));
@@ -155,27 +133,14 @@ public void getShouldThrowExceptionWhenStorageFails() {
155133
// given
156134
final String cacheKey = MODEL_CACHE_KEY_PREFIX + PBUUID;
157135

158-
// Mock that the model is not in cache
159136
when(cache.getIfPresent(eq(cacheKey))).thenReturn(null);
160-
161-
// Simulate an error when accessing the storage bucket
162137
lenient().when(storage.get(GCS_BUCKET_NAME)).thenThrow(new StorageException(500, "Storage Error"));
163138

164139
// when
165-
Future<OnnxModelRunner> future = target.get(ONNX_MODEL_PATH, PBUUID);
140+
final Future<OnnxModelRunner> future = target.get(ONNX_MODEL_PATH, PBUUID);
166141

167142
// then
168143
future.onComplete(ar -> {
169-
170-
System.out.println(
171-
"future.onComplete: \n" +
172-
" ar: " + ar + "\n" +
173-
" ar.failed(): " + ar.failed() + "\n" +
174-
" ar.cause(): " + ar.cause() + "\n" +
175-
" ar.cause().getMessage(): " + ar.cause().getMessage() + "\n" +
176-
" ar.result(): " + ar.result()
177-
);
178-
179144
assertThat(ar.cause()).isInstanceOf(PreBidException.class);
180145
assertThat(ar.cause().getMessage()).contains("Error accessing GCS artefact for model");
181146
});
@@ -187,31 +152,18 @@ public void getShouldThrowExceptionWhenOnnxModelFails() throws OrtException {
187152
final String cacheKey = MODEL_CACHE_KEY_PREFIX + PBUUID;
188153
final byte[] bytes = new byte[]{1, 2, 3};
189154

190-
// Mock that the model is not in cache
191155
when(cache.getIfPresent(eq(cacheKey))).thenReturn(null);
192-
193-
// Simulate an error when accessing the storage bucket
194-
when(storage.get(GCS_BUCKET_NAME)).thenReturn(bucket);;
156+
when(storage.get(GCS_BUCKET_NAME)).thenReturn(bucket);
195157
lenient().when(bucket.get(ONNX_MODEL_PATH)).thenReturn(blob);
196158
lenient().when(blob.getContent()).thenReturn(bytes);
197-
lenient().when(onnxModelRunnerFactory.create(bytes)).thenThrow(new OrtException("Failed to convert blob to ONNX model"));
159+
lenient().when(onnxModelRunnerFactory.create(bytes)).thenThrow(
160+
new OrtException("Failed to convert blob to ONNX model"));
198161

199162
// when
200-
Future<OnnxModelRunner> future = target.get(ONNX_MODEL_PATH, PBUUID);
163+
final Future<OnnxModelRunner> future = target.get(ONNX_MODEL_PATH, PBUUID);
201164

202165
// then
203166
future.onComplete(ar -> {
204-
205-
System.out.println(
206-
"future.onComplete: \n" +
207-
" ar: " + ar + "\n" +
208-
" ar.failed(): " + ar.failed() + "\n" +
209-
" ar.cause(): " + ar.cause() + "\n" +
210-
" ar.cause().getMessage(): " + ar.cause().getMessage() + "\n" +
211-
" ar.result(): " + ar.result()
212-
);
213-
214-
215167
assertThat(ar.failed()).isTrue();
216168
assertThat(ar.cause()).isInstanceOf(PreBidException.class);
217169
assertThat(ar.cause().getMessage()).contains("Failed to convert blob to ONNX model");
@@ -223,29 +175,16 @@ public void getShouldThrowExceptionWhenBucketNotFound() {
223175
// given
224176
final String cacheKey = MODEL_CACHE_KEY_PREFIX + PBUUID;
225177

226-
// Mock that the model is not in cache
227178
when(cache.getIfPresent(eq(cacheKey))).thenReturn(null);
228-
229-
// Simulate an error when accessing the storage bucket
230-
when(storage.get(GCS_BUCKET_NAME)).thenReturn(bucket);;
179+
when(storage.get(GCS_BUCKET_NAME)).thenReturn(bucket);
231180
lenient().when(bucket.get(ONNX_MODEL_PATH)).thenReturn(blob);
232181
lenient().when(blob.getContent()).thenThrow(new PreBidException("Bucket not found"));
233182

234183
// when
235-
Future<OnnxModelRunner> future = target.get(ONNX_MODEL_PATH, PBUUID);
184+
final Future<OnnxModelRunner> future = target.get(ONNX_MODEL_PATH, PBUUID);
236185

237186
// then
238187
future.onComplete(ar -> {
239-
240-
System.out.println(
241-
"future.onComplete: \n" +
242-
" ar: " + ar + "\n" +
243-
" ar.failed(): " + ar.failed() + "\n" +
244-
" ar.cause(): " + ar.cause() + "\n" +
245-
" ar.cause().getMessage(): " + ar.cause().getMessage() + "\n" +
246-
" ar.result(): " + ar.result()
247-
);
248-
249188
assertThat(ar.failed()).isTrue();
250189
assertThat(ar.cause()).isInstanceOf(PreBidException.class);
251190
assertThat(ar.cause().getMessage()).contains("Bucket not found");

0 commit comments

Comments
 (0)