6
6
package org .opensearch .ml .engine .algorithms .remote ;
7
7
8
8
import static org .junit .Assert .assertEquals ;
9
+ import static org .junit .Assert .assertNull ;
9
10
import static org .mockito .ArgumentMatchers .any ;
10
11
import static org .mockito .Mockito .doThrow ;
11
12
import static org .mockito .Mockito .mock ;
12
- import static org .mockito .Mockito .spy ;
13
13
import static org .mockito .Mockito .verify ;
14
14
import static org .mockito .Mockito .when ;
15
15
39
39
import org .opensearch .ml .engine .MLEngineClassLoader ;
40
40
import org .opensearch .ml .engine .MLStaticMockBase ;
41
41
import org .opensearch .ml .engine .encryptor .Encryptor ;
42
- import org .opensearch .ml .engine .encryptor .EncryptorImpl ;
43
42
44
43
import com .google .common .collect .ImmutableMap ;
45
44
@@ -64,7 +63,9 @@ public class RemoteModelTest extends MLStaticMockBase {
64
63
public void setUp () {
65
64
MockitoAnnotations .openMocks (this );
66
65
remoteModel = new RemoteModel ();
67
- encryptor = spy (new EncryptorImpl (null , "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=" ));
66
+
67
+ encryptor = mock (Encryptor .class );
68
+ when (encryptor .decrypt (any (), any ())).thenReturn ("test_api_key" );
68
69
}
69
70
70
71
@ Test
@@ -189,7 +190,7 @@ public void initModel_NullHeader() {
189
190
when (mlModel .getConnector ()).thenReturn (connector );
190
191
remoteModel .initModel (mlModel , ImmutableMap .of (), encryptor );
191
192
Map <String , String > decryptedHeaders = connector .getDecryptedHeaders ();
192
- Assert . assertNull (decryptedHeaders );
193
+ assertNull (decryptedHeaders );
193
194
}
194
195
195
196
@ Test
@@ -200,12 +201,59 @@ public void initModel_WithHeader() {
200
201
Map <String , String > decryptedHeaders = connector .getDecryptedHeaders ();
201
202
RemoteConnectorExecutor executor = remoteModel .getConnectorExecutor ();
202
203
Assert .assertNotNull (executor );
203
- Assert . assertNull (decryptedHeaders );
204
+ assertNull (decryptedHeaders );
204
205
Assert .assertNotNull (executor .getConnector ().getDecryptedHeaders ());
205
206
assertEquals (1 , executor .getConnector ().getDecryptedHeaders ().size ());
206
207
assertEquals ("Bearer test_api_key" , executor .getConnector ().getDecryptedHeaders ().get ("Authorization" ));
207
208
remoteModel .close ();
208
- Assert .assertNull (remoteModel .getConnectorExecutor ());
209
+ assertNull (remoteModel .getConnectorExecutor ());
210
+ }
211
+
212
+ @ Test
213
+ public void initModel_setsTenantIdOnClonedConnector_whenMissing () {
214
+ Connector connector = createConnector (ImmutableMap .of ("Authorization" , "Bearer ${credential.key}" ));
215
+ when (mlModel .getConnector ()).thenReturn (connector );
216
+ when (mlModel .getTenantId ()).thenReturn ("tenantId" );
217
+ remoteModel .initModel (mlModel , ImmutableMap .of (), encryptor );
218
+ RemoteConnectorExecutor executor = remoteModel .getConnectorExecutor ();
219
+ remoteModel .close ();
220
+ assertNull (connector .getTenantId ());
221
+ assertEquals ("tenantId" , executor .getConnector ().getTenantId ());
222
+ }
223
+
224
+ @ Test
225
+ public void initModel_bothTenantIdsNull () {
226
+ Connector connector = createConnector (ImmutableMap .of ("Authorization" , "Bearer ${credential.key}" ));
227
+ when (mlModel .getConnector ()).thenReturn (connector );
228
+ when (mlModel .getTenantId ()).thenReturn (null );
229
+ remoteModel .initModel (mlModel , ImmutableMap .of (), encryptor );
230
+ RemoteConnectorExecutor executor = remoteModel .getConnectorExecutor ();
231
+ assertNull (connector .getTenantId ());
232
+ assertNull (executor .getConnector ().getTenantId ());
233
+ }
234
+
235
+ @ Test
236
+ public void initModel_connectorHasTenantId () {
237
+ Connector connector = createConnector (ImmutableMap .of ("Authorization" , "Bearer ${credential.key}" ));
238
+ connector .setTenantId ("connectorTenantId" );
239
+ when (mlModel .getConnector ()).thenReturn (connector );
240
+ when (mlModel .getTenantId ()).thenReturn (null );
241
+ remoteModel .initModel (mlModel , ImmutableMap .of (), encryptor );
242
+ RemoteConnectorExecutor executor = remoteModel .getConnectorExecutor ();
243
+ assertEquals ("connectorTenantId" , connector .getTenantId ());
244
+ assertEquals ("connectorTenantId" , executor .getConnector ().getTenantId ());
245
+ }
246
+
247
+ @ Test
248
+ public void initModel_bothHaveTenantIds () {
249
+ Connector connector = createConnector (ImmutableMap .of ("Authorization" , "Bearer ${credential.key}" ));
250
+ connector .setTenantId ("connectorTenantId" );
251
+ when (mlModel .getConnector ()).thenReturn (connector );
252
+ when (mlModel .getTenantId ()).thenReturn ("modelTenantId" );
253
+ remoteModel .initModel (mlModel , ImmutableMap .of (), encryptor );
254
+ RemoteConnectorExecutor executor = remoteModel .getConnectorExecutor ();
255
+ assertEquals ("connectorTenantId" , connector .getTenantId ());
256
+ assertEquals ("connectorTenantId" , executor .getConnector ().getTenantId ());
209
257
}
210
258
211
259
private Connector createConnector (Map <String , String > headers ) {
@@ -222,7 +270,7 @@ private Connector createConnector(Map<String, String> headers) {
222
270
.name ("test connector" )
223
271
.protocol (ConnectorProtocols .HTTP )
224
272
.version ("1" )
225
- .credential (ImmutableMap .of ("key" , encryptor . encrypt ( "test_api_key" , null ) ))
273
+ .credential (ImmutableMap .of ("key" , "dummy-encrypted-value" ))
226
274
.actions (Arrays .asList (predictAction ))
227
275
.build ();
228
276
return connector ;
0 commit comments