Skip to content

Commit

Permalink
Merge branch 'main' into SignificanceLookup
Browse files Browse the repository at this point in the history
  • Loading branch information
iverase committed Sep 20, 2023
2 parents b2dec82 + cdedf53 commit 31ed031
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,19 @@ public static ElserMlNodeModel parseConfig(
Map<String, Object> settings
) {
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(settings, Model.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(settings, Model.TASK_SETTINGS);

var serviceSettings = serviceSettingsFromMap(serviceSettingsMap);

Map<String, Object> taskSettingsMap;
// task settings are optional
if (settings.containsKey(Model.TASK_SETTINGS)) {
taskSettingsMap = removeFromMapOrThrowIfNull(settings, Model.TASK_SETTINGS);
} else {
taskSettingsMap = Map.of();
}

var taskSettings = taskSettingsFromMap(taskType, taskSettingsMap);

if (throwOnUnknownFields == false) {
if (throwOnUnknownFields) {
throwIfNotEmptyMap(settings);
throwIfNotEmptyMap(serviceSettingsMap);
throwIfNotEmptyMap(taskSettingsMap);
Expand Down Expand Up @@ -133,8 +140,6 @@ private static ElserMlNodeTaskSettings taskSettingsFromMap(TaskType taskType, Ma
}

// no config options yet
throwIfNotEmptyMap(config);

return ElserMlNodeTaskSettings.DEFAULT;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,18 @@

package org.elasticsearch.xpack.inference.services.elser;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.inference.Model;
import org.elasticsearch.xpack.inference.TaskType;

import java.util.HashMap;
import java.util.Map;

import static org.hamcrest.Matchers.containsString;
import static org.mockito.Mockito.mock;

public class ElserMlNodeServiceTests extends ESTestCase {

public static Model randomModelConfig(String modelId, TaskType taskType) {
Expand All @@ -25,4 +33,124 @@ public static Model randomModelConfig(String modelId, TaskType taskType) {
default -> throw new IllegalArgumentException("task type " + taskType + " is not supported");
};
}

public void testParseConfigStrict() {
var service = new ElserMlNodeService(mock(Client.class));

var settings = new HashMap<String, Object>();
settings.put(
Model.SERVICE_SETTINGS,
new HashMap<>(Map.of(ElserMlNodeServiceSettings.NUM_ALLOCATIONS, 1, ElserMlNodeServiceSettings.NUM_THREADS, 4))
);
settings.put(Model.TASK_SETTINGS, Map.of());

ElserMlNodeModel parsedModel = service.parseConfigStrict("foo", TaskType.SPARSE_EMBEDDING, settings);

assertEquals(
new ElserMlNodeModel(
"foo",
TaskType.SPARSE_EMBEDDING,
ElserMlNodeService.NAME,
new ElserMlNodeServiceSettings(1, 4),
ElserMlNodeTaskSettings.DEFAULT
),
parsedModel
);
}

public void testParseConfigStrictWithNoTaskSettings() {
var service = new ElserMlNodeService(mock(Client.class));

var settings = new HashMap<String, Object>();
settings.put(
Model.SERVICE_SETTINGS,
new HashMap<>(Map.of(ElserMlNodeServiceSettings.NUM_ALLOCATIONS, 1, ElserMlNodeServiceSettings.NUM_THREADS, 4))
);

ElserMlNodeModel parsedModel = service.parseConfigStrict("foo", TaskType.SPARSE_EMBEDDING, settings);

assertEquals(
new ElserMlNodeModel(
"foo",
TaskType.SPARSE_EMBEDDING,
ElserMlNodeService.NAME,
new ElserMlNodeServiceSettings(1, 4),
ElserMlNodeTaskSettings.DEFAULT
),
parsedModel
);
}

public void testParseConfigStrictWithUnknownSettings() {

for (boolean throwOnUnknown : new boolean[] { true, false }) {
{
var settings = new HashMap<String, Object>();
settings.put(
Model.SERVICE_SETTINGS,
new HashMap<>(Map.of(ElserMlNodeServiceSettings.NUM_ALLOCATIONS, 1, ElserMlNodeServiceSettings.NUM_THREADS, 4))
);
settings.put(Model.TASK_SETTINGS, Map.of());
settings.put("foo", "bar");

if (throwOnUnknown) {
var e = expectThrows(
ElasticsearchStatusException.class,
() -> ElserMlNodeService.parseConfig(throwOnUnknown, "foo", TaskType.SPARSE_EMBEDDING, settings)
);
assertThat(
e.getMessage(),
containsString("Model configuration contains settings [{foo=bar}] unknown to the [elser_mlnode] service")
);
} else {
var parsed = ElserMlNodeService.parseConfig(throwOnUnknown, "foo", TaskType.SPARSE_EMBEDDING, settings);
}
}

{
var settings = new HashMap<String, Object>();
settings.put(
Model.SERVICE_SETTINGS,
new HashMap<>(Map.of(ElserMlNodeServiceSettings.NUM_ALLOCATIONS, 1, ElserMlNodeServiceSettings.NUM_THREADS, 4))
);
settings.put(Model.TASK_SETTINGS, Map.of("foo", "bar"));

if (throwOnUnknown) {
var e = expectThrows(
ElasticsearchStatusException.class,
() -> ElserMlNodeService.parseConfig(throwOnUnknown, "foo", TaskType.SPARSE_EMBEDDING, settings)
);
assertThat(
e.getMessage(),
containsString("Model configuration contains settings [{foo=bar}] unknown to the [elser_mlnode] service")
);
} else {
var parsed = ElserMlNodeService.parseConfig(throwOnUnknown, "foo", TaskType.SPARSE_EMBEDDING, settings);
}
}

{
var settings = new HashMap<String, Object>();
settings.put(
Model.SERVICE_SETTINGS,
new HashMap<>(
Map.of(ElserMlNodeServiceSettings.NUM_ALLOCATIONS, 1, ElserMlNodeServiceSettings.NUM_THREADS, 4, "foo", "bar")
)
);

if (throwOnUnknown) {
var e = expectThrows(
ElasticsearchStatusException.class,
() -> ElserMlNodeService.parseConfig(throwOnUnknown, "foo", TaskType.SPARSE_EMBEDDING, settings)
);
assertThat(
e.getMessage(),
containsString("Model configuration contains settings [{foo=bar}] unknown to the [elser_mlnode] service")
);
} else {
var parsed = ElserMlNodeService.parseConfig(throwOnUnknown, "foo", TaskType.SPARSE_EMBEDDING, settings);
}
}
}
}
}

0 comments on commit 31ed031

Please sign in to comment.