|
1 | 1 | package org.prebid.server.hooks.modules.greenbids.real.time.data.model.predictor;
|
2 | 2 |
|
3 | 3 | import ai.onnxruntime.OnnxTensor;
|
| 4 | +import ai.onnxruntime.OnnxValue; |
4 | 5 | import ai.onnxruntime.OrtException;
|
5 | 6 | import ai.onnxruntime.OrtSession;
|
6 | 7 | import org.prebid.server.exception.PreBidException;
|
@@ -49,14 +50,42 @@ private Map<String, Map<String, Boolean>> processModelResults(
|
49 | 50 | List<ThrottlingMessage> throttlingMessages,
|
50 | 51 | Double threshold) {
|
51 | 52 |
|
| 53 | + validateThrottlingMessages(throttlingMessages); |
| 54 | + |
52 | 55 | return StreamSupport.stream(results.spliterator(), false)
|
53 |
| - .filter(onnxItem -> Objects.equals(onnxItem.getKey(), "probabilities")) |
| 56 | + .filter(onnxItem -> { |
| 57 | + validateOnnxTensor(onnxItem); |
| 58 | + return Objects.equals(onnxItem.getKey(), "probabilities"); |
| 59 | + }) |
54 | 60 | .map(onnxItem -> (OnnxTensor) onnxItem.getValue())
|
55 |
| - .map(tensor -> extractAndProcessProbabilities(tensor, throttlingMessages, threshold)) |
| 61 | + .map(tensor -> { |
| 62 | + validateTensorSize(tensor, throttlingMessages.size()); |
| 63 | + return extractAndProcessProbabilities(tensor, throttlingMessages, threshold); |
| 64 | + }) |
56 | 65 | .flatMap(map -> map.entrySet().stream())
|
57 | 66 | .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
|
58 | 67 | }
|
59 | 68 |
|
| 69 | + private void validateThrottlingMessages(List<ThrottlingMessage> throttlingMessages) { |
| 70 | + if (throttlingMessages == null || throttlingMessages.isEmpty()) { |
| 71 | + throw new PreBidException("throttlingMessages cannot be null or empty"); |
| 72 | + } |
| 73 | + } |
| 74 | + |
| 75 | + private void validateOnnxTensor(Map.Entry<String, OnnxValue> onnxItem) { |
| 76 | + if (!(onnxItem.getValue() instanceof OnnxTensor)) { |
| 77 | + throw new PreBidException("Expected OnnxTensor for 'probabilities', but found: " |
| 78 | + + onnxItem.getValue().getClass().getName()); |
| 79 | + } |
| 80 | + } |
| 81 | + |
| 82 | + private void validateTensorSize(OnnxTensor tensor, int expectedSize) { |
| 83 | + final long[] tensorShape = tensor.getInfo().getShape(); |
| 84 | + if (tensorShape.length == 0 || tensorShape[0] != expectedSize) { |
| 85 | + throw new PreBidException("Mismatch between tensor size and throttlingMessages size"); |
| 86 | + } |
| 87 | + } |
| 88 | + |
60 | 89 | private Map<String, Map<String, Boolean>> extractAndProcessProbabilities(
|
61 | 90 | OnnxTensor tensor,
|
62 | 91 | List<ThrottlingMessage> throttlingMessages,
|
|
0 commit comments