Skip to content

Commit

Permalink
Include changes from main in ModelRegistry
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosdelest committed Feb 1, 2024
1 parent 3c3169d commit 3c9c32f
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,14 @@ public interface ModelRegistry {
void deleteModel(String modelId, ActionListener<Boolean> listener);

/**
* Semi parsed model where model id, task type and service
* Semi parsed model where inference entity id, task type and service
* are known but the settings are not parsed.
*/
record UnparsedModel(String modelId, TaskType taskType, String service, Map<String, Object> settings, Map<String, Object> secrets) {}
record UnparsedModel(
String inferenceEntityId,
TaskType taskType,
String service,
Map<String, Object> settings,
Map<String, Object> secrets
) {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import org.elasticsearch.test.ESSingleNodeTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.elasticsearch.xpack.inference.registry.ModelRegistryImpl;
import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeModel;
import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeService;
import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeServiceSettingsTests;
Expand Down Expand Up @@ -54,13 +54,13 @@
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.Mockito.mock;

public class ModelRegistryIT extends ESSingleNodeTestCase {
public class ModelRegistryImplIT extends ESSingleNodeTestCase {

private ModelRegistry modelRegistry;
private ModelRegistryImpl ModelRegistryImpl;

@Before
public void createComponents() {
modelRegistry = new ModelRegistry(client());
ModelRegistryImpl = new ModelRegistryImpl(client());
}

@Override
Expand All @@ -74,7 +74,7 @@ public void testStoreModel() throws Exception {
AtomicReference<Boolean> storeModelHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();

blockingCall(listener -> modelRegistry.storeModel(model, listener), storeModelHolder, exceptionHolder);
blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), storeModelHolder, exceptionHolder);

assertThat(storeModelHolder.get(), is(true));
assertThat(exceptionHolder.get(), is(nullValue()));
Expand All @@ -86,7 +86,7 @@ public void testStoreModelWithUnknownFields() throws Exception {
AtomicReference<Boolean> storeModelHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();

blockingCall(listener -> modelRegistry.storeModel(model, listener), storeModelHolder, exceptionHolder);
blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), storeModelHolder, exceptionHolder);

assertNull(storeModelHolder.get());
assertNotNull(exceptionHolder.get());
Expand All @@ -105,12 +105,12 @@ public void testGetModel() throws Exception {
AtomicReference<Boolean> putModelHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();

blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder);
blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), putModelHolder, exceptionHolder);
assertThat(putModelHolder.get(), is(true));

// now get the model
AtomicReference<ModelRegistry.UnparsedModel> modelHolder = new AtomicReference<>();
blockingCall(listener -> modelRegistry.getModelWithSecrets(inferenceEntityId, listener), modelHolder, exceptionHolder);
AtomicReference<ModelRegistryImpl.UnparsedModel> modelHolder = new AtomicReference<>();
blockingCall(listener -> ModelRegistryImpl.getModelWithSecrets(inferenceEntityId, listener), modelHolder, exceptionHolder);
assertThat(exceptionHolder.get(), is(nullValue()));
assertThat(modelHolder.get(), not(nullValue()));

Expand All @@ -132,13 +132,13 @@ public void testStoreModelFailsWhenModelExists() throws Exception {
AtomicReference<Boolean> putModelHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();

blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder);
blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), putModelHolder, exceptionHolder);
assertThat(putModelHolder.get(), is(true));
assertThat(exceptionHolder.get(), is(nullValue()));

putModelHolder.set(false);
// an model with the same id exists
blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder);
blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), putModelHolder, exceptionHolder);
assertThat(putModelHolder.get(), is(false));
assertThat(exceptionHolder.get(), not(nullValue()));
assertThat(
Expand All @@ -153,20 +153,20 @@ public void testDeleteModel() throws Exception {
Model model = buildElserModelConfig(id, TaskType.SPARSE_EMBEDDING);
AtomicReference<Boolean> putModelHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder);
blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), putModelHolder, exceptionHolder);
assertThat(putModelHolder.get(), is(true));
}

AtomicReference<Boolean> deleteResponseHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
blockingCall(listener -> modelRegistry.deleteModel("model1", listener), deleteResponseHolder, exceptionHolder);
blockingCall(listener -> ModelRegistryImpl.deleteModel("model1", listener), deleteResponseHolder, exceptionHolder);
assertThat(exceptionHolder.get(), is(nullValue()));
assertTrue(deleteResponseHolder.get());

// get should fail
deleteResponseHolder.set(false);
AtomicReference<ModelRegistry.UnparsedModel> modelHolder = new AtomicReference<>();
blockingCall(listener -> modelRegistry.getModelWithSecrets("model1", listener), modelHolder, exceptionHolder);
AtomicReference<ModelRegistryImpl.UnparsedModel> modelHolder = new AtomicReference<>();
blockingCall(listener -> ModelRegistryImpl.getModelWithSecrets("model1", listener), modelHolder, exceptionHolder);

