diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelDefinitionPartAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelDefinitionPartAction.java index 7bf3bacdf9f48..6afa2fcf7b67f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelDefinitionPartAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelDefinitionPartAction.java @@ -12,20 +12,15 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.HandledTransportAction; import org.elasticsearch.action.support.master.AcknowledgedResponse; -import org.elasticsearch.action.support.master.TransportMasterNodeAction; import org.elasticsearch.client.Client; import org.elasticsearch.client.OriginSettingClient; -import org.elasticsearch.cluster.ClusterState; -import org.elasticsearch.cluster.block.ClusterBlockException; -import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; -import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; -import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.XPackField; import org.elasticsearch.xpack.core.ml.MachineLearningField; @@ -40,7 +35,17 @@ import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; -public class TransportPutTrainedModelDefinitionPartAction extends TransportMasterNodeAction { +/** + * The action that allows users to put parts of the model definition. + * + * The action is a {@link HandledTransportAction} as opposed to a {@link org.elasticsearch.action.support.master.TransportMasterNodeAction}. + * This comes with pros and cons. The benefit is that when a model is imported it may spread over hundreds of documents with + * each one being of considerable size. Thus, making this a {@link HandledTransportAction} avoids putting that load on the master node. + * On the downsides, it is care is needed when it comes to adding new fields on those trained model definition docs. The action + * could execute on a node that is on a newer version than the master node. This may mean the native model index does not have + * the mappings required for newly added fields on later versions. + */ +public class TransportPutTrainedModelDefinitionPartAction extends HandledTransportAction { private static final Logger logger = LogManager.getLogger(TransportPutTrainedModelDefinitionPartAction.class); private final TrainedModelProvider trainedModelProvider; @@ -50,32 +55,23 @@ public class TransportPutTrainedModelDefinitionPartAction extends TransportMaste @Inject public TransportPutTrainedModelDefinitionPartAction( TransportService transportService, - ClusterService clusterService, - ThreadPool threadPool, XPackLicenseState licenseState, ActionFilters actionFilters, - IndexNameExpressionResolver indexNameExpressionResolver, Client client, TrainedModelProvider trainedModelProvider ) { - super( - PutTrainedModelDefinitionPartAction.NAME, - transportService, - clusterService, - threadPool, - actionFilters, - Request::new, - indexNameExpressionResolver, - AcknowledgedResponse::readFrom, - ThreadPool.Names.SAME - ); + super(PutTrainedModelDefinitionPartAction.NAME, transportService, actionFilters, Request::new); this.licenseState = licenseState; this.trainedModelProvider = trainedModelProvider; this.client = new OriginSettingClient(client, ML_ORIGIN); } @Override - protected void masterOperation(Task task, Request request, ClusterState state, ActionListener listener) { + protected void doExecute(Task task, Request request, ActionListener listener) { + if (MachineLearningField.ML_API_FEATURE.check(licenseState) == false) { + listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING)); + return; + } ActionListener configActionListener = ActionListener.wrap(config -> { TrainedModelLocation location = config.getLocation(); @@ -123,19 +119,4 @@ protected void masterOperation(Task task, Request request, ClusterState state, A trainedModelProvider.getTrainedModel(request.getModelId(), GetTrainedModelsAction.Includes.empty(), configActionListener); } - - @Override - protected ClusterBlockException checkBlock(Request request, ClusterState state) { - // TODO do we really need to do this??? - return null; - } - - @Override - protected void doExecute(Task task, Request request, ActionListener listener) { - if (MachineLearningField.ML_API_FEATURE.check(licenseState)) { - super.doExecute(task, request, listener); - } else { - listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING)); - } - } }