diff --git a/CHANGELOG.md b/CHANGELOG.md index b11ffb7e3bd19..6a31b5e21d607 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - [Rule based auto-tagging] Add get rule API ([#17336](https://github.com/opensearch-project/OpenSearch/pull/17336)) - [Rule based auto-tagging] Add Delete Rule API ([#18184](https://github.com/opensearch-project/OpenSearch/pull/18184)) - Add paginated wlm/stats API ([#17638](https://github.com/opensearch-project/OpenSearch/pull/17638)) +- [Rule based auto-tagging] Add Create rule API ([#17792](https://github.com/opensearch-project/OpenSearch/pull/17792)) - Implement parallel shard refresh behind cluster settings ([#17782](https://github.com/opensearch-project/OpenSearch/pull/17782)) - Bump OpenSearch Core main branch to 3.0.0 ([#18039](https://github.com/opensearch-project/OpenSearch/pull/18039)) - [Rule based Auto-tagging] Add wlm `ActionFilter` ([#17791](https://github.com/opensearch-project/OpenSearch/pull/17791)) diff --git a/modules/autotagging-commons/build.gradle b/modules/autotagging-commons/build.gradle index 5c996fc85bea9..8b9c4fbb7d409 100644 --- a/modules/autotagging-commons/build.gradle +++ b/modules/autotagging-commons/build.gradle @@ -6,7 +6,6 @@ * compatible open source license. */ - opensearchplugin { name = "rule-framework" description = 'OpenSearch Rule Framework plugin' diff --git a/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/CreateRuleRequest.java b/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/CreateRuleRequest.java new file mode 100644 index 0000000000000..7963357b893d9 --- /dev/null +++ b/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/CreateRuleRequest.java @@ -0,0 +1,74 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.rule; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.rule.autotagging.Rule; + +import java.io.IOException; + +/** + * A request for create Rule + * Example request: + * curl -X PUT "localhost:9200/_rules/{featureType}/" -H 'Content-Type: application/json' -d ' + * { + * "description": "description1", + * "attribute_name": ["log*", "event*"], + * "feature_type": "poOiU851RwyLYvV5lbvv5w" + * }' + * @opensearch.experimental + */ +public class CreateRuleRequest extends ActionRequest { + private final Rule rule; + + /** + * constructor for CreateRuleRequest + * @param rule the rule to create + */ + public CreateRuleRequest(Rule rule) { + this.rule = rule; + } + + /** + * Constructs a CreateRuleRequest from a StreamInput for deserialization + * @param in - The {@link StreamInput} instance to read from. + */ + public CreateRuleRequest(StreamInput in) throws IOException { + super(in); + rule = new Rule(in); + } + + @Override + public ActionRequestValidationException validate() { + try { + rule.getFeatureType().getFeatureValueValidator().validate(rule.getFeatureValue()); + return null; + } catch (Exception e) { + ActionRequestValidationException validationException = new ActionRequestValidationException(); + validationException.addValidationError("Validation failed: " + e.getMessage()); + return validationException; + } + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + rule.writeTo(out); + } + + /** + * rule getter + */ + public Rule getRule() { + return rule; + } +} diff --git a/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/CreateRuleResponse.java b/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/CreateRuleResponse.java new file mode 100644 index 0000000000000..f040372b69335 --- /dev/null +++ b/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/CreateRuleResponse.java @@ -0,0 +1,76 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.rule; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.rule.autotagging.Rule; + +import java.io.IOException; +import java.util.Map; + +import static org.opensearch.rule.autotagging.Rule._ID_STRING; + +/** + * Response for the create API for Rule + * Example response: + * { + * "_id":"wi6VApYBoX5wstmtU_8l", + * "description":"description1", + * "index_pattern":["log*", "uvent*"], + * "workload_group":"poOiU851RwyLYvV5lbvv5w", + * "updated_at":"2025-04-04T20:54:22.406Z" + * } + * @opensearch.experimental + */ +public class CreateRuleResponse extends ActionResponse implements ToXContent, ToXContentObject { + private final String _id; + private final Rule rule; + + /** + * contructor for CreateRuleResponse + * @param id - the id for the rule created + * @param rule - the rule created + */ + public CreateRuleResponse(String id, final Rule rule) { + this._id = id; + this.rule = rule; + } + + /** + * Constructs a CreateRuleResponse from a StreamInput for deserialization + * @param in - The {@link StreamInput} instance to read from. + */ + public CreateRuleResponse(StreamInput in) throws IOException { + _id = in.readString(); + rule = new Rule(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(_id); + rule.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return rule.toXContent(builder, new MapParams(Map.of(_ID_STRING, _id))); + } + + /** + * rule getter + */ + public Rule getRule() { + return rule; + } +} diff --git a/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/GetRuleRequest.java b/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/GetRuleRequest.java index 9cdebb782edc9..630a329688b2b 100644 --- a/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/GetRuleRequest.java +++ b/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/GetRuleRequest.java @@ -26,10 +26,9 @@ /** * A request for get Rule * Example Request: - * The endpoint "localhost:9200/_wlm/rule" is specific to the Workload Management feature to manage rules - * curl -X GET "localhost:9200/_wlm/rule" - get all rules - * curl -X GET "localhost:9200/_wlm/rule/{_id}" - get single rule by id - * curl -X GET "localhost:9200/_wlm/rule?index_pattern=a,b" - get all rules containing attribute index_pattern as a or b + * curl -X GET "localhost:9200/_rules/{featureType}/" - get all rules for {featureType} + * curl -X GET "localhost:9200/_rules/{featureType}/{_id}" - get single rule by id + * curl -X GET "localhost:9200/_rules/{featureType}?index_pattern=a,b" - get all rules containing attribute index_pattern as a or b for {featureType} * @opensearch.experimental */ @ExperimentalApi diff --git a/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/GetRuleResponse.java b/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/GetRuleResponse.java index e3c0bb49043a7..2ce79850084db 100644 --- a/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/GetRuleResponse.java +++ b/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/GetRuleResponse.java @@ -29,9 +29,9 @@ * "rules": [ * { * "_id": "z1MJApUB0zgMcDmz-UQq", - * "description": "Rule for tagging query_group_id to index123" + * "description": "Rule for tagging workload_group_id to index123" * "index_pattern": ["index123"], - * "query_group": "query_group_id", + * "workload_group": "workload_group_id", * "updated_at": "2025-02-14T01:19:22.589Z" * }, * ... diff --git a/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/RulePersistenceService.java b/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/RulePersistenceService.java index 674e77a990849..b29323da421e7 100644 --- a/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/RulePersistenceService.java +++ b/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/RulePersistenceService.java @@ -17,6 +17,13 @@ */ public interface RulePersistenceService { + /** + * Create rules based on the provided request. + * @param request The request containing the details for creating the rule. + * @param listener The listener that will handle the response or failure. + */ + void createRule(CreateRuleRequest request, ActionListener listener); + /** * Get rules based on the provided request. * @param request The request containing the details for retrieving the rule. diff --git a/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/RuleRoutingService.java b/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/RuleRoutingService.java new file mode 100644 index 0000000000000..e0d08f371a2aa --- /dev/null +++ b/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/RuleRoutingService.java @@ -0,0 +1,25 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.rule; + +import org.opensearch.core.action.ActionListener; + +/** + * Interface that handles rule routing logic + * @opensearch.experimental + */ +public interface RuleRoutingService { + + /** + * Handles a create rule request by routing it to the appropriate node. + * @param request the create rule request + * @param listener listener to handle the final response + */ + void handleCreateRuleRequest(CreateRuleRequest request, ActionListener listener); +} diff --git a/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/RuleUtils.java b/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/RuleUtils.java new file mode 100644 index 0000000000000..7c66eac988f9b --- /dev/null +++ b/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/RuleUtils.java @@ -0,0 +1,67 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.rule; + +import org.opensearch.common.annotation.ExperimentalApi; +import org.opensearch.rule.autotagging.Attribute; +import org.opensearch.rule.autotagging.Rule; + +import java.util.Collections; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +/** + * Utility class for operations related to {@link Rule} objects. + * @opensearch.experimental + */ +@ExperimentalApi +public class RuleUtils { + + /** + * constructor for RuleUtils + */ + public RuleUtils() {} + + /** + * Checks if a duplicate rule exists and returns its id. + * Two rules are considered to be duplicate when meeting all the criteria below + * 1. They have the same feature type + * 2. They have the exact same attributes + * 3. For each attribute, the sets of values must intersect — i.e., at least one common value must exist + * between the current rule and the one being checked. + * + * @param rule The rule to be validated against ruleMap. + * @param ruleMap This map contains existing rules to be checked + */ + public static Optional getDuplicateRuleId(Rule rule, Map ruleMap) { + Map> targetAttributeMap = rule.getAttributeMap(); + for (Map.Entry entry : ruleMap.entrySet()) { + Rule currRule = entry.getValue(); + Map> existingAttributeMap = currRule.getAttributeMap(); + + if (rule.getFeatureType() != currRule.getFeatureType() || targetAttributeMap.size() != existingAttributeMap.size()) { + continue; + } + boolean allAttributesIntersect = true; + for (Attribute attribute : targetAttributeMap.keySet()) { + Set targetAttributeValues = targetAttributeMap.get(attribute); + Set existingAttributeValues = existingAttributeMap.get(attribute); + if (existingAttributeValues == null || Collections.disjoint(targetAttributeValues, existingAttributeValues)) { + allAttributesIntersect = false; + break; + } + } + if (allAttributesIntersect) { + return Optional.of(entry.getKey()); + } + } + return Optional.empty(); + } +} diff --git a/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/autotagging/AutoTaggingRegistry.java b/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/autotagging/AutoTaggingRegistry.java index 9cfc5ccb1e342..be817e66cbd7a 100644 --- a/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/autotagging/AutoTaggingRegistry.java +++ b/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/autotagging/AutoTaggingRegistry.java @@ -59,6 +59,9 @@ private static void validateFeatureType(FeatureType featureType) { "Feature type name " + name + " should not be null, empty or have more than " + MAX_FEATURE_TYPE_NAME_LENGTH + "characters" ); } + if (featureType.getFeatureValueValidator() == null) { + throw new IllegalStateException("FeatureValueValidator is not defined for feature type " + name); + } } /** diff --git a/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/autotagging/FeatureType.java b/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/autotagging/FeatureType.java index c752f917264de..9fc2cf62b462d 100644 --- a/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/autotagging/FeatureType.java +++ b/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/autotagging/FeatureType.java @@ -17,13 +17,11 @@ /** * Represents a feature type within the auto-tagging feature. Feature types define different categories of - * characteristics that can be used for tagging and classification. Implementations of this interface are - * responsible for registering feature types in {@link AutoTaggingRegistry}. Implementations must ensure that + * characteristics that can be used for tagging and classification. Implementations must ensure that * feature types are uniquely identifiable by their class and name. * * Implementers should follow these guidelines: * Feature types should be singletons and managed centrally to avoid duplicates. - * {@link #registerFeatureType()} must be called during initialization to ensure the feature type is available. * * @opensearch.experimental */ @@ -49,6 +47,16 @@ public interface FeatureType extends Writeable { */ Map getAllowedAttributesRegistry(); + /** + * returns the validator for feature value + */ + default FeatureValueValidator getFeatureValueValidator() { + return new FeatureValueValidator() { + @Override + public void validate(String featureValue) {} + }; + } + /** * returns max attribute values * @return @@ -65,11 +73,6 @@ default int getMaxCharLengthPerAttributeValue() { return DEFAULT_MAX_ATTRIBUTE_VALUE_LENGTH; } - /** - * makes the feature type usable and available to framework plugin - */ - void registerFeatureType(); - /** * checks the validity of the input attribute * @param attribute diff --git a/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/autotagging/FeatureValueValidator.java b/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/autotagging/FeatureValueValidator.java new file mode 100644 index 0000000000000..7c21982ce9f2a --- /dev/null +++ b/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/autotagging/FeatureValueValidator.java @@ -0,0 +1,22 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.rule.autotagging; + +/** + * Interface for validating a feature value against pre-defined values (such as + * values from the index, cluster state, etc.) for a specific feature type. + * @opensearch.experimental + */ +public interface FeatureValueValidator { + /** + * Validates the given feature value. + * @param featureValue the value to validate + */ + void validate(String featureValue); +} diff --git a/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/service/IndexStoredRulePersistenceService.java b/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/service/IndexStoredRulePersistenceService.java index f42ea1e951e4f..b0d31a829b2b6 100644 --- a/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/service/IndexStoredRulePersistenceService.java +++ b/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/service/IndexStoredRulePersistenceService.java @@ -13,19 +13,27 @@ import org.opensearch.ResourceNotFoundException; import org.opensearch.action.DocWriteResponse; import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; import org.opensearch.action.search.SearchRequestBuilder; +import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.clustermanager.AcknowledgedResponse; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.ToXContent; import org.opensearch.index.engine.DocumentMissingException; import org.opensearch.index.query.QueryBuilder; -import org.opensearch.index.query.QueryBuilders; +import org.opensearch.rule.CreateRuleRequest; +import org.opensearch.rule.CreateRuleResponse; import org.opensearch.rule.DeleteRuleRequest; import org.opensearch.rule.GetRuleRequest; import org.opensearch.rule.GetRuleResponse; import org.opensearch.rule.RuleEntityParser; import org.opensearch.rule.RulePersistenceService; import org.opensearch.rule.RuleQueryMapper; +import org.opensearch.rule.RuleUtils; import org.opensearch.rule.autotagging.Rule; import org.opensearch.search.SearchHit; import org.opensearch.search.sort.SortOrder; @@ -34,6 +42,7 @@ import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.stream.Collectors; import static org.opensearch.rule.autotagging.Rule._ID_STRING; @@ -48,6 +57,7 @@ public class IndexStoredRulePersistenceService implements RulePersistenceService */ private final String indexName; private final Client client; + private final ClusterService clusterService; private final int maxRulesPerPage; private final RuleEntityParser parser; private final RuleQueryMapper queryBuilder; @@ -58,6 +68,7 @@ public class IndexStoredRulePersistenceService implements RulePersistenceService * This service handles persistence and retrieval of stored rules within an OpenSearch index. * @param indexName - The name of the OpenSearch index where the rules are stored. * @param client - The OpenSearch client used to interact with the OpenSearch cluster. + * @param clusterService * @param maxRulesPerPage - The maximum number of rules that can be returned in a single get request. * @param parser * @param queryBuilder @@ -65,26 +76,90 @@ public class IndexStoredRulePersistenceService implements RulePersistenceService public IndexStoredRulePersistenceService( String indexName, Client client, + ClusterService clusterService, int maxRulesPerPage, RuleEntityParser parser, RuleQueryMapper queryBuilder ) { this.indexName = indexName; this.client = client; + this.clusterService = clusterService; this.maxRulesPerPage = maxRulesPerPage; this.parser = parser; this.queryBuilder = queryBuilder; } + /** + * Entry point for the create rule API logic in persistence service. + * It ensures the index exists, validates for duplicate rules, and persists the new rule. + * @param request The CreateRuleRequest + * @param listener ActionListener for CreateRuleResponse + */ + public void createRule(CreateRuleRequest request, ActionListener listener) { + try (ThreadContext.StoredContext ctx = stashContext()) { + if (!clusterService.state().metadata().hasIndex(indexName)) { + logger.error("Index {} does not exist", indexName); + listener.onFailure(new IllegalStateException("Index" + indexName + " does not exist")); + } else { + Rule rule = request.getRule(); + validateNoDuplicateRule(rule, ActionListener.wrap(unused -> persistRule(rule, listener), listener::onFailure)); + } + } + } + + /** + * Validates that no existing rule has the same attribute map as the given rule. + * This validation must be performed one at a time to prevent writing duplicate rules. + * @param rule - the rule we check duplicate against + * @param listener - listener for validateNoDuplicateRule response + */ + private void validateNoDuplicateRule(Rule rule, ActionListener listener) { + QueryBuilder query = queryBuilder.from(new GetRuleRequest(null, rule.getAttributeMap(), null, rule.getFeatureType())); + getRuleFromIndex(null, query, null, new ActionListener<>() { + @Override + public void onResponse(GetRuleResponse getRuleResponse) { + Optional duplicateRuleId = RuleUtils.getDuplicateRuleId(rule, getRuleResponse.getRules()); + duplicateRuleId.ifPresentOrElse( + id -> listener.onFailure(new IllegalArgumentException("Duplicate rule exists under id " + id)), + () -> listener.onResponse(null) + ); + } + + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + }); + } + + /** + * Persist the rule in the index + * @param rule - The rule to update. + * @param listener - ActionListener for CreateRuleResponse + */ + private void persistRule(Rule rule, ActionListener listener) { + try { + IndexRequest indexRequest = new IndexRequest(indexName).source( + rule.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS) + ); + IndexResponse indexResponse = client.index(indexRequest).get(); + listener.onResponse(new CreateRuleResponse(indexResponse.getId(), rule)); + } catch (Exception e) { + logger.error("Error saving rule to index: {}", indexName); + listener.onFailure(new RuntimeException("Failed to save rule to index.")); + } + } + /** * Entry point for the get rule api logic in persistence service. * @param getRuleRequest the getRuleRequest to process. * @param listener the listener for GetRuleResponse. */ public void getRule(GetRuleRequest getRuleRequest, ActionListener listener) { - final QueryBuilder getQueryBuilder = queryBuilder.from(getRuleRequest) - .filter(QueryBuilders.existsQuery(getRuleRequest.getFeatureType().getName())); - getRuleFromIndex(getRuleRequest.getId(), getQueryBuilder, getRuleRequest.getSearchAfter(), listener); + try (ThreadContext.StoredContext context = stashContext()) { + final QueryBuilder getQueryBuilder = queryBuilder.from(getRuleRequest); + getRuleFromIndex(getRuleRequest.getId(), getQueryBuilder, getRuleRequest.getSearchAfter(), listener); + } } /** @@ -95,22 +170,19 @@ public void getRule(GetRuleRequest getRuleRequest, ActionListener listener) { - // Stash the current thread context when interacting with system index to perform - // operations as the system itself, bypassing authorization checks. This ensures that - // actions within this block are trusted and executed with system-level privileges. - try (ThreadContext.StoredContext context = getContext()) { + try { SearchRequestBuilder searchRequest = client.prepareSearch(indexName).setQuery(queryBuilder).setSize(maxRulesPerPage); if (searchAfter != null) { searchRequest.addSort(_ID_STRING, SortOrder.ASC).searchAfter(new Object[] { searchAfter }); } - searchRequest.execute(ActionListener.wrap(searchResponse -> { - List hits = Arrays.asList(searchResponse.getHits().getHits()); - if (hasNoResults(id, listener, hits)) return; - handleGetRuleResponse(hits, listener); - }, e -> { - logger.error("Failed to fetch all rules: {}", e.getMessage()); - listener.onFailure(e); - })); + + SearchResponse searchResponse = searchRequest.get(); + List hits = Arrays.asList(searchResponse.getHits().getHits()); + if (hasNoResults(id, listener, hits)) return; + handleGetRuleResponse(hits, listener); + } catch (Exception e) { + logger.error("Failed to fetch all rules: {}", e.getMessage()); + listener.onFailure(e); } } @@ -134,13 +206,9 @@ void handleGetRuleResponse(List hits, ActionListener listener.onResponse(new GetRuleResponse(ruleMap, nextSearchAfter)); } - private ThreadContext.StoredContext getContext() { - return client.threadPool().getThreadContext().stashContext(); - } - @Override public void deleteRule(DeleteRuleRequest request, ActionListener listener) { - try (ThreadContext.StoredContext context = getContext()) { + try (ThreadContext.StoredContext context = stashContext()) { DeleteRequest deleteRequest = new DeleteRequest(indexName).id(request.getRuleId()); client.delete(deleteRequest, ActionListener.wrap(deleteResponse -> { boolean acknowledged = deleteResponse.getResult() == DocWriteResponse.Result.DELETED; @@ -159,4 +227,15 @@ public void deleteRule(DeleteRuleRequest request, ActionListener> attributeFilters = request.getAttributeFilters(); final String id = request.getId(); + boolQuery.filter(QueryBuilders.existsQuery(request.getFeatureType().getName())); if (id != null) { return boolQuery.must(QueryBuilders.termQuery(_ID_STRING, id)); } diff --git a/modules/autotagging-commons/common/src/test/java/org/opensearch/rule/RuleUtilsTests.java b/modules/autotagging-commons/common/src/test/java/org/opensearch/rule/RuleUtilsTests.java new file mode 100644 index 0000000000000..2780f329925c9 --- /dev/null +++ b/modules/autotagging-commons/common/src/test/java/org/opensearch/rule/RuleUtilsTests.java @@ -0,0 +1,88 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.rule; + +import org.opensearch.rule.autotagging.Rule; +import org.opensearch.rule.autotagging.RuleTests; +import org.opensearch.rule.utils.RuleTestUtils; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static org.opensearch.rule.action.GetRuleResponseTests.ruleOne; +import static org.opensearch.rule.utils.RuleTestUtils.ATTRIBUTE_VALUE_ONE; +import static org.opensearch.rule.utils.RuleTestUtils.ATTRIBUTE_VALUE_TWO; +import static org.opensearch.rule.utils.RuleTestUtils.DESCRIPTION_ONE; +import static org.opensearch.rule.utils.RuleTestUtils.FEATURE_VALUE_ONE; +import static org.opensearch.rule.utils.RuleTestUtils.TIMESTAMP_ONE; +import static org.opensearch.rule.utils.RuleTestUtils._ID_ONE; +import static org.opensearch.rule.utils.RuleTestUtils._ID_TWO; +import static org.opensearch.rule.utils.RuleTestUtils.ruleTwo; + +public class RuleUtilsTests extends OpenSearchTestCase { + + public void testDuplicateRuleFound() { + Optional result = RuleUtils.getDuplicateRuleId(ruleOne, Map.of(_ID_ONE, ruleOne, _ID_TWO, ruleTwo)); + assertTrue(result.isPresent()); + assertEquals(_ID_ONE, result.get()); + } + + public void testNoAttributeIntersection() { + Optional result = RuleUtils.getDuplicateRuleId(ruleOne, Map.of(_ID_TWO, ruleTwo)); + assertTrue(result.isEmpty()); + } + + public void testAttributeSizeMismatch() { + Rule testRule = Rule.builder() + .description(DESCRIPTION_ONE) + .featureType(RuleTestUtils.MockRuleFeatureType.INSTANCE) + .featureValue(FEATURE_VALUE_ONE) + .attributeMap( + Map.of( + RuleTestUtils.MockRuleAttributes.MOCK_RULE_ATTRIBUTE_ONE, + Set.of(ATTRIBUTE_VALUE_ONE), + RuleTestUtils.MockRuleAttributes.MOCK_RULE_ATTRIBUTE_TWO, + Set.of(ATTRIBUTE_VALUE_TWO) + ) + ) + .updatedAt(TIMESTAMP_ONE) + .build(); + Optional result = RuleUtils.getDuplicateRuleId(ruleOne, Map.of(_ID_TWO, testRule)); + assertTrue(result.isEmpty()); + } + + public void testPartialAttributeValueIntersection() { + Rule ruleWithPartialOverlap = Rule.builder() + .description(DESCRIPTION_ONE) + .featureType(RuleTestUtils.MockRuleFeatureType.INSTANCE) + .featureValue(FEATURE_VALUE_ONE) + .attributeMap(Map.of(RuleTestUtils.MockRuleAttributes.MOCK_RULE_ATTRIBUTE_ONE, Set.of(ATTRIBUTE_VALUE_ONE, "extra_value"))) + .updatedAt(TIMESTAMP_ONE) + .build(); + + Optional result = RuleUtils.getDuplicateRuleId(ruleWithPartialOverlap, Map.of(_ID_ONE, ruleOne)); + assertTrue(result.isPresent()); + assertEquals(_ID_ONE, result.get()); + } + + public void testDifferentFeatureTypes() { + Rule differentFeatureTypeRule = Rule.builder() + .description(DESCRIPTION_ONE) + .featureType(RuleTests.TestFeatureType.INSTANCE) + .featureValue(FEATURE_VALUE_ONE) + .attributeMap(RuleTests.ATTRIBUTE_MAP) + .updatedAt(TIMESTAMP_ONE) + .build(); + + Optional result = RuleUtils.getDuplicateRuleId(differentFeatureTypeRule, Map.of(_ID_ONE, ruleOne)); + assertTrue(result.isEmpty()); + } +} diff --git a/modules/autotagging-commons/common/src/test/java/org/opensearch/rule/action/CreateRuleRequestTests.java b/modules/autotagging-commons/common/src/test/java/org/opensearch/rule/action/CreateRuleRequestTests.java new file mode 100644 index 0000000000000..7804714c37dcc --- /dev/null +++ b/modules/autotagging-commons/common/src/test/java/org/opensearch/rule/action/CreateRuleRequestTests.java @@ -0,0 +1,34 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.rule.action; + +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.rule.CreateRuleRequest; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; + +import static org.opensearch.rule.utils.RuleTestUtils.assertEqualRule; +import static org.opensearch.rule.utils.RuleTestUtils.ruleOne; + +public class CreateRuleRequestTests extends OpenSearchTestCase { + + /** + * Test case to verify the serialization and deserialization of CreateRuleRequest. + */ + public void testSerialization() throws IOException { + CreateRuleRequest request = new CreateRuleRequest(ruleOne); + BytesStreamOutput out = new BytesStreamOutput(); + request.writeTo(out); + StreamInput streamInput = out.bytes().streamInput(); + CreateRuleRequest otherRequest = new CreateRuleRequest(streamInput); + assertEqualRule(ruleOne, otherRequest.getRule(), false); + } +} diff --git a/modules/autotagging-commons/common/src/test/java/org/opensearch/rule/action/CreateRuleResponseTests.java b/modules/autotagging-commons/common/src/test/java/org/opensearch/rule/action/CreateRuleResponseTests.java new file mode 100644 index 0000000000000..dc445dad2e82c --- /dev/null +++ b/modules/autotagging-commons/common/src/test/java/org/opensearch/rule/action/CreateRuleResponseTests.java @@ -0,0 +1,62 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.rule.action; + +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.rule.CreateRuleResponse; +import org.opensearch.rule.autotagging.Rule; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.Map; + +import static org.opensearch.rule.action.GetRuleResponseTests.ruleOne; +import static org.opensearch.rule.utils.RuleTestUtils._ID_ONE; +import static org.opensearch.rule.utils.RuleTestUtils.assertEqualRules; +import static org.mockito.Mockito.mock; + +public class CreateRuleResponseTests extends OpenSearchTestCase { + + /** + * Test case to verify serialization and deserialization of CreateRuleResponse + */ + public void testSerialization() throws IOException { + CreateRuleResponse response = new CreateRuleResponse(_ID_ONE, ruleOne); + BytesStreamOutput out = new BytesStreamOutput(); + response.writeTo(out); + StreamInput streamInput = out.bytes().streamInput(); + CreateRuleResponse otherResponse = new CreateRuleResponse(streamInput); + Rule responseRule = response.getRule(); + Rule otherResponseRule = otherResponse.getRule(); + assertEqualRules(Map.of(_ID_ONE, responseRule), Map.of(_ID_ONE, otherResponseRule), false); + } + + /** + * Test case to validate the toXContent method of CreateRuleResponse + */ + public void testToXContentCreateRule() throws IOException { + XContentBuilder builder = JsonXContent.contentBuilder().prettyPrint(); + CreateRuleResponse response = new CreateRuleResponse(_ID_ONE, ruleOne); + String actual = response.toXContent(builder, mock(ToXContent.Params.class)).toString(); + String expected = "{\n" + + " \"_id\" : \"AgfUO5Ja9yfvhdONlYi3TQ==\",\n" + + " \"description\" : \"description_1\",\n" + + " \"mock_attribute_one\" : [\n" + + " \"mock_attribute_one\"\n" + + " ],\n" + + " \"mock_feature_type\" : \"feature_value_one\",\n" + + " \"updated_at\" : \"2024-01-26T08:58:57.558Z\"\n" + + "}"; + assertEquals(expected, actual); + } +} diff --git a/modules/autotagging-commons/common/src/test/java/org/opensearch/rule/action/DeleteRuleRequestTests.java b/modules/autotagging-commons/common/src/test/java/org/opensearch/rule/action/DeleteRuleRequestTests.java index 315f94c0e0437..55213a245b5ad 100644 --- a/modules/autotagging-commons/common/src/test/java/org/opensearch/rule/action/DeleteRuleRequestTests.java +++ b/modules/autotagging-commons/common/src/test/java/org/opensearch/rule/action/DeleteRuleRequestTests.java @@ -16,7 +16,7 @@ import java.io.IOException; -import static org.opensearch.rule.action.GetRuleRequestTests._ID_ONE; +import static org.opensearch.rule.utils.RuleTestUtils._ID_ONE; public class DeleteRuleRequestTests extends OpenSearchTestCase { diff --git a/modules/autotagging-commons/common/src/test/java/org/opensearch/rule/action/GetRuleRequestTests.java b/modules/autotagging-commons/common/src/test/java/org/opensearch/rule/action/GetRuleRequestTests.java index a451a58356606..c904373588f17 100644 --- a/modules/autotagging-commons/common/src/test/java/org/opensearch/rule/action/GetRuleRequestTests.java +++ b/modules/autotagging-commons/common/src/test/java/org/opensearch/rule/action/GetRuleRequestTests.java @@ -11,15 +11,15 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.rule.GetRuleRequest; -import org.opensearch.rule.autotagging.Attribute; -import org.opensearch.rule.autotagging.Rule; import org.opensearch.rule.utils.RuleTestUtils; import org.opensearch.test.OpenSearchTestCase; import java.io.IOException; import java.util.HashMap; -import java.util.Map; -import java.util.Set; + +import static org.opensearch.rule.utils.RuleTestUtils.ATTRIBUTE_MAP; +import static org.opensearch.rule.utils.RuleTestUtils.SEARCH_AFTER; +import static org.opensearch.rule.utils.RuleTestUtils._ID_ONE; public class GetRuleRequestTests extends OpenSearchTestCase { /** @@ -64,63 +64,4 @@ public void testValidate() { request = new GetRuleRequest(_ID_ONE, ATTRIBUTE_MAP, "", RuleTestUtils.MockRuleFeatureType.INSTANCE); assertThrows(IllegalArgumentException.class, request::validate); } - - public static final String _ID_ONE = "id_1"; - public static final String SEARCH_AFTER = "search_after"; - public static final String _ID_TWO = "G5iIq84j7eK1qIAAAAIH53=1"; - public static final String FEATURE_VALUE_ONE = "feature_value_one"; - public static final String FEATURE_VALUE_TWO = "feature_value_two"; - public static final String ATTRIBUTE_VALUE_ONE = "mock_attribute_one"; - public static final String ATTRIBUTE_VALUE_TWO = "mock_attribute_two"; - public static final String DESCRIPTION_ONE = "description_1"; - public static final String DESCRIPTION_TWO = "description_2"; - public static final String TIMESTAMP_ONE = "2024-01-26T08:58:57.558Z"; - public static final String TIMESTAMP_TWO = "2023-01-26T08:58:57.558Z"; - public static final Map> ATTRIBUTE_MAP = Map.of( - RuleTestUtils.MockRuleAttributes.MOCK_RULE_ATTRIBUTE_ONE, - Set.of(ATTRIBUTE_VALUE_ONE) - ); - - public static final Rule ruleOne = Rule.builder() - .description(DESCRIPTION_ONE) - .featureType(RuleTestUtils.MockRuleFeatureType.INSTANCE) - .featureValue(FEATURE_VALUE_ONE) - .attributeMap(ATTRIBUTE_MAP) - .updatedAt(TIMESTAMP_ONE) - .build(); - - public static final Rule ruleTwo = Rule.builder() - .description(DESCRIPTION_TWO) - .featureType(RuleTestUtils.MockRuleFeatureType.INSTANCE) - .featureValue(FEATURE_VALUE_TWO) - .attributeMap(Map.of(RuleTestUtils.MockRuleAttributes.MOCK_RULE_ATTRIBUTE_TWO, Set.of(ATTRIBUTE_VALUE_TWO))) - .updatedAt(TIMESTAMP_TWO) - .build(); - - public static Map ruleMap() { - return Map.of(_ID_ONE, ruleOne, _ID_TWO, ruleTwo); - } - - public static void assertEqualRules(Map mapOne, Map mapTwo, boolean ruleUpdated) { - assertEquals(mapOne.size(), mapTwo.size()); - for (Map.Entry entry : mapOne.entrySet()) { - String id = entry.getKey(); - assertTrue(mapTwo.containsKey(id)); - Rule one = mapOne.get(id); - Rule two = mapTwo.get(id); - assertEqualRule(one, two, ruleUpdated); - } - } - - public static void assertEqualRule(Rule one, Rule two, boolean ruleUpdated) { - if (ruleUpdated) { - assertEquals(one.getDescription(), two.getDescription()); - assertEquals(one.getFeatureType(), two.getFeatureType()); - assertEquals(one.getFeatureValue(), two.getFeatureValue()); - assertEquals(one.getAttributeMap(), two.getAttributeMap()); - assertEquals(one.getAttributeMap(), two.getAttributeMap()); - } else { - assertEquals(one, two); - } - } } diff --git a/modules/autotagging-commons/common/src/test/java/org/opensearch/rule/action/GetRuleResponseTests.java b/modules/autotagging-commons/common/src/test/java/org/opensearch/rule/action/GetRuleResponseTests.java index f01bb94fab276..4a8fc3ef6d7bc 100644 --- a/modules/autotagging-commons/common/src/test/java/org/opensearch/rule/action/GetRuleResponseTests.java +++ b/modules/autotagging-commons/common/src/test/java/org/opensearch/rule/action/GetRuleResponseTests.java @@ -24,10 +24,10 @@ import java.util.Map; import java.util.Set; -import static org.opensearch.rule.action.GetRuleRequestTests.SEARCH_AFTER; -import static org.opensearch.rule.action.GetRuleRequestTests._ID_ONE; -import static org.opensearch.rule.action.GetRuleRequestTests.assertEqualRules; -import static org.opensearch.rule.action.GetRuleRequestTests.ruleMap; +import static org.opensearch.rule.utils.RuleTestUtils.SEARCH_AFTER; +import static org.opensearch.rule.utils.RuleTestUtils._ID_ONE; +import static org.opensearch.rule.utils.RuleTestUtils.assertEqualRules; +import static org.opensearch.rule.utils.RuleTestUtils.ruleMap; import static org.mockito.Mockito.mock; public class GetRuleResponseTests extends OpenSearchTestCase { @@ -109,7 +109,7 @@ public void testToXContentGetSingleRule() throws IOException { String expected = "{\n" + " \"rules\" : [\n" + " {\n" - + " \"_id\" : \"id_1\",\n" + + " \"_id\" : \"AgfUO5Ja9yfvhdONlYi3TQ==\",\n" + " \"description\" : \"description_1\",\n" + " \"mock_attribute_one\" : [\n" + " \"mock_attribute_one\"\n" diff --git a/modules/autotagging-commons/common/src/test/java/org/opensearch/rule/autotagging/AutoTaggingRegistryTests.java b/modules/autotagging-commons/common/src/test/java/org/opensearch/rule/autotagging/AutoTaggingRegistryTests.java index eee1d527dc6e9..56adf333e1a1c 100644 --- a/modules/autotagging-commons/common/src/test/java/org/opensearch/rule/autotagging/AutoTaggingRegistryTests.java +++ b/modules/autotagging-commons/common/src/test/java/org/opensearch/rule/autotagging/AutoTaggingRegistryTests.java @@ -9,12 +9,13 @@ package org.opensearch.rule.autotagging; import org.opensearch.ResourceNotFoundException; +import org.opensearch.rule.utils.RuleTestUtils; import org.opensearch.test.OpenSearchTestCase; import org.junit.BeforeClass; import static org.opensearch.rule.autotagging.AutoTaggingRegistry.MAX_FEATURE_TYPE_NAME_LENGTH; import static org.opensearch.rule.autotagging.RuleTests.INVALID_FEATURE; -import static org.opensearch.rule.autotagging.RuleTests.TEST_FEATURE_TYPE; +import static org.opensearch.rule.utils.RuleTestUtils.FEATURE_TYPE_NAME; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -22,14 +23,13 @@ public class AutoTaggingRegistryTests extends OpenSearchTestCase { @BeforeClass public static void setUpOnce() { - FeatureType featureType = mock(FeatureType.class); - when(featureType.getName()).thenReturn(TEST_FEATURE_TYPE); + FeatureType featureType = RuleTestUtils.MockRuleFeatureType.INSTANCE; AutoTaggingRegistry.registerFeatureType(featureType); } public void testGetFeatureType_Success() { - FeatureType retrievedFeatureType = AutoTaggingRegistry.getFeatureType(TEST_FEATURE_TYPE); - assertEquals(TEST_FEATURE_TYPE, retrievedFeatureType.getName()); + FeatureType retrievedFeatureType = AutoTaggingRegistry.getFeatureType(FEATURE_TYPE_NAME); + assertEquals(FEATURE_TYPE_NAME, retrievedFeatureType.getName()); } public void testRuntimeException() { @@ -39,7 +39,7 @@ public void testRuntimeException() { public void testIllegalStateExceptionException() { assertThrows(IllegalStateException.class, () -> AutoTaggingRegistry.registerFeatureType(null)); FeatureType featureType = mock(FeatureType.class); - when(featureType.getName()).thenReturn(TEST_FEATURE_TYPE); + when(featureType.getName()).thenReturn(FEATURE_TYPE_NAME); assertThrows(IllegalStateException.class, () -> AutoTaggingRegistry.registerFeatureType(featureType)); when(featureType.getName()).thenReturn(randomAlphaOfLength(MAX_FEATURE_TYPE_NAME_LENGTH + 1)); assertThrows(IllegalStateException.class, () -> AutoTaggingRegistry.registerFeatureType(featureType)); diff --git a/modules/autotagging-commons/common/src/test/java/org/opensearch/rule/autotagging/RuleTests.java b/modules/autotagging-commons/common/src/test/java/org/opensearch/rule/autotagging/RuleTests.java index 4ae89b496a1b3..5f20640a74f24 100644 --- a/modules/autotagging-commons/common/src/test/java/org/opensearch/rule/autotagging/RuleTests.java +++ b/modules/autotagging-commons/common/src/test/java/org/opensearch/rule/autotagging/RuleTests.java @@ -39,7 +39,6 @@ public class RuleTests extends AbstractSerializingTestCase { Set.of("value2") ); public static final String UPDATED_AT = "2025-02-24T07:42:10.123456Z"; - public static final String INVALID_CLASS = "invalid_class"; public static final String INVALID_ATTRIBUTE = "invalid_attribute"; public static final String INVALID_FEATURE = "invalid_feature"; @@ -92,7 +91,7 @@ public static class TestFeatureType implements FeatureType { public TestFeatureType() {} static { - INSTANCE.registerFeatureType(); + AutoTaggingRegistry.registerFeatureType(INSTANCE); } @Override @@ -114,11 +113,6 @@ public int getMaxCharLengthPerAttributeValue() { public Map getAllowedAttributesRegistry() { return ALLOWED_ATTRIBUTES; } - - @Override - public void registerFeatureType() { - AutoTaggingRegistry.registerFeatureType(INSTANCE); - } } static Rule buildRule( diff --git a/modules/autotagging-commons/common/src/test/java/org/opensearch/rule/service/IndexStoredRulePersistenceServiceTests.java b/modules/autotagging-commons/common/src/test/java/org/opensearch/rule/service/IndexStoredRulePersistenceServiceTests.java index 09a693dfd2b4e..eb054ea8124e5 100644 --- a/modules/autotagging-commons/common/src/test/java/org/opensearch/rule/service/IndexStoredRulePersistenceServiceTests.java +++ b/modules/autotagging-commons/common/src/test/java/org/opensearch/rule/service/IndexStoredRulePersistenceServiceTests.java @@ -12,12 +12,15 @@ import org.opensearch.ResourceNotFoundException; import org.opensearch.action.DocWriteResponse; import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; import org.opensearch.action.search.SearchRequestBuilder; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.clustermanager.AcknowledgedResponse; import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.metadata.Metadata; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.action.ActionFuture; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; @@ -25,6 +28,8 @@ import org.opensearch.core.index.shard.ShardId; import org.opensearch.index.engine.DocumentMissingException; import org.opensearch.index.query.QueryBuilder; +import org.opensearch.rule.CreateRuleRequest; +import org.opensearch.rule.CreateRuleResponse; import org.opensearch.rule.DeleteRuleRequest; import org.opensearch.rule.GetRuleRequest; import org.opensearch.rule.GetRuleResponse; @@ -39,11 +44,26 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.client.Client; +import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import org.mockito.ArgumentCaptor; import static org.opensearch.rule.XContentRuleParserTests.VALID_JSON; +import static org.opensearch.rule.utils.RuleTestUtils.ATTRIBUTE_MAP; +import static org.opensearch.rule.utils.RuleTestUtils.ATTRIBUTE_VALUE_ONE; +import static org.opensearch.rule.utils.RuleTestUtils.MockRuleAttributes.MOCK_RULE_ATTRIBUTE_ONE; +import static org.opensearch.rule.utils.RuleTestUtils.MockRuleFeatureType; import static org.opensearch.rule.utils.RuleTestUtils.TEST_INDEX_NAME; import static org.opensearch.rule.utils.RuleTestUtils._ID_ONE; import static org.mockito.ArgumentMatchers.any; @@ -51,7 +71,6 @@ import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.any; import static org.mockito.Mockito.anyInt; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -59,92 +78,208 @@ @SuppressWarnings("unchecked") public class IndexStoredRulePersistenceServiceTests extends OpenSearchTestCase { - public static final int MAX_VALUES_PER_PAGE = 50; + private static final int MAX_VALUES_PER_PAGE = 50; - public void testGetRuleByIdSuccess() { - GetRuleRequest getRuleRequest = mock(GetRuleRequest.class); - when(getRuleRequest.getId()).thenReturn(_ID_ONE); - when(getRuleRequest.getAttributeFilters()).thenReturn(new HashMap<>()); - QueryBuilder queryBuilder = mock(QueryBuilder.class); - RuleQueryMapper mockRuleQueryMapper = mock(RuleQueryMapper.class); - RuleEntityParser mockRuleEntityParser = mock(RuleEntityParser.class); - Rule mockRule = mock(Rule.class); + private Client client; + private ClusterService clusterService; + private RuleQueryMapper ruleQueryMapper; + private RuleEntityParser ruleEntityParser; + private SearchRequestBuilder searchRequestBuilder; + private RulePersistenceService rulePersistenceService; + private QueryBuilder queryBuilder; + private Rule rule; + + public void setUp() throws Exception { + super.setUp(); + searchRequestBuilder = mock(SearchRequestBuilder.class); + client = setUpMockClient(searchRequestBuilder); - when(mockRuleEntityParser.parse(anyString())).thenReturn(mockRule); - when(mockRuleQueryMapper.from(getRuleRequest)).thenReturn(queryBuilder); + rule = mock(Rule.class); + + clusterService = mock(ClusterService.class); + ClusterState clusterState = mock(ClusterState.class); + Metadata metadata = mock(Metadata.class); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(metadata.hasIndex(TEST_INDEX_NAME)).thenReturn(true); + + ruleQueryMapper = mock(RuleQueryMapper.class); + ruleEntityParser = mock(RuleEntityParser.class); + queryBuilder = mock(QueryBuilder.class); when(queryBuilder.filter(any())).thenReturn(queryBuilder); + when(ruleQueryMapper.from(any(GetRuleRequest.class))).thenReturn(queryBuilder); + when(ruleEntityParser.parse(anyString())).thenReturn(rule); + + rulePersistenceService = new IndexStoredRulePersistenceService( + TEST_INDEX_NAME, + client, + clusterService, + MAX_VALUES_PER_PAGE, + ruleEntityParser, + ruleQueryMapper + ); + } + + public void testCreateRuleOnExistingIndex() throws Exception { + CreateRuleRequest createRuleRequest = mock(CreateRuleRequest.class); + when(createRuleRequest.getRule()).thenReturn(rule); + when(rule.toXContent(any(), any())).thenAnswer(invocation -> invocation.getArgument(0)); + + SearchResponse searchResponse = mock(SearchResponse.class); + when(searchResponse.getHits()).thenReturn(new SearchHits(new SearchHit[] {}, new TotalHits(0, TotalHits.Relation.EQUAL_TO), 1.0f)); + when(searchRequestBuilder.get()).thenReturn(searchResponse); + + IndexResponse indexResponse = mock(IndexResponse.class); + when(indexResponse.getId()).thenReturn(_ID_ONE); + ActionFuture future = mock(ActionFuture.class); + when(future.get()).thenReturn(indexResponse); + when(client.index(any(IndexRequest.class))).thenReturn(future); + + ActionListener listener = mock(ActionListener.class); + rulePersistenceService.createRule(createRuleRequest, listener); + + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(CreateRuleResponse.class); + verify(listener).onResponse(responseCaptor.capture()); + assertNotNull(responseCaptor.getValue().getRule()); + } + + public void testConcurrentCreateDuplicateRules() throws InterruptedException { + ExecutorService singleThreadExecutor = Executors.newSingleThreadExecutor(); + int threadCount = 10; + CountDownLatch latch = new CountDownLatch(threadCount); + Set storedAttributeMaps = ConcurrentHashMap.newKeySet(); - SearchRequestBuilder searchRequestBuilder = mock(SearchRequestBuilder.class); - Client client = setUpMockClient(searchRequestBuilder); + CreateRuleRequest createRuleRequest = mock(CreateRuleRequest.class); + when(rule.getAttributeMap()).thenReturn(ATTRIBUTE_MAP); + when(rule.getFeatureType()).thenReturn(MockRuleFeatureType.INSTANCE); + when(createRuleRequest.getRule()).thenReturn(rule); RulePersistenceService rulePersistenceService = new IndexStoredRulePersistenceService( TEST_INDEX_NAME, client, + clusterService, MAX_VALUES_PER_PAGE, - mockRuleEntityParser, - mockRuleQueryMapper - ); + ruleEntityParser, + ruleQueryMapper + ) { + @Override + public void createRule(CreateRuleRequest request, ActionListener listener) { + singleThreadExecutor.execute(() -> { + Rule rule = request.getRule(); + validateNoDuplicateRule(rule, new ActionListener() { + @Override + public void onResponse(Void unused) { + synchronized (storedAttributeMaps) { + storedAttributeMaps.add(MOCK_RULE_ATTRIBUTE_ONE.getName()); + } + listener.onResponse(new CreateRuleResponse("fake-id", rule)); + latch.countDown(); + } + + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + latch.countDown(); + } + }); + }); + } + + public void validateNoDuplicateRule(Rule rule, ActionListener listener) { + synchronized (storedAttributeMaps) { + if (storedAttributeMaps.contains(MOCK_RULE_ATTRIBUTE_ONE.getName())) { + listener.onFailure(new IllegalArgumentException("Duplicate rule exists with attribute map")); + } else { + listener.onResponse(null); + } + } + } + }; + + class TestListener implements ActionListener { + final AtomicInteger successCount = new AtomicInteger(); + final AtomicInteger failureCount = new AtomicInteger(); + final List failures = Collections.synchronizedList(new ArrayList<>()); + + @Override + public void onResponse(CreateRuleResponse response) { + successCount.incrementAndGet(); + } + + @Override + public void onFailure(Exception e) { + failureCount.incrementAndGet(); + failures.add(e); + } + } + TestListener testListener = new TestListener(); + + for (int i = 0; i < threadCount; i++) { + new Thread(() -> rulePersistenceService.createRule(createRuleRequest, testListener)).start(); + } + boolean completed = latch.await(10, TimeUnit.SECONDS); + singleThreadExecutor.shutdown(); + assertTrue("All create calls should complete", completed); + assertEquals(1, testListener.successCount.get()); + assertEquals(threadCount - 1, testListener.failureCount.get()); + for (Exception e : testListener.failures) { + assertTrue(e instanceof IllegalArgumentException); + assertTrue(e.getMessage().contains("Duplicate rule")); + } + } + + public void testCreateDuplicateRule() { + CreateRuleRequest createRuleRequest = mock(CreateRuleRequest.class); + when(createRuleRequest.getRule()).thenReturn(rule); + when(rule.getAttributeMap()).thenReturn(Map.of(MOCK_RULE_ATTRIBUTE_ONE, Set.of(ATTRIBUTE_VALUE_ONE))); + when(rule.getFeatureType()).thenReturn(RuleTestUtils.MockRuleFeatureType.INSTANCE); SearchResponse searchResponse = mock(SearchResponse.class); - SearchHits searchHits = new SearchHits(new SearchHit[] { new SearchHit(1) }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0f); - when(searchResponse.getHits()).thenReturn(searchHits); - SearchHit hit = searchHits.getHits()[0]; + SearchHit hit = new SearchHit(1); hit.sourceRef(new BytesArray(VALID_JSON)); + SearchHits searchHits = new SearchHits(new SearchHit[] { hit }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0f); + when(searchResponse.getHits()).thenReturn(searchHits); + when(searchRequestBuilder.get()).thenReturn(searchResponse); - ActionListener listener = mock(ActionListener.class); - - doAnswer((invocation) -> { - ActionListener actionListener = invocation.getArgument(0); - actionListener.onResponse(searchResponse); - return null; - }).when(searchRequestBuilder).execute(any(ActionListener.class)); + ActionListener listener = mock(ActionListener.class); + when(ruleEntityParser.parse(any(String.class))).thenReturn(rule); + rulePersistenceService.createRule(createRuleRequest, listener); + ArgumentCaptor failureCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(failureCaptor.capture()); + } + public void testGetRuleByIdSuccess() { + GetRuleRequest getRuleRequest = mock(GetRuleRequest.class); + when(getRuleRequest.getId()).thenReturn(_ID_ONE); + when(getRuleRequest.getAttributeFilters()).thenReturn(new HashMap<>()); when(getRuleRequest.getFeatureType()).thenReturn(RuleTestUtils.MockRuleFeatureType.INSTANCE); - rulePersistenceService.getRule(getRuleRequest, listener); + SearchResponse searchResponse = mock(SearchResponse.class); + SearchHit searchHit = new SearchHit(1); + searchHit.sourceRef(new BytesArray(VALID_JSON)); + SearchHits searchHits = new SearchHits(new SearchHit[] { searchHit }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0f); + when(searchResponse.getHits()).thenReturn(searchHits); + when(searchRequestBuilder.get()).thenReturn(searchResponse); + + ActionListener listener = mock(ActionListener.class); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(GetRuleResponse.class); + rulePersistenceService.getRule(getRuleRequest, listener); verify(listener).onResponse(responseCaptor.capture()); GetRuleResponse response = responseCaptor.getValue(); - assertEquals(response.getRules().size(), 1); + assertEquals(1, response.getRules().size()); } public void testGetRuleByIdNotFound() { GetRuleRequest getRuleRequest = mock(GetRuleRequest.class); when(getRuleRequest.getId()).thenReturn(_ID_ONE); - QueryBuilder queryBuilder = mock(QueryBuilder.class); - RuleQueryMapper mockRuleQueryMapper = mock(RuleQueryMapper.class); - RuleEntityParser mockRuleEntityParser = mock(RuleEntityParser.class); - Rule mockRule = mock(Rule.class); - - when(mockRuleEntityParser.parse(anyString())).thenReturn(mockRule); - when(mockRuleQueryMapper.from(getRuleRequest)).thenReturn(queryBuilder); - when(queryBuilder.filter(any())).thenReturn(queryBuilder); - - SearchRequestBuilder searchRequestBuilder = mock(SearchRequestBuilder.class); - Client client = setUpMockClient(searchRequestBuilder); - - RulePersistenceService rulePersistenceService = new IndexStoredRulePersistenceService( - TEST_INDEX_NAME, - client, - MAX_VALUES_PER_PAGE, - mockRuleEntityParser, - mockRuleQueryMapper - ); + when(getRuleRequest.getFeatureType()).thenReturn(RuleTestUtils.MockRuleFeatureType.INSTANCE); SearchResponse searchResponse = mock(SearchResponse.class); + when(searchRequestBuilder.get()).thenReturn(searchResponse); when(searchResponse.getHits()).thenReturn(new SearchHits(new SearchHit[] {}, new TotalHits(0, TotalHits.Relation.EQUAL_TO), 1.0f)); - ActionListener listener = mock(ActionListener.class); - doAnswer(invocationOnMock -> { - ActionListener actionListener = invocationOnMock.getArgument(0); - actionListener.onResponse(searchResponse); - return null; - }).when(searchRequestBuilder).execute(any(ActionListener.class)); - - when(getRuleRequest.getFeatureType()).thenReturn(RuleTestUtils.MockRuleFeatureType.INSTANCE); rulePersistenceService.getRule(getRuleRequest, listener); - ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(exceptionCaptor.capture()); Exception exception = exceptionCaptor.getValue(); @@ -174,26 +309,15 @@ private Client setUpMockClient(SearchRequestBuilder searchRequestBuilder) { public void testDeleteRule_successful() { String ruleId = "test-rule-id"; DeleteRuleRequest request = new DeleteRuleRequest(ruleId, RuleTestUtils.MockRuleFeatureType.INSTANCE); - - Client client = mock(Client.class); ThreadPool threadPool = mock(ThreadPool.class); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); - RulePersistenceService rulePersistenceService = new IndexStoredRulePersistenceService( - TEST_INDEX_NAME, - client, - MAX_VALUES_PER_PAGE, - mock(RuleEntityParser.class), - mock(RuleQueryMapper.class) - ); - ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(DeleteRequest.class); ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass( ActionListener.class ); - @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); rulePersistenceService.deleteRule(request, listener); @@ -212,26 +336,15 @@ public void testDeleteRule_successful() { public void testDeleteRule_notFound() { String ruleId = "missing-rule-id"; DeleteRuleRequest request = new DeleteRuleRequest(ruleId, RuleTestUtils.MockRuleFeatureType.INSTANCE); - - Client client = mock(Client.class); ThreadPool threadPool = mock(ThreadPool.class); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); - RulePersistenceService rulePersistenceService = new IndexStoredRulePersistenceService( - TEST_INDEX_NAME, - client, - MAX_VALUES_PER_PAGE, - mock(RuleEntityParser.class), - mock(RuleQueryMapper.class) - ); - ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(DeleteRequest.class); ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass( ActionListener.class ); - @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); rulePersistenceService.deleteRule(request, listener); diff --git a/modules/autotagging-commons/common/src/test/java/org/opensearch/rule/utils/RuleTestUtils.java b/modules/autotagging-commons/common/src/test/java/org/opensearch/rule/utils/RuleTestUtils.java index 6ec75e4b942ff..ac45b3e784446 100644 --- a/modules/autotagging-commons/common/src/test/java/org/opensearch/rule/utils/RuleTestUtils.java +++ b/modules/autotagging-commons/common/src/test/java/org/opensearch/rule/utils/RuleTestUtils.java @@ -11,10 +11,14 @@ import org.opensearch.rule.autotagging.Attribute; import org.opensearch.rule.autotagging.AutoTaggingRegistry; import org.opensearch.rule.autotagging.FeatureType; +import org.opensearch.rule.autotagging.Rule; import java.util.Map; import java.util.Set; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + public class RuleTestUtils { public static final String _ID_ONE = "AgfUO5Ja9yfvhdONlYi3TQ=="; public static final String ATTRIBUTE_VALUE_ONE = "mock_attribute_one"; @@ -22,12 +26,63 @@ public class RuleTestUtils { public static final String DESCRIPTION_ONE = "description_1"; public static final String FEATURE_TYPE_NAME = "mock_feature_type"; public static final String TEST_INDEX_NAME = ".test_index_for_rule"; + public static final String INVALID_ATTRIBUTE = "invalid_attribute"; + + public static final String SEARCH_AFTER = "search_after"; + public static final String _ID_TWO = "G5iIq84j7eK1qIAAAAIH53=1"; + public static final String FEATURE_VALUE_ONE = "feature_value_one"; + public static final String FEATURE_VALUE_TWO = "feature_value_two"; + public static final String DESCRIPTION_TWO = "description_2"; + public static final String TIMESTAMP_ONE = "2024-01-26T08:58:57.558Z"; + public static final String TIMESTAMP_TWO = "2023-01-26T08:58:57.558Z"; + public static final Map> ATTRIBUTE_MAP = Map.of( MockRuleAttributes.MOCK_RULE_ATTRIBUTE_ONE, Set.of(ATTRIBUTE_VALUE_ONE) ); - public static final String INVALID_ATTRIBUTE = "invalid_attribute"; + public static final Rule ruleOne = Rule.builder() + .description(DESCRIPTION_ONE) + .featureType(RuleTestUtils.MockRuleFeatureType.INSTANCE) + .featureValue(FEATURE_VALUE_ONE) + .attributeMap(ATTRIBUTE_MAP) + .updatedAt(TIMESTAMP_ONE) + .build(); + + public static final Rule ruleTwo = Rule.builder() + .description(DESCRIPTION_TWO) + .featureType(RuleTestUtils.MockRuleFeatureType.INSTANCE) + .featureValue(FEATURE_VALUE_TWO) + .attributeMap(Map.of(RuleTestUtils.MockRuleAttributes.MOCK_RULE_ATTRIBUTE_TWO, Set.of(ATTRIBUTE_VALUE_TWO))) + .updatedAt(TIMESTAMP_TWO) + .build(); + + public static Map ruleMap() { + return Map.of(_ID_ONE, ruleOne, _ID_TWO, ruleTwo); + } + + public static void assertEqualRules(Map mapOne, Map mapTwo, boolean ruleUpdated) { + assertEquals(mapOne.size(), mapTwo.size()); + for (Map.Entry entry : mapOne.entrySet()) { + String id = entry.getKey(); + assertTrue(mapTwo.containsKey(id)); + Rule one = mapOne.get(id); + Rule two = mapTwo.get(id); + assertEqualRule(one, two, ruleUpdated); + } + } + + public static void assertEqualRule(Rule one, Rule two, boolean ruleUpdated) { + if (ruleUpdated) { + assertEquals(one.getDescription(), two.getDescription()); + assertEquals(one.getFeatureType(), two.getFeatureType()); + assertEquals(one.getFeatureValue(), two.getFeatureValue()); + assertEquals(one.getAttributeMap(), two.getAttributeMap()); + assertEquals(one.getAttributeMap(), two.getAttributeMap()); + } else { + assertEquals(one, two); + } + } public static class MockRuleFeatureType implements FeatureType { @@ -36,7 +91,7 @@ public static class MockRuleFeatureType implements FeatureType { private MockRuleFeatureType() {} static { - INSTANCE.registerFeatureType(); + AutoTaggingRegistry.registerFeatureType(INSTANCE); } @Override @@ -53,11 +108,6 @@ public Map getAllowedAttributesRegistry() { MockRuleAttributes.MOCK_RULE_ATTRIBUTE_TWO ); } - - @Override - public void registerFeatureType() { - AutoTaggingRegistry.registerFeatureType(INSTANCE); - } } public enum MockRuleAttributes implements Attribute { diff --git a/modules/autotagging-commons/spi/src/main/java/org/opensearch/rule/spi/RuleFrameworkExtension.java b/modules/autotagging-commons/spi/src/main/java/org/opensearch/rule/spi/RuleFrameworkExtension.java index 6197fa9e03093..5c34bc29efdda 100644 --- a/modules/autotagging-commons/spi/src/main/java/org/opensearch/rule/spi/RuleFrameworkExtension.java +++ b/modules/autotagging-commons/spi/src/main/java/org/opensearch/rule/spi/RuleFrameworkExtension.java @@ -9,6 +9,7 @@ package org.opensearch.rule.spi; import org.opensearch.rule.RulePersistenceService; +import org.opensearch.rule.RuleRoutingService; import org.opensearch.rule.autotagging.FeatureType; import java.util.function.Supplier; @@ -25,9 +26,14 @@ public interface RuleFrameworkExtension { Supplier getRulePersistenceServiceSupplier(); /** - * It tells the framework its FeatureType which can be used by Transport classes to handle the - * consumer specific persistence - * @return + * This method is used to flow implementation from consumer plugins into framework plugin + * @return the plugin specific implementation of RuleRoutingService + */ + Supplier getRuleRoutingServiceSupplier(); + + /** + * Flow implementation from consumer plugins into framework plugin + * @return the specific implementation of FeatureType */ - FeatureType getFeatureType(); + Supplier getFeatureTypeSupplier(); } diff --git a/modules/autotagging-commons/src/main/java/org/opensearch/rule/RuleFrameworkPlugin.java b/modules/autotagging-commons/src/main/java/org/opensearch/rule/RuleFrameworkPlugin.java index e386c0021475f..cd6197cf890f7 100644 --- a/modules/autotagging-commons/src/main/java/org/opensearch/rule/RuleFrameworkPlugin.java +++ b/modules/autotagging-commons/src/main/java/org/opensearch/rule/RuleFrameworkPlugin.java @@ -22,14 +22,20 @@ import org.opensearch.plugins.Plugin; import org.opensearch.rest.RestController; import org.opensearch.rest.RestHandler; +import org.opensearch.rule.action.CreateRuleAction; import org.opensearch.rule.action.DeleteRuleAction; import org.opensearch.rule.action.GetRuleAction; +import org.opensearch.rule.action.TransportCreateRuleAction; import org.opensearch.rule.action.TransportDeleteRuleAction; import org.opensearch.rule.action.TransportGetRuleAction; import org.opensearch.rule.autotagging.AutoTaggingRegistry; +import org.opensearch.rule.autotagging.FeatureType; +import org.opensearch.rule.rest.RestCreateRuleAction; import org.opensearch.rule.rest.RestDeleteRuleAction; import org.opensearch.rule.rest.RestGetRuleAction; import org.opensearch.rule.spi.RuleFrameworkExtension; +import org.opensearch.threadpool.ExecutorBuilder; +import org.opensearch.threadpool.FixedExecutorBuilder; import java.util.ArrayList; import java.util.Collection; @@ -38,15 +44,31 @@ /** * This plugin provides the central APIs which can provide CRUD support to all consumers of Rule framework + * Plugins that define custom rule logic must implement {@link RuleFrameworkExtension}, which ensures + * their feature types and persistence services are automatically registered and available to the Rule Framework. */ public class RuleFrameworkPlugin extends Plugin implements ExtensiblePlugin, ActionPlugin { + /** + * The name of the thread pool dedicated to rule execution. + */ + public static final String RULE_THREAD_POOL_NAME = "rule_serial_executor"; + /** + * The number of threads allocated in the rule execution thread pool. This is set to 1 to ensure serial execution. + */ + public static final int RULE_THREAD_COUNT = 1; + /** + * The maximum number of tasks that can be queued in the rule execution thread pool. + */ + public static final int RULE_QUEUE_SIZE = 100; + /** * constructor for RuleFrameworkPlugin */ public RuleFrameworkPlugin() {} private final RulePersistenceServiceRegistry rulePersistenceServiceRegistry = new RulePersistenceServiceRegistry(); + private final RuleRoutingServiceRegistry ruleRoutingServiceRegistry = new RuleRoutingServiceRegistry(); private final List ruleFrameworkExtensions = new ArrayList<>(); @Override @@ -55,7 +77,8 @@ public RuleFrameworkPlugin() {} ruleFrameworkExtensions.forEach(this::consumeFrameworkExtension); return List.of( new ActionPlugin.ActionHandler<>(GetRuleAction.INSTANCE, TransportGetRuleAction.class), - new ActionPlugin.ActionHandler<>(DeleteRuleAction.INSTANCE, TransportDeleteRuleAction.class) + new ActionPlugin.ActionHandler<>(DeleteRuleAction.INSTANCE, TransportDeleteRuleAction.class), + new ActionPlugin.ActionHandler<>(CreateRuleAction.INSTANCE, TransportCreateRuleAction.class) ); } @@ -69,12 +92,20 @@ public List getRestHandlers( IndexNameExpressionResolver indexNameExpressionResolver, Supplier nodesInCluster ) { - return List.of(new RestGetRuleAction(), new RestDeleteRuleAction()); + return List.of(new RestGetRuleAction(), new RestDeleteRuleAction(), new RestCreateRuleAction()); + } + + @Override + public List> getExecutorBuilders(Settings settings) { + return List.of(new FixedExecutorBuilder(settings, RULE_THREAD_POOL_NAME, RULE_THREAD_COUNT, RULE_QUEUE_SIZE, "rule-threadpool")); } @Override public Collection createGuiceModules() { - return List.of(b -> { b.bind(RulePersistenceServiceRegistry.class).toInstance(rulePersistenceServiceRegistry); }); + return List.of(b -> { + b.bind(RulePersistenceServiceRegistry.class).toInstance(rulePersistenceServiceRegistry); + b.bind(RuleRoutingServiceRegistry.class).toInstance(ruleRoutingServiceRegistry); + }); } @Override @@ -83,10 +114,10 @@ public void loadExtensions(ExtensionLoader loader) { } private void consumeFrameworkExtension(RuleFrameworkExtension ruleFrameworkExtension) { - AutoTaggingRegistry.registerFeatureType(ruleFrameworkExtension.getFeatureType()); - rulePersistenceServiceRegistry.register( - ruleFrameworkExtension.getFeatureType(), - ruleFrameworkExtension.getRulePersistenceServiceSupplier().get() - ); + FeatureType featureType = ruleFrameworkExtension.getFeatureTypeSupplier().get(); + AutoTaggingRegistry.registerFeatureType(featureType); + rulePersistenceServiceRegistry.register(featureType, ruleFrameworkExtension.getRulePersistenceServiceSupplier().get()); + ruleRoutingServiceRegistry.register(featureType, ruleFrameworkExtension.getRuleRoutingServiceSupplier().get()); + } } diff --git a/modules/autotagging-commons/src/main/java/org/opensearch/rule/RulePersistenceServiceRegistry.java b/modules/autotagging-commons/src/main/java/org/opensearch/rule/RulePersistenceServiceRegistry.java index 57f47f1a6ad0f..1479d1d3ded21 100644 --- a/modules/autotagging-commons/src/main/java/org/opensearch/rule/RulePersistenceServiceRegistry.java +++ b/modules/autotagging-commons/src/main/java/org/opensearch/rule/RulePersistenceServiceRegistry.java @@ -41,7 +41,7 @@ public void register(FeatureType featureType, RulePersistenceService rulePersist */ public RulePersistenceService getRulePersistenceService(FeatureType featureType) { if (!rulePersistenceServices.containsKey(featureType.getName())) { - throw new IllegalArgumentException("Unknown feature type: " + featureType.getName()); + throw new IllegalArgumentException("Unknown feature type for persistence service: " + featureType.getName()); } return rulePersistenceServices.get(featureType.getName()); } diff --git a/modules/autotagging-commons/src/main/java/org/opensearch/rule/RuleRoutingServiceRegistry.java b/modules/autotagging-commons/src/main/java/org/opensearch/rule/RuleRoutingServiceRegistry.java new file mode 100644 index 0000000000000..d8bc29640a21e --- /dev/null +++ b/modules/autotagging-commons/src/main/java/org/opensearch/rule/RuleRoutingServiceRegistry.java @@ -0,0 +1,48 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.rule; + +import org.opensearch.rule.autotagging.FeatureType; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * This class manages implementations of {@link RuleRoutingService} + */ +public class RuleRoutingServiceRegistry { + private final Map ruleRoutingServices = new ConcurrentHashMap<>(); + + /** + * default constructor + */ + public RuleRoutingServiceRegistry() {} + + /** + * This method is used to register the concrete implementations of RuleRoutingService + * @param featureType + * @param ruleRoutingService + */ + public void register(FeatureType featureType, RuleRoutingService ruleRoutingService) { + if (ruleRoutingServices.put(featureType.getName(), ruleRoutingService) != null) { + throw new IllegalArgumentException("Duplicate rule routing service: " + featureType.getName()); + } + } + + /** + * It is used to get feature type specific {@link RuleRoutingService} implementation + * @param featureType - the type of feature to retrieve the routing service for + */ + public RuleRoutingService getRuleRoutingService(FeatureType featureType) { + if (!ruleRoutingServices.containsKey(featureType.getName())) { + throw new IllegalArgumentException("Unknown feature type for routing service: " + featureType.getName()); + } + return ruleRoutingServices.get(featureType.getName()); + } +} diff --git a/modules/autotagging-commons/src/main/java/org/opensearch/rule/action/CreateRuleAction.java b/modules/autotagging-commons/src/main/java/org/opensearch/rule/action/CreateRuleAction.java new file mode 100644 index 0000000000000..fe7b424b31ca4 --- /dev/null +++ b/modules/autotagging-commons/src/main/java/org/opensearch/rule/action/CreateRuleAction.java @@ -0,0 +1,36 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.rule.action; + +import org.opensearch.action.ActionType; +import org.opensearch.rule.CreateRuleResponse; + +/** + * Action type for creating a Rule + * @opensearch.experimental + */ +public class CreateRuleAction extends ActionType { + + /** + * An instance of CreateRuleAction + */ + public static final CreateRuleAction INSTANCE = new CreateRuleAction(); + + /** + * Name for CreateRuleAction + */ + public static final String NAME = "cluster:admin/opensearch/rule/_create"; + + /** + * Default constructor + */ + private CreateRuleAction() { + super(NAME, CreateRuleResponse::new); + } +} diff --git a/modules/autotagging-commons/src/main/java/org/opensearch/rule/action/GetRuleAction.java b/modules/autotagging-commons/src/main/java/org/opensearch/rule/action/GetRuleAction.java index e59eabf682510..aa795d3aa4f72 100644 --- a/modules/autotagging-commons/src/main/java/org/opensearch/rule/action/GetRuleAction.java +++ b/modules/autotagging-commons/src/main/java/org/opensearch/rule/action/GetRuleAction.java @@ -18,7 +18,7 @@ public class GetRuleAction extends ActionType { /** - * An instance of GetWlmRuleAction + * An instance of GetRuleAction */ public static final GetRuleAction INSTANCE = new GetRuleAction(); diff --git a/modules/autotagging-commons/src/main/java/org/opensearch/rule/action/TransportCreateRuleAction.java b/modules/autotagging-commons/src/main/java/org/opensearch/rule/action/TransportCreateRuleAction.java new file mode 100644 index 0000000000000..f808cf7427cbc --- /dev/null +++ b/modules/autotagging-commons/src/main/java/org/opensearch/rule/action/TransportCreateRuleAction.java @@ -0,0 +1,106 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.rule.action; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.TransportAction; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.rule.CreateRuleRequest; +import org.opensearch.rule.CreateRuleResponse; +import org.opensearch.rule.RulePersistenceService; +import org.opensearch.rule.RulePersistenceServiceRegistry; +import org.opensearch.rule.RuleRoutingServiceRegistry; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportChannel; +import org.opensearch.transport.TransportException; +import org.opensearch.transport.TransportRequestHandler; +import org.opensearch.transport.TransportService; + +import java.io.IOException; + +import static org.opensearch.rule.RuleFrameworkPlugin.RULE_THREAD_POOL_NAME; + +/** + * Transport action to create Rules + * @opensearch.experimental + */ +public class TransportCreateRuleAction extends TransportAction { + private final ThreadPool threadPool; + private final RuleRoutingServiceRegistry ruleRoutingServiceRegistry; + private final RulePersistenceServiceRegistry rulePersistenceServiceRegistry; + + /** + * Constructor for TransportCreateRuleAction + * @param transportService - a {@link TransportService} object + * @param actionFilters - a {@link ActionFilters} object + * @param threadPool - a {@link ThreadPool} object + * @param rulePersistenceServiceRegistry - a {@link RulePersistenceServiceRegistry} object + * @param ruleRoutingServiceRegistry - a {@link RuleRoutingServiceRegistry} object + */ + @Inject + public TransportCreateRuleAction( + TransportService transportService, + ThreadPool threadPool, + ActionFilters actionFilters, + RulePersistenceServiceRegistry rulePersistenceServiceRegistry, + RuleRoutingServiceRegistry ruleRoutingServiceRegistry + ) { + super(CreateRuleAction.NAME, actionFilters, transportService.getTaskManager()); + this.ruleRoutingServiceRegistry = ruleRoutingServiceRegistry; + this.threadPool = threadPool; + this.rulePersistenceServiceRegistry = rulePersistenceServiceRegistry; + + transportService.registerRequestHandler( + CreateRuleAction.NAME, + ThreadPool.Names.SAME, + CreateRuleRequest::new, + new TransportRequestHandler() { + @Override + public void messageReceived(CreateRuleRequest request, TransportChannel channel, Task task) { + executeLocally(request, ActionListener.wrap(response -> { + try { + channel.sendResponse(response); + } catch (IOException e) { + logger.error("Failed to send CreateRuleResponse to transport channel", e); + throw new TransportException("Fail to send", e); + } + }, exception -> { + try { + channel.sendResponse(exception); + } catch (IOException e) { + logger.error("Failed to send exception response to transport channel", e); + throw new TransportException("Fail to send", e); + } + })); + } + } + ); + } + + @Override + protected void doExecute(Task task, CreateRuleRequest request, ActionListener listener) { + ruleRoutingServiceRegistry.getRuleRoutingService(request.getRule().getFeatureType()).handleCreateRuleRequest(request, listener); + } + + /** + * Executes the create rule operation locally on the dedicated rule thread pool. + * @param request the CreateRuleRequest + * @param listener listener to handle response or failure + */ + private void executeLocally(CreateRuleRequest request, ActionListener listener) { + threadPool.executor(RULE_THREAD_POOL_NAME).execute(() -> { + final RulePersistenceService rulePersistenceService = rulePersistenceServiceRegistry.getRulePersistenceService( + request.getRule().getFeatureType() + ); + rulePersistenceService.createRule(request, listener); + }); + } +} diff --git a/modules/autotagging-commons/src/main/java/org/opensearch/rule/rest/RestCreateRuleAction.java b/modules/autotagging-commons/src/main/java/org/opensearch/rule/rest/RestCreateRuleAction.java new file mode 100644 index 0000000000000..7a5f45e95a0ba --- /dev/null +++ b/modules/autotagging-commons/src/main/java/org/opensearch/rule/rest/RestCreateRuleAction.java @@ -0,0 +1,74 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.rule.rest; + +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.RestResponse; +import org.opensearch.rest.action.RestResponseListener; +import org.opensearch.rule.CreateRuleRequest; +import org.opensearch.rule.CreateRuleResponse; +import org.opensearch.rule.action.CreateRuleAction; +import org.opensearch.rule.autotagging.FeatureType; +import org.opensearch.rule.autotagging.Rule.Builder; +import org.opensearch.transport.client.node.NodeClient; +import org.joda.time.Instant; + +import java.io.IOException; +import java.util.List; + +import static org.opensearch.rest.RestRequest.Method.PUT; +import static org.opensearch.rule.rest.RestGetRuleAction.FEATURE_TYPE; + +/** + * Rest action to create a Rule + * @opensearch.experimental + */ +public class RestCreateRuleAction extends BaseRestHandler { + /** + * constructor for RestCreateRuleAction + */ + public RestCreateRuleAction() {} + + @Override + public String getName() { + return "create_rule"; + } + + @Override + public List routes() { + return List.of(new RestHandler.Route(PUT, "_rules/{featureType}")); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + final FeatureType featureType = FeatureType.from(request.param(FEATURE_TYPE)); + try (XContentParser parser = request.contentParser()) { + Builder builder = Builder.fromXContent(parser, featureType); + CreateRuleRequest createRuleRequest = new CreateRuleRequest(builder.updatedAt(Instant.now().toString()).build()); + + return channel -> client.execute(CreateRuleAction.INSTANCE, createRuleRequest, createRuleResponse(channel)); + } + } + + private RestResponseListener createRuleResponse(final RestChannel channel) { + return new RestResponseListener<>(channel) { + @Override + public RestResponse buildResponse(final CreateRuleResponse response) throws Exception { + return new BytesRestResponse(RestStatus.OK, response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS)); + } + }; + } +} diff --git a/modules/autotagging-commons/src/test/java/org/opensearch/rule/InMemoryRuleProcessingServiceTests.java b/modules/autotagging-commons/src/test/java/org/opensearch/rule/InMemoryRuleProcessingServiceTests.java index 20a53c345edf7..dfe347007959a 100644 --- a/modules/autotagging-commons/src/test/java/org/opensearch/rule/InMemoryRuleProcessingServiceTests.java +++ b/modules/autotagging-commons/src/test/java/org/opensearch/rule/InMemoryRuleProcessingServiceTests.java @@ -124,7 +124,7 @@ public enum WLMFeatureType implements FeatureType { WLM; static { - WLM.registerFeatureType(); + AutoTaggingRegistry.registerFeatureType(WLM); } @Override @@ -136,11 +136,6 @@ public String getName() { public Map getAllowedAttributesRegistry() { return Map.of("test_attribute", TestAttribute.TEST_ATTRIBUTE); } - - @Override - public void registerFeatureType() { - AutoTaggingRegistry.registerFeatureType(WLM); - } } public enum TestAttribute implements Attribute { diff --git a/modules/autotagging-commons/src/test/java/org/opensearch/rule/RuleFrameworkPluginTests.java b/modules/autotagging-commons/src/test/java/org/opensearch/rule/RuleFrameworkPluginTests.java index e9f0a68d35dea..4e11d12f9facb 100644 --- a/modules/autotagging-commons/src/test/java/org/opensearch/rule/RuleFrameworkPluginTests.java +++ b/modules/autotagging-commons/src/test/java/org/opensearch/rule/RuleFrameworkPluginTests.java @@ -28,7 +28,7 @@ public class RuleFrameworkPluginTests extends OpenSearchTestCase { public void testGetActions() { List> handlers = plugin.getActions(); - assertEquals(2, handlers.size()); + assertEquals(3, handlers.size()); assertEquals(GetRuleAction.INSTANCE.name(), handlers.get(0).getAction().name()); } diff --git a/modules/autotagging-commons/src/test/java/org/opensearch/rule/RulePersistenceServiceRegistryTests.java b/modules/autotagging-commons/src/test/java/org/opensearch/rule/RulePersistenceServiceRegistryTests.java index 116ad7a730c58..57cc35ac4ac24 100644 --- a/modules/autotagging-commons/src/test/java/org/opensearch/rule/RulePersistenceServiceRegistryTests.java +++ b/modules/autotagging-commons/src/test/java/org/opensearch/rule/RulePersistenceServiceRegistryTests.java @@ -15,8 +15,8 @@ import static org.mockito.Mockito.when; public class RulePersistenceServiceRegistryTests extends OpenSearchTestCase { - RulePersistenceServiceRegistry registry = new RulePersistenceServiceRegistry();; - FeatureType mockFeatureType = mock(FeatureType.class);; + RulePersistenceServiceRegistry registry = new RulePersistenceServiceRegistry(); + FeatureType mockFeatureType = mock(FeatureType.class); RulePersistenceService mockService = mock(RulePersistenceService.class); public void testRegisterAndGetService() { @@ -43,6 +43,6 @@ public void testGetRulePersistenceService_UnknownFeature() { IllegalArgumentException.class, () -> registry.getRulePersistenceService(mockFeatureType) ); - assertTrue(ex.getMessage().contains("Unknown feature type: unknown_feature")); + assertTrue(ex.getMessage().contains("Unknown feature type")); } } diff --git a/modules/autotagging-commons/src/test/java/org/opensearch/rule/action/CreateRuleActionTests.java b/modules/autotagging-commons/src/test/java/org/opensearch/rule/action/CreateRuleActionTests.java new file mode 100644 index 0000000000000..5941cd3e9a6a7 --- /dev/null +++ b/modules/autotagging-commons/src/test/java/org/opensearch/rule/action/CreateRuleActionTests.java @@ -0,0 +1,22 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.rule.action; + +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.test.OpenSearchTestCase; + +public class CreateRuleActionTests extends OpenSearchTestCase { + public void testGetName() { + assertEquals("cluster:admin/opensearch/rule/_create", CreateRuleAction.NAME); + } + + public void testCreateResponseReader() { + assertTrue(CreateRuleAction.INSTANCE.getResponseReader() instanceof Writeable.Reader); + } +} diff --git a/modules/autotagging-commons/src/test/java/org/opensearch/rule/action/TransportCreateRuleActionTests.java b/modules/autotagging-commons/src/test/java/org/opensearch/rule/action/TransportCreateRuleActionTests.java new file mode 100644 index 0000000000000..fa413044c5efc --- /dev/null +++ b/modules/autotagging-commons/src/test/java/org/opensearch/rule/action/TransportCreateRuleActionTests.java @@ -0,0 +1,80 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.rule.action; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.core.action.ActionListener; +import org.opensearch.rule.CreateRuleRequest; +import org.opensearch.rule.CreateRuleResponse; +import org.opensearch.rule.RulePersistenceServiceRegistry; +import org.opensearch.rule.RuleRoutingService; +import org.opensearch.rule.RuleRoutingServiceRegistry; +import org.opensearch.rule.autotagging.FeatureType; +import org.opensearch.rule.autotagging.Rule; +import org.opensearch.rule.service.IndexStoredRulePersistenceService; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import java.util.concurrent.ExecutorService; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +@SuppressWarnings("unchecked") +public class TransportCreateRuleActionTests extends OpenSearchTestCase { + private TransportService transportService; + private ThreadPool threadPool; + private ActionFilters actionFilters; + private TransportCreateRuleAction action; + private FeatureType mockFeatureType; + private RuleRoutingServiceRegistry routingRegistry; + private RulePersistenceServiceRegistry persistenceRegistry; + private RuleRoutingService mockRoutingService; + + private final String testIndexName = "test-index"; + + public void setUp() throws Exception { + super.setUp(); + transportService = mock(TransportService.class); + threadPool = mock(ThreadPool.class); + actionFilters = mock(ActionFilters.class); + mockFeatureType = mock(FeatureType.class); + routingRegistry = mock(RuleRoutingServiceRegistry.class); + persistenceRegistry = mock(RulePersistenceServiceRegistry.class); + when(mockFeatureType.getName()).thenReturn("test_feature"); + mockRoutingService = mock(RuleRoutingService.class); + routingRegistry.register(mockFeatureType, mockRoutingService); + + ExecutorService executorService = mock(ExecutorService.class); + doAnswer(invocation -> { + Runnable runnable = invocation.getArgument(0); + runnable.run(); + return null; + }).when(executorService).execute(any()); + when(threadPool.executor(any())).thenReturn(executorService); + action = new TransportCreateRuleAction(transportService, threadPool, actionFilters, persistenceRegistry, routingRegistry); + } + + public void testExecution() { + IndexStoredRulePersistenceService persistenceService = mock(IndexStoredRulePersistenceService.class); + when(persistenceRegistry.getRulePersistenceService(mockFeatureType)).thenReturn(persistenceService); + when(routingRegistry.getRuleRoutingService(mockFeatureType)).thenReturn(mockRoutingService); + Rule rule = mock(Rule.class); + when(rule.getFeatureType()).thenReturn(mockFeatureType); + CreateRuleRequest request = new CreateRuleRequest(rule); + ActionListener listener = mock(ActionListener.class); + action.doExecute(null, request, listener); + verify(routingRegistry, times(1)).getRuleRoutingService(request.getRule().getFeatureType()); + } +} diff --git a/modules/autotagging-commons/src/test/java/org/opensearch/rule/rest/RestCreateRuleActionTests.java b/modules/autotagging-commons/src/test/java/org/opensearch/rule/rest/RestCreateRuleActionTests.java new file mode 100644 index 0000000000000..b6020c7b08e1c --- /dev/null +++ b/modules/autotagging-commons/src/test/java/org/opensearch/rule/rest/RestCreateRuleActionTests.java @@ -0,0 +1,25 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.rule.rest; + +import org.opensearch.test.OpenSearchTestCase; + +public class RestCreateRuleActionTests extends OpenSearchTestCase { + RestCreateRuleAction action = new RestCreateRuleAction();; + + public void testGetName() { + assertEquals("create_rule", action.getName()); + } + + public void testRoutes() { + var routes = action.routes(); + assertEquals(1, routes.size()); + assertTrue(routes.stream().anyMatch(r -> r.getPath().equals("_rules/{featureType}"))); + } +} diff --git a/plugins/workload-management/src/main/java/org/opensearch/plugin/wlm/WorkloadManagementPlugin.java b/plugins/workload-management/src/main/java/org/opensearch/plugin/wlm/WorkloadManagementPlugin.java index dedc02a06e0c6..3ea3f5548ef5c 100644 --- a/plugins/workload-management/src/main/java/org/opensearch/plugin/wlm/WorkloadManagementPlugin.java +++ b/plugins/workload-management/src/main/java/org/opensearch/plugin/wlm/WorkloadManagementPlugin.java @@ -14,6 +14,7 @@ import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Module; +import org.opensearch.common.network.NetworkService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.IndexScopedSettings; import org.opensearch.common.settings.Setting; @@ -22,6 +23,7 @@ import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.discovery.SeedHostsProvider; import org.opensearch.env.Environment; import org.opensearch.env.NodeEnvironment; import org.opensearch.indices.SystemIndexDescriptor; @@ -38,8 +40,11 @@ import org.opensearch.plugin.wlm.rest.RestGetWorkloadGroupAction; import org.opensearch.plugin.wlm.rest.RestUpdateWorkloadGroupAction; import org.opensearch.plugin.wlm.rule.WorkloadGroupFeatureType; +import org.opensearch.plugin.wlm.rule.WorkloadGroupFeatureValueValidator; +import org.opensearch.plugin.wlm.rule.WorkloadGroupRuleRoutingService; import org.opensearch.plugin.wlm.service.WorkloadGroupPersistenceService; import org.opensearch.plugins.ActionPlugin; +import org.opensearch.plugins.DiscoveryPlugin; import org.opensearch.plugins.Plugin; import org.opensearch.plugins.SystemIndexPlugin; import org.opensearch.repositories.RepositoriesService; @@ -47,6 +52,7 @@ import org.opensearch.rest.RestHandler; import org.opensearch.rule.InMemoryRuleProcessingService; import org.opensearch.rule.RulePersistenceService; +import org.opensearch.rule.RuleRoutingService; import org.opensearch.rule.autotagging.FeatureType; import org.opensearch.rule.service.IndexStoredRulePersistenceService; import org.opensearch.rule.spi.RuleFrameworkExtension; @@ -55,18 +61,21 @@ import org.opensearch.rule.storage.XContentRuleParser; import org.opensearch.script.ScriptService; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; import org.opensearch.transport.client.Client; import org.opensearch.watcher.ResourceWatcherService; import java.util.Collection; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.function.Supplier; /** * Plugin class for WorkloadManagement */ -public class WorkloadManagementPlugin extends Plugin implements ActionPlugin, SystemIndexPlugin, RuleFrameworkExtension { +public class WorkloadManagementPlugin extends Plugin implements ActionPlugin, SystemIndexPlugin, DiscoveryPlugin, RuleFrameworkExtension { + /** * The name of the index where rules are stored. */ @@ -75,8 +84,9 @@ public class WorkloadManagementPlugin extends Plugin implements ActionPlugin, Sy * The maximum number of rules allowed per GET request. */ public static final int MAX_RULES_PER_PAGE = 50; - - private final RulePersistenceServiceHolder rulePersistenceServiceHolder = new RulePersistenceServiceHolder(); + private static FeatureType featureType; + private static RulePersistenceService rulePersistenceService; + private static RuleRoutingService ruleRoutingService; private AutoTaggingActionFilter autoTaggingActionFilter; @@ -99,21 +109,30 @@ public Collection createComponents( IndexNameExpressionResolver indexNameExpressionResolver, Supplier repositoriesServiceSupplier ) { - RulePersistenceServiceHolder.rulePersistenceService = new IndexStoredRulePersistenceService( + featureType = new WorkloadGroupFeatureType(new WorkloadGroupFeatureValueValidator(clusterService)); + rulePersistenceService = new IndexStoredRulePersistenceService( INDEX_NAME, client, + clusterService, MAX_RULES_PER_PAGE, - new XContentRuleParser(WorkloadGroupFeatureType.INSTANCE), + new XContentRuleParser(featureType), new IndexBasedRuleQueryMapper() ); + ruleRoutingService = new WorkloadGroupRuleRoutingService(client, clusterService); InMemoryRuleProcessingService ruleProcessingService = new InMemoryRuleProcessingService( - WorkloadGroupFeatureType.INSTANCE, + featureType, DefaultAttributeValueStore::new ); autoTaggingActionFilter = new AutoTaggingActionFilter(ruleProcessingService, threadPool); return Collections.emptyList(); } + @Override + public Map> getSeedHostProviders(TransportService transportService, NetworkService networkService) { + ((WorkloadGroupRuleRoutingService) ruleRoutingService).setTransportService(transportService); + return Collections.emptyMap(); + } + @Override public List getActionFilters() { return List.of(autoTaggingActionFilter); @@ -131,7 +150,7 @@ public List getActionFilters() { @Override public Collection getSystemIndexDescriptors(Settings settings) { - return List.of(new SystemIndexDescriptor(INDEX_NAME, "System index used for storing rules")); + return List.of(new SystemIndexDescriptor(INDEX_NAME, "System index used for storing workload_group rules")); } @Override @@ -164,15 +183,16 @@ public Collection createGuiceModules() { @Override public Supplier getRulePersistenceServiceSupplier() { - return () -> RulePersistenceServiceHolder.rulePersistenceService; + return () -> rulePersistenceService; } @Override - public FeatureType getFeatureType() { - return WorkloadGroupFeatureType.INSTANCE; + public Supplier getRuleRoutingServiceSupplier() { + return () -> ruleRoutingService; } - static class RulePersistenceServiceHolder { - private static RulePersistenceService rulePersistenceService; + @Override + public Supplier getFeatureTypeSupplier() { + return () -> featureType; } } diff --git a/plugins/workload-management/src/main/java/org/opensearch/plugin/wlm/rule/WorkloadGroupFeatureType.java b/plugins/workload-management/src/main/java/org/opensearch/plugin/wlm/rule/WorkloadGroupFeatureType.java index 1dce67ee4e72d..fc9dfa3136277 100644 --- a/plugins/workload-management/src/main/java/org/opensearch/plugin/wlm/rule/WorkloadGroupFeatureType.java +++ b/plugins/workload-management/src/main/java/org/opensearch/plugin/wlm/rule/WorkloadGroupFeatureType.java @@ -10,8 +10,8 @@ import org.opensearch.rule.RuleAttribute; import org.opensearch.rule.autotagging.Attribute; -import org.opensearch.rule.autotagging.AutoTaggingRegistry; import org.opensearch.rule.autotagging.FeatureType; +import org.opensearch.rule.autotagging.FeatureValueValidator; import java.util.Map; @@ -20,10 +20,6 @@ * @opensearch.experimental */ public class WorkloadGroupFeatureType implements FeatureType { - /** - * The instance for WorkloadGroupFeatureType - */ - public static final WorkloadGroupFeatureType INSTANCE = new WorkloadGroupFeatureType(); /** * Name for WorkloadGroupFeatureType */ @@ -34,8 +30,15 @@ public class WorkloadGroupFeatureType implements FeatureType { RuleAttribute.INDEX_PATTERN.getName(), RuleAttribute.INDEX_PATTERN ); + private final FeatureValueValidator featureValueValidator; - private WorkloadGroupFeatureType() {} + /** + * constructor for WorkloadGroupFeatureType + * @param featureValueValidator + */ + public WorkloadGroupFeatureType(FeatureValueValidator featureValueValidator) { + this.featureValueValidator = featureValueValidator; + } @Override public String getName() { @@ -58,7 +61,7 @@ public Map getAllowedAttributesRegistry() { } @Override - public void registerFeatureType() { - AutoTaggingRegistry.registerFeatureType(INSTANCE); + public FeatureValueValidator getFeatureValueValidator() { + return featureValueValidator; } } diff --git a/plugins/workload-management/src/main/java/org/opensearch/plugin/wlm/rule/WorkloadGroupFeatureValueValidator.java b/plugins/workload-management/src/main/java/org/opensearch/plugin/wlm/rule/WorkloadGroupFeatureValueValidator.java new file mode 100644 index 0000000000000..0ea7621943615 --- /dev/null +++ b/plugins/workload-management/src/main/java/org/opensearch/plugin/wlm/rule/WorkloadGroupFeatureValueValidator.java @@ -0,0 +1,40 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.plugin.wlm.rule; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ResourceNotFoundException; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.rule.autotagging.FeatureValueValidator; + +/** + * Validator for the workload_group feature type + * @opensearch.experimental + */ +public class WorkloadGroupFeatureValueValidator implements FeatureValueValidator { + private final ClusterService clusterService; + private final Logger logger = LogManager.getLogger(WorkloadGroupFeatureValueValidator.class); + + /** + * constructor for WorkloadGroupFeatureValueValidator + * @param clusterService + */ + public WorkloadGroupFeatureValueValidator(ClusterService clusterService) { + this.clusterService = clusterService; + } + + @Override + public void validate(String featureValue) { + if (!clusterService.state().metadata().workloadGroups().containsKey(featureValue)) { + logger.error("{} is not a valid workload group id.", featureValue); + throw new ResourceNotFoundException(featureValue + " is not a valid workload group id."); + } + } +} diff --git a/plugins/workload-management/src/main/java/org/opensearch/plugin/wlm/rule/WorkloadGroupRuleRoutingService.java b/plugins/workload-management/src/main/java/org/opensearch/plugin/wlm/rule/WorkloadGroupRuleRoutingService.java new file mode 100644 index 0000000000000..a70482eb40186 --- /dev/null +++ b/plugins/workload-management/src/main/java/org/opensearch/plugin/wlm/rule/WorkloadGroupRuleRoutingService.java @@ -0,0 +1,143 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.plugin.wlm.rule; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.action.ActionListenerResponseHandler; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.routing.IndexShardRoutingTable; +import org.opensearch.cluster.routing.ShardRouting; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.plugin.wlm.WorkloadManagementPlugin; +import org.opensearch.rule.CreateRuleRequest; +import org.opensearch.rule.CreateRuleResponse; +import org.opensearch.rule.RuleRoutingService; +import org.opensearch.rule.action.CreateRuleAction; +import org.opensearch.transport.TransportService; +import org.opensearch.transport.client.Client; + +import java.util.Map; +import java.util.Optional; + +/** + * Service responsible for routing CreateRule requests to the correct node based on primary shard ownership. + * @opensearch.experimental + */ +public class WorkloadGroupRuleRoutingService implements RuleRoutingService { + private final Client client; + private final ClusterService clusterService; + private TransportService transportService; + private static final Logger logger = LogManager.getLogger(WorkloadGroupRuleRoutingService.class); + private static final Map indexSettings = Map.of("index.number_of_shards", 1, "index.auto_expand_replicas", "0-all"); + + /** + * Constructor for WorkloadGroupRuleRoutingService + * @param client + * @param clusterService + */ + public WorkloadGroupRuleRoutingService(Client client, ClusterService clusterService) { + this.client = client; + this.clusterService = clusterService; + } + + /** + * Set {@link TransportService} for WorkloadGroupRuleRoutingService + * @param transportService + */ + public void setTransportService(TransportService transportService) { + this.transportService = transportService; + } + + @Override + public void handleCreateRuleRequest(CreateRuleRequest request, ActionListener listener) { + String indexName = WorkloadManagementPlugin.INDEX_NAME; + + try (ThreadContext.StoredContext ctx = client.threadPool().getThreadContext().stashContext()) { + if (clusterService.state().metadata().hasIndex(indexName)) { + routeRequest(request, listener, indexName); + return; + } + createIndex(indexName, new ActionListener<>() { + @Override + public void onResponse(CreateIndexResponse response) { + if (!response.isAcknowledged()) { + logger.error("Failed to create index " + indexName); + listener.onFailure(new IllegalStateException(indexName + " index creation not acknowledged")); + } else { + routeRequest(request, listener, indexName); + } + } + + @Override + public void onFailure(Exception e) { + Throwable cause = ExceptionsHelper.unwrapCause(e); + if (cause instanceof ResourceAlreadyExistsException) { + routeRequest(request, listener, indexName); + } else { + logger.error("Failed to create index {}: {}", indexName, e.getMessage()); + listener.onFailure(e); + } + } + }); + } + } + + /** + * Creates the backing index if it does not exist, then runs the given success callback. + * @param indexName the name of the index to create + * @param listener listener to handle failures + */ + private void createIndex(String indexName, ActionListener listener) { + final CreateIndexRequest createRequest = new CreateIndexRequest(indexName).settings(indexSettings); + client.admin().indices().create(createRequest, listener); + } + + /** + * Routes the CreateRuleRequest to the primary shard node for the given index. + * Executes locally if the current node is the primary. + * @param request the CreateRuleRequest + * @param listener listener to handle response or failure + * @param indexName the index name used to find the primary shard node + */ + private void routeRequest(CreateRuleRequest request, ActionListener listener, String indexName) { + Optional optionalPrimaryNode = getPrimaryShardNode(indexName); + if (optionalPrimaryNode.isEmpty()) { + listener.onFailure(new IllegalStateException("Primary node for index [" + indexName + "] not found")); + return; + } + DiscoveryNode primaryNode = optionalPrimaryNode.get(); + transportService.sendRequest( + primaryNode, + CreateRuleAction.NAME, + request, + new ActionListenerResponseHandler<>(listener, CreateRuleResponse::new) + ); + } + + /** + * Retrieves the discovery node that holds the primary shard for the given index. + * @param indexName the index name + */ + private Optional getPrimaryShardNode(String indexName) { + ClusterState state = clusterService.state(); + return Optional.ofNullable(state.getRoutingTable().index(indexName)) + .map(table -> table.shard(0)) + .map(IndexShardRoutingTable::primaryShard) + .filter(ShardRouting::assignedToNode) + .map(shard -> state.nodes().get(shard.currentNodeId())); + } +} diff --git a/plugins/workload-management/src/test/java/org/opensearch/plugin/wlm/AutoTaggingActionFilterTests.java b/plugins/workload-management/src/test/java/org/opensearch/plugin/wlm/AutoTaggingActionFilterTests.java index ec458bfb9b9f1..45802c28ee458 100644 --- a/plugins/workload-management/src/test/java/org/opensearch/plugin/wlm/AutoTaggingActionFilterTests.java +++ b/plugins/workload-management/src/test/java/org/opensearch/plugin/wlm/AutoTaggingActionFilterTests.java @@ -92,9 +92,6 @@ public String getName() { public Map getAllowedAttributesRegistry() { return Map.of("test_attribute", TestAttribute.TEST_ATTRIBUTE); } - - @Override - public void registerFeatureType() {} } public enum TestAttribute implements Attribute { diff --git a/plugins/workload-management/src/test/java/org/opensearch/plugin/wlm/WorkloadManagementPluginTests.java b/plugins/workload-management/src/test/java/org/opensearch/plugin/wlm/WorkloadManagementPluginTests.java index 6838102ac3bf1..458707fb1aee1 100644 --- a/plugins/workload-management/src/test/java/org/opensearch/plugin/wlm/WorkloadManagementPluginTests.java +++ b/plugins/workload-management/src/test/java/org/opensearch/plugin/wlm/WorkloadManagementPluginTests.java @@ -98,7 +98,20 @@ public void testGetSystemIndexDescriptorsReturnsCorrectDescriptor() { } public void testGetFeatureTypeReturnsWorkloadGroupFeatureType() { - FeatureType featureType = plugin.getFeatureType(); + plugin.createComponents( + mock(Client.class), + mock(ClusterService.class), + mock(ThreadPool.class), + mock(ResourceWatcherService.class), + mock(ScriptService.class), + mock(NamedXContentRegistry.class), + mock(Environment.class), + null, + mock(NamedWriteableRegistry.class), + mock(IndexNameExpressionResolver.class), + () -> mock(RepositoriesService.class) + ); + FeatureType featureType = plugin.getFeatureTypeSupplier().get(); assertEquals("workload_group", featureType.getName()); } diff --git a/plugins/workload-management/src/test/java/org/opensearch/plugin/wlm/rule/WorkloadGroupFeatureTypeTests.java b/plugins/workload-management/src/test/java/org/opensearch/plugin/wlm/rule/WorkloadGroupFeatureTypeTests.java index c2728a36e9196..a55e345fd56da 100644 --- a/plugins/workload-management/src/test/java/org/opensearch/plugin/wlm/rule/WorkloadGroupFeatureTypeTests.java +++ b/plugins/workload-management/src/test/java/org/opensearch/plugin/wlm/rule/WorkloadGroupFeatureTypeTests.java @@ -8,6 +8,7 @@ package org.opensearch.plugin.wlm.rule; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.rule.RuleAttribute; import org.opensearch.rule.autotagging.Attribute; import org.opensearch.rule.autotagging.AutoTaggingRegistry; @@ -15,8 +16,10 @@ import java.util.Map; +import static org.mockito.Mockito.mock; + public class WorkloadGroupFeatureTypeTests extends OpenSearchTestCase { - WorkloadGroupFeatureType featureType = WorkloadGroupFeatureType.INSTANCE; + WorkloadGroupFeatureType featureType = new WorkloadGroupFeatureType(new WorkloadGroupFeatureValueValidator(mock(ClusterService.class))); public void testGetName_returnsCorrectName() { assertEquals("workload_group", featureType.getName()); diff --git a/server/src/main/java/org/opensearch/node/Node.java b/server/src/main/java/org/opensearch/node/Node.java index 25a43aa635127..22ccfb6e438df 100644 --- a/server/src/main/java/org/opensearch/node/Node.java +++ b/server/src/main/java/org/opensearch/node/Node.java @@ -1362,6 +1362,7 @@ protected Node(final Environment initialEnvironment, Collection clas clusterManagerMetrics, remoteClusterStateService ); + final SearchPipelineService searchPipelineService = new SearchPipelineService( clusterService, threadPool,