[ML] Add mixed cluster tests for inference #108392

Merged
merged 17 commits into from
May 15, 2024
16 changes: 11 additions & 5 deletions TESTING.asciidoc
@@ -551,13 +551,19 @@ When running `./gradlew check`, minimal bwc checks are also run against compatib

==== BWC Testing against a specific remote/branch

-Sometimes a backward compatibility change spans two versions. A common case is a new functionality
-that needs a BWC bridge in an unreleased versioned of a release branch (for example, 5.x).
-To test the changes, you can instruct Gradle to build the BWC version from another remote/branch combination instead of
-pulling the release branch from GitHub. You do so using the `bwc.remote` and `bwc.refspec.BRANCH` system properties:
+Sometimes a backward compatibility change spans two versions.
+A common case is a new functionality that needs a BWC bridge in an unreleased versioned of a release branch (for example, 5.x).
+Another use case, since the introduction of serverless, is to test BWC against main in addition to the other released branches.
+To do so, specify the `bwc.refspec` remote and branch to use for the BWC build as `origin/main`.
+To test against main, you will also need to create a new version in link:./server/src/main/java/org/elasticsearch/Version.java[Version.java],
+increment `elasticsearch` in link:./build-tools-internal/version.properties[version.properties], and hard-code the `project.version` for ml-cpp
Contributor

Just curious why we need to change the version in ml cpp?

Member Author

If you increment the elasticsearch version to one that doesn't exist, then the build system will not find the (non-existent) ML-CPP artifact. If, instead, we hardcode the current version of elasticsearch, the existing ML-CPP artifact will be downloaded.

+in link:./x-pack/plugin/ml/build.gradle[ml/build.gradle].
+
+In general, to test the changes, you can instruct Gradle to build the BWC version from another remote/branch combination instead of pulling the release branch from GitHub.
+You do so using the `bwc.refspec.{VERSION}` system property:

-------------------------------------------------
-./gradlew check -Dbwc.remote=${remote} -Dbwc.refspec.5.x=index_req_bwc_5.x
+./gradlew check -Dtests.bwc.refspec.8.15=origin/main
-------------------------------------------------

The branch needs to be available on the remote that the BWC makes of the
37 changes: 37 additions & 0 deletions x-pack/plugin/inference/qa/mixed-cluster/build.gradle
@@ -0,0 +1,37 @@
import org.elasticsearch.gradle.Version
import org.elasticsearch.gradle.VersionProperties
import org.elasticsearch.gradle.util.GradleUtils
import org.elasticsearch.gradle.internal.info.BuildParams
import org.elasticsearch.gradle.testclusters.StandaloneRestIntegTestTask

apply plugin: 'elasticsearch.internal-java-rest-test'
apply plugin: 'elasticsearch.internal-test-artifact-base'
apply plugin: 'elasticsearch.bwc-test'

dependencies {
  testImplementation project(path: ':x-pack:plugin:inference:qa:inference-service-tests')
  compileOnly project(':x-pack:plugin:core')
  javaRestTestImplementation(testArtifact(project(xpackModule('core'))))
  javaRestTestImplementation project(path: xpackModule('inference'))
  clusterPlugins project(
    ':x-pack:plugin:inference:qa:test-service-plugin'
  )
}

// inference is available in 8.11 or later
def supportedVersion = bwcVersion -> {
  return bwcVersion.onOrAfter(Version.fromString("8.11.0"));
}

BuildParams.bwcVersions.withWireCompatible(supportedVersion) { bwcVersion, baseName ->
  def javaRestTest = tasks.register("v${bwcVersion}#javaRestTest", StandaloneRestIntegTestTask) {
    usesBwcDistribution(bwcVersion)
    systemProperty("tests.old_cluster_version", bwcVersion)
    maxParallelForks = 1
  }

  tasks.register(bwcTaskName(bwcVersion)) {
    dependsOn javaRestTest
  }
}
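
Note on the version filter: the `supportedVersion` predicate above keeps only the wire-compatible BWC versions at or after 8.11.0, so a mixed-cluster task is registered for each of those and for nothing older. A minimal standalone sketch of that selection, assuming plain `major.minor.revision` strings; `BwcVersionFilterSketch` and its methods are hypothetical stand-ins for illustration, not the `org.elasticsearch.gradle.Version` API, and the candidate versions are invented:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical stand-in for the build's version type; it models only the
// comparison needed to illustrate the filter.
final class BwcVersionFilterSketch {

    // Parses "major.minor.revision" into three integers.
    static int[] parse(String version) {
        String[] parts = version.split("\\.");
        return new int[] { Integer.parseInt(parts[0]), Integer.parseInt(parts[1]), Integer.parseInt(parts[2]) };
    }

    // True when version >= minimum, comparing components numerically
    // (so 8.9.x sorts before 8.11.x, unlike a plain string comparison).
    static boolean onOrAfter(String version, String minimum) {
        int[] v = parse(version);
        int[] m = parse(minimum);
        for (int i = 0; i < 3; i++) {
            if (v[i] != m[i]) {
                return v[i] > m[i];
            }
        }
        return true;
    }

    // Keeps the candidates that would get a mixed-cluster test task.
    static List<String> supported(List<String> candidates) {
        List<String> kept = new ArrayList<>();
        for (String candidate : candidates) {
            if (onOrAfter(candidate, "8.11.0")) {
                kept.add(candidate);
            }
        }
        return kept;
    }

    public static void main(String[] args) {
        // Invented candidate list, for illustration only.
        System.out.println(supported(Arrays.asList("7.17.0", "8.10.4", "8.11.0", "8.13.2")));
    }
}
```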

@@ -0,0 +1,129 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.qa.mixed;

import org.apache.http.util.EntityUtils;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.test.rest.ESRestTestCase;
import org.hamcrest.Matchers;

import java.io.IOException;
import java.util.List;
import java.util.Map;

public abstract class BaseMixedTestCase extends MixedClusterSpecTestCase {
    protected static String getUrl(MockWebServer webServer) {
        return Strings.format("http://%s:%s", webServer.getHostName(), webServer.getPort());
    }

    @Override
    protected Settings restClientSettings() {
        String token = ESRestTestCase.basicAuthHeaderValue("x_pack_rest_user", new SecureString("x-pack-test-password".toCharArray()));
        return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build();
    }

    protected void delete(String inferenceId, TaskType taskType) throws IOException {
        var request = new Request("DELETE", Strings.format("_inference/%s/%s", taskType, inferenceId));
        var response = ESRestTestCase.client().performRequest(request);
        ESRestTestCase.assertOK(response);
    }

    protected void delete(String inferenceId) throws IOException {
        var request = new Request("DELETE", Strings.format("_inference/%s", inferenceId));
        var response = ESRestTestCase.client().performRequest(request);
        ESRestTestCase.assertOK(response);
    }

    protected Map<String, Object> getAll() throws IOException {
        var request = new Request("GET", "_inference/_all");
        var response = ESRestTestCase.client().performRequest(request);
        ESRestTestCase.assertOK(response);
        return ESRestTestCase.entityAsMap(response);
    }

    protected Map<String, Object> get(String inferenceId) throws IOException {
        var endpoint = Strings.format("_inference/%s", inferenceId);
        var request = new Request("GET", endpoint);
        var response = ESRestTestCase.client().performRequest(request);
        ESRestTestCase.assertOK(response);
        return ESRestTestCase.entityAsMap(response);
    }

    protected Map<String, Object> get(TaskType taskType, String inferenceId) throws IOException {
        var endpoint = Strings.format("_inference/%s/%s", taskType, inferenceId);
        var request = new Request("GET", endpoint);
        var response = ESRestTestCase.client().performRequest(request);
        ESRestTestCase.assertOK(response);
        return ESRestTestCase.entityAsMap(response);
    }

    protected Map<String, Object> inference(String inferenceId, TaskType taskType, String input) throws IOException {
        var endpoint = Strings.format("_inference/%s/%s", taskType, inferenceId);
        var request = new Request("POST", endpoint);
        request.setJsonEntity("{\"input\": [" + '"' + input + '"' + "]}");

        var response = ESRestTestCase.client().performRequest(request);
        ESRestTestCase.assertOK(response);
        return ESRestTestCase.entityAsMap(response);
    }

    protected Map<String, Object> rerank(String inferenceId, List<String> inputs, String query) throws IOException {
        var endpoint = Strings.format("_inference/rerank/%s", inferenceId);
        var request = new Request("POST", endpoint);

        StringBuilder body = new StringBuilder("{").append("\"query\":\"").append(query).append("\",").append("\"input\":[");

        for (int i = 0; i < inputs.size(); i++) {
            body.append("\"").append(inputs.get(i)).append("\"");
            if (i < inputs.size() - 1) {
                body.append(",");
            }
        }

        body.append("]}");
        request.setJsonEntity(body.toString());

        var response = ESRestTestCase.client().performRequest(request);
        ESRestTestCase.assertOK(response);
        return ESRestTestCase.entityAsMap(response);
    }

    protected void put(String inferenceId, String modelConfig, TaskType taskType) throws IOException {
        String endpoint = Strings.format("_inference/%s/%s?error_trace", taskType, inferenceId);
        var request = new Request("PUT", endpoint);
        request.setJsonEntity(modelConfig);
        var response = ESRestTestCase.client().performRequest(request);
        logger.warn("PUT response: {}", response.toString());
        System.out.println("PUT response: " + response.toString());
        ESRestTestCase.assertOKAndConsume(response);
    }

    protected static void assertOkOrCreated(Response response) throws IOException {
        int statusCode = response.getStatusLine().getStatusCode();
        // Once EntityUtils.toString(entity) is called the entity cannot be reused.
        // Avoid that call with check here.
        if (statusCode == 200 || statusCode == 201) {
            return;
        }

        String responseStr = EntityUtils.toString(response.getEntity());
        ESTestCase.assertThat(
            responseStr,
            response.getStatusLine().getStatusCode(),
            Matchers.anyOf(Matchers.equalTo(200), Matchers.equalTo(201))
        );
    }
}
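
Note on the request bodies: the `inference` and `rerank` helpers above assemble JSON by string concatenation, so an input containing a double quote or a backslash would produce a malformed body. A minimal sketch of an escaping variant that uses only the JDK; `RerankBodySketch`, `escape` and `rerankBody` are hypothetical names for illustration and are not part of this PR:

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical helper, not part of the PR: builds a rerank-style request
// body while escaping the characters JSON forbids inside a string.
final class RerankBodySketch {

    // Escapes quotes, backslashes and control characters for use inside a JSON string.
    static String escape(String s) {
        StringBuilder out = new StringBuilder(s.length() + 2);
        for (int i = 0; i < s.length(); i++) {
            char c = s.charAt(i);
            switch (c) {
                case '"':
                    out.append("\\\"");
                    break;
                case '\\':
                    out.append("\\\\");
                    break;
                case '\n':
                    out.append("\\n");
                    break;
                case '\r':
                    out.append("\\r");
                    break;
                case '\t':
                    out.append("\\t");
                    break;
                default:
                    if (c < 0x20) {
                        // Remaining control characters become a four-digit hex escape.
                        out.append(String.format("\\u%04x", (int) c));
                    } else {
                        out.append(c);
                    }
            }
        }
        return out.toString();
    }

    // Produces the same {"query": ..., "input": [...]} shape as the rerank helper in the diff.
    static String rerankBody(String query, List<String> inputs) {
        StringBuilder body = new StringBuilder("{\"query\":\"").append(escape(query)).append("\",\"input\":[");
        for (int i = 0; i < inputs.size(); i++) {
            if (i > 0) {
                body.append(",");
            }
            body.append('"').append(escape(inputs.get(i))).append('"');
        }
        return body.append("]}").toString();
    }

    public static void main(String[] args) {
        // The second input contains quotes that plain concatenation would not escape.
        System.out.println(rerankBody("q", Arrays.asList("a", "b \"c\"")));
    }
}
```

The hand-rolled escaping only keeps the sketch dependency-free; inside the Elasticsearch test framework an XContent builder would likely be the more idiomatic way to produce the body.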