Skip to content

Commit

Permalink
add alibaba cloud inference service
Browse files Browse the repository at this point in the history
  • Loading branch information
weizijun committed Jul 23, 2024
1 parent 78263de commit aa2ad43
Show file tree
Hide file tree
Showing 50 changed files with 4,497 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ENTERPRISE_GEOIP_DOWNLOADER = def(8_708_00_0);
public static final TransportVersion NODES_STATS_ENUM_SET = def(8_709_00_0);
public static final TransportVersion MASTER_NODE_METRICS = def(8_710_00_0);
public static final TransportVersion ML_INFERENCE_ALIBABACLOUD_SEARCH_ADDED = def(8_711_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@
import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseServiceSettings;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseTaskSettings;
import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockSecretSettings;
import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionTaskSettings;
Expand Down Expand Up @@ -111,6 +118,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
addCustomElandWriteables(namedWriteables);
addAnthropicNamedWritables(namedWriteables);
addAmazonBedrockNamedWriteables(namedWriteables);
addAlibabaCloudSearchNamedWriteables(namedWriteables);

return namedWriteables;
}
Expand Down Expand Up @@ -475,4 +483,58 @@ private static void addAnthropicNamedWritables(List<NamedWriteableRegistry.Entry
)
);
}

private static void addAlibabaCloudSearchNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
AlibabaCloudSearchServiceSettings.NAME,
AlibabaCloudSearchServiceSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
AlibabaCloudSearchEmbeddingsServiceSettings.NAME,
AlibabaCloudSearchEmbeddingsServiceSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
TaskSettings.class,
AlibabaCloudSearchEmbeddingsTaskSettings.NAME,
AlibabaCloudSearchEmbeddingsTaskSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
AlibabaCloudSearchSparseServiceSettings.NAME,
AlibabaCloudSearchSparseServiceSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
TaskSettings.class,
AlibabaCloudSearchSparseTaskSettings.NAME,
AlibabaCloudSearchSparseTaskSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
AlibabaCloudSearchRerankServiceSettings.NAME,
AlibabaCloudSearchRerankServiceSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
TaskSettings.class,
AlibabaCloudSearchRerankTaskSettings.NAME,
AlibabaCloudSearchRerankTaskSettings::new
)
);

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
import org.elasticsearch.xpack.inference.rest.RestInferenceAction;
import org.elasticsearch.xpack.inference.rest.RestPutInferenceModelAction;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchService;
import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockService;
import org.elasticsearch.xpack.inference.services.anthropic.AnthropicService;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService;
Expand Down Expand Up @@ -221,6 +222,7 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
context -> new MistralService(httpFactory.get(), serviceComponents.get()),
context -> new AnthropicService(httpFactory.get(), serviceComponents.get()),
context -> new AmazonBedrockService(httpFactory.get(), amazonBedrockFactory.get(), serviceComponents.get()),
context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get()),
ElasticsearchInternalService::new
);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.action.alibabacloudsearch;

import org.elasticsearch.inference.InputType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankModel;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel;

import java.util.Map;
import java.util.Objects;

