From 059ff8124951712d6373a8146596dfc093464841 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 17 May 2024 13:50:52 +0200 Subject: [PATCH] Extend abstract class for performing tests in semantic_text --- x-pack/plugin/inference/build.gradle | 1 + ...emanticTextNonDynamicFieldMapperTests.java | 87 +++++++++++++++++++ 2 files changed, 88 insertions(+) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextNonDynamicFieldMapperTests.java diff --git a/x-pack/plugin/inference/build.gradle b/x-pack/plugin/inference/build.gradle index f1542da745e25..f1f1311196435 100644 --- a/x-pack/plugin/inference/build.gradle +++ b/x-pack/plugin/inference/build.gradle @@ -32,6 +32,7 @@ dependencies { compileOnly project(":server") compileOnly project(path: xpackModule('core')) testImplementation(testArtifact(project(xpackModule('core')))) + testImplementation(testArtifact(project(':server'))) testImplementation(project(':x-pack:plugin:inference:qa:test-service-plugin')) testImplementation project(':modules:reindex') clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin') diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextNonDynamicFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextNonDynamicFieldMapperTests.java new file mode 100644 index 0000000000000..7a1e6d7660e77 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextNonDynamicFieldMapperTests.java @@ -0,0 +1,87 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.mapper; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.index.mapper.NonDynamicFieldMapperTests; +import org.elasticsearch.inference.Model; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.xpack.inference.Utils; +import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.junit.Before; + +import java.util.Collection; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.nullValue; + +public class SemanticTextNonDynamicFieldMapperTests extends NonDynamicFieldMapperTests { + + @Before + public void setup() throws Exception { + Utils.storeSparseModel(client()); + } + + @Override + protected Collection> getPlugins() { + return List.of(Utils.TestInferencePlugin.class); + } + + @Override + protected String getTypeName() { + return SemanticTextFieldMapper.CONTENT_TYPE; + } + + @Override + protected String getMapping() { + return """ + "type": "%s", + "inference_id": "%s" + """.formatted(SemanticTextFieldMapper.CONTENT_TYPE, TestSparseInferenceServiceExtension.TestInferenceService.NAME); + } + + private void storeSparseModel() throws Exception { + Model model = new TestSparseInferenceServiceExtension.TestSparseModel( + TestSparseInferenceServiceExtension.TestInferenceService.NAME, + new TestSparseInferenceServiceExtension.TestServiceSettings("sparse_model", null, false) + ); + storeModel(model); + } + + private void storeModel(Model model) throws Exception { + ModelRegistry modelRegistry = new ModelRegistry(client()); + + AtomicReference storeModelHolder = new AtomicReference<>(); + AtomicReference exceptionHolder = new AtomicReference<>(); + + blockingCall(listener -> modelRegistry.storeModel(model, listener), storeModelHolder, exceptionHolder); + + assertThat(storeModelHolder.get(), is(true)); + assertThat(exceptionHolder.get(), is(nullValue())); + } + + private void blockingCall(Consumer> function, AtomicReference response, AtomicReference error) + throws InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + ActionListener listener = ActionListener.wrap(r -> { + response.set(r); + latch.countDown(); + }, e -> { + error.set(e); + latch.countDown(); + }); + + function.accept(listener); + latch.await(); + } +}