Skip to content

Onboards to centralized resource access control mechanism for ml-model-group #3715

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 16 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions common/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ dependencies {
testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.15.2'
testImplementation "org.opensearch.test:framework:${opensearch_version}"

compileOnly group: 'org.opensearch', name:'opensearch-security-spi', version:"${opensearch_build}"

compileOnly group: 'org.apache.commons', name: 'commons-text', version: '1.10.0'
compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.11.0'
compileOnly group: 'org.json', name: 'json', version: '20231013'
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common;

import org.opensearch.security.spi.resources.client.ResourceSharingClient;

/**
* Accessor for resource sharing client
*/
public class ResourceSharingClientAccessor {
private ResourceSharingClient CLIENT;

private static ResourceSharingClientAccessor resourceSharingClientAccessor;

private ResourceSharingClientAccessor() {}

public static ResourceSharingClientAccessor getInstance() {
if (resourceSharingClientAccessor == null) {
resourceSharingClientAccessor = new ResourceSharingClientAccessor();
}

return resourceSharingClientAccessor;
}

/**
* Set the resource sharing client
*/
public void setResourceSharingClient(ResourceSharingClient client) {
resourceSharingClientAccessor.CLIENT = client;
}

/**
* Get the resource sharing client
*/
public ResourceSharingClient getResourceSharingClient() {
return resourceSharingClientAccessor.CLIENT;
}

}
4 changes: 3 additions & 1 deletion plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ opensearchplugin {
name 'opensearch-ml'
description 'machine learning plugin for opensearch'
classname 'org.opensearch.ml.plugin.MachineLearningPlugin'
extendedPlugins = ['opensearch-job-scheduler']
extendedPlugins = ['opensearch-job-scheduler', 'opensearch-security;optional=true']
}

configurations {
Expand All @@ -72,6 +72,8 @@ dependencies {

zipArchive group: 'org.opensearch.plugin', name:'opensearch-job-scheduler', version: "${opensearch_build}"
compileOnly "org.opensearch:opensearch-job-scheduler-spi:${opensearch_build}"
compileOnly group: 'org.opensearch', name:'opensearch-security-spi', version:"${opensearch_build}"

implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
implementation "org.opensearch.client:opensearch-rest-client:${opensearch_version}"
// Multi-tenant SDK Client
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.commons.authuser.User;
Expand Down Expand Up @@ -66,6 +67,7 @@
public class CreateControllerTransportAction extends HandledTransportAction<ActionRequest, MLCreateControllerResponse> {
MLIndicesHandler mlIndicesHandler;
Client client;
Settings settings;
MLModelManager mlModelManager;
ClusterService clusterService;
MLModelCacheHelper mlModelCacheHelper;
Expand All @@ -78,6 +80,7 @@ public CreateControllerTransportAction(
ActionFilters actionFilters,
MLIndicesHandler mlIndicesHandler,
Client client,
Settings settings,
ClusterService clusterService,
ModelAccessControlHelper modelAccessControlHelper,
MLModelCacheHelper mlModelCacheHelper,
Expand All @@ -87,6 +90,7 @@ public CreateControllerTransportAction(
super(MLCreateControllerAction.NAME, transportService, actionFilters, MLCreateControllerRequest::new);
this.mlIndicesHandler = mlIndicesHandler;
this.client = client;
this.settings = settings;
this.mlModelManager = mlModelManager;
this.clusterService = clusterService;
this.mlModelCacheHelper = mlModelCacheHelper;
Expand All @@ -112,7 +116,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLCrea
Boolean isHidden = mlModel.getIsHidden();
if (functionName == TEXT_EMBEDDING || functionName == REMOTE) {
modelAccessControlHelper
.validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(hasPermission -> {
.validateModelGroupAccess(user, mlModel.getModelGroupId(), client, settings, ActionListener.wrap(hasPermission -> {
if (hasPermission) {
if (mlModel.getModelState() != MLModelState.DEPLOYING) {
indexAndCreateController(mlModel, controller, wrappedListener);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
Expand Down Expand Up @@ -55,6 +56,7 @@
@FieldDefaults(level = AccessLevel.PRIVATE)
public class DeleteControllerTransportAction extends HandledTransportAction<ActionRequest, DeleteResponse> {
Client client;
Settings settings;
NamedXContentRegistry xContentRegistry;
ClusterService clusterService;
MLModelManager mlModelManager;
Expand All @@ -67,6 +69,7 @@ public DeleteControllerTransportAction(
TransportService transportService,
ActionFilters actionFilters,
Client client,
Settings settings,
NamedXContentRegistry xContentRegistry,
ClusterService clusterService,
MLModelManager mlModelManager,
Expand Down Expand Up @@ -98,7 +101,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> {
Boolean isHidden = mlModel.getIsHidden();
modelAccessControlHelper
.validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(hasPermission -> {
.validateModelGroupAccess(user, mlModel.getModelGroupId(), client, settings, ActionListener.wrap(hasPermission -> {
if (hasPermission) {
mlModelManager
.getController(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
Expand Down Expand Up @@ -48,6 +49,7 @@
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
public class GetControllerTransportAction extends HandledTransportAction<ActionRequest, MLControllerGetResponse> {
Client client;
Settings settings;
NamedXContentRegistry xContentRegistry;
ClusterService clusterService;
MLModelManager mlModelManager;
Expand All @@ -59,6 +61,7 @@ public GetControllerTransportAction(
TransportService transportService,
ActionFilters actionFilters,
Client client,
Settings settings,
NamedXContentRegistry xContentRegistry,
ClusterService clusterService,
MLModelManager mlModelManager,
Expand All @@ -67,6 +70,7 @@ public GetControllerTransportAction(
) {
super(MLControllerGetAction.NAME, transportService, actionFilters, MLControllerGetRequest::new);
this.client = client;
this.settings = settings;
this.xContentRegistry = xContentRegistry;
this.clusterService = clusterService;
this.mlModelManager = mlModelManager;
Expand Down Expand Up @@ -96,34 +100,40 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLCont
mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> {
Boolean isHidden = mlModel.getIsHidden();
modelAccessControlHelper
.validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(hasPermission -> {
if (hasPermission) {
wrappedListener.onResponse(MLControllerGetResponse.builder().controller(controller).build());
} else {
wrappedListener
.onFailure(
new OpenSearchStatusException(
getErrorMessage(
"User doesn't have privilege to perform this operation on this model controller.",
modelId,
isHidden
),
RestStatus.FORBIDDEN
)
.validateModelGroupAccess(
user,
mlModel.getModelGroupId(),
client,
settings,
ActionListener.wrap(hasPermission -> {
if (hasPermission) {
wrappedListener.onResponse(MLControllerGetResponse.builder().controller(controller).build());
} else {
wrappedListener
.onFailure(
new OpenSearchStatusException(
getErrorMessage(
"User doesn't have privilege to perform this operation on this model controller.",
modelId,
isHidden
),
RestStatus.FORBIDDEN
)
);
}
}, exception -> {
log
.error(
getErrorMessage(
"Permission denied: Unable to create the model controller for the given model.",
modelId,
isHidden
),
exception
);
}
}, exception -> {
log
.error(
getErrorMessage(
"Permission denied: Unable to create the model controller for the given model.",
modelId,
isHidden
),
exception
);
wrappedListener.onFailure(exception);
}));
wrappedListener.onFailure(exception);
})
);
},
e -> wrappedListener
.onFailure(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.commons.authuser.User;
Expand Down Expand Up @@ -60,6 +61,7 @@
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
public class UpdateControllerTransportAction extends HandledTransportAction<ActionRequest, UpdateResponse> {
Client client;
Settings settings;
MLModelManager mlModelManager;
MLModelCacheHelper mlModelCacheHelper;
ClusterService clusterService;
Expand All @@ -71,6 +73,7 @@ public UpdateControllerTransportAction(
TransportService transportService,
ActionFilters actionFilters,
Client client,
Settings settings,
ClusterService clusterService,
ModelAccessControlHelper modelAccessControlHelper,
MLModelCacheHelper mlModelCacheHelper,
Expand All @@ -79,6 +82,7 @@ public UpdateControllerTransportAction(
) {
super(MLUpdateControllerAction.NAME, transportService, actionFilters, MLUpdateControllerRequest::new);
this.client = client;
this.settings = settings;
this.mlModelManager = mlModelManager;
this.clusterService = clusterService;
this.mlModelCacheHelper = mlModelCacheHelper;
Expand All @@ -104,7 +108,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Update
Boolean isHidden = mlModel.getIsHidden();
if (functionName == TEXT_EMBEDDING || functionName == REMOTE) {
modelAccessControlHelper
.validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(hasPermission -> {
.validateModelGroupAccess(user, mlModel.getModelGroupId(), client, settings, ActionListener.wrap(hasPermission -> {
if (hasPermission) {
mlModelManager.getController(modelId, ActionListener.wrap(controller -> {
boolean isDeployRequiredAfterUpdate = controller.isDeployRequiredAfterUpdate(updateControllerInput);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLDepl
}
} else {
modelAccessControlHelper
.validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> {
.validateModelGroupAccess(user, mlModel.getModelGroupId(), client, settings, ActionListener.wrap(access -> {
if (!access) {
wrappedListener
.onFailure(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
Expand Down Expand Up @@ -65,17 +66,20 @@ public class MLSearchHandler {
private ModelAccessControlHelper modelAccessControlHelper;

private ClusterService clusterService;
private Settings settings;

public MLSearchHandler(
Client client,
NamedXContentRegistry xContentRegistry,
ModelAccessControlHelper modelAccessControlHelper,
ClusterService clusterService
ClusterService clusterService,
Settings settings
) {
this.modelAccessControlHelper = modelAccessControlHelper;
this.client = client;
this.xContentRegistry = xContentRegistry;
this.clusterService = clusterService;
this.settings = settings;
}

/**
Expand Down Expand Up @@ -144,7 +148,7 @@ public void search(SdkClient sdkClient, SearchRequest request, String tenantId,
.searchDataObjectAsync(searchDataObjectRequest)
.whenComplete(SdkClientUtils.wrapSearchCompletion(doubleWrapperListener));
} else {
SearchSourceBuilder sourceBuilder = modelAccessControlHelper.createSearchSourceBuilder(user);
SearchSourceBuilder sourceBuilder = modelAccessControlHelper.createSearchSourceBuilder(user, settings);
SearchRequest modelGroupSearchRequest = new SearchRequest();
sourceBuilder.fetchSource(new String[] { MLModelGroup.MODEL_GROUP_ID_FIELD, }, null);
sourceBuilder.size(10000);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
Expand Down Expand Up @@ -57,6 +58,7 @@ public class DeleteModelGroupTransportAction extends HandledTransportAction<Acti
final SdkClient sdkClient;
final NamedXContentRegistry xContentRegistry;
final ClusterService clusterService;
final Settings settings;

final ModelAccessControlHelper modelAccessControlHelper;
private final MLFeatureEnabledSetting mlFeatureEnabledSetting;
Expand All @@ -66,6 +68,7 @@ public DeleteModelGroupTransportAction(
TransportService transportService,
ActionFilters actionFilters,
Client client,
Settings settings,
SdkClient sdkClient,
NamedXContentRegistry xContentRegistry,
ClusterService clusterService,
Expand All @@ -74,6 +77,7 @@ public DeleteModelGroupTransportAction(
) {
super(MLModelGroupDeleteAction.NAME, transportService, actionFilters, MLModelGroupDeleteRequest::new);
this.client = client;
this.settings = settings;
this.sdkClient = sdkClient;
this.xContentRegistry = xContentRegistry;
this.clusterService = clusterService;
Expand All @@ -93,6 +97,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<DeleteResponse> wrappedListener = ActionListener.runBefore(actionListener, context::restore);
// TODO: Remove this feature flag check once feature is GA, as it will be enabled by default
validateAndDeleteModelGroup(modelGroupId, tenantId, wrappedListener);
}
}
Expand All @@ -107,6 +112,7 @@ private void validateAndDeleteModelGroup(String modelGroupId, String tenantId, A
modelGroupId,
client,
sdkClient,
settings,
ActionListener
.wrap(
hasAccess -> handleAccessValidation(hasAccess, modelGroupId, tenantId, listener),
Expand Down
Loading
Loading