Skip to content

Commit 747b4d6

Browse files
committed
excluding circuit breaker for Agent
Signed-off-by: Dhrubo Saha <[email protected]>
1 parent 913b033 commit 747b4d6

File tree

4 files changed

+31
-4
lines changed

4 files changed

+31
-4
lines changed

plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java

+4-1
Original file line numberDiff line numberDiff line change
@@ -975,7 +975,10 @@ private <T> ThreadedActionListener<T> threadedActionListener(String threadPoolNa
975975
* @param runningTaskLimit limit
976976
*/
977977
public void checkAndAddRunningTask(MLTask mlTask, Integer runningTaskLimit) {
978-
if (Objects.nonNull(mlTask) && mlTask.getFunctionName() != FunctionName.REMOTE) {
978+
979+
// for agent and remote model prediction we don't need to check circuit breaker
980+
if (Objects.nonNull(mlTask) && mlTask.getFunctionName() != FunctionName.REMOTE
981+
&& mlTask.getFunctionName() != FunctionName.AGENT) {
979982
checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
980983
}
981984
mlTaskManager.checkLimitAndAddRunningTask(mlTask, runningTaskLimit);

plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,7 @@ public void dispatchTask(
112112
if (clusterService.localNode().getId().equals(nodeId)) {
113113
// Execute ML task locally
114114
log.debug("Execute ML request {} locally on node {}", request.getRequestID(), nodeId);
115-
checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
116-
executeTask(request, listener);
115+
checkCBAndExecute(functionName, request, listener);
117116
} else {
118117
// Execute ML task remotely
119118
log.debug("Execute ML request {} remotely on node {}", request.getRequestID(), nodeId);
@@ -130,7 +129,8 @@ public void dispatchTask(
130129
protected abstract void executeTask(Request request, ActionListener<Response> listener);
131130

132131
protected void checkCBAndExecute(FunctionName functionName, Request request, ActionListener<Response> listener) {
133-
if (functionName != FunctionName.REMOTE) {
132+
// for agent and remote model prediction we don't need to check circuit breaker
133+
if (functionName != FunctionName.REMOTE && functionName != FunctionName.AGENT) {
134134
checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
135135
}
136136
executeTask(request, listener);

plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java

+12
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,18 @@ public void testRegisterMLModel_CircuitBreakerOpen() {
361361
verify(mlTaskManager).updateMLTask(anyString(), any(), anyMap(), anyLong(), anyBoolean());
362362
}
363363

364+
/**
 * Registers a model whose input function name is {@code AGENT} while
 * {@code mlCircuitBreakerService.checkOpenCB()} is stubbed to return a breaker,
 * and expects a {@code CircuitBreakingException} carrying the disk-breaker message.
 *
 * NOTE(review): the method name says "NotOpenForAgent", yet the expectation below is
 * that the circuit breaker DOES trip. Presumably this is because the production check
 * reads the function name from {@code mlTask}, not from {@code registerModelInput},
 * and {@code mlTask} is not an AGENT task here -- confirm against the fixture setup,
 * then either rename the test or set the function name on {@code mlTask} and assert
 * that no exception is thrown.
 */
public void testRegisterMLModel_CircuitBreakerNotOpenForAgent() {
365+
// Only the register input is marked as AGENT; mlTask is left as configured by the fixture.
registerModelInput.setFunctionName(FunctionName.AGENT);
366+
doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any());
367+
// Stub checkOpenCB() to hand back a breaker object (presumably meaning "a breaker is open").
when(mlCircuitBreakerService.checkOpenCB()).thenReturn(thresholdCircuitBreaker);
368+
when(thresholdCircuitBreaker.getName()).thenReturn("Disk Circuit Breaker");
369+
when(thresholdCircuitBreaker.getThreshold()).thenReturn(87);
370+
expectedEx.expect(CircuitBreakingException.class);
371+
expectedEx.expectMessage("Disk Circuit Breaker is open, please check your resources!");
372+
modelManager.registerMLModel(registerModelInput, mlTask);
373+
// NOTE(review): if expectedEx is a JUnit ExpectedException rule, the exception thrown above
// leaves the method before this verify runs, so it asserts nothing -- confirm.
verify(mlTaskManager).updateMLTask(anyString(), any(), anyMap(), anyLong(), anyBoolean());
374+
}
375+
364376
public void testRegisterMLModel_InitModelIndexFailure() {
365377
doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any());
366378
when(mlCircuitBreakerService.checkOpenCB()).thenReturn(null);

plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java

+12
Original file line numberDiff line numberDiff line change
@@ -155,4 +155,16 @@ public void testRun_NoCircuitbreakerforRemote() {
155155
Long value = (Long) mlStats.getStat(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT).getValue();
156156
assertEquals(0L, value.longValue());
157157
}
158+
159+
/**
 * Runs an {@code AGENT} task through {@code mlTaskRunner.run(...)} while
 * {@code checkOpenCB()} is stubbed to return a breaker, then asserts that the
 * {@code ML_CIRCUIT_BREAKER_TRIGGER_COUNT} node-level stat is still 0.
 */
public void testRun_NoCircuitbreakerforAgent() {
160+
// Stub checkOpenCB() to hand back a breaker object (presumably meaning "a breaker is open").
when(mlCircuitBreakerService.checkOpenCB()).thenReturn(thresholdCircuitBreaker);
161+
when(thresholdCircuitBreaker.getName()).thenReturn("Memory Circuit Breaker");
162+
when(thresholdCircuitBreaker.getThreshold()).thenReturn(87);
163+
TransportService transportService = mock(TransportService.class);
164+
// Raw ActionListener type is used for the mock; no callbacks are verified in this test.
ActionListener listener = mock(ActionListener.class);
165+
MLTaskRequest request = new MLTaskRequest(false);
166+
mlTaskRunner.run(FunctionName.AGENT, request, transportService, listener);
167+
// The trigger-count stat must remain untouched for an AGENT task.
Long value = (Long) mlStats.getStat(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT).getValue();
168+
assertEquals(0L, value.longValue());
169+
}
158170
}

0 commit comments

Comments
 (0)