From 9eed2ff1ee1c3d7036bae96bfea82c07979d2e3b Mon Sep 17 00:00:00 2001 From: Hailong Cui Date: Thu, 6 Jun 2024 01:24:42 +0800 Subject: [PATCH] change request type to ActionRequest BaseGetConfigTransportAction to fix class cast exception (#1221) Signed-off-by: Hailong Cui --- .../opensearch/ad/model/AnomalyDetector.java | 12 ++--- .../BaseGetConfigTransportAction.java | 10 ++-- .../ad/model/AnomalyDetectorTests.java | 49 +++++++++++++++++++ 3 files changed, 61 insertions(+), 10 deletions(-) diff --git a/src/main/java/org/opensearch/ad/model/AnomalyDetector.java b/src/main/java/org/opensearch/ad/model/AnomalyDetector.java index cab54f30a..e09c1bf96 100644 --- a/src/main/java/org/opensearch/ad/model/AnomalyDetector.java +++ b/src/main/java/org/opensearch/ad/model/AnomalyDetector.java @@ -271,9 +271,9 @@ public AnomalyDetector(StreamInput input) throws IOException { } else { this.imputationOption = null; } - this.recencyEmphasis = input.readInt(); - this.seasonIntervals = input.readInt(); - this.historyIntervals = input.readInt(); + this.recencyEmphasis = input.readOptionalInt(); + this.seasonIntervals = input.readOptionalInt(); + this.historyIntervals = input.readOptionalInt(); if (input.readBoolean()) { this.rules = input.readList(Rule::new); } @@ -333,9 +333,9 @@ public void writeTo(StreamOutput output) throws IOException { } else { output.writeBoolean(false); } - output.writeInt(recencyEmphasis); - output.writeInt(seasonIntervals); - output.writeInt(historyIntervals); + output.writeOptionalInt(recencyEmphasis); + output.writeOptionalInt(seasonIntervals); + output.writeOptionalInt(historyIntervals); if (rules != null) { output.writeBoolean(true); output.writeList(rules); diff --git a/src/main/java/org/opensearch/timeseries/transport/BaseGetConfigTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/BaseGetConfigTransportAction.java index 9e1ece6f2..f3fe74608 100644 --- a/src/main/java/org/opensearch/timeseries/transport/BaseGetConfigTransportAction.java +++ b/src/main/java/org/opensearch/timeseries/transport/BaseGetConfigTransportAction.java @@ -24,6 +24,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionType; import org.opensearch.action.get.MultiGetItemResponse; import org.opensearch.action.get.MultiGetRequest; @@ -74,7 +75,7 @@ import com.google.common.collect.Sets; public abstract class BaseGetConfigTransportAction & TimeSeriesIndex, IndexManagementType extends IndexManagement, TaskManagerType extends TaskManager, ConfigType extends Config, EntityProfileActionType extends ActionType, EntityProfileRunnerType extends EntityProfileRunner, TaskProfileType extends TaskProfile, ConfigProfileType extends ConfigProfile, ProfileActionType extends ActionType, TaskProfileRunnerType extends TaskProfileRunner, ProfileRunnerType extends ProfileRunner> - extends HandledTransportAction { + extends HandledTransportAction { private static final Logger LOG = LogManager.getLogger(BaseGetConfigTransportAction.class); @@ -156,8 +157,9 @@ public BaseGetConfigTransportAction( } @Override - public void doExecute(Task task, GetConfigRequest request, ActionListener actionListener) { - String configID = request.getConfigID(); + public void doExecute(Task task, ActionRequest request, ActionListener actionListener) { + GetConfigRequest getConfigRequest = GetConfigRequest.fromActionRequest(request); + String configID = getConfigRequest.getConfigID(); User user = ParseUtils.getUserContext(client); ActionListener listener = wrapRestActionListener(actionListener, FAIL_TO_GET_FORECASTER); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { @@ -166,7 +168,7 @@ public void doExecute(Task task, GetConfigRequest request, ActionListener getExecute(request, listener), + (config) -> getExecute(getConfigRequest, listener), client, clusterService, xContentRegistry, diff --git a/src/test/java/org/opensearch/ad/model/AnomalyDetectorTests.java b/src/test/java/org/opensearch/ad/model/AnomalyDetectorTests.java index ebd3fecaf..bfdb2a84c 100644 --- a/src/test/java/org/opensearch/ad/model/AnomalyDetectorTests.java +++ b/src/test/java/org/opensearch/ad/model/AnomalyDetectorTests.java @@ -17,14 +17,28 @@ import java.io.IOException; import java.time.Instant; import java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.List; import java.util.Locale; import java.util.concurrent.TimeUnit; +import org.junit.Assert; import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.RangeQueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.metrics.ValueCountAggregationBuilder; import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.common.exception.ValidationException; @@ -901,4 +915,39 @@ public void testParseAnomalyDetector_withCustomIndex_withCustomResultIndexTTL() AnomalyDetector parsedDetector = AnomalyDetector.parse(TestHelpers.parser(detectorString), "id", 1L, null, null); assertEquals(30, (int) parsedDetector.getCustomResultIndexTTL()); } + + public void testSerializeAndDeserializeAnomalyDetector() throws IOException { + // register writer and reader for type Feature + Writeable.WriteableRegistry.registerWriter(Feature.class, (o, v) -> { + o.writeByte((byte) 23); + ((Feature) v).writeTo(o); + }); + Writeable.WriteableRegistry.registerReader((byte) 23, Feature::new); + + // write to streamOutput + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), Instant.now()); + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + detector.writeTo(bytesStreamOutput); + + // register namedWriteables + List namedWriteables = new ArrayList<>(); + namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, BoolQueryBuilder.NAME, BoolQueryBuilder::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, TermQueryBuilder.NAME, TermQueryBuilder::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, RangeQueryBuilder.NAME, RangeQueryBuilder::new)); + namedWriteables + .add( + new NamedWriteableRegistry.Entry( + AggregationBuilder.class, + ValueCountAggregationBuilder.NAME, + ValueCountAggregationBuilder::new + ) + ); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + StreamInput input = new NamedWriteableAwareStreamInput(streamInput, new NamedWriteableRegistry(namedWriteables)); + + AnomalyDetector deserializedDetector = new AnomalyDetector(input); + Assert.assertEquals(deserializedDetector, detector); + Assert.assertEquals(deserializedDetector.getSeasonIntervals(), detector.getSeasonIntervals()); + } }