Skip to content

Commit

Permalink
[8.14] [ML] External inference service rolling upgrade tests (elastic…
Browse files Browse the repository at this point in the history
…#107619) (elastic#107868)

* [ML] External inference service rolling upgrade tests (elastic#107619)

Rolling upgrade tests for the OpenAI, Cohere and Hugging Face services.
Each service is tested by pointing its URL at a local mock web server
and mocking the responses.

* fix compilation

* Reduce inference rolling upgrade test parallelism

---------

Co-authored-by: Elastic Machine <[email protected]>
Co-authored-by: Mark Vieira <[email protected]>
  • Loading branch information
3 people authored Apr 25, 2024
1 parent cf5588a commit 4259c37
Show file tree
Hide file tree
Showing 11 changed files with 1,137 additions and 39 deletions.
34 changes: 34 additions & 0 deletions x-pack/plugin/inference/qa/rolling-upgrade/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

import org.elasticsearch.gradle.Version
import org.elasticsearch.gradle.VersionProperties
import org.elasticsearch.gradle.internal.info.BuildParams
import org.elasticsearch.gradle.testclusters.StandaloneRestIntegTestTask

// Project-internal convention plugins. Judging by their ids they wire up the Java REST
// test source set ("javaRestTest"), test-artifact publishing, and backwards-compatibility
// (BWC) test support — NOTE(review): plugin behavior is defined elsewhere; confirm there.
apply plugin: 'elasticsearch.internal-java-rest-test'
apply plugin: 'elasticsearch.internal-test-artifact-base'
apply plugin: 'elasticsearch.bwc-test'


// Dependencies for the rolling-upgrade REST tests of the inference plugin.
dependencies {
// x-pack core is needed only at compile time.
compileOnly project(':x-pack:plugin:core')
// Test artifact of x-pack core, made available to the javaRestTest source set.
javaRestTestImplementation(testArtifact(project(xpackModule('core'))))
// The inference plugin under test.
javaRestTestImplementation project(path: xpackModule('inference'))
// The "javaRestTest" test artifact of the generic :qa:rolling-upgrade project
// (presumably the source of the shared rolling-upgrade test base classes — confirm).
javaRestTestImplementation(testArtifact(project(":qa:rolling-upgrade"), "javaRestTest"))
}

// Inference API added in 8.11
// Register one standalone REST integration test task per wire-compatible BWC version
// that passes the version filter below.
// NOTE(review): `after("8.11.0")` presumably excludes 8.11.0 itself, which would mean
// only versions newer than 8.11.0 are exercised — confirm this is intended.
BuildParams.bwcVersions.withWireCompatible(v -> v.after("8.11.0")) { bwcVersion, baseName ->
tasks.register(bwcTaskName(bwcVersion), StandaloneRestIntegTestTask) {
// Start the test cluster from the BWC distribution of the old version.
usesBwcDistribution(bwcVersion)
// Exposed to the tests so they can branch on the version being upgraded from.
systemProperty("tests.old_cluster_version", bwcVersion)
// Run the tests in a single fork (reduced parallelism for these upgrade tests).
maxParallelForks = 1
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.application;

import com.carrotsearch.randomizedtesting.annotations.Name;

import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.junit.AfterClass;
import org.junit.BeforeClass;

import java.io.IOException;
import java.util.List;
import java.util.Map;

import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.not;

/**
 * Rolling upgrade test for the {@code azureopenai} text embedding inference service.
 *
 * <p>The service URL in the endpoint configuration is pointed at a local {@link MockWebServer}
 * and every call that is expected to reach the service is preceded by an enqueued mock response.
 * The test method runs once per upgrade stage (old, mixed, upgraded) and checks that an endpoint
 * created in the old cluster remains readable and usable through the upgrade.
 *
 * <p>The test is currently muted with {@code @AwaitsFix} because the URL cannot be set in the tests.
 */
public class AzureOpenAiServiceUpgradeIT extends InferenceUpgradeTestCase {

    /** First version in which the Azure OpenAI embeddings service is available. */
    private static final String OPEN_AI_AZURE_EMBEDDINGS_ADDED = "8.14.0";

    /** Mock server standing in for the external embeddings service; shared by all test instances. */
    private static MockWebServer openAiEmbeddingsServer;

    /**
     * @param upgradedNodes value of the {@code upgradedNodes} test parameter, forwarded to the base class
     */
    public AzureOpenAiServiceUpgradeIT(@Name("upgradedNodes") int upgradedNodes) {
        super(upgradedNodes);
    }

    /**
     * Creates and starts the mock web server before any test in this class runs.
     *
     * @throws IOException if the server cannot be started
     */
    @BeforeClass
    public static void startWebServer() throws IOException {
        openAiEmbeddingsServer = new MockWebServer();
        openAiEmbeddingsServer.start();
    }

    /**
     * Closes the mock web server once all tests in this class have run.
     */
    @AfterClass
    public static void shutdown() {
        // Guard against a failed @BeforeClass: without the check an NPE here would
        // hide the original start-up failure.
        if (openAiEmbeddingsServer != null) {
            openAiEmbeddingsServer.close();
            openAiEmbeddingsServer = null;
        }
    }

    /**
     * Creates an embeddings endpoint in the old cluster, verifies it can be read and used for
     * inference in the mixed and upgraded clusters, and finally creates, uses and deletes a second
     * endpoint in the fully upgraded cluster.
     *
     * @throws IOException if a request to the cluster fails
     */
    @SuppressWarnings("unchecked")
    @AwaitsFix(bugUrl = "Cannot set the URL in the tests")
    public void testOpenAiEmbeddings() throws IOException {
        var openAiEmbeddingsSupported = getOldClusterTestVersion().onOrAfter(OPEN_AI_AZURE_EMBEDDINGS_ADDED);
        assumeTrue("Azure OpenAI embedding service added in " + OPEN_AI_AZURE_EMBEDDINGS_ADDED, openAiEmbeddingsSupported);

        final String oldClusterId = "old-cluster-embeddings";
        final String upgradedClusterId = "upgraded-cluster-embeddings";

        if (isOldCluster()) {
            // queue a response as PUT will call the service
            openAiEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(OpenAiServiceUpgradeIT.embeddingResponse()));
            put(oldClusterId, embeddingConfig(getUrl(openAiEmbeddingsServer)), TaskType.TEXT_EMBEDDING);

            var configs = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, oldClusterId).get("models");
            assertThat(configs, hasSize(1));
        } else if (isMixedCluster()) {
            var configs = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, oldClusterId).get("models");
            // Fail with a clear message instead of an IndexOutOfBoundsException if the endpoint is missing.
            assertThat(configs, hasSize(1));
            assertEquals("azureopenai", configs.get(0).get("service"));

            assertEmbeddingInference(oldClusterId);
        } else if (isUpgradedCluster()) {
            // check old cluster model
            var configs = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, oldClusterId).get("models");
            assertThat(configs, hasSize(1));
            assertEquals("azureopenai", configs.get(0).get("service"));

            // Inference on old cluster model
            assertEmbeddingInference(oldClusterId);

            // queue a response as PUT will call the service
            openAiEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(OpenAiServiceUpgradeIT.embeddingResponse()));
            put(upgradedClusterId, embeddingConfig(getUrl(openAiEmbeddingsServer)), TaskType.TEXT_EMBEDDING);

            configs = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, upgradedClusterId).get("models");
            assertThat(configs, hasSize(1));

            // Inference on the new config
            assertEmbeddingInference(upgradedClusterId);

            delete(oldClusterId);
            delete(upgradedClusterId);
        }
    }

    /**
     * Enqueues one mock embedding response, runs text embedding inference against the given
     * endpoint and asserts that the returned result map is not empty.
     *
     * @param inferenceId id of the inference endpoint to call
     * @throws IOException if the inference request fails
     */
    void assertEmbeddingInference(String inferenceId) throws IOException {
        openAiEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(OpenAiServiceUpgradeIT.embeddingResponse()));
        var inferenceMap = inference(inferenceId, TaskType.TEXT_EMBEDDING, "some text");
        assertThat(inferenceMap.entrySet(), not(empty()));
    }

    /**
     * Builds the JSON body used to create an {@code azureopenai} endpoint.
     *
     * @param url URL to place in the {@code url} service setting (the mock server in these tests)
     * @return the endpoint configuration as a JSON string
     */
    private String embeddingConfig(String url) {
        // All lines of the text block share one indentation level, so the incidental
        // whitespace is stripped and the resulting JSON has no leading spaces.
        return Strings.format("""
            {
            "service": "azureopenai",
            "service_settings": {
            "api_key": "XXXX",
            "url": "%s",
            "resource_name": "resource_name",
            "deployment_id": "deployment_id",
            "api_version": "2024-02-01"
            }
            }
            """, url);
    }

}
Loading

0 comments on commit 4259c37

Please sign in to comment.