From fadbdcc627b28630c3f29da9817bf6a6b7cb1b99 Mon Sep 17 00:00:00 2001 From: Dhrubo Saha Date: Tue, 6 May 2025 16:23:34 -0700 Subject: [PATCH 1/2] adding tenantId to the connector executor when this is inline connector Signed-off-by: Dhrubo Saha --- .../engine/algorithms/remote/RemoteModel.java | 4 +++ .../algorithms/remote/RemoteModelTest.java | 27 ++++++++++++++----- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java index 0ce2362a77..e0042c7d80 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java @@ -103,6 +103,10 @@ public void initModel(MLModel model, Map params, Encryptor encry Connector connector = model.getConnector().cloneConnector(); connector .decrypt(PREDICT.name(), (credential, tenantId) -> encryptor.decrypt(credential, model.getTenantId()), model.getTenantId()); + // This situation can only happen for inline connector where we don't provide tenant id. + if (connector.getTenantId() == null && model.getTenantId() != null) { + connector.setTenantId(model.getTenantId()); + } this.connectorExecutor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class); this.connectorExecutor.setScriptService((ScriptService) params.get(SCRIPT_SERVICE)); this.connectorExecutor.setClusterService((ClusterService) params.get(CLUSTER_SERVICE)); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java index 8bce0d6394..9b80544d03 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java @@ -6,10 +6,10 @@ package org.opensearch.ml.engine.algorithms.remote; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -39,7 +39,6 @@ import org.opensearch.ml.engine.MLEngineClassLoader; import org.opensearch.ml.engine.MLStaticMockBase; import org.opensearch.ml.engine.encryptor.Encryptor; -import org.opensearch.ml.engine.encryptor.EncryptorImpl; import com.google.common.collect.ImmutableMap; @@ -64,7 +63,9 @@ public class RemoteModelTest extends MLStaticMockBase { public void setUp() { MockitoAnnotations.openMocks(this); remoteModel = new RemoteModel(); - encryptor = spy(new EncryptorImpl(null, "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=")); + + encryptor = mock(Encryptor.class); + when(encryptor.decrypt(any(), any())).thenReturn("test_api_key"); } @Test @@ -189,7 +190,7 @@ public void initModel_NullHeader() { when(mlModel.getConnector()).thenReturn(connector); remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor); Map decryptedHeaders = connector.getDecryptedHeaders(); - Assert.assertNull(decryptedHeaders); + assertNull(decryptedHeaders); } @Test @@ -200,12 +201,24 @@ public void initModel_WithHeader() { Map decryptedHeaders = connector.getDecryptedHeaders(); RemoteConnectorExecutor executor = remoteModel.getConnectorExecutor(); Assert.assertNotNull(executor); - Assert.assertNull(decryptedHeaders); + assertNull(decryptedHeaders); Assert.assertNotNull(executor.getConnector().getDecryptedHeaders()); assertEquals(1, executor.getConnector().getDecryptedHeaders().size()); assertEquals("Bearer test_api_key", executor.getConnector().getDecryptedHeaders().get("Authorization")); remoteModel.close(); - Assert.assertNull(remoteModel.getConnectorExecutor()); + assertNull(remoteModel.getConnectorExecutor()); + } + + @Test + public void initModel_setsTenantIdOnClonedConnector_whenMissing() { + Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}")); + when(mlModel.getConnector()).thenReturn(connector); + when(mlModel.getTenantId()).thenReturn("tenantId"); + remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor); + RemoteConnectorExecutor executor = remoteModel.getConnectorExecutor(); + remoteModel.close(); + assertNull(connector.getTenantId()); + assertEquals("tenantId", executor.getConnector().getTenantId()); } private Connector createConnector(Map headers) { @@ -222,7 +235,7 @@ private Connector createConnector(Map headers) { .name("test connector") .protocol(ConnectorProtocols.HTTP) .version("1") - .credential(ImmutableMap.of("key", encryptor.encrypt("test_api_key", null))) + .credential(ImmutableMap.of("key", "dummy-encrypted-value")) .actions(Arrays.asList(predictAction)) .build(); return connector; From f71d879af396f5f66ae8df1c9baad98e3309674d Mon Sep 17 00:00:00 2001 From: Dhrubo Saha Date: Wed, 7 May 2025 10:42:26 -0700 Subject: [PATCH 2/2] added more unit tests Signed-off-by: Dhrubo Saha --- .../algorithms/remote/RemoteModelTest.java | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java index 9b80544d03..f01e81475b 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java @@ -221,6 +221,41 @@ public void initModel_setsTenantIdOnClonedConnector_whenMissing() { assertEquals("tenantId", executor.getConnector().getTenantId()); } + @Test + public void initModel_bothTenantIdsNull() { + Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}")); + when(mlModel.getConnector()).thenReturn(connector); + when(mlModel.getTenantId()).thenReturn(null); + remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor); + RemoteConnectorExecutor executor = remoteModel.getConnectorExecutor(); + assertNull(connector.getTenantId()); + assertNull(executor.getConnector().getTenantId()); + } + + @Test + public void initModel_connectorHasTenantId() { + Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}")); + connector.setTenantId("connectorTenantId"); + when(mlModel.getConnector()).thenReturn(connector); + when(mlModel.getTenantId()).thenReturn(null); + remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor); + RemoteConnectorExecutor executor = remoteModel.getConnectorExecutor(); + assertEquals("connectorTenantId", connector.getTenantId()); + assertEquals("connectorTenantId", executor.getConnector().getTenantId()); + } + + @Test + public void initModel_bothHaveTenantIds() { + Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}")); + connector.setTenantId("connectorTenantId"); + when(mlModel.getConnector()).thenReturn(connector); + when(mlModel.getTenantId()).thenReturn("modelTenantId"); + remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor); + RemoteConnectorExecutor executor = remoteModel.getConnectorExecutor(); + assertEquals("connectorTenantId", connector.getTenantId()); + assertEquals("connectorTenantId", executor.getConnector().getTenantId()); + } + private Connector createConnector(Map headers) { ConnectorAction predictAction = ConnectorAction .builder()