Skip to content

Commit

Permalink
Rest test using the test plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
davidkyle committed Nov 22, 2023
1 parent 2081517 commit 3b19ef3
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ public class MockInferenceServiceIT extends ESRestTestCase {
@ClassRule
public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
.distribution(DistributionType.DEFAULT)
.setting("xpack.ml.enabled", "true")
.setting("xpack.license.self_generated.type", "trial")
.setting("xpack.security.enabled", "true")
.plugin("org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin")
Expand Down Expand Up @@ -111,11 +110,11 @@ public void testMockService_DoesNotReturnSecretsInGetResponse() throws IOExcepti

var serviceSettings = (Map<String, Object>) getModel.get("service_settings");
assertNull(serviceSettings.get("api_key"));
assertNotNull(serviceSettings.get("model_id"));
assertNotNull(serviceSettings.get("model"));

var putServiceSettings = (Map<String, Object>) putModel.get("service_settings");
assertNull(serviceSettings.get("api_key"));
assertNotNull(serviceSettings.get("model_id"));
assertNull(putServiceSettings.get("api_key"));
assertNotNull(putServiceSettings.get("model"));
}

private Map<String, Object> putModel(String modelId, String modelConfig, TaskType taskType) throws IOException {
Expand Down
2 changes: 0 additions & 2 deletions x-pack/plugin/inference/qa/test-service-plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ dependencies {
compileOnly project(':x-pack:plugin:core')
compileOnly project(':x-pack:plugin:inference')
compileOnly project(':x-pack:plugin:ml')
// compileOnly testArtifact(project(xpackModule('core')))
javaRestTestImplementation project(':x-pack:plugin:inference')
}

tasks.named("javaRestTest").configure {
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
Expand All @@ -28,14 +27,11 @@
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
import org.elasticsearch.xpack.inference.services.MapParsingUtils;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;


Expand Down Expand Up @@ -98,11 +94,12 @@ public TransportVersion getMinimalSupportedVersion() {

public abstract static class TestInferenceServiceBase implements InferenceService {

@SuppressWarnings("unchecked")
private static Map<String, Object> getTaskSettingsMap(Map<String, Object> settings) {
Map<String, Object> taskSettingsMap;
// task settings are optional
if (settings.containsKey(ModelConfigurations.TASK_SETTINGS)) {
taskSettingsMap = MapParsingUtils.removeFromMapOrThrowIfNull(settings, ModelConfigurations.TASK_SETTINGS);
taskSettingsMap = (Map<String, Object>) settings.remove(ModelConfigurations.TASK_SETTINGS);
} else {
taskSettingsMap = Map.of();
}
Expand All @@ -115,35 +112,33 @@ public TestInferenceServiceBase(InferenceServicePlugin.InferenceServiceFactoryCo
}

@Override
@SuppressWarnings("unchecked")
public TestServiceModel parseRequestConfig(
String modelId,
TaskType taskType,
Map<String, Object> config,
Set<String> platfromArchitectures
) {
Map<String, Object> serviceSettingsMap = MapParsingUtils.removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
var serviceSettingsMap = (Map<String, Object>) config.remove(ModelConfigurations.SERVICE_SETTINGS);
var serviceSettings = TestServiceSettings.fromMap(serviceSettingsMap);
var secretSettings = TestSecretSettings.fromMap(serviceSettingsMap);

var taskSettingsMap = getTaskSettingsMap(config);
var taskSettings = TestTaskSettings.fromMap(taskSettingsMap);

MapParsingUtils.throwIfNotEmptyMap(config, name());
MapParsingUtils.throwIfNotEmptyMap(serviceSettingsMap, name());
MapParsingUtils.throwIfNotEmptyMap(taskSettingsMap, name());

return new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings);
}

@Override
@SuppressWarnings("unchecked")
public TestServiceModel parsePersistedConfig(
String modelId,
TaskType taskType,
Map<String, Object> config,
Map<String, Object> secrets
) {
Map<String, Object> serviceSettingsMap = MapParsingUtils.removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> secretSettingsMap = MapParsingUtils.removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);
var serviceSettingsMap = (Map<String, Object>) config.remove(ModelConfigurations.SERVICE_SETTINGS);
var secretSettingsMap = (Map<String, Object>) secrets.remove(ModelSecrets.SECRET_SETTINGS);

var serviceSettings = TestServiceSettings.fromMap(serviceSettingsMap);
var secretSettings = TestSecretSettings.fromMap(secretSettingsMap);
Expand All @@ -163,10 +158,9 @@ public void infer(
) {
switch (model.getConfigurations().getTaskType()) {
case SPARSE_EMBEDDING -> {
var results = new ArrayList<TextExpansionResults>();
var results = new ArrayList<TestResults>();
input.forEach(i -> {
int numTokensInResult = Strings.tokenizeToStringArray(i, " ").length;
results.add(createWeightedTokens(numTokensInResult));
results.add(new TestResults("bar"));
});
listener.onResponse(results);
}
Expand Down Expand Up @@ -225,11 +219,11 @@ public record TestServiceSettings(String model) implements ServiceSettings {
public static TestServiceSettings fromMap(Map<String, Object> map) {
ValidationException validationException = new ValidationException();

String model = MapParsingUtils.removeAsType(map, "model", String.class);
String model = (String) map.remove("model");

if (model == null) {
validationException.addValidationError(
MapParsingUtils.missingSettingErrorMsg("model", ModelConfigurations.SERVICE_SETTINGS)
"missing model"
);
}

Expand Down Expand Up @@ -273,7 +267,7 @@ public record TestTaskSettings(Integer temperature) implements TaskSettings {
private static final String NAME = "test_task_settings";

public static TestTaskSettings fromMap(Map<String, Object> map) {
Integer temperature = MapParsingUtils.removeAsType(map, "temperature", Integer.class);
Integer temperature = (Integer) map.remove("temperature");
return new TestTaskSettings(temperature);
}

Expand Down Expand Up @@ -314,10 +308,10 @@ public record TestSecretSettings(String apiKey) implements SecretSettings {
public static TestSecretSettings fromMap(Map<String, Object> map) {
ValidationException validationException = new ValidationException();

String apiKey = MapParsingUtils.removeAsType(map, "api_key", String.class);
String apiKey = (String) map.remove("api_key");

if (apiKey == null) {
validationException.addValidationError(MapParsingUtils.missingSettingErrorMsg("api_key", ModelSecrets.SECRET_SETTINGS));
validationException.addValidationError("missing api_key");
}

if (validationException.validationErrors().isEmpty() == false) {
Expand Down Expand Up @@ -355,12 +349,48 @@ public TransportVersion getMinimalSupportedVersion() {
}
}

private static TextExpansionResults createWeightedTokens(int numTokens) {
Random rng = new Random();
List<TextExpansionResults.WeightedToken> tokenList = new ArrayList<>();
for (int i = 0; i < numTokens; i++) {
tokenList.add(new TextExpansionResults.WeightedToken(Integer.toString(i), rng.nextFloat()));
private static class TestResults implements InferenceResults {

private String result;

public TestResults(String result) {
this.result = result;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.field("result", result);
return builder;
}

@Override
public String getWriteableName() {
return "test_result";
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(result);
}

@Override
public String getResultsField() {
return "result";
}

@Override
public Map<String, Object> asMap() {
return Map.of("result", result);
}

@Override
public Map<String, Object> asMap(String outputField) {
return Map.of(outputField, result);
}

@Override
public Object predictedValue() {
return result;
}
return new TextExpansionResults("not_used", tokenList, false);
}
}

0 comments on commit 3b19ef3

Please sign in to comment.