diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java index 65027653..b6f9164b 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java @@ -38,7 +38,6 @@ import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector; import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetectorJob; import com.amazon.opendistroforelasticsearch.ad.model.AnomalyResult; -import com.amazon.opendistroforelasticsearch.ad.model.ProfileName; import com.amazon.opendistroforelasticsearch.ad.rest.RestAnomalyDetectorJobAction; import com.amazon.opendistroforelasticsearch.ad.rest.RestDeleteAnomalyDetectorAction; import com.amazon.opendistroforelasticsearch.ad.rest.RestExecuteAnomalyDetectorAction; @@ -67,6 +66,8 @@ import com.amazon.opendistroforelasticsearch.ad.transport.DeleteDetectorTransportAction; import com.amazon.opendistroforelasticsearch.ad.transport.DeleteModelAction; import com.amazon.opendistroforelasticsearch.ad.transport.DeleteModelTransportAction; +import com.amazon.opendistroforelasticsearch.ad.transport.ProfileAction; +import com.amazon.opendistroforelasticsearch.ad.transport.ProfileTransportAction; import com.amazon.opendistroforelasticsearch.ad.transport.RCFResultAction; import com.amazon.opendistroforelasticsearch.ad.transport.RCFResultTransportAction; import com.amazon.opendistroforelasticsearch.ad.transport.StopDetectorAction; @@ -185,12 +186,8 @@ public List getRestHandlers( jobRunner.setAnomalyResultHandler(anomalyResultHandler); jobRunner.setSettings(settings); - AnomalyDetectorProfileRunner profileRunner = new AnomalyDetectorProfileRunner(client, this.xContentRegistry); - RestGetAnomalyDetectorAction restGetAnomalyDetectorAction = new RestGetAnomalyDetectorAction( - restController, - profileRunner, - ProfileName.getNames() - ); + AnomalyDetectorProfileRunner profileRunner = new AnomalyDetectorProfileRunner(client, this.xContentRegistry, this.nodeFilter); + RestGetAnomalyDetectorAction restGetAnomalyDetectorAction = new RestGetAnomalyDetectorAction(restController, profileRunner); RestIndexAnomalyDetectorAction restIndexAnomalyDetectorAction = new RestIndexAnomalyDetectorAction( settings, restController, @@ -438,7 +435,8 @@ public List getNamedWriteables() { new ActionHandler<>(ThresholdResultAction.INSTANCE, ThresholdResultTransportAction.class), new ActionHandler<>(AnomalyResultAction.INSTANCE, AnomalyResultTransportAction.class), new ActionHandler<>(CronAction.INSTANCE, CronTransportAction.class), - new ActionHandler<>(ADStatsNodesAction.INSTANCE, ADStatsNodesTransportAction.class) + new ActionHandler<>(ADStatsNodesAction.INSTANCE, ADStatsNodesTransportAction.class), + new ActionHandler<>(ProfileAction.INSTANCE, ProfileTransportAction.class) ); } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorProfileRunner.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorProfileRunner.java index 1cf75cda..b254eb3a 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorProfileRunner.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorProfileRunner.java @@ -30,6 +30,7 @@ import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.client.Client; +import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentParseException; @@ -44,23 +45,30 @@ import org.elasticsearch.search.sort.FieldSortBuilder; import org.elasticsearch.search.sort.SortOrder; +import com.amazon.opendistroforelasticsearch.ad.indices.AnomalyDetectionIndices; import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetectorJob; import com.amazon.opendistroforelasticsearch.ad.model.AnomalyResult; import com.amazon.opendistroforelasticsearch.ad.model.DetectorProfile; import com.amazon.opendistroforelasticsearch.ad.model.DetectorState; import com.amazon.opendistroforelasticsearch.ad.model.ProfileName; +import com.amazon.opendistroforelasticsearch.ad.transport.ProfileAction; +import com.amazon.opendistroforelasticsearch.ad.transport.ProfileRequest; +import com.amazon.opendistroforelasticsearch.ad.transport.ProfileResponse; +import com.amazon.opendistroforelasticsearch.ad.util.DiscoveryNodeFilterer; import com.amazon.opendistroforelasticsearch.ad.util.MultiResponsesDelegateActionListener; public class AnomalyDetectorProfileRunner { private final Logger logger = LogManager.getLogger(AnomalyDetectorProfileRunner.class); private Client client; private NamedXContentRegistry xContentRegistry; + private DiscoveryNodeFilterer nodeFilter; static String FAIL_TO_FIND_DETECTOR_MSG = "Fail to find detector with id: "; static String FAIL_TO_GET_PROFILE_MSG = "Fail to get profile for detector "; - public AnomalyDetectorProfileRunner(Client client, NamedXContentRegistry xContentRegistry) { + public AnomalyDetectorProfileRunner(Client client, NamedXContentRegistry xContentRegistry, DiscoveryNodeFilterer nodeFilter) { this.client = client; this.xContentRegistry = xContentRegistry; + this.nodeFilter = nodeFilter; } public void profile(String detectorId, ActionListener listener, Set profiles) { @@ -70,9 +78,28 @@ public void profile(String detectorId, ActionListener listener, return; } + // total number of listeners we need to define. Needed by MultiResponsesDelegateActionListener to decide when to consolidate results + // and return to users + int totalListener = 0; + + if (profiles.contains(ProfileName.STATE)) { + totalListener++; + } + + if (profiles.contains(ProfileName.ERROR)) { + totalListener++; + } + + if (profiles.contains(ProfileName.COORDINATING_NODE) + || profiles.contains(ProfileName.SHINGLE_SIZE) + || profiles.contains(ProfileName.TOTAL_SIZE_IN_BYTES) + || profiles.contains(ProfileName.MODELS)) { + totalListener++; + } + MultiResponsesDelegateActionListener delegateListener = new MultiResponsesDelegateActionListener( listener, - profiles.size(), + totalListener, "Fail to fetch profile for " + detectorId ); @@ -102,6 +129,13 @@ private void prepareProfile( if (profiles.contains(ProfileName.ERROR)) { profileError(detectorId, enabledTimeMs, listener); } + + if (profiles.contains(ProfileName.COORDINATING_NODE) + || profiles.contains(ProfileName.SHINGLE_SIZE) + || profiles.contains(ProfileName.TOTAL_SIZE_IN_BYTES) + || profiles.contains(ProfileName.MODELS)) { + profileModels(detectorId, profiles, listener); + } } catch (IOException | XContentParseException | NullPointerException e) { logger.error(e); listener.failImmediately(FAIL_TO_GET_PROFILE_MSG, e); @@ -280,8 +314,42 @@ private SearchRequest createLatestAnomalyResultRequest(String detectorId, long e SearchSourceBuilder source = new SearchSourceBuilder().query(filterQuery).size(1).sort(sortQuery); - SearchRequest request = new SearchRequest(AnomalyResult.ANOMALY_RESULT_INDEX); + SearchRequest request = new SearchRequest(AnomalyDetectionIndices.ALL_AD_RESULTS_INDEX_PATTERN); request.source(source); return request; } + + private void profileModels( + String detectorId, + Set profiles, + MultiResponsesDelegateActionListener listener + ) { + DiscoveryNode[] dataNodes = nodeFilter.getEligibleDataNodes(); + ProfileRequest profileRequest = new ProfileRequest(detectorId, profiles, dataNodes); + client.execute(ProfileAction.INSTANCE, profileRequest, onModelResponse(detectorId, profiles, listener)); + } + + private ActionListener onModelResponse( + String detectorId, + Set profiles, + MultiResponsesDelegateActionListener listener + ) { + return ActionListener.wrap(profileResponse -> { + DetectorProfile profile = new DetectorProfile(); + if (profiles.contains(ProfileName.COORDINATING_NODE)) { + profile.setCoordinatingNode(profileResponse.getCoordinatingNode()); + } + if (profiles.contains(ProfileName.SHINGLE_SIZE)) { + profile.setShingleSize(profileResponse.getShingleSize()); + } + if (profiles.contains(ProfileName.TOTAL_SIZE_IN_BYTES)) { + profile.setTotalSizeInBytes(profileResponse.getTotalSizeInBytes()); + } + if (profiles.contains(ProfileName.MODELS)) { + profile.setModelProfile(profileResponse.getModelProfile()); + } + + listener.onResponse(profile); + }, listener::onFailure); + } } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/constant/CommonName.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/constant/CommonName.java index c378dc1d..bed9f174 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/constant/CommonName.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/constant/CommonName.java @@ -44,4 +44,14 @@ public class CommonName { // box type public static final String BOX_TYPE_KEY = "box_type"; + + // ====================================== + // Profile name + // ====================================== + public static final String STATE = "state"; + public static final String ERROR = "error"; + public static final String COORDINATING_NODE = "coordinating_node"; + public static final String SHINGLE_SIZE = "shingle_size"; + public static final String TOTAL_SIZE_IN_BYTES = "total_size_in_bytes"; + public static final String MODELS = "models"; } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/feature/FeatureManager.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/feature/FeatureManager.java index 6af6398b..07e160bc 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/feature/FeatureManager.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/feature/FeatureManager.java @@ -536,4 +536,13 @@ private double[][] transpose(double[][] matrix) { private long truncateToMinute(long epochMillis) { return Instant.ofEpochMilli(epochMillis).truncatedTo(ChronoUnit.MINUTES).toEpochMilli(); } + + public int getShingleSize(String detectorId) { + Deque> shingle = detectorIdsToTimeShingles.get(detectorId); + if (shingle != null) { + return shingle.size(); + } else { + return -1; + } + } } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java index ea8f08db..cebd2498 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java @@ -22,6 +22,7 @@ import java.time.Instant; import java.util.AbstractMap.SimpleImmutableEntry; import java.util.Arrays; +import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; @@ -890,4 +891,24 @@ private double computeRcfConfidence(RandomCutForest forest) { return Math.max(0, confidence); // Replaces -0 wth 0 for cosmetic purpose. } } + + /** + * Get all RCF partition's size corresponding to a detector. Thresholding models' size is a constant since they are small in size (KB). + * @param detectorId detector id + * @return a map of model id to its memory size + */ + public Map getModelSize(String detectorId) { + Map res = new HashMap<>(); + forests + .entrySet() + .stream() + .filter(entry -> getDetectorIdForModelId(entry.getKey()).equals(detectorId)) + .forEach(entry -> { res.put(entry.getKey(), estimateModelSize(entry.getValue().getModel())); }); + thresholds + .entrySet() + .stream() + .filter(entry -> getDetectorIdForModelId(entry.getKey()).equals(detectorId)) + .forEach(entry -> { res.put(entry.getKey(), 0L); }); + return res; + } } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/model/DetectorProfile.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/model/DetectorProfile.java index 30650cbe..41e58c61 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/model/DetectorProfile.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/model/DetectorProfile.java @@ -24,27 +24,56 @@ import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; +import com.amazon.opendistroforelasticsearch.ad.constant.CommonName; + public class DetectorProfile implements ToXContentObject, Mergeable { private DetectorState state; private String error; - - private static final String STATE_FIELD = "state"; - private static final String ERROR_FIELD = "error"; + private ModelProfile[] modelProfile; + private int shingleSize; + private String coordinatingNode; + private long totalSizeInBytes; public XContentBuilder toXContent(XContentBuilder builder) throws IOException { return toXContent(builder, ToXContent.EMPTY_PARAMS); } + public DetectorProfile() { + state = null; + error = null; + modelProfile = null; + shingleSize = -1; + coordinatingNode = null; + totalSizeInBytes = -1; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { XContentBuilder xContentBuilder = builder.startObject(); if (state != null) { - xContentBuilder.field(STATE_FIELD, state); + xContentBuilder.field(CommonName.STATE, state); } if (error != null) { - xContentBuilder.field(ERROR_FIELD, error); + xContentBuilder.field(CommonName.ERROR, error); + } + if (modelProfile != null && modelProfile.length > 0) { + xContentBuilder.startArray(CommonName.MODELS); + for (ModelProfile profile : modelProfile) { + profile.toXContent(xContentBuilder, params); + } + xContentBuilder.endArray(); + } + if (shingleSize != -1) { + xContentBuilder.field(CommonName.SHINGLE_SIZE, shingleSize); + } + if (coordinatingNode != null) { + xContentBuilder.field(CommonName.COORDINATING_NODE, coordinatingNode); } + if (totalSizeInBytes != -1) { + xContentBuilder.field(CommonName.TOTAL_SIZE_IN_BYTES, totalSizeInBytes); + } + return xContentBuilder.endObject(); } @@ -64,6 +93,38 @@ public void setError(String error) { this.error = error; } + public ModelProfile[] getModelProfile() { + return modelProfile; + } + + public void setModelProfile(ModelProfile[] modelProfile) { + this.modelProfile = modelProfile; + } + + public int getShingleSize() { + return shingleSize; + } + + public void setShingleSize(int shingleSize) { + this.shingleSize = shingleSize; + } + + public String getCoordinatingNode() { + return coordinatingNode; + } + + public void setCoordinatingNode(String coordinatingNode) { + this.coordinatingNode = coordinatingNode; + } + + public long getTotalSizeInBytes() { + return totalSizeInBytes; + } + + public void setTotalSizeInBytes(long totalSizeInBytes) { + this.totalSizeInBytes = totalSizeInBytes; + } + @Override public void merge(Mergeable other) { if (this == other || other == null || getClass() != other.getClass()) { diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/model/ModelProfile.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/model/ModelProfile.java new file mode 100644 index 00000000..94f47f83 --- /dev/null +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/model/ModelProfile.java @@ -0,0 +1,79 @@ +package com.amazon.opendistroforelasticsearch.ad.model; + +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +import java.io.IOException; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; + +public class ModelProfile implements Writeable, ToXContent { + // field name in toXContent + public static final String MODEL_ID = "model_id"; + public static final String MODEL_SIZE_IN_BYTES = "model_size_in_bytes"; + public static final String NODE_ID = "node_id"; + + private final String modelId; + private final long modelSizeInBytes; + private final String nodeId; + + public ModelProfile(String modelId, long modelSize, String nodeId) { + super(); + this.modelId = modelId; + this.modelSizeInBytes = modelSize; + this.nodeId = nodeId; + } + + public ModelProfile(StreamInput in) throws IOException { + modelId = in.readString(); + modelSizeInBytes = in.readVLong(); + nodeId = in.readString(); + } + + public String getModelId() { + return modelId; + } + + public long getModelSize() { + return modelSizeInBytes; + } + + public String getNodeId() { + return nodeId; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(MODEL_ID, modelId); + if (modelSizeInBytes > 0) { + builder.field(MODEL_SIZE_IN_BYTES, modelSizeInBytes); + } + builder.field(NODE_ID, nodeId); + builder.endObject(); + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + out.writeVLong(modelSizeInBytes); + out.writeString(nodeId); + } +} diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/model/ProfileName.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/model/ProfileName.java index ea0be275..3c3fa93b 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/model/ProfileName.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/model/ProfileName.java @@ -19,9 +19,15 @@ import java.util.HashSet; import java.util.Set; +import com.amazon.opendistroforelasticsearch.ad.constant.CommonName; + public enum ProfileName { - STATE("state"), - ERROR("error"); + STATE(CommonName.STATE), + ERROR(CommonName.ERROR), + COORDINATING_NODE(CommonName.COORDINATING_NODE), + SHINGLE_SIZE(CommonName.SHINGLE_SIZE), + TOTAL_SIZE_IN_BYTES(CommonName.TOTAL_SIZE_IN_BYTES), + MODELS(CommonName.MODELS); private String name; @@ -38,26 +44,20 @@ public String getName() { return name; } - /** - * Get set of profile names - * - * @return set of profile names - */ - public static Set getNames() { - Set names = new HashSet<>(); - - for (ProfileName statName : ProfileName.values()) { - names.add(statName.getName()); - } - return names; - } - public static ProfileName getName(String name) { switch (name) { - case "state": + case CommonName.STATE: return STATE; - case "error": + case CommonName.ERROR: return ERROR; + case CommonName.COORDINATING_NODE: + return COORDINATING_NODE; + case CommonName.SHINGLE_SIZE: + return SHINGLE_SIZE; + case CommonName.TOTAL_SIZE_IN_BYTES: + return TOTAL_SIZE_IN_BYTES; + case CommonName.MODELS: + return MODELS; default: throw new IllegalArgumentException("Unsupported profile types"); } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/rest/RestGetAnomalyDetectorAction.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/rest/RestGetAnomalyDetectorAction.java index b52b6966..f45dbc52 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/rest/RestGetAnomalyDetectorAction.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/rest/RestGetAnomalyDetectorAction.java @@ -50,8 +50,10 @@ import java.io.IOException; import java.util.Arrays; import java.util.HashSet; +import java.util.List; import java.util.Locale; import java.util.Set; +import java.util.stream.Collectors; import static com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector.ANOMALY_DETECTORS_INDEX; import static com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX; @@ -71,15 +73,17 @@ public class RestGetAnomalyDetectorAction extends BaseRestHandler { private final AnomalyDetectorProfileRunner profileRunner; private final Set allProfileTypeStrs; private final Set allProfileTypes; + private final Set defaultProfileTypes; - public RestGetAnomalyDetectorAction( - RestController controller, - AnomalyDetectorProfileRunner profileRunner, - Set allProfileTypeStrs - ) { + public RestGetAnomalyDetectorAction(RestController controller, AnomalyDetectorProfileRunner profileRunner) { this.profileRunner = profileRunner; - this.allProfileTypes = new HashSet(Arrays.asList(ProfileName.values())); - this.allProfileTypeStrs = ProfileName.getNames(); + + List allProfiles = Arrays.asList(ProfileName.values()); + this.allProfileTypes = new HashSet(allProfiles); + this.allProfileTypeStrs = getProfileListStrs(allProfiles); + + List defaultProfiles = Arrays.asList(ProfileName.ERROR, ProfileName.STATE); + this.defaultProfileTypes = new HashSet(defaultProfiles); String path = String.format(Locale.ROOT, "%s/{%s}", AnomalyDetectorPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID); controller.registerHandler(RestRequest.Method.GET, path, this); @@ -111,13 +115,14 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli throw new IllegalStateException(CommonErrorMessages.DISABLED_ERR_MSG); } String detectorId = request.param(DETECTOR_ID); - boolean returnJob = request.paramAsBoolean("job", false); String typesStr = request.param(TYPE); String rawPath = request.rawPath(); if (!Strings.isEmpty(typesStr) || rawPath.endsWith(PROFILE) || rawPath.endsWith(PROFILE + "/")) { + boolean all = request.paramAsBoolean("_all", false); return channel -> profileRunner - .profile(detectorId, getProfileActionListener(channel, detectorId), getProfilesToCollect(typesStr)); + .profile(detectorId, getProfileActionListener(channel, detectorId), getProfilesToCollect(typesStr, all)); } else { + boolean returnJob = request.paramAsBoolean("job", false); MultiGetRequest.Item adItem = new MultiGetRequest.Item(ANOMALY_DETECTORS_INDEX, detectorId) .version(RestActions.parseVersion(request)); MultiGetRequest multiGetRequest = new MultiGetRequest().add(adItem); @@ -201,12 +206,18 @@ private RestResponse buildInternalServerErrorResponse(Exception e, String errorM return new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, errorMsg); } - private Set getProfilesToCollect(String typesStr) { - if (Strings.isEmpty(typesStr)) { + private Set getProfilesToCollect(String typesStr, boolean all) { + if (all) { return this.allProfileTypes; + } else if (Strings.isEmpty(typesStr)) { + return this.defaultProfileTypes; } else { Set typesInRequest = new HashSet<>(Arrays.asList(typesStr.split(","))); return ProfileName.getNames(Sets.intersection(this.allProfileTypeStrs, typesInRequest)); } } + + private Set getProfileListStrs(List profileList) { + return profileList.stream().map(profile -> profile.getName()).collect(Collectors.toSet()); + } } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/ProfileAction.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/ProfileAction.java new file mode 100644 index 00000000..4bfbb519 --- /dev/null +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/ProfileAction.java @@ -0,0 +1,35 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.amazon.opendistroforelasticsearch.ad.transport; + +import org.elasticsearch.action.ActionType; + +/** + * Profile transport action + */ +public class ProfileAction extends ActionType { + + public static final ProfileAction INSTANCE = new ProfileAction(); + public static final String NAME = "cluster:admin/ad/detector/profile"; + + /** + * Constructor + */ + private ProfileAction() { + super(NAME, ProfileResponse::new); + } + +} diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/ProfileNodeRequest.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/ProfileNodeRequest.java new file mode 100644 index 00000000..7f25d397 --- /dev/null +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/ProfileNodeRequest.java @@ -0,0 +1,65 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.amazon.opendistroforelasticsearch.ad.transport; + +import org.elasticsearch.action.support.nodes.BaseNodeRequest; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; + +import com.amazon.opendistroforelasticsearch.ad.model.ProfileName; + +import java.io.IOException; +import java.util.Set; + +/** + * Class representing a nodes's profile request + */ +public class ProfileNodeRequest extends BaseNodeRequest { + private ProfileRequest request; + + public ProfileNodeRequest(StreamInput in) throws IOException { + super(in); + this.request = new ProfileRequest(in); + } + + /** + * Constructor + * + * @param request profile request + */ + public ProfileNodeRequest(ProfileRequest request) { + this.request = request; + } + + public String getDetectorId() { + return request.getDetectorId(); + } + + /** + * Get the set that tracks which profiles should be retrieved + * + * @return the set that contains the profile names marked for retrieval + */ + public Set getProfilesToBeRetrieved() { + return request.getProfilesToBeRetrieved(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + request.writeTo(out); + } +} diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/ProfileNodeResponse.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/ProfileNodeResponse.java new file mode 100644 index 00000000..df4c7cbb --- /dev/null +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/ProfileNodeResponse.java @@ -0,0 +1,110 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.amazon.opendistroforelasticsearch.ad.transport; + +import org.elasticsearch.action.support.nodes.BaseNodeResponse; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ToXContentFragment; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Map; + +/** + * Profile response on a node + */ +public class ProfileNodeResponse extends BaseNodeResponse implements ToXContentFragment { + // filed name in toXContent + static final String MODEL_SIZE_IN_BYTES = "model_size_in_bytes"; + static final String SHINGLE_SIZE = "shingle_size"; + + private Map modelSize; + private int shingleSize; + + /** + * Constructor + * + * @param in StreamInput + * @throws IOException throws an IO exception if the StreamInput cannot be read from + */ + public ProfileNodeResponse(StreamInput in) throws IOException { + super(in); + modelSize = in.readMap(StreamInput::readString, StreamInput::readLong); + shingleSize = in.readInt(); + } + + /** + * Constructor + * + * @param node DiscoveryNode object + * @param modelSize Mapping of model id to its memory consumption in bytes + * @param shingleSize shingle size + */ + public ProfileNodeResponse(DiscoveryNode node, Map modelSize, int shingleSize) { + super(node); + this.modelSize = modelSize; + this.shingleSize = shingleSize; + } + + /** + * Creates a new ProfileNodeResponse object and reads in the profile from an input stream + * + * @param in StreamInput to read from + * @return ProfileNodeResponse object corresponding to the input stream + * @throws IOException throws an IO exception if the StreamInput cannot be read from + */ + public static ProfileNodeResponse readProfiles(StreamInput in) throws IOException { + return new ProfileNodeResponse(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeMap(modelSize, StreamOutput::writeString, StreamOutput::writeLong); + out.writeInt(shingleSize); + } + + /** + * Converts profile to xContent + * + * @param builder XContentBuilder + * @param params Params + * @return XContentBuilder + * @throws IOException thrown by builder for invalid field + */ + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(MODEL_SIZE_IN_BYTES); + for (Map.Entry entry : modelSize.entrySet()) { + builder.field(entry.getKey(), entry.getValue()); + } + builder.endObject(); + + builder.field(SHINGLE_SIZE, shingleSize); + + return builder; + } + + public Map getModelSize() { + return modelSize; + } + + public int getShingleSize() { + return shingleSize; + } +} diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/ProfileRequest.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/ProfileRequest.java new file mode 100644 index 00000000..531312a1 --- /dev/null +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/ProfileRequest.java @@ -0,0 +1,84 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.amazon.opendistroforelasticsearch.ad.transport; + +import org.elasticsearch.action.support.nodes.BaseNodesRequest; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; + +import com.amazon.opendistroforelasticsearch.ad.model.ProfileName; + +import java.io.IOException; +import java.util.HashSet; +import java.util.Set; + +/** + * implements a request to obtain profiles about an AD detector + */ +public class ProfileRequest extends BaseNodesRequest { + + private Set profilesToBeRetrieved; + private String detectorId; + + public ProfileRequest(StreamInput in) throws IOException { + super(in); + int size = in.readVInt(); + profilesToBeRetrieved = new HashSet(); + if (size != 0) { + for (int i = 0; i < size; i++) { + profilesToBeRetrieved.add(in.readEnum(ProfileName.class)); + } + } + detectorId = in.readString(); + } + + /** + * Constructor + * + * @param detectorId detector's id + * @param profilesToBeRetrieved profiles to be retrieved + * @param nodes nodes of nodes' profiles to be retrieved + */ + public ProfileRequest(String detectorId, Set profilesToBeRetrieved, DiscoveryNode... nodes) { + super(nodes); + this.detectorId = detectorId; + this.profilesToBeRetrieved = profilesToBeRetrieved; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeVInt(profilesToBeRetrieved.size()); + for (ProfileName profile : profilesToBeRetrieved) { + out.writeEnum(profile); + } + out.writeString(detectorId); + } + + public String getDetectorId() { + return detectorId; + } + + /** + * Get the set that tracks which profiles should be retrieved + * + * @return the set that contains the profile names marked for retrieval + */ + public Set getProfilesToBeRetrieved() { + return profilesToBeRetrieved; + } +} diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/ProfileResponse.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/ProfileResponse.java new file mode 100644 index 00000000..cdfd79e8 --- /dev/null +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/ProfileResponse.java @@ -0,0 +1,146 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.amazon.opendistroforelasticsearch.ad.transport; + +import org.elasticsearch.action.FailedNodeException; +import org.elasticsearch.action.support.nodes.BaseNodesResponse; +import org.elasticsearch.cluster.ClusterName; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ToXContentFragment; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import com.amazon.opendistroforelasticsearch.ad.constant.CommonName; +import com.amazon.opendistroforelasticsearch.ad.model.ModelProfile; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +/** + * This class consists of the aggregated responses from the nodes + */ +public class ProfileResponse extends BaseNodesResponse implements ToXContentFragment { + // filed name in toXContent + static final String COORDINATING_NODE = CommonName.COORDINATING_NODE; + static final String SHINGLE_SIZE = CommonName.SHINGLE_SIZE; + static final String TOTAL_SIZE = CommonName.TOTAL_SIZE_IN_BYTES; + static final String MODELS = CommonName.MODELS; + + private ModelProfile[] modelProfile; + private int shingleSize; + private String coordinatingNode; + private long totalSizeInBytes; + + /** + * Constructor + * + * @param in StreamInput + * @throws IOException thrown when unable to read from stream + */ + public ProfileResponse(StreamInput in) throws IOException { + super(in); + int size = in.readVInt(); + modelProfile = new ModelProfile[size]; + for (int i = 0; i < size; i++) { + modelProfile[i] = new ModelProfile(in); + } + shingleSize = in.readVInt(); + coordinatingNode = in.readString(); + totalSizeInBytes = in.readVLong(); + } + + /** + * Constructor + * + * @param clusterName name of cluster + * @param nodes List of ProfileNodeResponse from nodes + * @param failures List of failures from nodes + */ + public ProfileResponse(ClusterName clusterName, List nodes, List failures) { + super(clusterName, nodes, failures); + totalSizeInBytes = 0L; + List modelProfileList = new ArrayList<>(); + for (ProfileNodeResponse response : nodes) { + String curNodeId = response.getNode().getId(); + if (response.getShingleSize() >= 0) { + coordinatingNode = curNodeId; + shingleSize = response.getShingleSize(); + } + for (Map.Entry entry : response.getModelSize().entrySet()) { + totalSizeInBytes += entry.getValue(); + modelProfileList.add(new ModelProfile(entry.getKey(), entry.getValue(), curNodeId)); + } + + } + if (coordinatingNode == null) { + coordinatingNode = ""; + } + this.modelProfile = modelProfileList.toArray(new ModelProfile[0]); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeVInt(modelProfile.length); + for (ModelProfile profile : modelProfile) { + profile.writeTo(out); + } + out.writeVInt(shingleSize); + out.writeString(coordinatingNode); + out.writeVLong(totalSizeInBytes); + } + + @Override + public void writeNodesTo(StreamOutput out, List nodes) throws IOException { + out.writeList(nodes); + } + + @Override + public List readNodesFrom(StreamInput in) throws IOException { + return in.readList(ProfileNodeResponse::readProfiles); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(COORDINATING_NODE, coordinatingNode); + builder.field(SHINGLE_SIZE, shingleSize); + builder.field(TOTAL_SIZE, totalSizeInBytes); + builder.startArray(MODELS); + for (ModelProfile profile : modelProfile) { + profile.toXContent(builder, params); + } + builder.endArray(); + return builder; + } + + public ModelProfile[] getModelProfile() { + return modelProfile; + } + + public int getShingleSize() { + return shingleSize; + } + + public String getCoordinatingNode() { + return coordinatingNode; + } + + public long getTotalSizeInBytes() { + return totalSizeInBytes; + } +} diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/ProfileTransportAction.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/ProfileTransportAction.java new file mode 100644 index 00000000..5806c5d7 --- /dev/null +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/ProfileTransportAction.java @@ -0,0 +1,109 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.amazon.opendistroforelasticsearch.ad.transport; + +import com.amazon.opendistroforelasticsearch.ad.feature.FeatureManager; +import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager; +import com.amazon.opendistroforelasticsearch.ad.model.ProfileName; +import org.elasticsearch.action.FailedNodeException; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.nodes.TransportNodesAction; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.threadpool.ThreadPool; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * This class contains the logic to extract the stats from the nodes + */ +public class ProfileTransportAction extends TransportNodesAction { + + private ModelManager modelManager; + private FeatureManager featureManager; + + /** + * Constructor + * + * @param threadPool ThreadPool to use + * @param clusterService ClusterService + * @param transportService TransportService + * @param actionFilters Action Filters + * @param modelManager model manager object + * @param featureManager feature manager object + */ + @Inject + public ProfileTransportAction( + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + ModelManager modelManager, + FeatureManager featureManager + ) { + super( + ProfileAction.NAME, + threadPool, + clusterService, + transportService, + actionFilters, + ProfileRequest::new, + ProfileNodeRequest::new, + ThreadPool.Names.MANAGEMENT, + ProfileNodeResponse.class + ); + this.modelManager = modelManager; + this.featureManager = featureManager; + } + + @Override + protected ProfileResponse newResponse(ProfileRequest request, List responses, List failures) { + return new ProfileResponse(clusterService.getClusterName(), responses, failures); + } + + @Override + protected ProfileNodeRequest newNodeRequest(ProfileRequest request) { + return new ProfileNodeRequest(request); + } + + @Override + protected ProfileNodeResponse newNodeResponse(StreamInput in) throws IOException { + return new ProfileNodeResponse(in); + } + + @Override + protected ProfileNodeResponse nodeOperation(ProfileNodeRequest request) { + String detectorId = request.getDetectorId(); + Set profiles = request.getProfilesToBeRetrieved(); + int shingleSize = -1; + if (profiles.contains(ProfileName.COORDINATING_NODE) || profiles.contains(ProfileName.SHINGLE_SIZE)) { + shingleSize = featureManager.getShingleSize(detectorId); + } + Map modelSize = null; + if (profiles.contains(ProfileName.TOTAL_SIZE_IN_BYTES) || profiles.contains(ProfileName.MODELS)) { + modelSize = modelManager.getModelSize(detectorId); + } else { + modelSize = new HashMap<>(); + } + return new ProfileNodeResponse(clusterService.localNode(), modelSize, shingleSize); + } +} diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/AbstractADTest.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/AbstractADTest.java index 29604746..728f4316 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/AbstractADTest.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/AbstractADTest.java @@ -26,14 +26,21 @@ import org.apache.logging.log4j.core.appender.AbstractAppender; import org.apache.logging.log4j.core.layout.PatternLayout; import org.apache.logging.log4j.util.StackLocatorUtil; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; +import test.com.amazon.opendistroforelasticsearch.ad.util.FakeNode; + public class AbstractADTest extends ESTestCase { protected static final Logger LOG = (Logger) LogManager.getLogger(AbstractADTest.class); + // transport test node + protected int nodesCount; + protected FakeNode[] testNodes; + /** * Log4j appender that uses a list to store log messages * @@ -122,4 +129,20 @@ protected static void tearDownThreadPool() { ThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS); threadPool = null; } + + public void setupTestNodes(Settings settings) { + nodesCount = randomIntBetween(2, 10); + testNodes = new FakeNode[nodesCount]; + for (int i = 0; i < testNodes.length; i++) { + testNodes[i] = new FakeNode("node" + i, threadPool, settings); + } + FakeNode.connectNodes(testNodes); + } + + public void tearDownTestNodes() { + for (FakeNode testNode : testNodes) { + testNode.close(); + } + testNodes = null; + } } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorProfileRunnerTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorProfileRunnerTests.java index 8f23f7ea..158d1c92 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorProfileRunnerTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorProfileRunnerTests.java @@ -17,6 +17,8 @@ import static com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector.ANOMALY_DETECTORS_INDEX; import static com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetectorJob.ANOMALY_DETECTOR_JOB_INDEX; +import static java.util.Collections.emptyMap; +import static java.util.Collections.emptySet; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; @@ -25,21 +27,28 @@ import java.time.Instant; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.FailedNodeException; import org.elasticsearch.action.get.GetRequest; import org.elasticsearch.action.get.GetResponse; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.client.Client; +import org.elasticsearch.cluster.ClusterName; +import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.index.IndexNotFoundException; import org.elasticsearch.search.SearchModule; @@ -53,17 +62,38 @@ import com.amazon.opendistroforelasticsearch.ad.model.AnomalyResult; import com.amazon.opendistroforelasticsearch.ad.model.DetectorProfile; import com.amazon.opendistroforelasticsearch.ad.model.DetectorState; +import com.amazon.opendistroforelasticsearch.ad.model.ModelProfile; import com.amazon.opendistroforelasticsearch.ad.model.ProfileName; +import com.amazon.opendistroforelasticsearch.ad.transport.ProfileNodeResponse; +import com.amazon.opendistroforelasticsearch.ad.transport.ProfileResponse; +import com.amazon.opendistroforelasticsearch.ad.util.DiscoveryNodeFilterer; public class AnomalyDetectorProfileRunnerTests extends ESTestCase { private static final Logger LOG = LogManager.getLogger(AnomalyDetectorProfileRunnerTests.class); private AnomalyDetectorProfileRunner runner; private Client client; + private DiscoveryNodeFilterer nodeFilter; private AnomalyDetector detector; private static Set stateOnly; private static Set stateNError; + private static Set modelProfile; private static String error = "No full shingle in current detection window"; + // profile model related + String node1; + String nodeName1; + DiscoveryNode discoveryNode1; + + String node2; + String nodeName2; + DiscoveryNode discoveryNode2; + + long modelSize; + String model1Id; + String model0Id; + + int shingleSize; + @Override protected NamedXContentRegistry xContentRegistry() { SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList()); @@ -79,6 +109,9 @@ public static void setUpOnce() { stateNError = new HashSet(); stateNError.add(ProfileName.ERROR); stateNError.add(ProfileName.STATE); + modelProfile = new HashSet( + Arrays.asList(ProfileName.SHINGLE_SIZE, ProfileName.MODELS, ProfileName.COORDINATING_NODE, ProfileName.TOTAL_SIZE_IN_BYTES) + ); } @Override @@ -86,7 +119,8 @@ public static void setUpOnce() { public void setUp() throws Exception { super.setUp(); client = mock(Client.class); - runner = new AnomalyDetectorProfileRunner(client, xContentRegistry()); + nodeFilter = mock(DiscoveryNodeFilterer.class); + runner = new AnomalyDetectorProfileRunner(client, xContentRegistry(), nodeFilter); } enum JobStatus { @@ -314,4 +348,95 @@ public void testExceptionOnStateFetching() throws IOException, InterruptedExcept }), stateOnly); assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); } + + @SuppressWarnings("unchecked") + private void setUpClientExecute() { + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[2]; + + node1 = "node1"; + nodeName1 = "nodename1"; + discoveryNode1 = new DiscoveryNode( + nodeName1, + node1, + new TransportAddress(TransportAddress.META_ADDRESS, 9300), + emptyMap(), + emptySet(), + Version.CURRENT + ); + + node2 = "node2"; + nodeName2 = "nodename2"; + discoveryNode2 = new DiscoveryNode( + nodeName2, + node2, + new TransportAddress(TransportAddress.META_ADDRESS, 9301), + emptyMap(), + emptySet(), + Version.CURRENT + ); + + modelSize = 4456448L; + model1Id = "Pl536HEBnXkDrah03glg_model_rcf_1"; + model0Id = "Pl536HEBnXkDrah03glg_model_rcf_0"; + + shingleSize = 6; + + String clusterName = "test-cluster-name"; + + Map modelSizeMap1 = new HashMap() { + { + put(model1Id, modelSize); + } + }; + + Map modelSizeMap2 = new HashMap() { + { + put(model0Id, modelSize); + } + }; + + LOG.info("hello"); + ProfileNodeResponse profileNodeResponse1 = new ProfileNodeResponse(discoveryNode1, modelSizeMap1, shingleSize); + ProfileNodeResponse profileNodeResponse2 = new ProfileNodeResponse(discoveryNode2, modelSizeMap2, -1); + List profileNodeResponses = Arrays.asList(profileNodeResponse1, profileNodeResponse2); + List failures = Collections.emptyList(); + ProfileResponse profileResponse = new ProfileResponse(new ClusterName(clusterName), profileNodeResponses, failures); + + listener.onResponse(profileResponse); + + return null; + }).when(client).execute(any(), any(), any()); + + } + + public void testProfileModels() throws InterruptedException, IOException { + setUpClientGet(true, JobStatus.ENABLED); + setUpClientExecute(); + + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + runner.profile(detector.getDetectorId(), ActionListener.wrap(profileResponse -> { + assertEquals(node1, profileResponse.getCoordinatingNode()); + assertEquals(shingleSize, profileResponse.getShingleSize()); + assertEquals(modelSize * 2, profileResponse.getTotalSizeInBytes()); + assertEquals(2, profileResponse.getModelProfile().length); + for (ModelProfile profile : profileResponse.getModelProfile()) { + assertTrue(node1.equals(profile.getNodeId()) || node2.equals(profile.getNodeId())); + assertEquals(modelSize, profile.getModelSize()); + if (node1.equals(profile.getNodeId())) { + assertEquals(model1Id, profile.getModelId()); + } + if (node2.equals(profile.getNodeId())) { + assertEquals(model0Id, profile.getModelId()); + } + } + inProgressLatch.countDown(); + }, exception -> { + assertTrue("Should not reach here ", false); + inProgressLatch.countDown(); + }), modelProfile); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorRestTestCase.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorRestTestCase.java index bdea7b5e..32dc442b 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorRestTestCase.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorRestTestCase.java @@ -134,7 +134,7 @@ public ToXContentObject[] getAnomalyDetector(String detectorId, BasicHeader head ); assertEquals("Unable to get anomaly detector " + detectorId, RestStatus.OK, restStatus(response)); XContentParser parser = createAdParser(XContentType.JSON.xContent(), response.getEntity().getContent()); - XContentParser.Token token = parser.nextToken(); + parser.nextToken(); XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser::getTokenLocation); String id = null; @@ -214,15 +214,23 @@ public void updateClusterSettings(String settingKey, Object value) throws Except assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); } - public Response getDetectorProfile(String detectorId) throws IOException { + public Response getDetectorProfile(String detectorId, boolean all, String customizedProfile) throws IOException { return TestHelpers .makeRequest( client(), "GET", - TestHelpers.AD_BASE_DETECTORS_URI + "/" + detectorId + "/" + RestHandlerUtils.PROFILE, + TestHelpers.AD_BASE_DETECTORS_URI + "/" + detectorId + "/" + RestHandlerUtils.PROFILE + customizedProfile + "?_all=" + all, null, "", ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) ); } + + public Response getDetectorProfile(String detectorId) throws IOException { + return getDetectorProfile(detectorId, false, ""); + } + + public Response getDetectorProfile(String detectorId, boolean all) throws IOException { + return getDetectorProfile(detectorId, all, ""); + } } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/rest/AnomalyDetectorRestApiIT.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/rest/AnomalyDetectorRestApiIT.java index c6544b98..d52125e0 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/rest/AnomalyDetectorRestApiIT.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/rest/AnomalyDetectorRestApiIT.java @@ -310,7 +310,7 @@ public void testPreviewAnomalyDetector() throws Exception { } public void testPreviewAnomalyDetectorWhichNotExist() throws Exception { - AnomalyDetector detector = createRandomAnomalyDetector(true, false); + createRandomAnomalyDetector(true, false); AnomalyDetectorExecutionInput input = new AnomalyDetectorExecutionInput( randomAlphaOfLength(5), Instant.now().minusSeconds(60 * 10), @@ -885,7 +885,7 @@ public void testStartAdjobWithEmptyFeatures() throws Exception { ); } - public void testProfileAnomalyDetector() throws Exception { + public void testDefaultProfileAnomalyDetector() throws Exception { AnomalyDetector detector = createRandomAnomalyDetector(true, true); updateClusterSettings(EnabledSetting.AD_PLUGIN_ENABLED, false); @@ -898,4 +898,18 @@ public void testProfileAnomalyDetector() throws Exception { Response profileResponse = getDetectorProfile(detector.getDetectorId()); assertEquals("Incorrect profile status", RestStatus.OK, restStatus(profileResponse)); } + + public void testAllProfileAnomalyDetector() throws Exception { + AnomalyDetector detector = createRandomAnomalyDetector(true, true); + + Response profileResponse = getDetectorProfile(detector.getDetectorId(), true); + assertEquals("Incorrect profile status", RestStatus.OK, restStatus(profileResponse)); + } + + public void testCustomizedProfileAnomalyDetector() throws Exception { + AnomalyDetector detector = createRandomAnomalyDetector(true, true); + + Response profileResponse = getDetectorProfile(detector.getDetectorId(), true, "/models/"); + assertEquals("Incorrect profile status", RestStatus.OK, restStatus(profileResponse)); + } } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTests.java index 7d84537e..d43e61e2 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/AnomalyResultTests.java @@ -121,13 +121,10 @@ import com.google.gson.JsonElement; -import test.com.amazon.opendistroforelasticsearch.ad.util.FakeNode; import test.com.amazon.opendistroforelasticsearch.ad.util.JsonDeserializer; public class AnomalyResultTests extends AbstractADTest { private static Settings settings = Settings.EMPTY; - private FakeNode[] testNodes; - private int nodesCount; private TransportService transportService; private ClusterService clusterService; private ADStateManager stateManager; @@ -149,13 +146,11 @@ public class AnomalyResultTests extends AbstractADTest { @BeforeClass public static void setUpBeforeClass() { setUpThreadPool(AnomalyResultTests.class.getSimpleName()); - settings = Settings.EMPTY; } @AfterClass public static void tearDownAfterClass() { tearDownThreadPool(); - settings = null; } @SuppressWarnings("unchecked") @@ -164,8 +159,8 @@ public static void tearDownAfterClass() { public void setUp() throws Exception { super.setUp(); super.setUpLog4jForJUnit(AnomalyResultTransportAction.class); - setupTestNodes(Settings.EMPTY); - FakeNode.connectNodes(testNodes); + setupTestNodes(settings); + runner = new ColdStartRunner(); transportService = testNodes[0].transportService; clusterService = testNodes[0].clusterService; @@ -260,21 +255,10 @@ public void setUp() throws Exception { adStats = new ADStats(indexUtils, normalModelManager, statsMap); } - public void setupTestNodes(Settings settings) { - nodesCount = randomIntBetween(2, 10); - testNodes = new FakeNode[nodesCount]; - for (int i = 0; i < testNodes.length; i++) { - testNodes[i] = new FakeNode("node" + i, threadPool, settings); - } - } - @Override @After public final void tearDown() throws Exception { - for (FakeNode testNode : testNodes) { - testNode.close(); - } - testNodes = null; + tearDownTestNodes(); runner.shutDown(); runner = null; client = null; diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/ProfileIT.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/ProfileIT.java new file mode 100644 index 00000000..2acd0542 --- /dev/null +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/ProfileIT.java @@ -0,0 +1,48 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.amazon.opendistroforelasticsearch.ad.transport; + +import com.amazon.opendistroforelasticsearch.ad.AnomalyDetectorPlugin; +import com.amazon.opendistroforelasticsearch.ad.model.ProfileName; + +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.test.ESIntegTestCase; + +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.concurrent.ExecutionException; + +@ESIntegTestCase.ClusterScope(transportClientRatio = 0.9) +public class ProfileIT extends ESIntegTestCase { + + @Override + protected Collection> nodePlugins() { + return Collections.singletonList(AnomalyDetectorPlugin.class); + } + + @Override + protected Collection> transportClientPlugins() { + return Collections.singletonList(AnomalyDetectorPlugin.class); + } + + public void testNormalProfile() throws ExecutionException, InterruptedException { + ProfileRequest profileRequest = new ProfileRequest("123", new HashSet()); + + ProfileResponse response = client().execute(ProfileAction.INSTANCE, profileRequest).get(); + assertTrue("getting profile failed", !response.hasFailures()); + } +} diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/ProfileTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/ProfileTests.java new file mode 100644 index 00000000..91c59c2e --- /dev/null +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/ProfileTests.java @@ -0,0 +1,263 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.amazon.opendistroforelasticsearch.ad.transport; + +import com.amazon.opendistroforelasticsearch.ad.common.exception.JsonPathNotFoundException; +import com.amazon.opendistroforelasticsearch.ad.model.ModelProfile; +import com.amazon.opendistroforelasticsearch.ad.model.ProfileName; +import com.google.gson.JsonArray; +import com.google.gson.JsonElement; + +import org.elasticsearch.Version; +import org.elasticsearch.action.FailedNodeException; +import org.elasticsearch.cluster.ClusterName; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.transport.TransportAddress; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.test.ESTestCase; +import org.junit.Before; +import org.junit.Test; +import test.com.amazon.opendistroforelasticsearch.ad.util.JsonDeserializer; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static java.util.Collections.emptyMap; +import static java.util.Collections.emptySet; +import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder; + +public class ProfileTests extends ESTestCase { + String node1, nodeName1, clusterName; + String node2, nodeName2; + Map clusterStats; + DiscoveryNode discoveryNode1, discoveryNode2; + long modelSize; + String model1Id; + String model0Id; + String detectorId; + int shingleSize; + Map modelSizeMap1, modelSizeMap2; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + clusterName = "test-cluster-name"; + + node1 = "node1"; + nodeName1 = "nodename1"; + discoveryNode1 = new DiscoveryNode( + nodeName1, + node1, + new TransportAddress(TransportAddress.META_ADDRESS, 9300), + emptyMap(), + emptySet(), + Version.CURRENT + ); + + node2 = "node2"; + nodeName2 = "nodename2"; + discoveryNode2 = new DiscoveryNode( + nodeName2, + node2, + new TransportAddress(TransportAddress.META_ADDRESS, 9301), + emptyMap(), + emptySet(), + Version.CURRENT + ); + + clusterStats = new HashMap<>(); + + modelSize = 4456448L; + model1Id = "Pl536HEBnXkDrah03glg_model_rcf_1"; + model0Id = "Pl536HEBnXkDrah03glg_model_rcf_0"; + detectorId = "123"; + shingleSize = 6; + + modelSizeMap1 = new HashMap() { + { + put(model1Id, modelSize); + } + }; + + modelSizeMap2 = new HashMap() { + { + put(model0Id, modelSize); + } + }; + } + + @Test + public void testProfileNodeRequest() throws IOException { + + Set profilesToRetrieve = new HashSet(); + profilesToRetrieve.add(ProfileName.COORDINATING_NODE); + ProfileRequest ProfileRequest = new ProfileRequest(detectorId, profilesToRetrieve); + ProfileNodeRequest ProfileNodeRequest = new ProfileNodeRequest(ProfileRequest); + assertEquals("ProfileNodeRequest has the wrong detector id", ProfileNodeRequest.getDetectorId(), detectorId); + assertEquals("ProfileNodeRequest has the wrong ProfileRequest", ProfileNodeRequest.getProfilesToBeRetrieved(), profilesToRetrieve); + + // Test serialization + BytesStreamOutput output = new BytesStreamOutput(); + ProfileNodeRequest.writeTo(output); + StreamInput streamInput = output.bytes().streamInput(); + ProfileNodeRequest nodeRequest = new ProfileNodeRequest(streamInput); + assertEquals("serialization has the wrong detector id", nodeRequest.getDetectorId(), detectorId); + assertEquals("serialization has the wrong ProfileRequest", nodeRequest.getProfilesToBeRetrieved(), profilesToRetrieve); + + } + + @Test + public void testProfileNodeResponse() throws IOException, JsonPathNotFoundException { + + // Test serialization + ProfileNodeResponse profileNodeResponse = new ProfileNodeResponse(discoveryNode1, modelSizeMap1, shingleSize); + BytesStreamOutput output = new BytesStreamOutput(); + profileNodeResponse.writeTo(output); + StreamInput streamInput = output.bytes().streamInput(); + ProfileNodeResponse readResponse = ProfileNodeResponse.readProfiles(streamInput); + assertEquals("serialization has the wrong model size", readResponse.getModelSize(), profileNodeResponse.getModelSize()); + assertEquals("serialization has the wrong shingle size", readResponse.getShingleSize(), profileNodeResponse.getShingleSize()); + + // Test toXContent + XContentBuilder builder = jsonBuilder(); + profileNodeResponse.toXContent(builder.startObject(), ToXContent.EMPTY_PARAMS).endObject(); + String json = Strings.toString(builder); + + for (Map.Entry profile : modelSizeMap1.entrySet()) { + assertEquals( + "toXContent has the wrong model size", + JsonDeserializer.getLongValue(json, ProfileNodeResponse.MODEL_SIZE_IN_BYTES, profile.getKey()), + profile.getValue().longValue() + ); + } + + assertEquals( + "toXContent has the wrong shingle size", + JsonDeserializer.getIntValue(json, ProfileNodeResponse.SHINGLE_SIZE), + shingleSize + ); + } + + @Test + public void testProfileRequest() throws IOException { + String detectorId = "123"; + Set profilesToRetrieve = new HashSet(); + profilesToRetrieve.add(ProfileName.COORDINATING_NODE); + ProfileRequest profileRequest = new ProfileRequest(detectorId, profilesToRetrieve); + + // Test Serialization + BytesStreamOutput output = new BytesStreamOutput(); + profileRequest.writeTo(output); + StreamInput streamInput = output.bytes().streamInput(); + ProfileRequest readRequest = new ProfileRequest(streamInput); + assertEquals( + "Serialization has the wrong profiles to be retrieved", + readRequest.getProfilesToBeRetrieved(), + profileRequest.getProfilesToBeRetrieved() + ); + assertEquals("Serialization has the wrong detector id", readRequest.getDetectorId(), profileRequest.getDetectorId()); + } + + @Test + public void testProfileResponse() throws IOException, JsonPathNotFoundException { + + ProfileNodeResponse profileNodeResponse1 = new ProfileNodeResponse(discoveryNode1, modelSizeMap1, shingleSize); + ProfileNodeResponse profileNodeResponse2 = new ProfileNodeResponse(discoveryNode2, modelSizeMap2, -1); + List profileNodeResponses = Arrays.asList(profileNodeResponse1, profileNodeResponse2); + List failures = Collections.emptyList(); + ProfileResponse profileResponse = new ProfileResponse(new ClusterName(clusterName), profileNodeResponses, failures); + + assertEquals(node1, profileResponse.getCoordinatingNode()); + assertEquals(shingleSize, profileResponse.getShingleSize()); + assertEquals(modelSize * 2, profileResponse.getTotalSizeInBytes()); + assertEquals(2, profileResponse.getModelProfile().length); + for (ModelProfile profile : profileResponse.getModelProfile()) { + assertTrue(node1.equals(profile.getNodeId()) || node2.equals(profile.getNodeId())); + assertEquals(modelSize, profile.getModelSize()); + if (node1.equals(profile.getNodeId())) { + assertEquals(model1Id, profile.getModelId()); + } + if (node2.equals(profile.getNodeId())) { + assertEquals(model0Id, profile.getModelId()); + } + } + + // Test toXContent + XContentBuilder builder = jsonBuilder(); + profileResponse.toXContent(builder.startObject(), ToXContent.EMPTY_PARAMS).endObject(); + String json = Strings.toString(builder); + + logger.info("JSON: " + json); + + assertEquals( + "toXContent has the wrong coordinating node", + node1, + JsonDeserializer.getTextValue(json, ProfileResponse.COORDINATING_NODE) + ); + assertEquals( + "toXContent has the wrong shingle size", + shingleSize, + JsonDeserializer.getLongValue(json, ProfileResponse.SHINGLE_SIZE) + ); + assertEquals("toXContent has the wrong total size", modelSize * 2, JsonDeserializer.getLongValue(json, ProfileResponse.TOTAL_SIZE)); + + JsonArray modelsJson = JsonDeserializer.getArrayValue(json, ProfileResponse.MODELS); + + for (int i = 0; i < modelsJson.size(); i++) { + JsonElement element = modelsJson.get(i); + assertTrue( + "toXContent has the wrong model id", + JsonDeserializer.getTextValue(element, ModelProfile.MODEL_ID).equals(model1Id) + || JsonDeserializer.getTextValue(element, ModelProfile.MODEL_ID).equals(model0Id) + ); + + assertEquals( + "toXContent has the wrong model size", + JsonDeserializer.getLongValue(element, ModelProfile.MODEL_SIZE_IN_BYTES), + modelSize + ); + + if (JsonDeserializer.getTextValue(element, ModelProfile.MODEL_ID).equals(model1Id)) { + assertEquals("toXContent has the wrong node id", JsonDeserializer.getTextValue(element, ModelProfile.NODE_ID), node1); + } else { + assertEquals("toXContent has the wrong node id", JsonDeserializer.getTextValue(element, ModelProfile.NODE_ID), node2); + } + + } + + // Test Serialization + BytesStreamOutput output = new BytesStreamOutput(); + + profileResponse.writeTo(output); + StreamInput streamInput = output.bytes().streamInput(); + ProfileResponse readResponse = new ProfileResponse(streamInput); + + builder = jsonBuilder(); + String readJson = Strings.toString(readResponse.toXContent(builder.startObject(), ToXContent.EMPTY_PARAMS).endObject()); + assertEquals("Serialization fails", readJson, json); + } +} diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/ProfileTransportActionTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/ProfileTransportActionTests.java new file mode 100644 index 00000000..58630bf7 --- /dev/null +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/transport/ProfileTransportActionTests.java @@ -0,0 +1,125 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.amazon.opendistroforelasticsearch.ad.transport; + +import com.amazon.opendistroforelasticsearch.ad.feature.FeatureManager; +import com.amazon.opendistroforelasticsearch.ad.ml.ModelManager; +import com.amazon.opendistroforelasticsearch.ad.model.ProfileName; + +import org.elasticsearch.action.FailedNodeException; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.transport.TransportService; +import org.junit.Before; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class ProfileTransportActionTests extends ESIntegTestCase { + private ProfileTransportAction action; + private String detectorId = "Pl536HEBnXkDrah03glg"; + String node1, nodeName1; + DiscoveryNode discoveryNode1; + Set profilesToRetrieve = new HashSet(); + private int shingleSize = 6; + private long modelSize = 4456448L; + private String modelId = "Pl536HEBnXkDrah03glg_model_rcf_1"; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + + ModelManager modelManager = mock(ModelManager.class); + FeatureManager featureManager = mock(FeatureManager.class); + + when(featureManager.getShingleSize(any(String.class))).thenReturn(shingleSize); + + Map modelSizes = new HashMap<>(); + modelSizes.put(modelId, modelSize); + when(modelManager.getModelSize(any(String.class))).thenReturn(modelSizes); + + action = new ProfileTransportAction( + client().threadPool(), + clusterService(), + mock(TransportService.class), + mock(ActionFilters.class), + modelManager, + featureManager + ); + + profilesToRetrieve = new HashSet(); + profilesToRetrieve.add(ProfileName.COORDINATING_NODE); + } + + @Test + public void testNewResponse() { + DiscoveryNode node = clusterService().localNode(); + ProfileRequest profileRequest = new ProfileRequest(detectorId, profilesToRetrieve, node); + + ProfileNodeResponse profileNodeResponse1 = new ProfileNodeResponse(node, new HashMap<>(), shingleSize); + List profileNodeResponses = Arrays.asList(profileNodeResponse1); + List failures = new ArrayList<>(); + + ProfileResponse profileResponse = action.newResponse(profileRequest, profileNodeResponses, failures); + assertEquals(node.getId(), profileResponse.getCoordinatingNode()); + } + + @Test + public void testNewNodeRequest() { + + ProfileRequest profileRequest = new ProfileRequest(detectorId, profilesToRetrieve); + + ProfileNodeRequest profileNodeRequest1 = new ProfileNodeRequest(profileRequest); + ProfileNodeRequest profileNodeRequest2 = action.newNodeRequest(profileRequest); + + assertEquals(profileNodeRequest1.getDetectorId(), profileNodeRequest2.getDetectorId()); + assertEquals(profileNodeRequest2.getProfilesToBeRetrieved(), profileNodeRequest2.getProfilesToBeRetrieved()); + } + + @Test + public void testNodeOperation() { + + DiscoveryNode nodeId = clusterService().localNode(); + ProfileRequest profileRequest = new ProfileRequest(detectorId, profilesToRetrieve, nodeId); + + ProfileNodeResponse response = action.nodeOperation(new ProfileNodeRequest(profileRequest)); + + assertEquals(shingleSize, response.getShingleSize()); + assertEquals(0, response.getModelSize().size()); + + profilesToRetrieve = new HashSet(); + profilesToRetrieve.add(ProfileName.TOTAL_SIZE_IN_BYTES); + + profileRequest = new ProfileRequest(detectorId, profilesToRetrieve, nodeId); + response = action.nodeOperation(new ProfileNodeRequest(profileRequest)); + + assertEquals(-1, response.getShingleSize()); + assertEquals(1, response.getModelSize().size()); + assertEquals(modelSize, response.getModelSize().get(modelId).longValue()); + } +} diff --git a/src/test/java/test/com/amazon/opendistroforelasticsearch/ad/util/JsonDeserializer.java b/src/test/java/test/com/amazon/opendistroforelasticsearch/ad/util/JsonDeserializer.java index bd2f9fcc..2788f59a 100644 --- a/src/test/java/test/com/amazon/opendistroforelasticsearch/ad/util/JsonDeserializer.java +++ b/src/test/java/test/com/amazon/opendistroforelasticsearch/ad/util/JsonDeserializer.java @@ -1,5 +1,5 @@ /* - * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). * You may not use this file except in compliance with the License. @@ -333,11 +333,12 @@ public static String getTextValue(String jsonString, String paths, boolean retur } /** - * Search an int number inside a JSON string matching the input path expression + * Search an array inside a JSON string matching the input path expression and convert each element using a function * * @param jsonString an encoded JSON string - * @param paths path fragments - * @return list of double + * @param function function to parse each element + * @param paths path fragments + * @return an array of values * @throws JsonPathNotFoundException if json path is invalid * @throws IOException if the underlying input source has problems * during parsing @@ -358,6 +359,24 @@ public static T[] getArrayValue(String jsonString, Function throw new JsonPathNotFoundException(); } + /** + * Search an array inside a JSON string matching the input path expression + * + * @param jsonString an encoded JSON string + * @param paths path fragments + * @throws JsonPathNotFoundException if json path is invalid + * @throws IOException if the underlying input source has problems + * during parsing + */ + @SuppressWarnings("unchecked") + public static JsonArray getArrayValue(String jsonString, String... paths) throws JsonPathNotFoundException, IOException { + JsonElement jsonNode = getChildNode(jsonString, paths); + if (jsonNode != null && jsonNode.isJsonArray()) { + return jsonNode.getAsJsonArray(); + } + throw new JsonPathNotFoundException(); + } + /** * Search a double number inside a JSON string matching the input path * expression