diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java index 0ce2362a77..e0042c7d80 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java @@ -103,6 +103,10 @@ public void initModel(MLModel model, Map params, Encryptor encry Connector connector = model.getConnector().cloneConnector(); connector .decrypt(PREDICT.name(), (credential, tenantId) -> encryptor.decrypt(credential, model.getTenantId()), model.getTenantId()); + // This situation can only happen for inline connector where we don't provide tenant id. + if (connector.getTenantId() == null && model.getTenantId() != null) { + connector.setTenantId(model.getTenantId()); + } this.connectorExecutor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class); this.connectorExecutor.setScriptService((ScriptService) params.get(SCRIPT_SERVICE)); this.connectorExecutor.setClusterService((ClusterService) params.get(CLUSTER_SERVICE)); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java index 8bce0d6394..f01e81475b 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java @@ -6,10 +6,10 @@ package org.opensearch.ml.engine.algorithms.remote; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -39,7 +39,6 @@ import org.opensearch.ml.engine.MLEngineClassLoader; import org.opensearch.ml.engine.MLStaticMockBase; import org.opensearch.ml.engine.encryptor.Encryptor; -import org.opensearch.ml.engine.encryptor.EncryptorImpl; import com.google.common.collect.ImmutableMap; @@ -64,7 +63,9 @@ public class RemoteModelTest extends MLStaticMockBase { public void setUp() { MockitoAnnotations.openMocks(this); remoteModel = new RemoteModel(); - encryptor = spy(new EncryptorImpl(null, "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=")); + + encryptor = mock(Encryptor.class); + when(encryptor.decrypt(any(), any())).thenReturn("test_api_key"); } @Test @@ -189,7 +190,7 @@ public void initModel_NullHeader() { when(mlModel.getConnector()).thenReturn(connector); remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor); Map decryptedHeaders = connector.getDecryptedHeaders(); - Assert.assertNull(decryptedHeaders); + assertNull(decryptedHeaders); } @Test @@ -200,12 +201,59 @@ public void initModel_WithHeader() { Map decryptedHeaders = connector.getDecryptedHeaders(); RemoteConnectorExecutor executor = remoteModel.getConnectorExecutor(); Assert.assertNotNull(executor); - Assert.assertNull(decryptedHeaders); + assertNull(decryptedHeaders); Assert.assertNotNull(executor.getConnector().getDecryptedHeaders()); assertEquals(1, executor.getConnector().getDecryptedHeaders().size()); assertEquals("Bearer test_api_key", executor.getConnector().getDecryptedHeaders().get("Authorization")); remoteModel.close(); - Assert.assertNull(remoteModel.getConnectorExecutor()); + assertNull(remoteModel.getConnectorExecutor()); + } + + @Test + public void initModel_setsTenantIdOnClonedConnector_whenMissing() { + Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}")); + when(mlModel.getConnector()).thenReturn(connector); + when(mlModel.getTenantId()).thenReturn("tenantId"); + remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor); + RemoteConnectorExecutor executor = remoteModel.getConnectorExecutor(); + remoteModel.close(); + assertNull(connector.getTenantId()); + assertEquals("tenantId", executor.getConnector().getTenantId()); + } + + @Test + public void initModel_bothTenantIdsNull() { + Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}")); + when(mlModel.getConnector()).thenReturn(connector); + when(mlModel.getTenantId()).thenReturn(null); + remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor); + RemoteConnectorExecutor executor = remoteModel.getConnectorExecutor(); + assertNull(connector.getTenantId()); + assertNull(executor.getConnector().getTenantId()); + } + + @Test + public void initModel_connectorHasTenantId() { + Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}")); + connector.setTenantId("connectorTenantId"); + when(mlModel.getConnector()).thenReturn(connector); + when(mlModel.getTenantId()).thenReturn(null); + remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor); + RemoteConnectorExecutor executor = remoteModel.getConnectorExecutor(); + assertEquals("connectorTenantId", connector.getTenantId()); + assertEquals("connectorTenantId", executor.getConnector().getTenantId()); + } + + @Test + public void initModel_bothHaveTenantIds() { + Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}")); + connector.setTenantId("connectorTenantId"); + when(mlModel.getConnector()).thenReturn(connector); + when(mlModel.getTenantId()).thenReturn("modelTenantId"); + remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor); + RemoteConnectorExecutor executor = remoteModel.getConnectorExecutor(); + assertEquals("connectorTenantId", connector.getTenantId()); + assertEquals("connectorTenantId", executor.getConnector().getTenantId()); } private Connector createConnector(Map headers) { @@ -222,7 +270,7 @@ private Connector createConnector(Map headers) { .name("test connector") .protocol(ConnectorProtocols.HTTP) .version("1") - .credential(ImmutableMap.of("key", encryptor.encrypt("test_api_key", null))) + .credential(ImmutableMap.of("key", "dummy-encrypted-value")) .actions(Arrays.asList(predictAction)) .build(); return connector;