/**
* Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the alibaba cloud search model type.
*/
public class AlibabaCloudSearchActionCreator implements AlibabaCloudSearchActionVisitor {
private final Sender sender;
private final ServiceComponents serviceComponents;

public AlibabaCloudSearchActionCreator(Sender sender, ServiceComponents serviceComponents) {
this.sender = Objects.requireNonNull(sender);
this.serviceComponents = Objects.requireNonNull(serviceComponents);
}

@Override
public ExecutableAction create(AlibabaCloudSearchEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType) {
var overriddenModel = AlibabaCloudSearchEmbeddingsModel.of(model, taskSettings, inputType);

return new AlibabaCloudSearchEmbeddingsAction(sender, overriddenModel, serviceComponents);
}

@Override
public ExecutableAction create(AlibabaCloudSearchSparseModel model, Map<String, Object> taskSettings, InputType inputType) {
var overriddenModel = AlibabaCloudSearchSparseModel.of(model, taskSettings, inputType);

return new AlibabaCloudSearchSparseAction(sender, overriddenModel, serviceComponents);
}

@Override
public ExecutableAction create(AlibabaCloudSearchRerankModel model, Map<String, Object> taskSettings) {
var overriddenModel = AlibabaCloudSearchRerankModel.of(model, taskSettings);

return new AlibabaCloudSearchRerankAction(sender, overriddenModel, serviceComponents);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.action.alibabacloudsearch;

import org.elasticsearch.inference.InputType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankModel;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel;

import java.util.Map;

public interface AlibabaCloudSearchActionVisitor {
ExecutableAction create(AlibabaCloudSearchEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType);

ExecutableAction create(AlibabaCloudSearchSparseModel model, Map<String, Object> taskSettings, InputType inputType);

ExecutableAction create(AlibabaCloudSearchRerankModel model, Map<String, Object> taskSettings);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.action.alibabacloudsearch;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount;
import org.elasticsearch.xpack.inference.external.http.sender.AlibabaCloudSearchEmbeddingsRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel;

import java.util.Objects;

import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException;

public class AlibabaCloudSearchEmbeddingsAction implements ExecutableAction {
private static final Logger logger = LogManager.getLogger(AlibabaCloudSearchEmbeddingsAction.class);

private final AlibabaCloudSearchAccount account;
private final AlibabaCloudSearchEmbeddingsModel model;
private final String failedToSendRequestErrorMessage;
private final Sender sender;
private final AlibabaCloudSearchEmbeddingsRequestManager requestCreator;

public AlibabaCloudSearchEmbeddingsAction(Sender sender, AlibabaCloudSearchEmbeddingsModel model, ServiceComponents serviceComponents) {
this.model = Objects.requireNonNull(model);
this.sender = Objects.requireNonNull(sender);
this.account = new AlibabaCloudSearchAccount(
this.model.getServiceSettings().getCommonSettings().getUri(),
this.model.getSecretSettings().apiKey()
);
this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(
this.model.getServiceSettings().getCommonSettings().getUri(),
"AlibabaCloud Search text embeddings"
);
this.requestCreator = AlibabaCloudSearchEmbeddingsRequestManager.of(account, model, serviceComponents.threadPool());
}

@Override
public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
try {
ActionListener<InferenceServiceResults> wrappedListener = wrapFailuresInElasticsearchException(
failedToSendRequestErrorMessage,
listener
);
sender.send(requestCreator, inferenceInputs, timeout, wrappedListener);
} catch (ElasticsearchException e) {
listener.onFailure(e);
} catch (Exception e) {
listener.onFailure(createInternalServerError(e, failedToSendRequestErrorMessage));
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.action.alibabacloudsearch;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount;
import org.elasticsearch.xpack.inference.external.http.sender.AlibabaCloudSearchRerankRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankModel;

import java.util.Objects;

import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException;

public class AlibabaCloudSearchRerankAction implements ExecutableAction {
private static final Logger logger = LogManager.getLogger(AlibabaCloudSearchRerankAction.class);

private final AlibabaCloudSearchAccount account;
private final AlibabaCloudSearchRerankModel model;
private final String failedToSendRequestErrorMessage;
private final Sender sender;
private final AlibabaCloudSearchRerankRequestManager requestCreator;

public AlibabaCloudSearchRerankAction(Sender sender, AlibabaCloudSearchRerankModel model, ServiceComponents serviceComponents) {
this.model = Objects.requireNonNull(model);
this.account = new AlibabaCloudSearchAccount(
this.model.getServiceSettings().getCommonSettings().getUri(),
this.model.getSecretSettings().apiKey()
);
this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(
this.model.getServiceSettings().getCommonSettings().getUri(),
"AlibabaCloud Search rerank"
);
this.sender = Objects.requireNonNull(sender);
this.requestCreator = AlibabaCloudSearchRerankRequestManager.of(account, model, serviceComponents.threadPool());
}

@Override
public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
try {
ActionListener<InferenceServiceResults> wrappedListener = wrapFailuresInElasticsearchException(
failedToSendRequestErrorMessage,
listener
);
sender.send(requestCreator, inferenceInputs, timeout, wrappedListener);
} catch (ElasticsearchException e) {
listener.onFailure(e);
} catch (Exception e) {
listener.onFailure(createInternalServerError(e, failedToSendRequestErrorMessage));
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.action.alibabacloudsearch;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount;
import org.elasticsearch.xpack.inference.external.http.sender.AlibabaCloudSearchSparseRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel;

import java.util.Objects;

import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException;

public class AlibabaCloudSearchSparseAction implements ExecutableAction {
private static final Logger logger = LogManager.getLogger(AlibabaCloudSearchSparseAction.class);

private final AlibabaCloudSearchAccount account;
private final AlibabaCloudSearchSparseModel model;
private final String failedToSendRequestErrorMessage;
private final Sender sender;
private final AlibabaCloudSearchSparseRequestManager requestCreator;

public AlibabaCloudSearchSparseAction(Sender sender, AlibabaCloudSearchSparseModel model, ServiceComponents serviceComponents) {
this.model = Objects.requireNonNull(model);
this.account = new AlibabaCloudSearchAccount(
this.model.getServiceSettings().getCommonSettings().getUri(),
this.model.getSecretSettings().apiKey()
);
this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(
this.model.getServiceSettings().getCommonSettings().getUri(),
"AlibabaCloud Search sparse embeddings"
);
this.sender = Objects.requireNonNull(sender);
requestCreator = AlibabaCloudSearchSparseRequestManager.of(account, model, serviceComponents.threadPool());
}

@Override
public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
try {
ActionListener<InferenceServiceResults> wrappedListener = wrapFailuresInElasticsearchException(
failedToSendRequestErrorMessage,
listener
);
sender.send(requestCreator, inferenceInputs, timeout, wrappedListener);
} catch (ElasticsearchException e) {
listener.onFailure(e);
} catch (Exception e) {
listener.onFailure(createInternalServerError(e, failedToSendRequestErrorMessage));
}
}
}
Loading

0 comments on commit aa2ad43

Please sign in to comment.