Skip to content

Add new additional_config field to model_config and DefaultModelConfig class #3786

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions common/src/main/java/org/opensearch/ml/common/MLModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.controller.MLRateLimiter;
import org.opensearch.ml.common.model.GeneralModelConfig;
import org.opensearch.ml.common.model.Guardrails;
import org.opensearch.ml.common.model.MLDeploySetting;
import org.opensearch.ml.common.model.MLModelConfig;
Expand Down Expand Up @@ -278,8 +279,10 @@ public MLModel(StreamInput input) throws IOException {
modelConfig = new MetricsCorrelationModelConfig(input);
} else if (algorithm.equals(FunctionName.QUESTION_ANSWERING)) {
modelConfig = new QuestionAnsweringModelConfig(input);
} else {
} else if (algorithm.equals(FunctionName.TEXT_EMBEDDING)) {
modelConfig = new TextEmbeddingModelConfig(input);
} else {
modelConfig = new GeneralModelConfig(input);
}
}
if (input.readBoolean()) {
Expand Down Expand Up @@ -623,8 +626,10 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
modelConfig = MetricsCorrelationModelConfig.parse(parser);
} else if (FunctionName.QUESTION_ANSWERING.name().equals(algorithmName)) {
modelConfig = QuestionAnsweringModelConfig.parse(parser);
} else {
} else if (FunctionName.TEXT_EMBEDDING.name().equals(algorithmName)) {
modelConfig = TextEmbeddingModelConfig.parse(parser);
} else {
modelConfig = GeneralModelConfig.parse(parser);
}
break;
case DEPLOY_SETTING_FIELD:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.model;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

import java.io.IOException;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.ParseField;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;

import lombok.Builder;
import lombok.Getter;
import lombok.Setter;

/**
 * Model configuration used for function types that have no dedicated config class.
 *
 * <p>All descriptive fields except {@code model_type} are optional: {@link #parse(XContentParser)}
 * leaves them {@code null} when absent and {@link #toXContent} omits them when {@code null}.
 * The transport serialization ({@link #writeTo} / the {@link StreamInput} constructor) therefore
 * encodes every nullable field as an optional value so that a sparse config can round-trip.
 *
 * <p>Top-level keys of the {@code all_config} JSON and keys of {@code additional_config} must be
 * disjoint; the constructor rejects overlapping keys with an {@link IllegalArgumentException}.
 */
@Setter
@Getter
public class GeneralModelConfig extends MLModelConfig {
    public static final String PARSE_FIELD_NAME = "general";
    public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry(
        GeneralModelConfig.class,
        new ParseField(PARSE_FIELD_NAME),
        it -> parse(it)
    );

    public static final String EMBEDDING_DIMENSION_FIELD = "embedding_dimension";
    public static final String FRAMEWORK_TYPE_FIELD = "framework_type";
    public static final String POOLING_MODE_FIELD = "pooling_mode";
    public static final String NORMALIZE_RESULT_FIELD = "normalize_result";
    public static final String MODEL_MAX_LENGTH_FIELD = "model_max_length";

    private final Integer embeddingDimension;
    private final FrameworkType frameworkType;
    private final PoolingMode poolingMode;
    private final boolean normalizeResult;
    private final Integer modelMaxLength;

    /**
     * Creates a config from already-parsed values.
     *
     * @param modelType          model type, passed to the superclass constructor
     * @param embeddingDimension embedding dimension, may be {@code null}
     * @param frameworkType      framework type, may be {@code null}
     * @param allConfig          JSON string holding the full model config, may be {@code null}
     * @param poolingMode        pooling mode, may be {@code null}
     * @param normalizeResult    whether results are normalized
     * @param modelMaxLength     maximum model input length, may be {@code null}
     * @param additionalConfig   free-form extra settings, may be {@code null}
     * @throws IllegalArgumentException if {@code allConfig} and {@code additionalConfig} share a key,
     *                                  or if {@code allConfig} cannot be parsed while
     *                                  {@code additionalConfig} is non-empty
     */
    @Builder(toBuilder = true)
    public GeneralModelConfig(
        String modelType,
        Integer embeddingDimension,
        FrameworkType frameworkType,
        String allConfig,
        PoolingMode poolingMode,
        boolean normalizeResult,
        Integer modelMaxLength,
        Map<String, Object> additionalConfig
    ) {
        super(modelType, allConfig, additionalConfig);
        this.embeddingDimension = embeddingDimension;
        this.frameworkType = frameworkType;
        this.poolingMode = poolingMode;
        this.normalizeResult = normalizeResult;
        this.modelMaxLength = modelMaxLength;

        validateNoDuplicateKeys(allConfig, additionalConfig);
    }

