
Commit 82da998

Update highlighting model translator to adapt new model (#3699)
* Update highlighting model translator to adapt new model
  Signed-off-by: Junqiu Lei <[email protected]>
* Resolve PR feedback
  Signed-off-by: Junqiu Lei <[email protected]>
* Resolve PR feedback 2
  Signed-off-by: Junqiu Lei <[email protected]>
* Resolve PR feedback 3
  Signed-off-by: Junqiu Lei <[email protected]>
---------
Signed-off-by: Junqiu Lei <[email protected]>
1 parent b9d5201 commit 82da998

File tree: 8 files changed (+1386, -328 lines)


ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/question_answering/QAConstants.java

Lines changed: 19 additions & 0 deletions
@@ -23,12 +23,31 @@ public final class QAConstants {
     // Context keys
     public static final String KEY_SENTENCES = "sentences";
 
+    // Sentence highlighting model predict chunk input key
+    public static final String HIGHLIGHTING_MODEL_CHUNK_NUMBER_KEY = "chunk";
+    public static final String HIGHLIGHTING_MODEL_INITIAL_CHUNK_NUMBER_STRING = "0";
+
     // Model input names
     public static final String INPUT_IDS = "input_ids";
     public static final String ATTENTION_MASK = "attention_mask";
     public static final String TOKEN_TYPE_IDS = "token_type_ids";
+    public static final String SENTENCE_IDS = "sentence_ids";
 
     // Default values for warm-up
     public static final String DEFAULT_WARMUP_QUESTION = "How is the weather?";
     public static final String DEFAULT_WARMUP_CONTEXT = "The weather is nice, it is beautiful day. The sun is shining. The sky is blue.";
+
+    // Default model configuration
+    public static final String TOKEN_MAX_LENGTH_KEY = "token_max_length";
+    public static final Integer DEFAULT_TOKEN_MAX_LENGTH = 512;
+    public static final String TOKEN_OVERLAP_STRIDE_LENGTH_KEY = "token_overlap_stride";
+    public static final Integer DEFAULT_TOKEN_OVERLAP_STRIDE_LENGTH = 128;
+    public static final String WITH_OVERFLOWING_TOKENS_KEY = "with_overflowing_tokens";
+    public static final Boolean DEFAULT_WITH_OVERFLOWING_TOKENS = true;
+    public static final String PADDING_KEY = "padding";
+    public static final Boolean DEFAULT_PADDING = false;
+    public static final String TOKENIZER_FILE_NAME = "tokenizer.json";
+    // Special token value used to ignore tokens in sentence ID mapping
+    public static final int IGNORE_TOKEN_ID = -100;
+    public static final int CONTEXT_START_DEFAULT_INDEX = 0;
 }
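The new token_max_length, token_overlap_stride, with_overflowing_tokens, and padding defaults suggest the translator now splits long question/context pairs into overlapping token windows. Below is a minimal sketch of how such defaults could be applied when building a DJL HuggingFaceTokenizer; the helper class, the typed-config map, and the exact builder options are assumptions for illustration, not the code in this commit.

import java.io.IOException;
import java.nio.file.Path;
import java.util.Map;

import org.opensearch.ml.engine.algorithms.question_answering.QAConstants;

import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;

// Hypothetical helper (not part of this commit): build a chunking tokenizer from a raw
// model-config map, falling back to the defaults introduced in QAConstants above.
final class TokenizerSettingsSketch {

    static HuggingFaceTokenizer buildTokenizer(Map<String, Object> config, Path modelDir) throws IOException {
        // Assumes the map already holds typed values; a real implementation would parse strings defensively.
        int maxLength = (int) config.getOrDefault(QAConstants.TOKEN_MAX_LENGTH_KEY, QAConstants.DEFAULT_TOKEN_MAX_LENGTH);
        int stride = (int) config.getOrDefault(QAConstants.TOKEN_OVERLAP_STRIDE_LENGTH_KEY, QAConstants.DEFAULT_TOKEN_OVERLAP_STRIDE_LENGTH);
        boolean padding = (boolean) config.getOrDefault(QAConstants.PADDING_KEY, QAConstants.DEFAULT_PADDING);

        // optMaxLength/optStride/optPadding are standard DJL tokenizer options; how overflowing
        // windows (the with_overflowing_tokens setting) are enabled depends on the DJL version
        // and on how SentenceHighlightingQATranslator actually configures its builder.
        return HuggingFaceTokenizer.builder()
            .optTokenizerPath(modelDir.resolve(QAConstants.TOKENIZER_FILE_NAME))
            .optTruncation(true)
            .optMaxLength(maxLength)
            .optStride(stride)
            .optPadding(padding)
            .build();
    }
}

With these defaults, input longer than 512 tokens is windowed into 512-token chunks overlapping by 128 tokens, which is what the chunked predict path in QuestionAnsweringModel (next file) iterates over.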

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/question_answering/QuestionAnsweringModel.java

Lines changed: 183 additions & 23 deletions
@@ -5,29 +5,41 @@
 
 package org.opensearch.ml.engine.algorithms.question_answering;
 
+import static org.opensearch.ml.engine.ModelHelper.PYTORCH_ENGINE;
 import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.DEFAULT_WARMUP_CONTEXT;
 import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.DEFAULT_WARMUP_QUESTION;
+import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.FIELD_HIGHLIGHTS;
+import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.FIELD_POSITION;
+import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.HIGHLIGHTING_MODEL_CHUNK_NUMBER_KEY;
+import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.HIGHLIGHTING_MODEL_INITIAL_CHUNK_NUMBER_STRING;
 import static org.opensearch.ml.engine.algorithms.question_answering.QAConstants.SENTENCE_HIGHLIGHTING_TYPE;
 
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 
 import org.opensearch.ml.common.FunctionName;
 import org.opensearch.ml.common.dataset.MLInputDataset;
 import org.opensearch.ml.common.dataset.QuestionAnsweringInputDataSet;
 import org.opensearch.ml.common.input.MLInput;
 import org.opensearch.ml.common.model.MLModelConfig;
+import org.opensearch.ml.common.output.model.ModelTensor;
 import org.opensearch.ml.common.output.model.ModelTensorOutput;
 import org.opensearch.ml.common.output.model.ModelTensors;
 import org.opensearch.ml.engine.algorithms.DLModel;
 import org.opensearch.ml.engine.annotation.Function;
 
+import ai.djl.huggingface.tokenizers.Encoding;
 import ai.djl.inference.Predictor;
 import ai.djl.modality.Input;
 import ai.djl.modality.Output;
 import ai.djl.translate.TranslateException;
 import ai.djl.translate.Translator;
 import ai.djl.translate.TranslatorFactory;
+import lombok.AllArgsConstructor;
+import lombok.Builder;
+import lombok.NoArgsConstructor;
 import lombok.extern.log4j.Log4j2;
 
 /**
@@ -36,7 +48,12 @@
  */
 @Log4j2
 @Function(FunctionName.QUESTION_ANSWERING)
+@Builder(toBuilder = true)
+@AllArgsConstructor
+@NoArgsConstructor
 public class QuestionAnsweringModel extends DLModel {
+    private MLModelConfig modelConfig;
+    private Translator<Input, Output> translator;
 
     @Override
     public void warmUp(Predictor predictor, String modelId, MLModelConfig modelConfig) throws TranslateException {
@@ -47,50 +64,193 @@ public void warmUp(Predictor predictor, String modelId, MLModelConfig modelConfig) throws TranslateException {
             throw new IllegalArgumentException("model id is null");
         }
 
+        // Initialize model config from model if it exists, the model config field is required for sentence highlighting model.
+        if (modelConfig != null) {
+            this.modelConfig = modelConfig;
+        }
+
         // Create input for the predictor
         Input input = new Input();
-        input.add(DEFAULT_WARMUP_QUESTION);
-        input.add(DEFAULT_WARMUP_CONTEXT);
+
+        if (isSentenceHighlightingModel()) {
+            input.add(MLInput.QUESTION_FIELD, DEFAULT_WARMUP_QUESTION);
+            input.add(MLInput.CONTEXT_FIELD, DEFAULT_WARMUP_CONTEXT);
+            // Add initial chunk number key value pair which is needed for sentence highlighting model
+            input.add(HIGHLIGHTING_MODEL_CHUNK_NUMBER_KEY, HIGHLIGHTING_MODEL_INITIAL_CHUNK_NUMBER_STRING);
+        } else {
+            input.add(DEFAULT_WARMUP_QUESTION);
+            input.add(DEFAULT_WARMUP_CONTEXT);
+        }
 
         // Run prediction to warm up the model
         predictor.predict(input);
     }
 
-    /**
-     * Checks if the model is configured for sentence highlighting.
-     *
-     * @param modelConfig The model configuration
-     * @return true if the model is configured for sentence highlighting, false otherwise
-     */
-    private boolean isSentenceHighlightingType(MLModelConfig modelConfig) {
-        if (modelConfig != null) {
-            return SENTENCE_HIGHLIGHTING_TYPE.equalsIgnoreCase(modelConfig.getModelType());
-        }
-        return false;
-    }
-
     @Override
     public ModelTensorOutput predict(String modelId, MLInput mlInput) throws TranslateException {
         MLInputDataset inputDataSet = mlInput.getInputDataset();
-        List<ModelTensors> tensorOutputs = new ArrayList<>();
-        Output output;
         QuestionAnsweringInputDataSet qaInputDataSet = (QuestionAnsweringInputDataSet) inputDataSet;
         String question = qaInputDataSet.getQuestion();
        String context = qaInputDataSet.getContext();
+
+        if (isSentenceHighlightingModel()) {
+            return predictSentenceHighlightingQA(question, context);
+        }
+
+        return predictStandardQA(question, context);
+    }
+
+    private boolean isSentenceHighlightingModel() {
+        return modelConfig != null && SENTENCE_HIGHLIGHTING_TYPE.equalsIgnoreCase(modelConfig.getModelType());
+    }
+
+    private ModelTensorOutput predictStandardQA(String question, String context) throws TranslateException {
         Input input = new Input();
         input.add(question);
         input.add(context);
-        output = getPredictor().predict(input);
-        tensorOutputs.add(parseModelTensorOutput(output, null));
-        return new ModelTensorOutput(tensorOutputs);
+
+        try {
+            Output output = getPredictor().predict(input);
+            ModelTensors tensors = parseModelTensorOutput(output, null);
+            return new ModelTensorOutput(List.of(tensors));
+        } catch (Exception e) {
+            log.error("Error processing standard QA model prediction", e);
+            throw new TranslateException("Failed to process standard QA model prediction", e);
+        }
+    }
+
+    private ModelTensorOutput predictSentenceHighlightingQA(String question, String context) throws TranslateException {
+        SentenceHighlightingQATranslator translator = (SentenceHighlightingQATranslator) getTranslator(PYTORCH_ENGINE, this.modelConfig);
+
+        try {
+            List<Map<String, Object>> allHighlights = new ArrayList<>();
+
+            // We need to process initial chunk first to get the overflow encodings
+            processChunk(question, context, HIGHLIGHTING_MODEL_INITIAL_CHUNK_NUMBER_STRING, allHighlights);
+
+            Encoding encodings = translator.getTokenizer().encode(question, context);
+            Encoding[] overflowEncodings = encodings.getOverflowing();
+
+            // Process overflow chunks if overflow encodings are present
+            if (overflowEncodings != null && overflowEncodings.length > 0) {
+                for (int i = 0; i < overflowEncodings.length; i++) {
+                    processChunk(question, context, String.valueOf(i + 1), allHighlights);
+                }
+            }
+
+            return createHighlightOutput(allHighlights);
+        } catch (Exception e) {
+            log.error("Error processing sentence highlighting model prediction", e);
+            throw new TranslateException("Failed to process chunks for sentence highlighting", e);
+        }
+    }
+
+    private void processChunk(String question, String context, String chunkNumber, List<Map<String, Object>> allHighlights)
+        throws TranslateException {
+        Input chunkInput = new Input();
+        chunkInput.add(MLInput.QUESTION_FIELD, question);
+        chunkInput.add(MLInput.CONTEXT_FIELD, context);
+        chunkInput.add(HIGHLIGHTING_MODEL_CHUNK_NUMBER_KEY, chunkNumber);
+
+        // Use batchPredict to process the chunk for complete results, predict only return the first result which can cause loss of relevant
+        // results
+        List<Output> outputs = getPredictor().batchPredict(List.of(chunkInput));
+
+        if (outputs.isEmpty()) {
+            return;
+        }
+
+        for (Output output : outputs) {
+            ModelTensors tensors = parseModelTensorOutput(output, null);
+            allHighlights.addAll(extractHighlights(tensors));
+        }
+    }
+
+    /**
+     * Extract highlights from model tensors output
+     *
+     * @param tensors The model tensors to extract highlights from
+     * @return List of highlight data maps
+     */
+    private List<Map<String, Object>> extractHighlights(ModelTensors tensors) throws TranslateException {
+        List<Map<String, Object>> highlights = new ArrayList<>();
+
+        for (ModelTensor tensor : tensors.getMlModelTensors()) {
+            Map<String, ?> dataAsMap = tensor.getDataAsMap();
+            if (dataAsMap != null && dataAsMap.containsKey(FIELD_HIGHLIGHTS)) {
+                try {
+                    List<Map<String, Object>> tensorHighlights = (List<Map<String, Object>>) dataAsMap.get(FIELD_HIGHLIGHTS);
+                    highlights.addAll(tensorHighlights);
+                } catch (ClassCastException e) {
+                    log.error("Failed to cast highlights data to expected format", e);
+                    throw new TranslateException("Failed to cast highlights data to expected format", e);
+                }
+            }
+        }
+
+        return highlights;
+    }
+
+    /**
+     * Create a model tensor output for highlights
+     *
+     * @param highlights The list of highlights to include
+     * @return ModelTensorOutput containing highlights
+     */
+    private ModelTensorOutput createHighlightOutput(List<Map<String, Object>> highlights) {
+        Map<String, Object> combinedData = new HashMap<>();
+
+        // Remove duplicates and sort by position
+        List<Map<String, Object>> uniqueSortedHighlights = removeDuplicatesAndSort(highlights);
+
+        combinedData.put(FIELD_HIGHLIGHTS, uniqueSortedHighlights);
+
+        ModelTensor combinedTensor = ModelTensor.builder().name(FIELD_HIGHLIGHTS).dataAsMap(combinedData).build();
+
+        return new ModelTensorOutput(List.of(new ModelTensors(List.of(combinedTensor))));
+    }
+
+    /**
+     * Removes duplicate sentences and sorts them by position
+     *
+     * @param highlights The list of highlights to process
+     * @return List of unique highlights sorted by position
+     */
+    private List<Map<String, Object>> removeDuplicatesAndSort(List<Map<String, Object>> highlights) {
+        // Use a map to detect duplicates by position
+        Map<Number, Map<String, Object>> uniqueMap = new HashMap<>();
+
+        // Add each highlight to the map, using position as the key
+        for (Map<String, Object> highlight : highlights) {
+            Number position = (Number) highlight.get(FIELD_POSITION);
+            if (!uniqueMap.containsKey(position)) {
+                uniqueMap.put(position, highlight);
+            }
+        }
+
+        // Convert back to list
+        List<Map<String, Object>> uniqueHighlights = new ArrayList<>(uniqueMap.values());
+
+        // Sort by position
+        uniqueHighlights.sort((a, b) -> {
+            Number posA = (Number) a.get(FIELD_POSITION);
+            Number posB = (Number) b.get(FIELD_POSITION);
+            return Double.compare(posA.doubleValue(), posB.doubleValue());
+        });
+
+        return uniqueHighlights;
     }
 
     @Override
     public Translator<Input, Output> getTranslator(String engine, MLModelConfig modelConfig) throws IllegalArgumentException {
-        if (isSentenceHighlightingType(modelConfig)) {
-            return SentenceHighlightingQATranslator.createDefault();
+        if (translator == null) {
+            if (modelConfig != null && SENTENCE_HIGHLIGHTING_TYPE.equalsIgnoreCase(modelConfig.getModelType())) {
+                translator = SentenceHighlightingQATranslator.create(modelConfig);
+            } else {
+                translator = new QuestionAnsweringTranslator();
+            }
         }
-        return new QuestionAnsweringTranslator();
+        return translator;
     }
 
     @Override
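To make the chunking flow in predictSentenceHighlightingQA concrete: chunk "0" is always predicted, the tokenizer reports any windows that spilled past the max length via Encoding.getOverflowing(), and each overflow window triggers one more processChunk call numbered "1", "2", and so on. The standalone sketch below mirrors only that counting step; the tokenizer name and settings are assumptions for illustration (the commit loads tokenizer.json from the model artifact through SentenceHighlightingQATranslator instead).

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;

// Standalone illustration (not part of the commit): derive the chunk count for a long
// context the same way predictSentenceHighlightingQA does with Encoding.getOverflowing().
public class ChunkCountSketch {
    public static void main(String[] args) throws Exception {
        // Illustrative tokenizer settings; bert-base-uncased is an assumption, and 512/128
        // mirror DEFAULT_TOKEN_MAX_LENGTH and DEFAULT_TOKEN_OVERLAP_STRIDE_LENGTH from QAConstants.
        try (HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.builder()
                .optTokenizerName("bert-base-uncased")
                .optTruncation(true)
                .optMaxLength(512)
                .optStride(128)
                .build()) {

            String question = "How is the weather?";
            String context = "The weather is nice, it is beautiful day. The sun is shining. ".repeat(100);

            Encoding encoding = tokenizer.encode(question, context);
            // Overflow windows are only populated when the tokenizer is configured to return
            // overflowing tokens (the commit's with_overflowing_tokens option); otherwise this
            // array is empty and only chunk "0" would be processed.
            Encoding[] overflow = encoding.getOverflowing();

            int totalChunks = 1 + (overflow == null ? 0 : overflow.length);
            System.out.println("chunks to predict: " + totalChunks);
        }
    }
}

Each chunk is then sent as a key/value Input (question, context, chunk number) through batchPredict, and the per-chunk highlights are merged, de-duplicated by position, and sorted before being returned as a single highlights tensor.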
