Skip to content

Commit

Permalink
fixing agent execution for multi-tenancy (#2792)
Browse files Browse the repository at this point in the history
Signed-off-by: Dhrubo Saha <[email protected]>
  • Loading branch information
dhrubo-os authored Aug 1, 2024
1 parent 1d274fd commit e94acc6
Show file tree
Hide file tree
Showing 10 changed files with 119 additions and 21 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.settings;

public interface SettingsChangeListener {
void onMultiTenancyEnabledChanged(boolean isEnabled);
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,5 @@ public interface Executable {
* @param input input data
* @return execution result
*/
default void execute(Input input, ActionListener<Output> listener) throws ExecuteException {
execute(input, null, listener);
}

/**
* Execute algorithm with given input data.
* @param input input data
* @param tenantId id of the tenant for multi-tenancy.
* For single tenant, it will be null
* @return execution result
*/
void execute(Input input, String tenantId, ActionListener<Output> listener) throws ExecuteException;
void execute(Input input, ActionListener<Output> listener) throws ExecuteException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public abstract class DLModelExecute implements MLExecutable {
protected Device[] devices;
protected AtomicInteger nextDevice = new AtomicInteger(0);

public abstract void execute(Input input, String tenantId, ActionListener<Output> listener);
public abstract void execute(Input input, ActionListener<Output> listener);

protected Predictor<float[][], ai.djl.modality.Output> getPredictor() {
int currentDevice = nextDevice.getAndIncrement();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.settings.SettingsChangeListener;
import org.opensearch.ml.common.spi.memory.Memory;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.engine.Executable;
Expand All @@ -66,7 +67,7 @@
@Data
@NoArgsConstructor
@Function(FunctionName.AGENT)
public class MLAgentExecutor implements Executable {
public class MLAgentExecutor implements Executable, SettingsChangeListener {

public static final String MEMORY_ID = "memory_id";
public static final String QUESTION = "question";
Expand Down Expand Up @@ -104,20 +105,25 @@ public MLAgentExecutor(
}

@Override
public void execute(Input input, String tenantId, ActionListener<Output> listener) {
public void onMultiTenancyEnabledChanged(boolean isEnabled) {
this.isMultiTenancyEnabled = isEnabled;
}

@Override
public void execute(Input input, ActionListener<Output> listener) {
if (!(input instanceof AgentMLInput)) {
throw new IllegalArgumentException("wrong input");
}
AgentMLInput agentMLInput = (AgentMLInput) input;
String agentId = agentMLInput.getAgentId();
String agentTenantId = agentMLInput.getTenantId();
String tenantId = agentMLInput.getTenantId();

RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) agentMLInput.getInputDataset();
if (inputDataSet == null || inputDataSet.getParameters() == null) {
throw new IllegalArgumentException("Agent input data can not be empty.");
}

if (isMultiTenancyEnabled && !Objects.equals(tenantId, agentTenantId)) {
if (isMultiTenancyEnabled && tenantId == null) {
throw new OpenSearchStatusException("You don't have permission to access this resource", RestStatus.FORBIDDEN);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ protected List<Map.Entry<Long, Long>> getAllIntervals() {
}

@Override
public void execute(Input input, String tenantId, ActionListener<Output> listener) {
public void execute(Input input, ActionListener<Output> listener) {
getLocalizationResults(
(AnomalyLocalizationInput) input,
ActionListener.wrap(o -> listener.onResponse(o), e -> listener.onFailure(e))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ public MetricsCorrelation(Client client, Settings settings, ClusterService clust
* algorithm is a list of objects. Each object contains 3 properties event_window, event_pattern and suspected_metrics
*/
@Override
public void execute(Input input, String tenantId, ActionListener<org.opensearch.ml.common.output.Output> listener) {
public void execute(Input input, ActionListener<org.opensearch.ml.common.output.Output> listener) {
if (!(input instanceof MetricsCorrelationInput)) {
throw new ExecuteException("wrong input");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public LocalSampleCalculator(Client client, Settings settings) {
}

@Override
public void execute(Input input, String tenantId, ActionListener<Output> listener) {
public void execute(Input input, ActionListener<Output> listener) {
if (!(input instanceof LocalSampleCalculatorInput)) {
throw new IllegalArgumentException("wrong input");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,9 @@ public Collection<Object> createComponents(
memoryFactoryMap,
mlFeatureEnabledSetting.isMultiTenancyEnabled()
);
// Register the agentExecutor as a listener
mlFeatureEnabledSetting.addListener(agentExecutor);

MLEngineClassLoader.register(FunctionName.LOCAL_SAMPLE_CALCULATOR, localSampleCalculator);
MLEngineClassLoader.register(FunctionName.AGENT, agentExecutor);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,14 @@
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MULTI_TENANCY_ENABLED;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED;

import java.util.ArrayList;
import java.util.List;

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.ml.common.settings.SettingsChangeListener;

import com.google.common.annotations.VisibleForTesting;

public class MLFeatureEnabledSetting {

Expand All @@ -25,6 +31,8 @@ public class MLFeatureEnabledSetting {
// This is to identify if this node is in multi-tenancy or not.
private volatile Boolean isMultiTenancyEnabled;

private final List<SettingsChangeListener> listeners = new ArrayList<>();

public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) {
isRemoteInferenceEnabled = ML_COMMONS_REMOTE_INFERENCE_ENABLED.get(settings);
isAgentFrameworkEnabled = ML_COMMONS_AGENT_FRAMEWORK_ENABLED.get(settings);
Expand All @@ -38,7 +46,10 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_AGENT_FRAMEWORK_ENABLED, it -> isAgentFrameworkEnabled = it);
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_LOCAL_MODEL_ENABLED, it -> isLocalModelEnabled = it);
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_MULTI_TENANCY_ENABLED, it -> isMultiTenancyEnabled = it);
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_MULTI_TENANCY_ENABLED, it -> {
isMultiTenancyEnabled = it;
notifyMultiTenancyListeners(it);
});
}

/**
Expand Down Expand Up @@ -73,4 +84,15 @@ public boolean isMultiTenancyEnabled() {
return isMultiTenancyEnabled;
}

public void addListener(SettingsChangeListener listener) {
listeners.add(listener);
}

@VisibleForTesting
void notifyMultiTenancyListeners(boolean isEnabled) {
for (SettingsChangeListener listener : listeners) {
listener.onMultiTenancyEnabledChanged(isEnabled);
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.settings;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_AGENT_FRAMEWORK_ENABLED;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ENABLED;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MULTI_TENANCY_ENABLED;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED;

import java.util.Set;

import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
import org.opensearch.ml.common.settings.SettingsChangeListener;

public class MLFeatureEnabledSettingTests {
@Mock
private ClusterService clusterService;
private Settings settings;
private MLFeatureEnabledSetting mlFeatureEnabledSetting;
private SettingsChangeListener listener;

@Before
public void setUp() {
MockitoAnnotations.openMocks(this);
settings = Settings.builder().put(ML_COMMONS_MULTI_TENANCY_ENABLED.getKey(), false).build();
when(clusterService.getSettings()).thenReturn(settings);
when(clusterService.getClusterSettings())
.thenReturn(
new ClusterSettings(
settings,
Set
.of(
ML_COMMONS_MULTI_TENANCY_ENABLED,
ML_COMMONS_REMOTE_INFERENCE_ENABLED,
ML_COMMONS_AGENT_FRAMEWORK_ENABLED,
ML_COMMONS_LOCAL_MODEL_ENABLED
)
)
);

mlFeatureEnabledSetting = new MLFeatureEnabledSetting(clusterService, settings);
listener = mock(SettingsChangeListener.class);
}

@Test
public void testAddListenerAndNotify() {
mlFeatureEnabledSetting.addListener(listener);

// Simulate settings change
mlFeatureEnabledSetting.notifyMultiTenancyListeners(false);

// Verify listener is notified
verify(listener, times(1)).onMultiTenancyEnabledChanged(false);
}
}

0 comments on commit e94acc6

Please sign in to comment.