Skip to content

Commit e87beca

Browse files
authored
Fix judgment handling for implicit judgments (#93)
* fix implicit judgment for the new nested data structure introduced in #77 Signed-off-by: wrigleyDan <[email protected]> * run ./gradlew spotlessApply Signed-off-by: wrigleyDan <[email protected]> * add entry to CHANGELOG.md Signed-off-by: wrigleyDan <[email protected]> * edit comment as per suggestion Signed-off-by: wrigleyDan <[email protected]> --------- Signed-off-by: wrigleyDan <[email protected]>
1 parent 6dc660e commit e87beca

File tree

7 files changed

+123
-39
lines changed

7 files changed

+123
-39
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,6 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
2929
- Extend hybrid search optimizer demo script to use models. ([#69](https://github.com/opensearch-project/search-relevance/pull/69))
3030
- Set limit for number of fields programmatically during index creation ([#74](https://github.com/opensearch-project/search-relevance/pull/74)
3131
- Change model for Judgment entity ([#77](https://github.com/opensearch-project/search-relevance/pull/77)
32+
- Fix judgment handling for implicit judgments to be aligned with data model for Judgment again ([#93](https://github.com/opensearch-project/search-relevance/pull/93)
3233

3334
### Security

src/main/java/org/opensearch/searchrelevance/indices/SearchRelevanceIndicesManager.java

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,12 @@
77
*/
88
package org.opensearch.searchrelevance.indices;
99

10-
import lombok.Builder;
11-
import lombok.Getter;
12-
import lombok.extern.log4j.Log4j2;
10+
import java.io.FileNotFoundException;
11+
import java.io.IOException;
12+
import java.io.InputStream;
13+
import java.util.Objects;
14+
import java.util.function.BiConsumer;
15+
1316
import org.opensearch.ResourceAlreadyExistsException;
1417
import org.opensearch.ResourceNotFoundException;
1518
import org.opensearch.action.DocWriteRequest.OpType;
@@ -34,13 +37,11 @@
3437
import org.opensearch.searchrelevance.exception.SearchRelevanceException;
3538
import org.opensearch.searchrelevance.shared.StashedThreadContext;
3639
import org.opensearch.transport.client.Client;
37-
import reactor.util.annotation.NonNull;
3840

39-
import java.io.FileNotFoundException;
40-
import java.io.IOException;
41-
import java.io.InputStream;
42-
import java.util.Objects;
43-
import java.util.function.BiConsumer;
41+
import lombok.Builder;
42+
import lombok.Getter;
43+
import lombok.extern.log4j.Log4j2;
44+
import reactor.util.annotation.NonNull;
4445

4546
/**
4647
* Manager for common search relevance system indices actions.

src/main/java/org/opensearch/searchrelevance/judgments/UbiJudgmentsProcessor.java

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
*/
88
package org.opensearch.searchrelevance.judgments;
99

10+
import java.util.ArrayList;
11+
import java.util.HashMap;
1012
import java.util.List;
1113
import java.util.Map;
1214

@@ -52,7 +54,87 @@ public void generateJudgmentRating(Map<String, Object> metadata, ActionListener<
5254
coecClickModel.calculateJudgments(new ActionListener<>() {
5355
@Override
5456
public void onResponse(List<Map<String, Object>> judgments) {
55-
listener.onResponse(judgments);
57+
// Create the result map in the expected format
58+
List<Map<String, Object>> formattedRatings = new ArrayList<>();
59+
for (Map<String, Object> queryJudgment : judgments) {
60+
String queryText = (String) queryJudgment.get("query");
61+
Object ratingData = queryJudgment.get("ratings");
62+
63+
if (!(ratingData instanceof Map)) {
64+
listener.onFailure(
65+
new SearchRelevanceException(
66+
"queryText " + queryText + " must have rating data as a Map.",
67+
RestStatus.BAD_REQUEST
68+
)
69+
);
70+
return;
71+
}
72+
73+
@SuppressWarnings("unchecked")
74+
Map<String, Object> ratingsMap = (Map<String, Object>) ratingData; // Cast to Map, not List
75+
76+
// Prepare a list to hold the docId and score maps for the current query
77+
List<Map<String, String>> docIdScoreList = new ArrayList<>();
78+
79+
// Iterate over the entrySet of the HashMap ***
80+
for (Map.Entry<String, Object> entry : ratingsMap.entrySet()) {
81+
String docId = entry.getKey(); // The key is the docId
82+
Object ratingObject = entry.getValue(); // The value is the rating
83+
84+
if (docId == null || docId.isEmpty()) {
85+
// This case is unlikely if the keys of the map are docIds, but good for defensive coding
86+
listener.onFailure(
87+
new SearchRelevanceException(
88+
"docId (map key) for queryText " + queryText + " must not be null or empty",
89+
RestStatus.BAD_REQUEST
90+
)
91+
);
92+
return;
93+
}
94+
if (ratingObject == null) {
95+
listener.onFailure(
96+
new SearchRelevanceException(
97+
"rating for docId '" + docId + "' in queryText " + queryText + " must not be null",
98+
RestStatus.BAD_REQUEST
99+
)
100+
);
101+
return;
102+
}
103+
104+
String rating = String.valueOf(ratingObject); // Convert rating to String
105+
106+
try {
107+
Float.parseFloat(rating);
108+
} catch (NumberFormatException e) {
109+
listener.onFailure(
110+
new SearchRelevanceException(
111+
"rating '"
112+
+ rating
113+
+ "' for docId '"
114+
+ docId
115+
+ "' in queryText "
116+
+ queryText
117+
+ " must be a valid float",
118+
RestStatus.BAD_REQUEST
119+
)
120+
);
121+
return;
122+
}
123+
124+
// Add the docId and score to the list for the current query
125+
Map<String, String> docScoreMap = new HashMap<>();
126+
docScoreMap.put("docId", docId);
127+
docScoreMap.put("score", rating);
128+
docIdScoreList.add(docScoreMap);
129+
}
130+
131+
// Add the formatted ratings for this query
132+
Map<String, Object> queryRatings = new HashMap<>();
133+
queryRatings.put("query", queryText);
134+
queryRatings.put("ratings", docIdScoreList);
135+
formattedRatings.add(queryRatings);
136+
}
137+
listener.onResponse(formattedRatings);
56138
}
57139

58140
@Override

src/main/java/org/opensearch/searchrelevance/ml/TokenizerUtil.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,16 @@
77
*/
88
package org.opensearch.searchrelevance.ml;
99

10+
import java.util.ArrayList;
11+
import java.util.List;
12+
1013
import com.knuddels.jtokkit.Encodings;
1114
import com.knuddels.jtokkit.api.Encoding;
1215
import com.knuddels.jtokkit.api.EncodingRegistry;
1316
import com.knuddels.jtokkit.api.EncodingType;
1417
import com.knuddels.jtokkit.api.IntArrayList;
1518
import com.knuddels.jtokkit.api.ModelType;
1619

17-
import java.util.ArrayList;
18-
import java.util.List;
19-
2020
/**
2121
* For OpenAI models, use their official tiktoken library - https://github.com/knuddelsgmbh/jtokkit
2222
*/

src/main/java/org/opensearch/searchrelevance/transport/queryset/PostQuerySetTransportAction.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,13 @@
77
*/
88
package org.opensearch.searchrelevance.transport.queryset;
99

10+
import java.util.HashMap;
11+
import java.util.List;
12+
import java.util.Map;
13+
import java.util.UUID;
14+
import java.util.concurrent.ExecutionException;
15+
import java.util.stream.Collectors;
16+
1017
import org.opensearch.action.index.IndexResponse;
1118
import org.opensearch.action.support.ActionFilters;
1219
import org.opensearch.action.support.HandledTransportAction;
@@ -24,13 +31,6 @@
2431
import org.opensearch.transport.TransportService;
2532
import org.opensearch.transport.client.Client;
2633

27-
import java.util.HashMap;
28-
import java.util.List;
29-
import java.util.Map;
30-
import java.util.UUID;
31-
import java.util.concurrent.ExecutionException;
32-
import java.util.stream.Collectors;
33-
3434
public class PostQuerySetTransportAction extends HandledTransportAction<PostQuerySetRequest, IndexResponse> {
3535
private final Client client;
3636
private final ClusterService clusterService;

src/main/java/org/opensearch/searchrelevance/transport/queryset/PutQuerySetTransportAction.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@
77
*/
88
package org.opensearch.searchrelevance.transport.queryset;
99

10+
import static org.opensearch.searchrelevance.model.QueryWithReference.DELIMITER;
11+
12+
import java.util.List;
13+
import java.util.UUID;
14+
import java.util.stream.Collectors;
15+
1016
import org.opensearch.action.index.IndexResponse;
1117
import org.opensearch.action.support.ActionFilters;
1218
import org.opensearch.action.support.HandledTransportAction;
@@ -23,12 +29,6 @@
2329
import org.opensearch.tasks.Task;
2430
import org.opensearch.transport.TransportService;
2531

26-
import java.util.List;
27-
import java.util.UUID;
28-
import java.util.stream.Collectors;
29-
30-
import static org.opensearch.searchrelevance.model.QueryWithReference.DELIMITER;
31-
3232
public class PutQuerySetTransportAction extends HandledTransportAction<PutQuerySetRequest, IndexResponse> {
3333
private final ClusterService clusterService;
3434
private final QuerySetDao querySetDao;

src/test/java/org/opensearch/searchrelevance/indices/SearchRelevanceIndicesManagerTests.java

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,19 @@
77
*/
88
package org.opensearch.searchrelevance.indices;
99

10+
import static org.mockito.ArgumentMatchers.any;
11+
import static org.mockito.Mockito.doAnswer;
12+
import static org.mockito.Mockito.mock;
13+
import static org.mockito.Mockito.never;
14+
import static org.mockito.Mockito.verify;
15+
import static org.mockito.Mockito.when;
16+
import static org.opensearch.searchrelevance.indices.SearchRelevanceIndices.QUERY_SET;
17+
18+
import java.io.IOException;
19+
import java.util.HashMap;
20+
import java.util.List;
21+
import java.util.Map;
22+
1023
import org.apache.lucene.search.TotalHits;
1124
import org.mockito.ArgumentCaptor;
1225
import org.mockito.Mock;
@@ -48,19 +61,6 @@
4861
import org.opensearch.transport.client.Client;
4962
import org.opensearch.transport.client.IndicesAdminClient;
5063

51-
import java.io.IOException;
52-
import java.util.HashMap;
53-
import java.util.List;
54-
import java.util.Map;
55-
56-
import static org.mockito.ArgumentMatchers.any;
57-
import static org.mockito.Mockito.doAnswer;
58-
import static org.mockito.Mockito.mock;
59-
import static org.mockito.Mockito.never;
60-
import static org.mockito.Mockito.verify;
61-
import static org.mockito.Mockito.when;
62-
import static org.opensearch.searchrelevance.indices.SearchRelevanceIndices.QUERY_SET;
63-
6464
public class SearchRelevanceIndicesManagerTests extends OpenSearchTestCase {
6565
@Mock
6666
private Client client;

0 commit comments

Comments
 (0)