Skip to content

Commit

Permalink
fix remote register model / circuit breaker 500 (#2264) (#2273)
Browse files Browse the repository at this point in the history
* move memory CB check into try block to catch exception and hand to listener for register remote model

Signed-off-by: Henry Lindeman <[email protected]>

* add test that memory cb exception is caught by action listener

Signed-off-by: Henry Lindeman <[email protected]>

* unthrow priviledgedExceptionAction

Signed-off-by: Henry Lindeman <[email protected]>

---------

Signed-off-by: Henry Lindeman <[email protected]>
(cherry picked from commit 30642e6)

Co-authored-by: Henry Lindeman <[email protected]>
  • Loading branch information
opensearch-trigger-bot[bot] and HenryL27 authored Mar 25, 2024
1 parent eeba1c3 commit 839785f
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -345,8 +345,8 @@ public void registerMLRemoteModel(
MLTask mlTask,
ActionListener<MLRegisterModelResponse> listener
) {
checkAndAddRunningTask(mlTask, maxRegisterTasksPerNode);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
checkAndAddRunningTask(mlTask, maxRegisterTasksPerNode);
mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment();
mlStats.createCounterStatIfAbsent(mlTask.getFunctionName(), REGISTER, ML_ACTION_REQUEST_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.breaker.MLCircuitBreakerService;
import org.opensearch.ml.breaker.MemoryCircuitBreaker;
import org.opensearch.ml.breaker.ThresholdCircuitBreaker;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.FunctionName;
Expand Down Expand Up @@ -112,6 +113,7 @@
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.stats.suppliers.CounterSupplier;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.monitor.jvm.JvmService;
import org.opensearch.script.ScriptService;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -449,6 +451,23 @@ public void testRegisterMLRemoteModel() throws PrivilegedActionException {
verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean());
}

public void testRegisterMLRemoteModel_WhenMemoryCBOpen_ThenFail() {
ActionListener<MLRegisterModelResponse> listener = mock(ActionListener.class);
MemoryCircuitBreaker memCB = new MemoryCircuitBreaker(mock(JvmService.class));
String memCBIsOpenMessage = memCB.getName() + " is open, please check your resources!";
when(mlCircuitBreakerService.checkOpenCB()).thenThrow(new MLLimitExceededException(memCBIsOpenMessage));

MLRegisterModelInput pretrainedInput = mockRemoteModelInput(true);
MLTask pretrainedTask = MLTask.builder().taskId("pretrained").modelId("pretrained").functionName(FunctionName.REMOTE).build();
modelManager.registerMLRemoteModel(pretrainedInput, pretrainedTask, listener);

ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class);
verify(listener, times(1)).onFailure(argCaptor.capture());
Exception e = argCaptor.getValue();
assertTrue(e instanceof MLLimitExceededException);
assertEquals(memCBIsOpenMessage, e.getMessage());
}

public void testIndexRemoteModel() throws PrivilegedActionException {
ActionListener<MLRegisterModelResponse> listener = mock(ActionListener.class);
doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any());
Expand Down

0 comments on commit 839785f

Please sign in to comment.