Skip to content

Commit 71e6365

Browse files
authored
adding tenantId to the connector executor when this is inline connector (#3837)
* adding tenantId to the connector executor when this is inline connector Signed-off-by: Dhrubo Saha <[email protected]> * added more unit tests Signed-off-by: Dhrubo Saha <[email protected]> --------- Signed-off-by: Dhrubo Saha <[email protected]>
1 parent dd887bc commit 71e6365

File tree

2 files changed

+59
-7
lines changed

2 files changed

+59
-7
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java

+4
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,10 @@ public void initModel(MLModel model, Map<String, Object> params, Encryptor encry
103103
Connector connector = model.getConnector().cloneConnector();
104104
connector
105105
.decrypt(PREDICT.name(), (credential, tenantId) -> encryptor.decrypt(credential, model.getTenantId()), model.getTenantId());
106+
// This situation can only happen for inline connector where we don't provide tenant id.
107+
if (connector.getTenantId() == null && model.getTenantId() != null) {
108+
connector.setTenantId(model.getTenantId());
109+
}
106110
this.connectorExecutor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class);
107111
this.connectorExecutor.setScriptService((ScriptService) params.get(SCRIPT_SERVICE));
108112
this.connectorExecutor.setClusterService((ClusterService) params.get(CLUSTER_SERVICE));

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java

+55-7
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
package org.opensearch.ml.engine.algorithms.remote;
77

88
import static org.junit.Assert.assertEquals;
9+
import static org.junit.Assert.assertNull;
910
import static org.mockito.ArgumentMatchers.any;
1011
import static org.mockito.Mockito.doThrow;
1112
import static org.mockito.Mockito.mock;
12-
import static org.mockito.Mockito.spy;
1313
import static org.mockito.Mockito.verify;
1414
import static org.mockito.Mockito.when;
1515

@@ -39,7 +39,6 @@
3939
import org.opensearch.ml.engine.MLEngineClassLoader;
4040
import org.opensearch.ml.engine.MLStaticMockBase;
4141
import org.opensearch.ml.engine.encryptor.Encryptor;
42-
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
4342

4443
import com.google.common.collect.ImmutableMap;
4544

@@ -64,7 +63,9 @@ public class RemoteModelTest extends MLStaticMockBase {
6463
public void setUp() {
6564
MockitoAnnotations.openMocks(this);
6665
remoteModel = new RemoteModel();
67-
encryptor = spy(new EncryptorImpl(null, "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="));
66+
67+
encryptor = mock(Encryptor.class);
68+
when(encryptor.decrypt(any(), any())).thenReturn("test_api_key");
6869
}
6970

7071
@Test
@@ -189,7 +190,7 @@ public void initModel_NullHeader() {
189190
when(mlModel.getConnector()).thenReturn(connector);
190191
remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor);
191192
Map<String, String> decryptedHeaders = connector.getDecryptedHeaders();
192-
Assert.assertNull(decryptedHeaders);
193+
assertNull(decryptedHeaders);
193194
}
194195

195196
@Test
@@ -200,12 +201,59 @@ public void initModel_WithHeader() {
200201
Map<String, String> decryptedHeaders = connector.getDecryptedHeaders();
201202
RemoteConnectorExecutor executor = remoteModel.getConnectorExecutor();
202203
Assert.assertNotNull(executor);
203-
Assert.assertNull(decryptedHeaders);
204+
assertNull(decryptedHeaders);
204205
Assert.assertNotNull(executor.getConnector().getDecryptedHeaders());
205206
assertEquals(1, executor.getConnector().getDecryptedHeaders().size());
206207
assertEquals("Bearer test_api_key", executor.getConnector().getDecryptedHeaders().get("Authorization"));
207208
remoteModel.close();
208-
Assert.assertNull(remoteModel.getConnectorExecutor());
209+
assertNull(remoteModel.getConnectorExecutor());
210+
}
211+
212+
@Test
213+
public void initModel_setsTenantIdOnClonedConnector_whenMissing() {
214+
Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}"));
215+
when(mlModel.getConnector()).thenReturn(connector);
216+
when(mlModel.getTenantId()).thenReturn("tenantId");
217+
remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor);
218+
RemoteConnectorExecutor executor = remoteModel.getConnectorExecutor();
219+
remoteModel.close();
220+
assertNull(connector.getTenantId());
221+
assertEquals("tenantId", executor.getConnector().getTenantId());
222+
}
223+
224+
@Test
225+
public void initModel_bothTenantIdsNull() {
226+
Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}"));
227+
when(mlModel.getConnector()).thenReturn(connector);
228+
when(mlModel.getTenantId()).thenReturn(null);
229+
remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor);
230+
RemoteConnectorExecutor executor = remoteModel.getConnectorExecutor();
231+
assertNull(connector.getTenantId());
232+
assertNull(executor.getConnector().getTenantId());
233+
}
234+
235+
@Test
236+
public void initModel_connectorHasTenantId() {
237+
Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}"));
238+
connector.setTenantId("connectorTenantId");
239+
when(mlModel.getConnector()).thenReturn(connector);
240+
when(mlModel.getTenantId()).thenReturn(null);
241+
remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor);
242+
RemoteConnectorExecutor executor = remoteModel.getConnectorExecutor();
243+
assertEquals("connectorTenantId", connector.getTenantId());
244+
assertEquals("connectorTenantId", executor.getConnector().getTenantId());
245+
}
246+
247+
@Test
248+
public void initModel_bothHaveTenantIds() {
249+
Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}"));
250+
connector.setTenantId("connectorTenantId");
251+
when(mlModel.getConnector()).thenReturn(connector);
252+
when(mlModel.getTenantId()).thenReturn("modelTenantId");
253+
remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor);
254+
RemoteConnectorExecutor executor = remoteModel.getConnectorExecutor();
255+
assertEquals("connectorTenantId", connector.getTenantId());
256+
assertEquals("connectorTenantId", executor.getConnector().getTenantId());
209257
}
210258

211259
private Connector createConnector(Map<String, String> headers) {
@@ -222,7 +270,7 @@ private Connector createConnector(Map<String, String> headers) {
222270
.name("test connector")
223271
.protocol(ConnectorProtocols.HTTP)
224272
.version("1")
225-
.credential(ImmutableMap.of("key", encryptor.encrypt("test_api_key", null)))
273+
.credential(ImmutableMap.of("key", "dummy-encrypted-value"))
226274
.actions(Arrays.asList(predictAction))
227275
.build();
228276
return connector;

0 commit comments

Comments
 (0)