package org.opensearch.ml.engine.algorithms.question_answering;

+ import static org.opensearch.ml.engine.ModelHelper.PYTORCH_ENGINE;
import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.DEFAULT_WARMUP_CONTEXT;
import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.DEFAULT_WARMUP_QUESTION;
+ import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.FIELD_HIGHLIGHTS;
+ import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.FIELD_POSITION;
+ import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.HIGHLIGHTING_MODEL_CHUNK_NUMBER_KEY;
+ import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.HIGHLIGHTING_MODEL_INITIAL_CHUNK_NUMBER_STRING;
import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.SENTENCE_HIGHLIGHTING_TYPE;

import java.util.ArrayList;
+ import java.util.HashMap;
import java.util.List;
+ import java.util.Map;

import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.QuestionAnsweringInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLModelConfig;
+ import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.algorithms.DLModel;
import org.opensearch.ml.engine.annotation.Function;

+ import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.inference.Predictor;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
+ import lombok.AllArgsConstructor;
+ import lombok.Builder;
+ import lombok.NoArgsConstructor;
import lombok.extern.log4j.Log4j2;

/**
 * Question answering model.
 */
@Log4j2
@Function(FunctionName.QUESTION_ANSWERING)
+ @Builder(toBuilder = true)
+ @AllArgsConstructor
+ @NoArgsConstructor
public class QuestionAnsweringModel extends DLModel {
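+     // modelConfig identifies the sentence highlighting model type; translator is created lazily in getTranslator() and reused across predictions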
+     private MLModelConfig modelConfig;
+     private Translator<Input, Output> translator;

    @Override
    public void warmUp(Predictor predictor, String modelId, MLModelConfig modelConfig) throws TranslateException {
@@ -47,50 +64,193 @@ public void warmUp(Predictor predictor, String modelId, MLModelConfig modelConfi
            throw new IllegalArgumentException("model id is null");
        }

+         // Initialize the model config from the model if it exists; this field is required for the sentence highlighting model.
+         if (modelConfig != null) {
+             this.modelConfig = modelConfig;
+         }
+
        // Create input for the predictor
        Input input = new Input();
-         input.add(DEFAULT_WARMUP_QUESTION);
-         input.add(DEFAULT_WARMUP_CONTEXT);
+
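+         // The sentence highlighting model expects named question/context fields plus a chunk number; the standard QA model takes positional inputs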
+         if (isSentenceHighlightingModel()) {
+             input.add(MLInput.QUESTION_FIELD, DEFAULT_WARMUP_QUESTION);
+             input.add(MLInput.CONTEXT_FIELD, DEFAULT_WARMUP_CONTEXT);
+             // Add the initial chunk number key-value pair, which the sentence highlighting model requires
+             input.add(HIGHLIGHTING_MODEL_CHUNK_NUMBER_KEY, HIGHLIGHTING_MODEL_INITIAL_CHUNK_NUMBER_STRING);
+         } else {
+             input.add(DEFAULT_WARMUP_QUESTION);
+             input.add(DEFAULT_WARMUP_CONTEXT);
+         }

        // Run prediction to warm up the model
        predictor.predict(input);
    }

-     /**
-      * Checks if the model is configured for sentence highlighting.
-      *
-      * @param modelConfig The model configuration
-      * @return true if the model is configured for sentence highlighting, false otherwise
-      */
-     private boolean isSentenceHighlightingType(MLModelConfig modelConfig) {
-         if (modelConfig != null) {
-             return SENTENCE_HIGHLIGHTING_TYPE.equalsIgnoreCase(modelConfig.getModelType());
-         }
-         return false;
-     }
-
    @Override
    public ModelTensorOutput predict(String modelId, MLInput mlInput) throws TranslateException {
        MLInputDataset inputDataSet = mlInput.getInputDataset();
-         List<ModelTensors> tensorOutputs = new ArrayList<>();
-         Output output;
        QuestionAnsweringInputDataSet qaInputDataSet = (QuestionAnsweringInputDataSet) inputDataSet;
        String question = qaInputDataSet.getQuestion();
        String context = qaInputDataSet.getContext();
+
+         if (isSentenceHighlightingModel()) {
+             return predictSentenceHighlightingQA(question, context);
+         }
+
+         return predictStandardQA(question, context);
+     }
+
+     private boolean isSentenceHighlightingModel() {
+         return modelConfig != null && SENTENCE_HIGHLIGHTING_TYPE.equalsIgnoreCase(modelConfig.getModelType());
+     }
+
+     private ModelTensorOutput predictStandardQA(String question, String context) throws TranslateException {
        Input input = new Input();
        input.add(question);
        input.add(context);
-         output = getPredictor().predict(input);
-         tensorOutputs.add(parseModelTensorOutput(output, null));
-         return new ModelTensorOutput(tensorOutputs);
+
+         try {
+             Output output = getPredictor().predict(input);
+             ModelTensors tensors = parseModelTensorOutput(output, null);
+             return new ModelTensorOutput(List.of(tensors));
+         } catch (Exception e) {
+             log.error("Error processing standard QA model prediction", e);
+             throw new TranslateException("Failed to process standard QA model prediction", e);
+         }
+     }
+
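+     /**
+      * Runs sentence highlighting prediction over the initial chunk and any tokenizer overflow chunks,
+      * then merges the collected highlights into a single output.
+      *
+      * @param question The question to answer
+      * @param context The context to highlight
+      * @return ModelTensorOutput containing the merged highlights
+      */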
+     private ModelTensorOutput predictSentenceHighlightingQA(String question, String context) throws TranslateException {
+         SentenceHighlightingQATranslator translator = (SentenceHighlightingQATranslator) getTranslator(PYTORCH_ENGINE, this.modelConfig);
+
+         try {
+             List<Map<String, Object>> allHighlights = new ArrayList<>();
+
+             // Process the initial chunk first to get the overflow encodings
+             processChunk(question, context, HIGHLIGHTING_MODEL_INITIAL_CHUNK_NUMBER_STRING, allHighlights);
+
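+             // Re-encode the question and context to find how many overflow chunks the tokenizer produced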
+             Encoding encodings = translator.getTokenizer().encode(question, context);
+             Encoding[] overflowEncodings = encodings.getOverflowing();
+
+             // Process overflow chunks if overflow encodings are present
+             if (overflowEncodings != null && overflowEncodings.length > 0) {
+                 for (int i = 0; i < overflowEncodings.length; i++) {
+                     processChunk(question, context, String.valueOf(i + 1), allHighlights);
+                 }
+             }
+
+             return createHighlightOutput(allHighlights);
+         } catch (Exception e) {
+             log.error("Error processing sentence highlighting model prediction", e);
+             throw new TranslateException("Failed to process chunks for sentence highlighting", e);
+         }
+     }
+
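+     /**
+      * Runs prediction for a single chunk, identified by its chunk number, and collects the extracted highlights.
+      *
+      * @param chunkNumber The index of the chunk to process, as a string, starting from the initial chunk
+      * @param allHighlights The list that collected highlights are appended to
+      */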
+     private void processChunk(String question, String context, String chunkNumber, List<Map<String, Object>> allHighlights)
+         throws TranslateException {
+         Input chunkInput = new Input();
+         chunkInput.add(MLInput.QUESTION_FIELD, question);
+         chunkInput.add(MLInput.CONTEXT_FIELD, context);
+         chunkInput.add(HIGHLIGHTING_MODEL_CHUNK_NUMBER_KEY, chunkNumber);
+
+         // Use batchPredict to get complete results for the chunk; predict only returns the first result, which can drop relevant results
+         List<Output> outputs = getPredictor().batchPredict(List.of(chunkInput));
+
+         if (outputs.isEmpty()) {
+             return;
+         }
+
+         for (Output output : outputs) {
+             ModelTensors tensors = parseModelTensorOutput(output, null);
+             allHighlights.addAll(extractHighlights(tensors));
+         }
+     }
+
+     /**
+      * Extracts highlights from the model tensors output.
+      *
+      * @param tensors The model tensors to extract highlights from
+      * @return List of highlight data maps
+      * @throws TranslateException if the highlights data cannot be cast to the expected format
+      */
+     private List<Map<String, Object>> extractHighlights(ModelTensors tensors) throws TranslateException {
+         List<Map<String, Object>> highlights = new ArrayList<>();
+
+         for (ModelTensor tensor : tensors.getMlModelTensors()) {
+             Map<String, ?> dataAsMap = tensor.getDataAsMap();
+             if (dataAsMap != null && dataAsMap.containsKey(FIELD_HIGHLIGHTS)) {
+                 try {
+                     List<Map<String, Object>> tensorHighlights = (List<Map<String, Object>>) dataAsMap.get(FIELD_HIGHLIGHTS);
+                     highlights.addAll(tensorHighlights);
+                 } catch (ClassCastException e) {
+                     log.error("Failed to cast highlights data to expected format", e);
+                     throw new TranslateException("Failed to cast highlights data to expected format", e);
+                 }
+             }
+         }
+
+         return highlights;
+     }
+
+     /**
+      * Creates a model tensor output for highlights.
+      *
+      * @param highlights The list of highlights to include
+      * @return ModelTensorOutput containing highlights
+      */
+     private ModelTensorOutput createHighlightOutput(List<Map<String, Object>> highlights) {
+         Map<String, Object> combinedData = new HashMap<>();
+
+         // Remove duplicates and sort by position
+         List<Map<String, Object>> uniqueSortedHighlights = removeDuplicatesAndSort(highlights);
+
+         combinedData.put(FIELD_HIGHLIGHTS, uniqueSortedHighlights);
+
+         ModelTensor combinedTensor = ModelTensor.builder().name(FIELD_HIGHLIGHTS).dataAsMap(combinedData).build();
+
+         return new ModelTensorOutput(List.of(new ModelTensors(List.of(combinedTensor))));
+     }
+
+     /**
+      * Removes duplicate sentences and sorts them by position.
+      *
+      * @param highlights The list of highlights to process
+      * @return List of unique highlights sorted by position
+      */
+     private List<Map<String, Object>> removeDuplicatesAndSort(List<Map<String, Object>> highlights) {
+         // Use a map to detect duplicates by position
+         Map<Number, Map<String, Object>> uniqueMap = new HashMap<>();
+
+         // Add each highlight to the map, using position as the key
+         for (Map<String, Object> highlight : highlights) {
+             Number position = (Number) highlight.get(FIELD_POSITION);
+             if (!uniqueMap.containsKey(position)) {
+                 uniqueMap.put(position, highlight);
+             }
+         }
+
+         // Convert back to list
+         List<Map<String, Object>> uniqueHighlights = new ArrayList<>(uniqueMap.values());
+
+         // Sort by position, comparing as doubles so integer and floating-point positions order consistently
+         uniqueHighlights.sort((a, b) -> {
+             Number posA = (Number) a.get(FIELD_POSITION);
+             Number posB = (Number) b.get(FIELD_POSITION);
+             return Double.compare(posA.doubleValue(), posB.doubleValue());
+         });
+
+         return uniqueHighlights;
    }

    @Override
    public Translator<Input, Output> getTranslator(String engine, MLModelConfig modelConfig) throws IllegalArgumentException {
-         if (isSentenceHighlightingType(modelConfig)) {
-             return SentenceHighlightingQATranslator.createDefault();
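+         // Create the translator once and cache it; predictSentenceHighlightingQA retrieves the same instance via getTranslator to access its tokenizer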
+         if (translator == null) {
+             if (modelConfig != null && SENTENCE_HIGHLIGHTING_TYPE.equalsIgnoreCase(modelConfig.getModelType())) {
+                 translator = SentenceHighlightingQATranslator.create(modelConfig);
+             } else {
+                 translator = new QuestionAnsweringTranslator();
+             }
        }
-         return new QuestionAnsweringTranslator();
+         return translator;
    }

    @Override
256
@ Override
0 commit comments