    /**
     * Parses a config object from the parser, which must be positioned on {@code START_OBJECT}.
     * Unknown fields are skipped.
     *
     * @param parser source parser
     * @return the parsed config
     * @throws IOException if the underlying parser fails
     */
    public static GeneralModelConfig parse(XContentParser parser) throws IOException {
        String modelType = null;
        Integer embeddingDimension = null;
        FrameworkType frameworkType = null;
        String allConfig = null;
        PoolingMode poolingMode = null;
        boolean normalizeResult = false;
        Integer modelMaxLength = null;
        Map<String, Object> additionalConfig = null;

        ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
        while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
            String fieldName = parser.currentName();
            parser.nextToken();

            switch (fieldName) {
                case MODEL_TYPE_FIELD:
                    modelType = parser.text();
                    break;
                case EMBEDDING_DIMENSION_FIELD:
                    embeddingDimension = parser.intValue();
                    break;
                case FRAMEWORK_TYPE_FIELD:
                    // from(...) already upper-cases with Locale.ROOT
                    frameworkType = FrameworkType.from(parser.text());
                    break;
                case ALL_CONFIG_FIELD:
                    allConfig = parser.text();
                    break;
                case POOLING_MODE_FIELD:
                    // from(...) already upper-cases with Locale.ROOT
                    poolingMode = PoolingMode.from(parser.text());
                    break;
                case NORMALIZE_RESULT_FIELD:
                    normalizeResult = parser.booleanValue();
                    break;
                case MODEL_MAX_LENGTH_FIELD:
                    modelMaxLength = parser.intValue();
                    break;
                case ADDITIONAL_CONFIG_FIELD:
                    additionalConfig = parser.map();
                    break;
                default:
                    parser.skipChildren();
                    break;
            }
        }
        return new GeneralModelConfig(
            modelType,
            embeddingDimension,
            frameworkType,
            allConfig,
            poolingMode,
            normalizeResult,
            modelMaxLength,
            additionalConfig
        );
    }

    @Override
    public String getWriteableName() {
        return PARSE_FIELD_NAME;
    }

    /**
     * Deserializes a config written by {@link #writeTo(StreamOutput)}. The read order must mirror
     * the write order exactly.
     *
     * @param in stream to read from
     * @throws IOException if reading fails
     */
    public GeneralModelConfig(StreamInput in) throws IOException {
        super(in);
        embeddingDimension = in.readOptionalInt();
        frameworkType = in.readBoolean() ? in.readEnum(FrameworkType.class) : null;
        poolingMode = in.readBoolean() ? in.readEnum(PoolingMode.class) : null;
        normalizeResult = in.readBoolean();
        modelMaxLength = in.readOptionalInt();
    }

    /**
     * Serializes this config. Nullable fields are written as optional values (presence flag followed
     * by the value) so that a config lacking {@code embedding_dimension} or {@code framework_type}
     * does not fail with a {@link NullPointerException}.
     *
     * @param out stream to write to
     * @throws IOException if writing fails
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        super.writeTo(out);
        out.writeOptionalInt(embeddingDimension);
        if (frameworkType != null) {
            out.writeBoolean(true);
            out.writeEnum(frameworkType);
        } else {
            out.writeBoolean(false);
        }
        if (poolingMode != null) {
            out.writeBoolean(true);
            out.writeEnum(poolingMode);
        } else {
            out.writeBoolean(false);
        }
        out.writeBoolean(normalizeResult);
        out.writeOptionalInt(modelMaxLength);
    }

    /**
     * Renders this config as an object. {@code null} fields are omitted, and
     * {@code normalize_result} is emitted only when it is {@code true}.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        if (modelType != null) {
            builder.field(MODEL_TYPE_FIELD, modelType);
        }
        if (embeddingDimension != null) {
            builder.field(EMBEDDING_DIMENSION_FIELD, embeddingDimension);
        }
        if (frameworkType != null) {
            builder.field(FRAMEWORK_TYPE_FIELD, frameworkType);
        }
        if (allConfig != null) {
            builder.field(ALL_CONFIG_FIELD, allConfig);
        }
        if (poolingMode != null) {
            builder.field(POOLING_MODE_FIELD, poolingMode);
        }
        if (normalizeResult) {
            builder.field(NORMALIZE_RESULT_FIELD, normalizeResult);
        }
        if (modelMaxLength != null) {
            builder.field(MODEL_MAX_LENGTH_FIELD, modelMaxLength);
        }
        if (additionalConfig != null) {
            builder.field(ADDITIONAL_CONFIG_FIELD, additionalConfig);
        }
        builder.endObject();
        return builder;
    }

    /** Pooling strategies accepted in the {@code pooling_mode} field. */
    public enum PoolingMode {
        MEAN("mean"),
        MEAN_SQRT_LEN("mean_sqrt_len"),
        MAX("max"),
        WEIGHTED_MEAN("weightedmean"),
        CLS("cls"),
        LAST_TOKEN("lasttoken");

        private final String name;

        public String getName() {
            return name;
        }

        PoolingMode(String name) {
            this.name = name;
        }

        /**
         * Resolves a pooling mode from its enum constant name, case-insensitively.
         *
         * @param value constant name such as {@code "mean"} or {@code "CLS"}
         * @return the matching constant
         * @throws IllegalArgumentException if {@code value} is {@code null} or matches no constant
         */
        public static PoolingMode from(String value) {
            if (value == null) {
                throw new IllegalArgumentException("Wrong pooling method");
            }
            try {
                return PoolingMode.valueOf(value.toUpperCase(Locale.ROOT));
            } catch (IllegalArgumentException e) {
                throw new IllegalArgumentException("Wrong pooling method", e);
            }
        }
    }

    /** Framework types accepted in the {@code framework_type} field. */
    public enum FrameworkType {
        HUGGINGFACE_TRANSFORMERS,
        SENTENCE_TRANSFORMERS,
        HUGGINGFACE_TRANSFORMERS_NEURON;

        /**
         * Resolves a framework type from its enum constant name, case-insensitively.
         *
         * @param value constant name such as {@code "sentence_transformers"}
         * @return the matching constant
         * @throws IllegalArgumentException if {@code value} is {@code null} or matches no constant
         */
        public static FrameworkType from(String value) {
            if (value == null) {
                throw new IllegalArgumentException("Wrong framework type");
            }
            try {
                return FrameworkType.valueOf(value.toUpperCase(Locale.ROOT));
            } catch (IllegalArgumentException e) {
                throw new IllegalArgumentException("Wrong framework type", e);
            }
        }
    }

    /**
     * Ensures that no top-level key of the {@code allConfig} JSON also appears in
     * {@code additionalConfig}. Does nothing when either argument is {@code null} or when
     * {@code additionalConfig} is empty.
     *
     * @throws IllegalArgumentException naming the overlapping keys, or reporting that
     *                                  {@code allConfig} could not be parsed as a JSON object
     */
    private static void validateNoDuplicateKeys(String allConfig, Map<String, Object> additionalConfig) {
        if (allConfig == null || additionalConfig == null || additionalConfig.isEmpty()) {
            return;
        }

        final Map<String, Object> allConfigMap;
        try {
            allConfigMap = XContentHelper.convertToMap(XContentType.JSON.xContent(), allConfig, false);
        } catch (RuntimeException e) {
            // Report a parse failure as what it is, not as a duplicate-key error, and keep the cause.
            throw new IllegalArgumentException("Unable to parse all_config as a JSON object to check for duplicate keys", e);
        }

        Set<String> duplicateKeys = allConfigMap.keySet().stream().filter(additionalConfig::containsKey).collect(Collectors.toSet());
        if (!duplicateKeys.isEmpty()) {
            // List only the overlapping keys, sorted so the message is deterministic.
            throw new IllegalArgumentException(
                "Duplicate keys found in both all_config and additional_config: "
                    + duplicateKeys.stream().sorted().collect(Collectors.joining(", "))
            );
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.common.model;

import java.io.IOException;
import java.util.Map;

import org.opensearch.core.common.io.stream.NamedWriteable;
import org.opensearch.core.common.io.stream.StreamInput;
Expand All @@ -21,9 +22,11 @@ public abstract class MLModelConfig implements ToXContentObject, NamedWriteable

public static final String MODEL_TYPE_FIELD = "model_type";
public static final String ALL_CONFIG_FIELD = "all_config";
public static final String ADDITIONAL_CONFIG_FIELD = "additional_config";

protected String modelType;
protected String allConfig;
protected Map<String, Object> additionalConfig;

public MLModelConfig(String modelType, String allConfig) {
if (modelType == null) {
Expand All @@ -33,13 +36,20 @@ public MLModelConfig(String modelType, String allConfig) {
this.allConfig = allConfig;
}

/**
 * Creates a config that also carries free-form additional settings.
 * Delegates {@code modelType} and {@code allConfig} to the two-argument constructor,
 * then stores {@code additionalConfig} by reference (no defensive copy is made).
 *
 * @param modelType        model type, handled by the two-argument constructor
 * @param allConfig        full model config string, handled by the two-argument constructor
 * @param additionalConfig extra settings map; stored as given, may be {@code null}
 */
public MLModelConfig(String modelType, String allConfig, Map<String, Object> additionalConfig) {
this(modelType, allConfig);
this.additionalConfig = additionalConfig;
}

/**
 * Reads the common config fields from a stream: {@code modelType} as a string,
 * {@code allConfig} as an optional string, then {@code additionalConfig} as a map.
 * The order must match {@link #writeTo(StreamOutput)}.
 *
 * NOTE(review): the {@code additionalConfig} read is unconditional. If streams can arrive from
 * nodes that do not write this field, a stream-version check is presumably needed here and in
 * {@code writeTo} -- confirm against the project's backward-compatibility policy.
 *
 * @param in stream to read from
 * @throws IOException if reading fails
 */
public MLModelConfig(StreamInput in) throws IOException {
this.modelType = in.readString();
this.allConfig = in.readOptionalString();
this.additionalConfig = in.readMap();
}

/**
 * Writes the common config fields to a stream: {@code modelType} as a string,
 * {@code allConfig} as an optional string, then {@code additionalConfig} as a map.
 * The order must match the {@code StreamInput} constructor.
 *
 * NOTE(review): {@code additionalConfig} may be {@code null} here (the three-argument constructor
 * stores it as given); confirm that {@code StreamOutput.writeMap} accepts a {@code null} map.
 * The write is also not gated on the stream version -- verify whether that is required.
 *
 * @param out stream to write to
 * @throws IOException if writing fails
 */
public void writeTo(StreamOutput out) throws IOException {
out.writeString(modelType);
out.writeOptionalString(allConfig);
out.writeMap(additionalConfig);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.controller.MLRateLimiter;
import org.opensearch.ml.common.model.GeneralModelConfig;
import org.opensearch.ml.common.model.Guardrails;
import org.opensearch.ml.common.model.MLDeploySetting;
import org.opensearch.ml.common.model.MLModelConfig;
Expand Down Expand Up @@ -193,8 +194,10 @@ public MLRegisterModelInput(StreamInput in) throws IOException {
this.modelConfig = new MetricsCorrelationModelConfig(in);
} else if (this.functionName.equals(FunctionName.QUESTION_ANSWERING)) {
this.modelConfig = new QuestionAnsweringModelConfig(in);
} else {
} else if (this.functionName.equals(FunctionName.TEXT_EMBEDDING)) {
this.modelConfig = new TextEmbeddingModelConfig(in);
} else {
this.modelConfig = new GeneralModelConfig(in);
}
}
this.deployModel = in.readBoolean();
Expand Down Expand Up @@ -449,8 +452,10 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName
case MODEL_CONFIG_FIELD:
if (FunctionName.QUESTION_ANSWERING.equals(functionName)) {
modelConfig = QuestionAnsweringModelConfig.parse(parser);
} else {
} else if (FunctionName.TEXT_EMBEDDING.equals(functionName)) {
modelConfig = TextEmbeddingModelConfig.parse(parser);
} else {
modelConfig = GeneralModelConfig.parse(parser);
}
break;
case DEPLOY_SETTING_FIELD:
Expand Down Expand Up @@ -598,8 +603,10 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo
case MODEL_CONFIG_FIELD:
if (FunctionName.QUESTION_ANSWERING.equals(functionName)) {
modelConfig = QuestionAnsweringModelConfig.parse(parser);
} else {
} else if (FunctionName.TEXT_EMBEDDING.equals(functionName)) {
modelConfig = TextEmbeddingModelConfig.parse(parser);
} else {
modelConfig = GeneralModelConfig.parse(parser);
}
break;
case DEPLOY_SETTING_FIELD:
Expand Down
Loading
Loading