assertThat(exceptionHolder.get(), not(nullValue()));
assertFalse(deleteResponseHolder.get());
Expand All @@ -186,13 +186,13 @@ public void testGetModelsByTaskType() throws InterruptedException {
AtomicReference<Boolean> putModelHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();

blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder);
blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), putModelHolder, exceptionHolder);
assertThat(putModelHolder.get(), is(true));
}

AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
AtomicReference<List<ModelRegistry.UnparsedModel>> modelHolder = new AtomicReference<>();
blockingCall(listener -> modelRegistry.getModelsByTaskType(TaskType.SPARSE_EMBEDDING, listener), modelHolder, exceptionHolder);
AtomicReference<List<ModelRegistryImpl.UnparsedModel>> modelHolder = new AtomicReference<>();
blockingCall(listener -> ModelRegistryImpl.getModelsByTaskType(TaskType.SPARSE_EMBEDDING, listener), modelHolder, exceptionHolder);
assertThat(modelHolder.get(), hasSize(3));
var sparseIds = sparseAndTextEmbeddingModels.stream()
.filter(m -> m.getConfigurations().getTaskType() == TaskType.SPARSE_EMBEDDING)
Expand All @@ -203,7 +203,7 @@ public void testGetModelsByTaskType() throws InterruptedException {
assertThat(m.secrets().keySet(), empty());
});

blockingCall(listener -> modelRegistry.getModelsByTaskType(TaskType.TEXT_EMBEDDING, listener), modelHolder, exceptionHolder);
blockingCall(listener -> ModelRegistryImpl.getModelsByTaskType(TaskType.TEXT_EMBEDDING, listener), modelHolder, exceptionHolder);
assertThat(modelHolder.get(), hasSize(2));
var denseIds = sparseAndTextEmbeddingModels.stream()
.filter(m -> m.getConfigurations().getTaskType() == TaskType.TEXT_EMBEDDING)
Expand All @@ -227,13 +227,13 @@ public void testGetAllModels() throws InterruptedException {
var model = createModel(randomAlphaOfLength(5), randomFrom(TaskType.values()), service);
createdModels.add(model);

blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder);
blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), putModelHolder, exceptionHolder);
assertThat(putModelHolder.get(), is(true));
assertNull(exceptionHolder.get());
}

AtomicReference<List<ModelRegistry.UnparsedModel>> modelHolder = new AtomicReference<>();
blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, exceptionHolder);
AtomicReference<List<ModelRegistryImpl.UnparsedModel>> modelHolder = new AtomicReference<>();
blockingCall(listener -> ModelRegistryImpl.getAllModels(listener), modelHolder, exceptionHolder);
assertThat(modelHolder.get(), hasSize(modelCount));
var getAllModels = modelHolder.get();

Expand All @@ -257,18 +257,18 @@ public void testGetModelWithSecrets() throws InterruptedException {
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();

var modelWithSecrets = createModelWithSecrets(inferenceEntityId, randomFrom(TaskType.values()), service, secret);
blockingCall(listener -> modelRegistry.storeModel(modelWithSecrets, listener), putModelHolder, exceptionHolder);
blockingCall(listener -> ModelRegistryImpl.storeModel(modelWithSecrets, listener), putModelHolder, exceptionHolder);
assertThat(putModelHolder.get(), is(true));
assertNull(exceptionHolder.get());

AtomicReference<ModelRegistry.UnparsedModel> modelHolder = new AtomicReference<>();
blockingCall(listener -> modelRegistry.getModelWithSecrets(inferenceEntityId, listener), modelHolder, exceptionHolder);
AtomicReference<ModelRegistryImpl.UnparsedModel> modelHolder = new AtomicReference<>();
blockingCall(listener -> ModelRegistryImpl.getModelWithSecrets(inferenceEntityId, listener), modelHolder, exceptionHolder);
assertThat(modelHolder.get().secrets().keySet(), hasSize(1));
var secretSettings = (Map<String, Object>) modelHolder.get().secrets().get("secret_settings");
assertThat(secretSettings.get("secret"), equalTo(secret));

// get model without secrets
blockingCall(listener -> modelRegistry.getModel(inferenceEntityId, listener), modelHolder, exceptionHolder);
blockingCall(listener -> ModelRegistryImpl.getModel(inferenceEntityId, listener), modelHolder, exceptionHolder);
assertThat(modelHolder.get().secrets().keySet(), empty());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class ModelRegistryTests extends ESTestCase {
public class ModelRegistryImplTests extends ESTestCase {

private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);

Expand Down

0 comments on commit 3c9c32f

Please sign in to comment.