diff --git a/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java b/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java index e3e2e19dc..acda20408 100644 --- a/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java +++ b/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java @@ -13,6 +13,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.cluster.routing.Preference; +import org.opensearch.commons.alerting.model.Monitor; import org.opensearch.core.action.ActionListener; import org.opensearch.action.ActionRequest; import org.opensearch.core.action.ActionResponse; @@ -183,13 +184,15 @@ public List getRestHandlers(Settings settings, new RestSearchCorrelationRuleAction(), new RestIndexCustomLogTypeAction(), new RestSearchCustomLogTypeAction(), - new RestDeleteCustomLogTypeAction() + new RestDeleteCustomLogTypeAction(), + new RestExecuteStreamingDetectorsAction(settings) ); } @Override public List getNamedXContent() { return List.of( + Monitor.Companion.getXCONTENT_REGISTRY(), Detector.XCONTENT_REGISTRY, DetectorInput.XCONTENT_REGISTRY, Rule.XCONTENT_REGISTRY, @@ -248,7 +251,8 @@ public List> getSettings() { SecurityAnalyticsSettings.CORRELATION_TIME_WINDOW, SecurityAnalyticsSettings.ENABLE_AUTO_CORRELATIONS, SecurityAnalyticsSettings.DEFAULT_MAPPING_SCHEMA, - SecurityAnalyticsSettings.ENABLE_WORKFLOW_USAGE + SecurityAnalyticsSettings.ENABLE_WORKFLOW_USAGE, + SecurityAnalyticsSettings.ENABLE_STREAMING_DETECTORS ); } @@ -279,7 +283,8 @@ public List> getSettings() { new ActionPlugin.ActionHandler<>(SearchCorrelationRuleAction.INSTANCE, TransportSearchCorrelationRuleAction.class), new ActionHandler<>(IndexCustomLogTypeAction.INSTANCE, TransportIndexCustomLogTypeAction.class), new ActionHandler<>(SearchCustomLogTypeAction.INSTANCE, TransportSearchCustomLogTypeAction.class), - new ActionHandler<>(DeleteCustomLogTypeAction.INSTANCE, TransportDeleteCustomLogTypeAction.class) + new ActionHandler<>(DeleteCustomLogTypeAction.INSTANCE, TransportDeleteCustomLogTypeAction.class), + new ActionHandler<>(ExecuteStreamingDetectorsAction.INSTANCE, TransportExecuteStreamingDetectorsAction.class) ); } diff --git a/src/main/java/org/opensearch/securityanalytics/action/ExecuteStreamingDetectorsAction.java b/src/main/java/org/opensearch/securityanalytics/action/ExecuteStreamingDetectorsAction.java new file mode 100644 index 000000000..7f73f642b --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/action/ExecuteStreamingDetectorsAction.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.action; + +import org.opensearch.action.ActionType; +import org.opensearch.action.bulk.BulkResponse; + +public class ExecuteStreamingDetectorsAction extends ActionType { + public static final ExecuteStreamingDetectorsAction INSTANCE = new ExecuteStreamingDetectorsAction(); + public static final String NAME = "cluster:admin/opensearch/securityanalytics/detectors/streaming/execute"; + + public ExecuteStreamingDetectorsAction() { + super(NAME, BulkResponse::new); + } +} diff --git a/src/main/java/org/opensearch/securityanalytics/converters/ExecuteStreamingWorkflowRequestConverter.java b/src/main/java/org/opensearch/securityanalytics/converters/ExecuteStreamingWorkflowRequestConverter.java new file mode 100644 index 000000000..8b7d30240 --- /dev/null +++ 
b/src/main/java/org/opensearch/securityanalytics/converters/ExecuteStreamingWorkflowRequestConverter.java @@ -0,0 +1,75 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.converters; + +import org.opensearch.common.inject.Inject; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.common.xcontent.support.XContentMapValues; +import org.opensearch.commons.alerting.action.ExecuteStreamingWorkflowRequest; +import org.opensearch.commons.alerting.model.IdDocPair; +import org.opensearch.commons.alerting.model.StreamingIndex; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.mapper.MapperParsingException; +import org.opensearch.securityanalytics.model.DocData; +import org.opensearch.securityanalytics.model.StreamingDetectorMetadata; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +public class ExecuteStreamingWorkflowRequestConverter { + private final NamedXContentRegistry xContentRegistry; + + @Inject + public ExecuteStreamingWorkflowRequestConverter(final NamedXContentRegistry xContentRegistry) { + this.xContentRegistry = xContentRegistry; + } + + public ExecuteStreamingWorkflowRequest convert(final StreamingDetectorMetadata streamingDetectorMetadata) { + final List streamingIndices = streamingDetectorMetadata.getIndexToDocData().entrySet().stream() + .map(entry -> createStreamingIndex(entry, streamingDetectorMetadata.getQueryFields())) + .collect(Collectors.toList()); + + return new ExecuteStreamingWorkflowRequest(streamingDetectorMetadata.getWorkflowId(), streamingIndices); + } + + private StreamingIndex createStreamingIndex(final Map.Entry> indexToDocData, final Set fieldNames) { + final List filteredIdDocPairs = getFilteredIdDocPairs(indexToDocData.getValue(), fieldNames); + return new StreamingIndex(indexToDocData.getKey(), filteredIdDocPairs); + } + + private List getFilteredIdDocPairs(final List indexToDocData, final Set fieldNames) { + return indexToDocData.stream() + .map(DocData::getIdDocPair) + .map(idDocPair -> { + final String docId = idDocPair.getDocId(); + final BytesReference filteredDocument = getFilteredDocument(idDocPair.getDocument(), fieldNames); + return new IdDocPair(docId, filteredDocument); + }) + .collect(Collectors.toList()); + } + + // TODO - this logic is consuming ~40% of the CPU. Is there a more efficient way to filter the docs? 
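+ // Editorial sketch (not part of this change; assumes the upstream XContentMapValues API): the three-arg + // XContentMapValues.filter(Map, String[], String[]) used below appears to rebuild the include/exclude automaton + // on every call. OpenSearch core also exposes XContentMapValues.filter(String[] includes, String[] excludes), + // which returns a reusable Function<Map<String, ?>, Map<String, Object>>. Compiling that filter once per + // workflow request and applying it to each document in the batch should amortize the automaton cost, e.g.: + //     final Function<Map<String, ?>, Map<String, Object>> docFilter = XContentMapValues.filter(fieldNames.toArray(String[]::new), new String[0]); + //     final Map<String, Object> filteredDocumentAsMap = docFilter.apply(documentAsMap);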
+ private BytesReference getFilteredDocument(final BytesReference document, final Set fieldNames) { + try { + final XContentParser xcp = XContentType.JSON.xContent().createParser( + xContentRegistry, LoggingDeprecationHandler.INSTANCE, document.streamInput()); + final Map documentAsMap = xcp.map(); + final Map filteredDocumentAsMap = XContentMapValues.filter(documentAsMap, fieldNames.toArray(String[]::new), new String[0]); + + final XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.map(filteredDocumentAsMap); + return BytesReference.bytes(builder); + } catch (final Exception e) { + throw new MapperParsingException("Exception parsing document to map", e); + } + } +} diff --git a/src/main/java/org/opensearch/securityanalytics/converters/IndexNameToDocDataConverter.java b/src/main/java/org/opensearch/securityanalytics/converters/IndexNameToDocDataConverter.java new file mode 100644 index 000000000..07ff0ff80 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/converters/IndexNameToDocDataConverter.java @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.converters; + +import org.opensearch.action.DocWriteRequest; +import org.opensearch.action.bulk.BulkItemResponse; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.commons.alerting.model.IdDocPair; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.securityanalytics.model.DocData; +import org.opensearch.securityanalytics.util.SecurityAnalyticsException; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.IntStream; + +public class IndexNameToDocDataConverter { + public Map> convert(final BulkRequest bulkRequest, final BulkResponse bulkResponse) { + if (bulkRequest.requests().size() != bulkResponse.getItems().length) { + throw new SecurityAnalyticsException( + "BulkRequest item length did not match BulkResponse item length. 
Unable to proceed.", + RestStatus.INTERNAL_SERVER_ERROR, + null + ); + } + + final Map> indexToDocData = new HashMap<>(); + IntStream.range(0, bulkRequest.requests().size()).forEach(requestIndex -> { + final DocWriteRequest request = bulkRequest.requests().get(requestIndex); + final BulkItemResponse response = bulkResponse.getItems()[requestIndex]; + + // No work for SAP to do if doc is being deleted or DocWriteRequest failed + if (isDeleteOperation(request) || response.isFailed()) { + return; + } + + indexToDocData.putIfAbsent(request.index(), new ArrayList<>()); + final BytesReference document = getDocument(request); + final String docId = response.getId(); + final IdDocPair idDocPair = new IdDocPair(docId, document); + final DocData docData = new DocData(idDocPair, requestIndex); + + indexToDocData.get(request.index()).add(docData); + }); + + return indexToDocData; + } + + private boolean isDeleteOperation(final DocWriteRequest docWriteRequest) { + return DocWriteRequest.OpType.DELETE.equals(docWriteRequest.opType()); + } + + private BytesReference getDocument(final DocWriteRequest docWriteRequest) { + switch (docWriteRequest.opType()) { + case CREATE: + case INDEX: return ((IndexRequest) docWriteRequest).source(); + case UPDATE: return ((UpdateRequest) docWriteRequest).doc().source(); + default: throw new UnsupportedOperationException("No handler for operation type: " + docWriteRequest.opType()); + } + } +} diff --git a/src/main/java/org/opensearch/securityanalytics/converters/StreamingDetectorMetadataConverter.java b/src/main/java/org/opensearch/securityanalytics/converters/StreamingDetectorMetadataConverter.java new file mode 100644 index 000000000..a62c0fa2d --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/converters/StreamingDetectorMetadataConverter.java @@ -0,0 +1,53 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.converters; + +import org.opensearch.securityanalytics.model.Detector; +import org.opensearch.securityanalytics.model.DetectorInput; +import org.opensearch.securityanalytics.model.DocData; +import org.opensearch.securityanalytics.model.StreamingDetectorMetadata; +import org.opensearch.securityanalytics.validators.StreamingDetectorValidators; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +public class StreamingDetectorMetadataConverter { + public List convert(final List detectors, final Map> indexToDocData) { + return detectors.stream() + .peek(StreamingDetectorValidators::validateDetector) + .filter(Detector::isStreamingDetector) + .filter(detector -> doesDetectorHaveIndexAsInput(detector, indexToDocData.keySet())) + .map(detector -> createStreamingDetectorMetadata(detector, indexToDocData)) + .collect(Collectors.toList()); + } + + // TODO - some edge cases here since index patterns and IndexRequests directly to a write index are not considered + private boolean doesDetectorHaveIndexAsInput(final Detector detector, final Set indexNames) { + final DetectorInput detectorInput = detector.getInputs().get(0); + return detectorInput.getIndices().stream().anyMatch(indexNames::contains); + } + + private StreamingDetectorMetadata createStreamingDetectorMetadata(final Detector detector, + final Map> indexToDocData) { + final Map> indexToDocDataForDetectorIndices = getIndexToDocDataForDetectorIndices( + detector.getInputs().get(0).getIndices(), indexToDocData); + + return new StreamingDetectorMetadata( + 
detector.getName(), + indexToDocDataForDetectorIndices, + detector.getWorkflowIds().get(0), + detector.getMonitorIds().get(0) + ); + } + + private Map> getIndexToDocDataForDetectorIndices(final List detectorIndices, + final Map> indexToDocData) { + return indexToDocData.entrySet().stream() + .filter(entry -> detectorIndices.contains(entry.getKey())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + } +} diff --git a/src/main/java/org/opensearch/securityanalytics/mapper/MapperService.java b/src/main/java/org/opensearch/securityanalytics/mapper/MapperService.java index e50af189a..a17583df7 100644 --- a/src/main/java/org/opensearch/securityanalytics/mapper/MapperService.java +++ b/src/main/java/org/opensearch/securityanalytics/mapper/MapperService.java @@ -77,7 +77,7 @@ public void createMappingAction(String indexName, String logType, String aliasMa // since you can't update documents in non-write indices String index = indexName; boolean shouldUpsertIndexTemplate = IndexUtils.isConcreteIndex(indexName, this.clusterService.state()) == false; - if (IndexUtils.isDataStream(indexName, this.clusterService.state())) { + if (IndexUtils.isDataStream(indexName, this.clusterService.state()) || IndexUtils.isAlias(indexName, this.clusterService.state())) { String writeIndex = IndexUtils.getWriteIndex(indexName, this.clusterService.state()); if (writeIndex != null) { index = writeIndex; diff --git a/src/main/java/org/opensearch/securityanalytics/model/Detector.java b/src/main/java/org/opensearch/securityanalytics/model/Detector.java index 46d3457a2..95b562fcd 100644 --- a/src/main/java/org/opensearch/securityanalytics/model/Detector.java +++ b/src/main/java/org/opensearch/securityanalytics/model/Detector.java @@ -24,14 +24,11 @@ import java.io.IOException; import java.time.Instant; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Locale; import java.util.Objects; -import java.util.stream.Collectors; - public class Detector implements Writeable, ToXContentObject { private static final Logger log = LogManager.getLogger(Detector.class); @@ -62,6 +59,7 @@ public class Detector implements Writeable, ToXContentObject { private static final String ALERTS_HISTORY_INDEX_PATTERN = "alert_history_index_pattern"; private static final String FINDINGS_INDEX = "findings_index"; private static final String FINDINGS_INDEX_PATTERN = "findings_index_pattern"; + private static final String STREAMING_DETECTOR_FIELD = "streaming_detector"; public static final String DETECTORS_INDEX = ".opensearch-sap-detectors-config"; @@ -115,13 +113,16 @@ public class Detector implements Writeable, ToXContentObject { private String findingsIndexPattern; + private Boolean streamingDetector; + private final String type; public Detector(String id, Long version, String name, Boolean enabled, Schedule schedule, Instant lastUpdateTime, Instant enabledTime, String logType, User user, List inputs, List triggers, List monitorIds, String ruleIndex, String alertsIndex, String alertsHistoryIndex, String alertsHistoryIndexPattern, - String findingsIndex, String findingsIndexPattern, Map rulePerMonitor, List workflowIds) { + String findingsIndex, String findingsIndexPattern, Map rulePerMonitor, List workflowIds, + Boolean streamingDetector) { this.type = DETECTOR_TYPE; this.id = id != null ? 
id : NO_ID; @@ -144,6 +145,7 @@ public Detector(String id, Long version, String name, Boolean enabled, Schedule this.ruleIdMonitorIdMap = rulePerMonitor; this.logType = logType; this.workflowIds = workflowIds != null ? workflowIds : null; + this.streamingDetector = streamingDetector != null ? streamingDetector : false; if (enabled) { Objects.requireNonNull(enabledTime); @@ -171,7 +173,8 @@ public Detector(StreamInput sin) throws IOException { sin.readOptionalString(), sin.readOptionalString(), sin.readMap(StreamInput::readString, StreamInput::readString), - sin.readStringList() + sin.readStringList(), + sin.readBoolean() ); } @@ -214,6 +217,8 @@ public void writeTo(StreamOutput out) throws IOException { if (workflowIds != null) { out.writeStringCollection(workflowIds); } + + out.writeBoolean(streamingDetector); } public XContentBuilder toXContentWithUser(XContentBuilder builder, Params params) throws IOException { @@ -283,10 +288,12 @@ private XContentBuilder createXContentBuilder(XContentBuilder builder, ToXConten builder.field(FINDINGS_INDEX, findingsIndex); builder.field(FINDINGS_INDEX_PATTERN, findingsIndexPattern); + builder.field(STREAMING_DETECTOR_FIELD, streamingDetector); if (params.paramAsBoolean("with_type", false)) { builder.endObject(); } + return builder.endObject(); } @@ -331,6 +338,8 @@ public static Detector parse(XContentParser xcp, String id, Long version) throws String findingsIndex = null; String findingsIndexPattern = null; + Boolean streamingDetector = null; + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, xcp.currentToken(), xcp); while (xcp.nextToken() != XContentParser.Token.END_OBJECT) { String fieldName = xcp.currentName(); @@ -427,6 +436,9 @@ public static Detector parse(XContentParser xcp, String id, Long version) throws case FINDINGS_INDEX_PATTERN: findingsIndexPattern = xcp.text(); break; + case STREAMING_DETECTOR_FIELD: + streamingDetector = xcp.booleanValue(); + break; default: xcp.skipChildren(); } @@ -462,7 +474,8 @@ public static Detector parse(XContentParser xcp, String id, Long version) throws findingsIndex, findingsIndexPattern, rulePerMonitor, - workflowIds + workflowIds, + streamingDetector ); } @@ -542,6 +555,10 @@ public List<String> getMonitorIds() { return monitorIds; } + public Boolean isStreamingDetector() { + return streamingDetector; + } + public void setUser(User user) { this.user = user; } @@ -603,6 +620,10 @@ public void setWorkflowIds(List<String> workflowIds) { this.workflowIds = workflowIds; } + public void setStreamingDetector(Boolean streamingDetector) { + this.streamingDetector = streamingDetector; + } + public List<String> getWorkflowIds() { return workflowIds; } diff --git a/src/main/java/org/opensearch/securityanalytics/model/DocData.java b/src/main/java/org/opensearch/securityanalytics/model/DocData.java new file mode 100644 index 000000000..4739c2c14 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/model/DocData.java @@ -0,0 +1,25 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.model; + +import org.opensearch.commons.alerting.model.IdDocPair; + +public class DocData { + private final IdDocPair idDocPair; + private final int bulkItemResponseIndex; + + public DocData(final IdDocPair idDocPair, final int bulkItemResponseIndex) { + this.idDocPair = idDocPair; + this.bulkItemResponseIndex = bulkItemResponseIndex; + } + + public IdDocPair getIdDocPair() { + return idDocPair; + } + + public int getBulkItemResponseIndex() { + return bulkItemResponseIndex; + } +} diff --git 
a/src/main/java/org/opensearch/securityanalytics/model/StreamingDetectorMetadata.java b/src/main/java/org/opensearch/securityanalytics/model/StreamingDetectorMetadata.java new file mode 100644 index 000000000..ef4512430 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/model/StreamingDetectorMetadata.java @@ -0,0 +1,51 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.model; + +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +public class StreamingDetectorMetadata { + private final String detectorName; + private final Map<String, List<DocData>> indexToDocData; + private final String workflowId; + private final String monitorId; + private final Set<String> queryFields; + + public StreamingDetectorMetadata(final String detectorName, final Map<String, List<DocData>> indexToDocData, + final String workflowId, final String monitorId) { + this.detectorName = detectorName; + this.indexToDocData = indexToDocData; + this.workflowId = workflowId; + this.monitorId = monitorId; + this.queryFields = new HashSet<>(); + } + + public String getDetectorName() { + return detectorName; + } + + public Map<String, List<DocData>> getIndexToDocData() { + return indexToDocData; + } + + public String getWorkflowId() { + return workflowId; + } + + public String getMonitorId() { + return monitorId; + } + + public Set<String> getQueryFields() { + return queryFields; + } + + public void addQueryFields(final Set<String> queryFieldsToAdd) { + queryFields.addAll(queryFieldsToAdd); + } +} diff --git a/src/main/java/org/opensearch/securityanalytics/resthandler/RestExecuteStreamingDetectorsAction.java b/src/main/java/org/opensearch/securityanalytics/resthandler/RestExecuteStreamingDetectorsAction.java new file mode 100644 index 000000000..247ac1af6 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/resthandler/RestExecuteStreamingDetectorsAction.java @@ -0,0 +1,80 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.resthandler; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.DocWriteRequest; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.bulk.BulkShardRequest; +import org.opensearch.action.support.ActiveShardCount; +import org.opensearch.client.Requests; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.settings.Settings; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.search.fetch.subphase.FetchSourceContext; +import org.opensearch.securityanalytics.SecurityAnalyticsPlugin; +import org.opensearch.securityanalytics.action.ExecuteStreamingDetectorsAction; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import static org.opensearch.rest.RestRequest.Method.POST; + +public class RestExecuteStreamingDetectorsAction extends BaseRestHandler { + private static final Logger log = LogManager.getLogger(RestExecuteStreamingDetectorsAction.class); + + private final boolean allowExplicitIndex; + + public RestExecuteStreamingDetectorsAction(Settings settings) { + this.allowExplicitIndex = MULTI_ALLOW_EXPLICIT_INDEX.get(settings); + } + + @Override + public String getName() { + return "run_detectors_action"; + } + + @Override + public List<Route> routes() { + return List.of( + new Route(POST, 
String.format(Locale.getDefault(), + "%s/streaming/execute", + SecurityAnalyticsPlugin.DETECTOR_BASE_URI)) + ); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + // The below is copied from https://github.com/opensearch-project/OpenSearch/blob/1f8b62fed81424576184dc9ef1ebe69f5156c904/server/src/main/java/org/opensearch/rest/action/document/RestBulkAction.java#L87 + BulkRequest bulkRequest = Requests.bulkRequest(); + String defaultIndex = request.param("index"); + String defaultRouting = request.param("routing"); + FetchSourceContext defaultFetchSourceContext = FetchSourceContext.parseFromRestRequest(request); + String defaultPipeline = request.param("pipeline"); + String waitForActiveShards = request.param("wait_for_active_shards"); + if (waitForActiveShards != null) { + bulkRequest.waitForActiveShards(ActiveShardCount.parseString(waitForActiveShards)); + } + Boolean defaultRequireAlias = request.paramAsBoolean(DocWriteRequest.REQUIRE_ALIAS, null); + bulkRequest.timeout(request.paramAsTime("timeout", BulkShardRequest.DEFAULT_TIMEOUT)); + bulkRequest.setRefreshPolicy(request.param("refresh")); + bulkRequest.add( + request.requiredContent(), + defaultIndex, + defaultRouting, + defaultFetchSourceContext, + defaultPipeline, + defaultRequireAlias, + allowExplicitIndex, + request.getMediaType() + ); + + return channel -> client.execute(ExecuteStreamingDetectorsAction.INSTANCE, bulkRequest, new RestToXContentListener<>(channel)); + } +} diff --git a/src/main/java/org/opensearch/securityanalytics/settings/SecurityAnalyticsSettings.java b/src/main/java/org/opensearch/securityanalytics/settings/SecurityAnalyticsSettings.java index f7edb182f..0b4d49444 100644 --- a/src/main/java/org/opensearch/securityanalytics/settings/SecurityAnalyticsSettings.java +++ b/src/main/java/org/opensearch/securityanalytics/settings/SecurityAnalyticsSettings.java @@ -151,4 +151,13 @@ public class SecurityAnalyticsSettings { "ecs", Setting.Property.NodeScope, Setting.Property.Dynamic ); + + /** + * Settings for streaming detectors + */ + public static final Setting ENABLE_STREAMING_DETECTORS = Setting.boolSetting( + "plugins.security_analytics.streaming_detectors_enabled", + false, + Setting.Property.NodeScope, Setting.Property.Dynamic + ); } \ No newline at end of file diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportExecuteStreamingDetectorsAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportExecuteStreamingDetectorsAction.java new file mode 100644 index 000000000..875edc9dd --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportExecuteStreamingDetectorsAction.java @@ -0,0 +1,352 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.transport; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.bulk.BulkAction; +import org.opensearch.action.bulk.BulkItemResponse; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.search.SearchAction; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; 
+import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.lucene.uid.Versions; +import org.opensearch.common.settings.Settings; +import org.opensearch.commons.alerting.AlertingPluginInterface; +import org.opensearch.commons.alerting.action.ExecuteStreamingWorkflowRequest; +import org.opensearch.commons.alerting.action.ExecuteStreamingWorkflowResponse; +import org.opensearch.commons.alerting.action.GetMonitorRequest; +import org.opensearch.commons.alerting.action.GetMonitorResponse; +import org.opensearch.commons.alerting.model.DocLevelMonitorInput; +import org.opensearch.commons.alerting.model.DocLevelQuery; +import org.opensearch.commons.alerting.model.Monitor; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.rest.RestRequest; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.securityanalytics.action.ExecuteStreamingDetectorsAction; +import org.opensearch.securityanalytics.converters.ExecuteStreamingWorkflowRequestConverter; +import org.opensearch.securityanalytics.converters.IndexNameToDocDataConverter; +import org.opensearch.securityanalytics.converters.StreamingDetectorMetadataConverter; +import org.opensearch.securityanalytics.model.Detector; +import org.opensearch.securityanalytics.model.DocData; +import org.opensearch.securityanalytics.model.StreamingDetectorMetadata; +import org.opensearch.securityanalytics.util.DetectorUtils; +import org.opensearch.securityanalytics.util.SecurityAnalyticsException; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import java.io.IOException; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +public class TransportExecuteStreamingDetectorsAction extends HandledTransportAction implements SecureTransportAction { + private static final Logger log = LogManager.getLogger(TransportExecuteStreamingDetectorsAction.class); + + private final ClusterService clusterService; + + private final Settings settings; + + private final Client client; + + private final ThreadPool threadPool; + + private final NamedXContentRegistry xContentRegistry; + + private final TransportSearchDetectorAction transportSearchDetectorAction; + + private final IndexNameToDocDataConverter indexNameToDocDataConverter; + + private final StreamingDetectorMetadataConverter streamingDetectorMetadataConverter; + + private final ExecuteStreamingWorkflowRequestConverter executeStreamingWorkflowRequestConverter; + + @Inject + public TransportExecuteStreamingDetectorsAction(final TransportService transportService, + final Client client, + final ClusterService clusterService, + final Settings settings, + final ActionFilters actionFilters, + final NamedXContentRegistry xContentRegistry, + final TransportSearchDetectorAction transportSearchDetectorAction, + final IndexNameToDocDataConverter indexNameToDocDataConverter, + final StreamingDetectorMetadataConverter streamingDetectorMetadataConverter, + final ExecuteStreamingWorkflowRequestConverter 
executeStreamingWorkflowRequestConverter) { + super(ExecuteStreamingDetectorsAction.NAME, transportService, actionFilters, BulkRequest::new); + this.client = client; + this.clusterService = clusterService; + this.settings = settings; + this.threadPool = this.client.threadPool(); + this.xContentRegistry = xContentRegistry; + this.transportSearchDetectorAction = transportSearchDetectorAction; + this.indexNameToDocDataConverter = indexNameToDocDataConverter; + this.streamingDetectorMetadataConverter = streamingDetectorMetadataConverter; + this.executeStreamingWorkflowRequestConverter = executeStreamingWorkflowRequestConverter; + } + + /** + * Executes the following steps sequentially + * 1. Submit the BulkRequest for indexing + * 2. Identify the detectors associated with the indices being written to + * 3. Get the query fields associated with the underlying monitors + * 4. Filter the documents based on the query fields of the underlying monitors + * 5. Pair the filtered documents to their corresponding detector(s) + * 6. Execute the underlying workflows for each relevant detector + * + * If there are any failures in the steps after the BulkRequest is indexed, the corresponding BulkItemResponses + * are updated with a RestStatus of 424 and details about the failure + */ + @Override + protected void doExecute(final Task task, final BulkRequest bulkRequest, final ActionListener<BulkResponse> listener) { + if (!validateUser()) { + listener.onFailure(SecurityAnalyticsException.wrap( + new OpenSearchStatusException("User is not authorized to perform this action. Contact administrator", RestStatus.FORBIDDEN) + )); + return; + } + + client.execute(BulkAction.INSTANCE, bulkRequest, new ActionListener<>() { + @Override + public void onResponse(final BulkResponse bulkResponse) { + identifyDetectors(bulkRequest, bulkResponse, listener); + } + + @Override + public void onFailure(final Exception e) { + listener.onFailure(e); + } + }); + } + + private boolean validateUser() { + final User user = readUserFromThreadContext(client.threadPool()); + + // If security is enabled, only allow the admin user to call this API + return user == null || isAdmin(user); + } + + private void identifyDetectors(final BulkRequest bulkRequest, final BulkResponse bulkResponse, final ActionListener<BulkResponse> listener) { + final Map<String, List<DocData>> indexToDocData = indexNameToDocDataConverter.convert(bulkRequest, bulkResponse); + final SearchRequest listDetectorsRequest = getListDetectorsRequest(); + + client.execute(SearchAction.INSTANCE, listDetectorsRequest, new ActionListener<>() { + @Override + public void onResponse(final SearchResponse searchResponse) { + final List<Detector> detectors; + try { + detectors = DetectorUtils.getDetectors(searchResponse, xContentRegistry); + } catch (final IOException e) { + handleAllDetectorsFailure(bulkResponse, e); + listener.onResponse(bulkResponse); + return; + } + + getMonitors(indexToDocData, detectors, listener, bulkResponse); + } + + @Override + public void onFailure(final Exception e) { + if (e instanceof IndexNotFoundException) { + log.warn("No detectors configured, skipping streaming detectors workflow"); + listener.onResponse(bulkResponse); + } else { + handleAllDetectorsFailure(bulkResponse, e); + listener.onResponse(bulkResponse); + } + } + }); + } + + private SearchRequest getListDetectorsRequest() { + final SearchSourceBuilder searchSourceBuilder = SearchSourceBuilder.searchSource().size(10000); // TODO - pagination + final SearchRequest searchRequest = new SearchRequest(); + 
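// Editorial note: the size(10000) above is a hard cap, so detectors beyond the first 10,000 hits are silently + // dropped. One way to address the pagination TODO would be a search_after loop: sort on _id, re-issue the query + // until a page comes back short, and accumulate detectors across pages. + 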
searchRequest.indices(Detector.DETECTORS_INDEX); + searchRequest.source(searchSourceBuilder); + + return searchRequest; + } + + private void getMonitors(final Map<String, List<DocData>> indexToDocData, + final List<Detector> detectors, + final ActionListener<BulkResponse> listener, + final BulkResponse bulkResponse) { + final List<StreamingDetectorMetadata> streamingDetectors = streamingDetectorMetadataConverter.convert(detectors, indexToDocData); + if (streamingDetectors.isEmpty()) { + log.debug("No streaming detectors identified for incoming data. Skipping streaming detectors workflow"); + listener.onResponse(bulkResponse); + return; + } + + final Map<String, StreamingDetectorMetadata> monitorIdToMetadata = streamingDetectors.stream() + .collect(Collectors.toMap(StreamingDetectorMetadata::getMonitorId, metadata -> metadata)); + + final AtomicInteger getMonitorCounter = new AtomicInteger(0); + // TODO - this pattern will submit a burst of requests if the detector/monitor/workflow count is high. Rate limiting should be applied + monitorIdToMetadata.keySet().forEach(monitorId -> { + final GetMonitorRequest getMonitorRequest = new GetMonitorRequest(monitorId, Versions.MATCH_ANY, RestRequest.Method.GET, null); + AlertingPluginInterface.INSTANCE.getMonitor((NodeClient) client, getMonitorRequest, new ActionListener<>() { + @Override + public void onResponse(final GetMonitorResponse getMonitorResponse) { + populateWorkflowIdToMetadata(monitorId, getMonitorResponse, monitorIdToMetadata, getMonitorCounter, + listener, bulkResponse); + } + + @Override + public void onFailure(final Exception e) { + handleDetectorFailure(bulkResponse, monitorIdToMetadata.get(monitorId), e); + + // Check the value returned by incrementAndGet so exactly one callback observes the final count + if (getMonitorCounter.incrementAndGet() == monitorIdToMetadata.size()) { + listener.onResponse(bulkResponse); + } + } + }); + }); + } + + private void populateWorkflowIdToMetadata(final String monitorId, + final GetMonitorResponse getMonitorResponse, + final Map<String, StreamingDetectorMetadata> monitorIdToMetadata, + final AtomicInteger getMonitorCounter, + final ActionListener<BulkResponse> listener, + final BulkResponse bulkResponse) { + final StreamingDetectorMetadata metadata = monitorIdToMetadata.get(monitorId); + if (isMonitorValidForStreaming(getMonitorResponse)) { + populateQueryFields(getMonitorResponse.getMonitor(), metadata); + + if (getMonitorCounter.incrementAndGet() == monitorIdToMetadata.size()) { + executeWorkflows(monitorIdToMetadata.values(), listener, bulkResponse); + } + } else { + final String errorMsg = String.format("Monitor with ID %s is invalid for streaming.", monitorId); + final SecurityAnalyticsException exception = new SecurityAnalyticsException(errorMsg, RestStatus.INTERNAL_SERVER_ERROR, null); + handleDetectorFailure(bulkResponse, metadata, exception); + + if (getMonitorCounter.incrementAndGet() == monitorIdToMetadata.size()) { + listener.onResponse(bulkResponse); + } + } + } + + private boolean isMonitorValidForStreaming(final GetMonitorResponse getMonitorResponse) { + return getMonitorResponse.getMonitor() != null && getMonitorResponse.getMonitor().getInputs().size() == 1; + } + + private void populateQueryFields(final Monitor monitor, final StreamingDetectorMetadata metadata) { + final DocLevelMonitorInput docLevelMonitorInput = (DocLevelMonitorInput) monitor.getInputs().get(0); + final Set<String> fieldNames = docLevelMonitorInput.getQueries().stream() + .map(DocLevelQuery::getQueryFieldNames) + .flatMap(Collection::stream) + .collect(Collectors.toSet()); + + metadata.addQueryFields(fieldNames); + } + + private void executeWorkflows(final Collection<StreamingDetectorMetadata> 
streamingDetectorMetadata, + final ActionListener<BulkResponse> listener, + final BulkResponse bulkResponse) { + final AtomicInteger workflowExecutionCounter = new AtomicInteger(0); + streamingDetectorMetadata.forEach(metadata -> { + final ExecuteStreamingWorkflowRequest executeWorkflowRequest = executeStreamingWorkflowRequestConverter.convert(metadata); + executeWorkflow(executeWorkflowRequest, metadata, workflowExecutionCounter, streamingDetectorMetadata.size(), listener, bulkResponse); + }); + } + + private void executeWorkflow(final ExecuteStreamingWorkflowRequest executeWorkflowRequest, final StreamingDetectorMetadata metadata, + final AtomicInteger workflowExecutionCounter, final int workflowCount, + final ActionListener<BulkResponse> listener, final BulkResponse bulkResponse) { + AlertingPluginInterface.INSTANCE.executeStreamingWorkflow((NodeClient) client, executeWorkflowRequest, new ActionListener<>() { + @Override + public void onResponse(final ExecuteStreamingWorkflowResponse executeStreamingWorkflowResponse) { + log.debug("Successfully ran workflow with ID {}", executeWorkflowRequest.getWorkflowId()); + + // Check the value returned by incrementAndGet so exactly one callback responds to the listener + if (workflowExecutionCounter.incrementAndGet() == workflowCount) { + listener.onResponse(bulkResponse); + } + } + + @Override + public void onFailure(final Exception e) { + log.debug("Failed to run workflow with ID {}", executeWorkflowRequest.getWorkflowId()); + handleDetectorFailure(bulkResponse, metadata, e); + + if (workflowExecutionCounter.incrementAndGet() == workflowCount) { + listener.onResponse(bulkResponse); + } + } + }); + } + + private void handleAllDetectorsFailure(final BulkResponse bulkResponse, final Exception exception) { + log.error("Failed to run all detectors", exception); + final String failureMessage = String.format("Failed to run all detectors due to %s.", exception); + + IntStream.range(0, bulkResponse.getItems().length).forEach(i -> { + final BulkItemResponse originalBulkItemResponse = bulkResponse.getItems()[i]; + final BulkItemResponse recreatedBulkItemResponse = recreateBulkItemResponseWithFailure(originalBulkItemResponse, failureMessage); + bulkResponse.getItems()[i] = recreatedBulkItemResponse; + }); + } + + private void handleDetectorFailure(final BulkResponse bulkResponse, final StreamingDetectorMetadata streamingDetectorMetadata, + final Exception exception) { + final String detectorName = streamingDetectorMetadata.getDetectorName(); + log.error("Failed to run detector with name {}", detectorName, exception); + final String failureMessage = String.format("Failed to run detector with name %s due to %s.", detectorName, exception); + + final List<DocData> failedDocData = streamingDetectorMetadata.getIndexToDocData().values().stream() + .flatMap(Collection::stream) + .collect(Collectors.toList()); + failedDocData.forEach(docData -> { + final BulkItemResponse originalBulkItemResponse = bulkResponse.getItems()[docData.getBulkItemResponseIndex()]; + final BulkItemResponse recreatedBulkItemResponse = recreateBulkItemResponseWithFailure(originalBulkItemResponse, failureMessage); + bulkResponse.getItems()[docData.getBulkItemResponseIndex()] = recreatedBulkItemResponse; + }); + } + + private BulkItemResponse recreateBulkItemResponseWithFailure(final BulkItemResponse originalBulkItemResponse, + final String currentFailureMessage) { + final String index; + final String docId; + final String failureMessage; + + // If a previous failure occurred for this document, the BulkItemResponse will already have a Failure entry + if 
(originalBulkItemResponse.isFailed()) { + index = originalBulkItemResponse.getFailure().getIndex(); + docId = originalBulkItemResponse.getFailure().getId(); + failureMessage = originalBulkItemResponse.getFailure().getCause().getMessage() + " " + currentFailureMessage; + + } else { + index = originalBulkItemResponse.getResponse().getIndex(); + docId = originalBulkItemResponse.getResponse().getId(); + failureMessage = currentFailureMessage; + } + + final SecurityAnalyticsException failureException = new SecurityAnalyticsException(failureMessage, RestStatus.FAILED_DEPENDENCY, null); + final BulkItemResponse.Failure failure = new BulkItemResponse.Failure(index, docId, failureException, RestStatus.FAILED_DEPENDENCY); + return new BulkItemResponse(originalBulkItemResponse.getItemId(), originalBulkItemResponse.getOpType(), failure); + } +} diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java index c27cc14da..4f59ec2c3 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java @@ -112,9 +112,12 @@ import java.util.Collection; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Set; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; @@ -147,6 +150,8 @@ public class TransportIndexDetectorAction extends HandledTransportAction> rulesById, Detector detector, ActionListener> listener, WriteRequest.RefreshPolicy refreshPolicy) throws Exception { + private void createMonitorFromQueries(List> rulesById, Detector detector, ActionListener> listener, WriteRequest.RefreshPolicy refreshPolicy, + List queryFieldNames) throws Exception { List> docLevelRules = rulesById.stream().filter(it -> !it.getRight().isAggregationRule()).collect( Collectors.toList()); List> bucketLevelRules = rulesById.stream().filter(it -> it.getRight().isAggregationRule()).collect( @@ -251,7 +259,7 @@ private void createMonitorFromQueries(List> rulesById, Detect List monitorRequests = new ArrayList<>(); if (!docLevelRules.isEmpty()) { - monitorRequests.add(createDocLevelMonitorRequest(docLevelRules, detector, refreshPolicy, Monitor.NO_ID, Method.POST)); + monitorRequests.add(createDocLevelMonitorRequest(docLevelRules, detector, refreshPolicy, Monitor.NO_ID, Method.POST, queryFieldNames)); } if (!bucketLevelRules.isEmpty()) { @@ -382,13 +390,15 @@ public void onFailure(Exception e) { log.error("Error saving workflow", e); actionListener.onFailure(e); } - }); + }, + enableStreamingDetectors); } else { actionListener.onResponse(monitorResponses); } } - private void updateMonitorFromQueries(String index, List> rulesById, Detector detector, ActionListener> listener, WriteRequest.RefreshPolicy refreshPolicy) throws Exception { + private void updateMonitorFromQueries(String index, List> rulesById, Detector detector, ActionListener> listener, WriteRequest.RefreshPolicy refreshPolicy, + List queryFieldNames) throws Exception { List monitorsToBeUpdated = new ArrayList<>(); List> bucketLevelRules = rulesById.stream().filter(it -> it.getRight().isAggregationRule()).collect( @@ -440,9 +450,9 @@ public void onResponse(Map> 
ruleFieldMappings) { // Process doc level monitors if (!docLevelRules.isEmpty()) { if (detector.getDocLevelMonitorId() == null) { - monitorsToBeAdded.add(createDocLevelMonitorRequest(docLevelRules, detector, refreshPolicy, Monitor.NO_ID, Method.POST)); + monitorsToBeAdded.add(createDocLevelMonitorRequest(docLevelRules, detector, refreshPolicy, Monitor.NO_ID, Method.POST, queryFieldNames)); } else { - monitorsToBeUpdated.add(createDocLevelMonitorRequest(docLevelRules, detector, refreshPolicy, detector.getDocLevelMonitorId(), Method.PUT)); + monitorsToBeUpdated.add(createDocLevelMonitorRequest(docLevelRules, detector, refreshPolicy, detector.getDocLevelMonitorId(), Method.PUT, queryFieldNames)); } } @@ -468,9 +478,9 @@ public void onFailure(Exception e) { // Process doc level monitors if (!docLevelRules.isEmpty()) { if (detector.getDocLevelMonitorId() == null) { - monitorsToBeAdded.add(createDocLevelMonitorRequest(docLevelRules, detector, refreshPolicy, Monitor.NO_ID, Method.POST)); + monitorsToBeAdded.add(createDocLevelMonitorRequest(docLevelRules, detector, refreshPolicy, Monitor.NO_ID, Method.POST, queryFieldNames)); } else { - monitorsToBeUpdated.add(createDocLevelMonitorRequest(docLevelRules, detector, refreshPolicy, detector.getDocLevelMonitorId(), Method.PUT)); + monitorsToBeUpdated.add(createDocLevelMonitorRequest(docLevelRules, detector, refreshPolicy, detector.getDocLevelMonitorId(), Method.PUT, queryFieldNames)); } } @@ -620,11 +630,12 @@ public void onFailure(Exception e) { log.error("Failed to update the workflow"); listener.onFailure(e); } - }); + }, + enableStreamingDetectors); } } - private IndexMonitorRequest createDocLevelMonitorRequest(List> queries, Detector detector, WriteRequest.RefreshPolicy refreshPolicy, String monitorId, RestRequest.Method restMethod) { + private IndexMonitorRequest createDocLevelMonitorRequest(List> queries, Detector detector, WriteRequest.RefreshPolicy refreshPolicy, String monitorId, RestRequest.Method restMethod, List queryFieldNames) { List docLevelMonitorInputs = new ArrayList<>(); List docLevelQueries = new ArrayList<>(); @@ -642,7 +653,7 @@ private IndexMonitorRequest createDocLevelMonitorRequest(List tags.add(rule.getCategory()); tags.addAll(rule.getTags().stream().map(Value::getValue).collect(Collectors.toList())); - DocLevelQuery docLevelQuery = new DocLevelQuery(id, name, actualQuery, tags); + DocLevelQuery docLevelQuery = new DocLevelQuery(id, name, actualQuery, tags, queryFieldNames); docLevelQueries.add(docLevelQuery); } DocLevelMonitorInput docLevelMonitorInput = new DocLevelMonitorInput(detector.getName(), detector.getInputs().get(0).getIndices(), docLevelQueries); @@ -1020,6 +1031,7 @@ void createDetector() { request.getDetector().setFindingsIndex(DetectorMonitorConfig.getFindingsIndex(ruleTopic)); request.getDetector().setFindingsIndexPattern(DetectorMonitorConfig.getFindingsIndexPattern(ruleTopic)); request.getDetector().setRuleIndex(DetectorMonitorConfig.getRuleIndex(ruleTopic)); + request.getDetector().setStreamingDetector(enableStreamingDetectors); User originalContextUser = this.user; log.debug("user from original context is {}", originalContextUser); @@ -1132,6 +1144,7 @@ void onGetResponse(Detector currentDetector, User user) { request.getDetector().setFindingsIndexPattern(DetectorMonitorConfig.getFindingsIndexPattern(ruleTopic)); request.getDetector().setRuleIndex(DetectorMonitorConfig.getRuleIndex(ruleTopic)); request.getDetector().setUser(user); + request.getDetector().setStreamingDetector(enableStreamingDetectors); if 
(!detector.getInputs().isEmpty()) { try { @@ -1327,11 +1340,7 @@ public void onResponse(SearchResponse response) { } else if (detectorInput.getCustomRules().size() > 0) { onFailures(new OpenSearchStatusException("Custom Rule Index not found", RestStatus.NOT_FOUND)); } else { - if (request.getMethod() == RestRequest.Method.POST) { - createMonitorFromQueries(queries, detector, listener, request.getRefreshPolicy()); - } else if (request.getMethod() == RestRequest.Method.PUT) { - updateMonitorFromQueries(logIndex, queries, detector, listener, request.getRefreshPolicy()); - } + upsertMonitorFromQueries(queries, detector, logIndex, listener); } } catch (Exception e) { onFailures(e); @@ -1345,6 +1354,57 @@ public void onFailure(Exception e) { }); } + private void upsertMonitorFromQueries(List<Pair<String, Rule>> queries, Detector detector, String logIndex, ActionListener<List<IndexMonitorResponse>> listener) throws Exception { + logger.debug("PERF_DEBUG: Fetching alias path pairs to construct rule_field_names"); + long start = System.currentTimeMillis(); + Set<String> ruleFieldNames = new HashSet<>(); + for (Pair<String, Rule> query : queries) { + List<String> queryFieldNames = query.getValue().getQueryFieldNames().stream().map(Value::getValue).collect(Collectors.toList()); + ruleFieldNames.addAll(queryFieldNames); + } + + CountDownLatch indexMappingsLatch = new CountDownLatch(1); + AtomicBoolean mappingsFetchFailed = new AtomicBoolean(false); + client.execute(GetIndexMappingsAction.INSTANCE, new GetIndexMappingsRequest(logIndex), new ActionListener<>() { + @Override + public void onResponse(GetIndexMappingsResponse getMappingsViewResponse) { + try { + List<Pair<String, String>> aliasPathPairs; + + aliasPathPairs = MapperUtils.getAllAliasPathPairs(getMappingsViewResponse.getMappings().get(logIndex)); + for (Pair<String, String> aliasPathPair : aliasPathPairs) { + if (ruleFieldNames.contains(aliasPathPair.getLeft())) { + ruleFieldNames.remove(aliasPathPair.getLeft()); + ruleFieldNames.add(aliasPathPair.getRight()); + } + } + } catch (Exception e) { + logger.error("Failure in parsing rule field names/aliases while " + + (detector.getId() == null ? "creating" : "updating") + + " detector. Not optimizing detector queries with relevant fields", e); + ruleFieldNames.clear(); + } finally { + indexMappingsLatch.countDown(); + } + + } + + @Override + public void onFailure(Exception e) { + log.error("Failed to fetch mappings view response for log index " + logIndex, e); + mappingsFetchFailed.set(true); + listener.onFailure(e); + indexMappingsLatch.countDown(); + } + }); + indexMappingsLatch.await(); + // The listener has already been notified of a mappings fetch failure; do not also attempt to index the monitors + if (mappingsFetchFailed.get()) { + return; + } + long took = System.currentTimeMillis() - start; + log.debug("PERF_DEBUG: completed collecting rule_field_names in {} millis", took); + if (request.getMethod() == Method.POST) { + createMonitorFromQueries(queries, detector, listener, request.getRefreshPolicy(), new ArrayList<>(ruleFieldNames)); + } else if (request.getMethod() == Method.PUT) { + updateMonitorFromQueries(logIndex, queries, detector, listener, request.getRefreshPolicy(), new ArrayList<>(ruleFieldNames)); + } + } + @SuppressWarnings("unchecked") public void importCustomRules(Detector detector, DetectorInput detectorInput, List<Pair<String, Rule>> queries, ActionListener<List<IndexMonitorResponse>> listener) { final String logIndex = detectorInput.getIndices().get(0); @@ -1381,11 +1441,7 @@ public void onResponse(SearchResponse response) { queries.add(Pair.of(id, rule)); } - if (request.getMethod() == RestRequest.Method.POST) { - createMonitorFromQueries(queries, detector, listener, request.getRefreshPolicy()); - } else if (request.getMethod() == RestRequest.Method.PUT) { - updateMonitorFromQueries(logIndex, queries, detector, listener, request.getRefreshPolicy()); - } + upsertMonitorFromQueries(queries, detector, logIndex, listener); } catch (Exception ex) { onFailures(ex); } @@ -1517,4 +1573,8 @@ private void setFilterByEnabled(boolean filterByEnabled) { private void setEnabledWorkflowUsage(boolean enabledWorkflowUsage) { this.enabledWorkflowUsage = enabledWorkflowUsage; } + + private void setEnableStreamingDetectors(boolean enableStreamingDetectors) { + this.enableStreamingDetectors = enableStreamingDetectors; + } } diff --git a/src/main/java/org/opensearch/securityanalytics/util/WorkflowService.java b/src/main/java/org/opensearch/securityanalytics/util/WorkflowService.java index 5ce495b98..9c412c75d 100644 --- a/src/main/java/org/opensearch/securityanalytics/util/WorkflowService.java +++ b/src/main/java/org/opensearch/securityanalytics/util/WorkflowService.java @@ -78,7 +78,8 @@ public void upsertWorkflow( RefreshPolicy refreshPolicy, String workflowId, Method method, - ActionListener<IndexWorkflowResponse> listener + ActionListener<IndexWorkflowResponse> listener, + boolean streamingWorkflow ) { List<String> addedMonitors = addedMonitorResponses != null ? addedMonitorResponses.stream().map(IndexMonitorResponse::getId).collect(Collectors.toList()) : Collections.emptyList(); List<String> updatedMonitors = updatedMonitorResponses != null ? 
updatedMonitorResponses.stream().map(IndexMonitorResponse::getId).collect(Collectors.toList()) : Collections.emptyList(); @@ -106,7 +107,7 @@ public void upsertWorkflow( IndexWorkflowRequest indexWorkflowRequest = createWorkflowRequest(monitorIds, detector, - refreshPolicy, workflowId, method, chainedMonitorFindings, cmfMonitorId); + refreshPolicy, workflowId, method, chainedMonitorFindings, cmfMonitorId, streamingWorkflow); AlertingPluginInterface.INSTANCE.indexWorkflow((NodeClient) client, indexWorkflowRequest, @@ -147,7 +148,7 @@ public void deleteWorkflow(String workflowId, ActionListener monitorIds, Detector detector, RefreshPolicy refreshPolicy, String workflowId, Method method, - ChainedMonitorFindings chainedMonitorFindings, String cmfMonitorId) { + ChainedMonitorFindings chainedMonitorFindings, String cmfMonitorId, boolean streamingWorkflow) { AtomicInteger index = new AtomicInteger(); List delegates = monitorIds.stream().map( monitorId -> { @@ -177,7 +178,8 @@ private IndexWorkflowRequest createWorkflowRequest(List monitorIds, Dete List.of(compositeInput), "security_analytics", Collections.emptyList(), - false + false, + streamingWorkflow ); return new IndexWorkflowRequest( diff --git a/src/main/java/org/opensearch/securityanalytics/validators/StreamingDetectorValidators.java b/src/main/java/org/opensearch/securityanalytics/validators/StreamingDetectorValidators.java new file mode 100644 index 000000000..c0c4a4631 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/validators/StreamingDetectorValidators.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.validators; + +import org.opensearch.core.rest.RestStatus; +import org.opensearch.securityanalytics.model.Detector; +import org.opensearch.securityanalytics.util.SecurityAnalyticsException; + +import java.util.Arrays; +import java.util.function.Predicate; + +public enum StreamingDetectorValidators { + INPUTS_VALIDATOR("inputs", detector -> detector.getInputs().size() == 1), + WORKFLOW_IDS_VALIDATOR("workflows", detector -> detector.getWorkflowIds().size() == 1), + MONITOR_IDS_VALIDATOR("monitors", detector -> detector.getMonitorIds().size() == 1); + + private final String elementName; + private final Predicate validator; + + StreamingDetectorValidators(final String elementName, final Predicate validator) { + this.elementName = elementName; + this.validator = validator; + } + + public static void validateDetector(final Detector detector) { + Arrays.stream(values()).forEach(detectorValidator -> { + final boolean isValid = detectorValidator.validator.test(detector); + if (!isValid) { + final String errorMsg = String.format("Detector with ID %s is invalid for streaming. 
Invalid element: %s", + detector.getId(), detectorValidator.elementName); + throw new SecurityAnalyticsException( + errorMsg, + RestStatus.INTERNAL_SERVER_ERROR, + null + ); + } + }); + } +} diff --git a/src/main/resources/mappings/finding_mapping.json b/src/main/resources/mappings/finding_mapping.json index fcb2cc152..ae31ee3e5 100644 --- a/src/main/resources/mappings/finding_mapping.json +++ b/src/main/resources/mappings/finding_mapping.json @@ -1,7 +1,7 @@ { "dynamic": "strict", "_meta" : { - "schema_version": 3 + "schema_version": 4 }, "properties": { "schema_version": { @@ -39,6 +39,9 @@ "query": { "type": "text" }, + "query_field_names": { + "type": "keyword" + }, "tags": { "type": "text", "fields" : { diff --git a/src/test/java/org/opensearch/securityanalytics/TestHelpers.java b/src/test/java/org/opensearch/securityanalytics/TestHelpers.java index 50c861788..1e56fa7b1 100644 --- a/src/test/java/org/opensearch/securityanalytics/TestHelpers.java +++ b/src/test/java/org/opensearch/securityanalytics/TestHelpers.java @@ -149,7 +149,7 @@ public static Detector randomDetector(String name, DetectorTrigger trigger = new DetectorTrigger(null, "windows-trigger", "1", List.of(randomDetectorType()), List.of("QuarksPwDump Clearing Access History"), List.of("high"), List.of("T0008"), List.of()); triggers.add(trigger); } - return new Detector(null, null, name, enabled, schedule, lastUpdateTime, enabledTime, detectorType, user, inputs, triggers, Collections.singletonList(""), "", "", "", "", "", "", Collections.emptyMap(), Collections.emptyList()); + return new Detector(null, null, name, enabled, schedule, lastUpdateTime, enabledTime, detectorType, user, inputs, triggers, Collections.singletonList(""), "", "", "", "", "", "", Collections.emptyMap(), Collections.emptyList(), false); } public static CustomLogType randomCustomLogType(String name, String description, String category, String source) { @@ -235,7 +235,8 @@ public static Detector randomDetectorWithNoUser() { "", "", Collections.emptyMap(), - Collections.emptyList() + Collections.emptyList(), + false ); } diff --git a/src/test/java/org/opensearch/securityanalytics/action/IndexDetectorResponseTests.java b/src/test/java/org/opensearch/securityanalytics/action/IndexDetectorResponseTests.java index db366056b..a9896ce4a 100644 --- a/src/test/java/org/opensearch/securityanalytics/action/IndexDetectorResponseTests.java +++ b/src/test/java/org/opensearch/securityanalytics/action/IndexDetectorResponseTests.java @@ -50,7 +50,8 @@ public void testIndexDetectorPostResponse() throws IOException { null, DetectorMonitorConfig.getFindingsIndex("others_application"), Collections.emptyMap(), - Collections.emptyList() + Collections.emptyList(), + false ); IndexDetectorResponse response = new IndexDetectorResponse("1234", 1L, RestStatus.OK, detector); Assert.assertNotNull(response); diff --git a/src/test/java/org/opensearch/securityanalytics/alerts/AlertingServiceTests.java b/src/test/java/org/opensearch/securityanalytics/alerts/AlertingServiceTests.java index 78dacd6e1..d250d2eef 100644 --- a/src/test/java/org/opensearch/securityanalytics/alerts/AlertingServiceTests.java +++ b/src/test/java/org/opensearch/securityanalytics/alerts/AlertingServiceTests.java @@ -65,7 +65,8 @@ public void testGetAlerts_success() { null, DetectorMonitorConfig.getFindingsIndex("others_application"), Collections.emptyMap(), - Collections.emptyList() + Collections.emptyList(), + false ); GetDetectorResponse getDetectorResponse = new GetDetectorResponse("detector_id123", 1L, 
RestStatus.OK, detector); @@ -242,7 +243,8 @@ public void testGetFindings_getFindingsByMonitorIdFailures() { null, DetectorMonitorConfig.getFindingsIndex("others_application"), Collections.emptyMap(), - Collections.emptyList() + Collections.emptyList(), + false ); GetDetectorResponse getDetectorResponse = new GetDetectorResponse("detector_id123", 1L, RestStatus.OK, detector); diff --git a/src/test/java/org/opensearch/securityanalytics/converters/ExecuteStreamingWorkflowRequestConverterTests.java b/src/test/java/org/opensearch/securityanalytics/converters/ExecuteStreamingWorkflowRequestConverterTests.java new file mode 100644 index 000000000..819c364bb --- /dev/null +++ b/src/test/java/org/opensearch/securityanalytics/converters/ExecuteStreamingWorkflowRequestConverterTests.java @@ -0,0 +1,195 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.converters; + +import org.junit.Before; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.commons.alerting.action.ExecuteStreamingWorkflowRequest; +import org.opensearch.commons.alerting.model.IdDocPair; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.index.mapper.MapperParsingException; +import org.opensearch.securityanalytics.model.StreamingDetectorMetadata; +import org.opensearch.test.OpenSearchTestCase; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.stream.IntStream; + +public class ExecuteStreamingWorkflowRequestConverterTests extends OpenSearchTestCase { + private static final String INDEX_NAME = UUID.randomUUID().toString(); + private static final String DOC_ID = UUID.randomUUID().toString(); + private static final String DOCUMENT_STRING = "{\"field1\":\"value1\",\"field2\":\"value2\"}"; + private static final String WORKFLOW_ID = UUID.randomUUID().toString(); + + @Mock + private NamedXContentRegistry xContentRegistry; + + private ExecuteStreamingWorkflowRequestConverter converter; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + converter = new ExecuteStreamingWorkflowRequestConverter(xContentRegistry); + } + + public void testSingleDetectorSingleIndexSingleDocNoFiltering() { + final BytesReference document = getDocument(DOCUMENT_STRING); + final Map> indexToDocIdPairs = getIndexToIdDocPairs(document); + final Collection metadata = List.of(getStreamingDetectorMetadata(Set.of("field1", "field2"))); + + final List result = converter.convert(metadata, indexToDocIdPairs); + assertEquals(1, result.size()); + assertEquals(WORKFLOW_ID, result.get(0).getWorkflowId()); + assertEquals(1, result.get(0).getIndices().size()); + assertEquals(INDEX_NAME, result.get(0).getIndices().get(0).getIndex()); + assertEquals(1, result.get(0).getIndices().get(0).getIdDocPairs().size()); + assertEquals(DOC_ID, result.get(0).getIndices().get(0).getIdDocPairs().get(0).getDocId()); + assertEquals(document, result.get(0).getIndices().get(0).getIdDocPairs().get(0).getDocument()); + } + + public void testFiltersDocFields() { + final Map> indexToDocIdPairs = getIndexToIdDocPairs(getDocument(DOCUMENT_STRING)); + final Collection metadata = List.of(getStreamingDetectorMetadata(Set.of("field1"))); + + final List result = converter.convert(metadata, 
indexToDocIdPairs); + assertEquals(1, result.size()); + assertEquals(WORKFLOW_ID, result.get(0).getWorkflowId()); + assertEquals(1, result.get(0).getIndices().size()); + assertEquals(INDEX_NAME, result.get(0).getIndices().get(0).getIndex()); + assertEquals(1, result.get(0).getIndices().get(0).getIdDocPairs().size()); + assertEquals(DOC_ID, result.get(0).getIndices().get(0).getIdDocPairs().get(0).getDocId()); + + final BytesReference filteredDocument = getDocument("{\"field1\":\"value1\"}"); + assertEquals(filteredDocument, result.get(0).getIndices().get(0).getIdDocPairs().get(0).getDocument()); + } + + public void testInvalidDocumentThrowsParsingException() { + final Map> indexToDocIdPairs = getIndexToIdDocPairs(getDocument("invalid doc")); + final Collection metadata = List.of(getStreamingDetectorMetadata(Set.of("field1"))); + + assertThrows(MapperParsingException.class, () -> converter.convert(metadata, indexToDocIdPairs)); + } + + public void testFiltersIndicesNotPartOfDetector() { + final BytesReference document = getDocument(DOCUMENT_STRING); + final Map> indexToDocIdPairs = new HashMap<>(); + indexToDocIdPairs.put(INDEX_NAME, List.of(new IdDocPair(DOC_ID, document))); + indexToDocIdPairs.put(UUID.randomUUID().toString(), List.of(new IdDocPair(DOC_ID, document))); + final Collection metadata = List.of(getStreamingDetectorMetadata(Set.of("field1", "field2"))); + + final List result = converter.convert(metadata, indexToDocIdPairs); + assertEquals(1, result.size()); + assertEquals(WORKFLOW_ID, result.get(0).getWorkflowId()); + assertEquals(1, result.get(0).getIndices().size()); + assertEquals(INDEX_NAME, result.get(0).getIndices().get(0).getIndex()); + assertEquals(1, result.get(0).getIndices().get(0).getIdDocPairs().size()); + assertEquals(DOC_ID, result.get(0).getIndices().get(0).getIdDocPairs().get(0).getDocId()); + assertEquals(document, result.get(0).getIndices().get(0).getIdDocPairs().get(0).getDocument()); + } + + public void testMultipleDocs() { + final String secondDocId = UUID.randomUUID().toString(); + final String secondDocument = "{\"field1\":\"abcdef\",\"field2\":\"value2\"}"; + final Map> indexToDocIdPairs = Map.of( + INDEX_NAME, + List.of(new IdDocPair(DOC_ID, getDocument(DOCUMENT_STRING)), new IdDocPair(secondDocId, getDocument(secondDocument))) + ); + final Collection metadata = List.of(getStreamingDetectorMetadata(Set.of("field1"))); + + final List result = converter.convert(metadata, indexToDocIdPairs); + assertEquals(1, result.size()); + assertEquals(WORKFLOW_ID, result.get(0).getWorkflowId()); + assertEquals(1, result.get(0).getIndices().size()); + assertEquals(INDEX_NAME, result.get(0).getIndices().get(0).getIndex()); + assertEquals(2, result.get(0).getIndices().get(0).getIdDocPairs().size()); + assertEquals(DOC_ID, result.get(0).getIndices().get(0).getIdDocPairs().get(0).getDocId()); + assertEquals(secondDocId, result.get(0).getIndices().get(0).getIdDocPairs().get(1).getDocId()); + + final BytesReference filteredDocument1 = getDocument("{\"field1\":\"value1\"}"); + assertEquals(filteredDocument1, result.get(0).getIndices().get(0).getIdDocPairs().get(0).getDocument()); + final BytesReference filteredDocument2 = getDocument("{\"field1\":\"abcdef\"}"); + assertEquals(filteredDocument2, result.get(0).getIndices().get(0).getIdDocPairs().get(1).getDocument()); + } + + public void testMultipleIndices() { + final String secondIndexName = UUID.randomUUID().toString(); + final String secondDocId = UUID.randomUUID().toString(); + final String secondDocument = 
"{\"field1\":\"abcdef\",\"field2\":\"value2\"}"; + final Map> indexToDocIdPairs = Map.of( + INDEX_NAME, + List.of(new IdDocPair(DOC_ID, getDocument(DOCUMENT_STRING))), + secondIndexName, + List.of(new IdDocPair(secondDocId, getDocument(secondDocument))) + ); + final StreamingDetectorMetadata metadata = new StreamingDetectorMetadata(List.of(INDEX_NAME, secondIndexName), WORKFLOW_ID, null); + metadata.addQueryFields(Set.of("field1")); + + final List result = converter.convert(List.of(metadata), indexToDocIdPairs); + assertEquals(1, result.size()); + assertEquals(WORKFLOW_ID, result.get(0).getWorkflowId()); + assertEquals(2, result.get(0).getIndices().size()); + assertEquals(INDEX_NAME, result.get(0).getIndices().get(0).getIndex()); + assertEquals(1, result.get(0).getIndices().get(0).getIdDocPairs().size()); + assertEquals(DOC_ID, result.get(0).getIndices().get(0).getIdDocPairs().get(0).getDocId()); + final BytesReference filteredDocument1 = getDocument("{\"field1\":\"value1\"}"); + assertEquals(filteredDocument1, result.get(0).getIndices().get(0).getIdDocPairs().get(0).getDocument()); + + assertEquals(secondIndexName, result.get(0).getIndices().get(1).getIndex()); + assertEquals(1, result.get(0).getIndices().get(1).getIdDocPairs().size()); + assertEquals(secondDocId, result.get(0).getIndices().get(1).getIdDocPairs().get(0).getDocId()); + final BytesReference filteredDocument2 = getDocument("{\"field1\":\"abcdef\"}"); + assertEquals(filteredDocument2, result.get(0).getIndices().get(1).getIdDocPairs().get(0).getDocument()); + } + + public void testMultipleWorkflow() { + final BytesReference document = getDocument(DOCUMENT_STRING); + final Map> indexToDocIdPairs = getIndexToIdDocPairs(document); + + final String workflowId2 = UUID.randomUUID().toString(); + final StreamingDetectorMetadata metadata1 = new StreamingDetectorMetadata(List.of(INDEX_NAME), WORKFLOW_ID, null); + metadata1.addQueryFields(Set.of("field1", "field2")); + final StreamingDetectorMetadata metadata2 = new StreamingDetectorMetadata(List.of(INDEX_NAME), workflowId2, null); + metadata2.addQueryFields(Set.of("field1")); + + final List result = converter.convert(List.of(metadata1, metadata2), indexToDocIdPairs); + assertEquals(2, result.size()); + assertEquals(WORKFLOW_ID, result.get(0).getWorkflowId()); + assertEquals(workflowId2, result.get(1).getWorkflowId()); + + IntStream.range(0, 2).forEach(i -> { + assertEquals(1, result.get(i).getIndices().size()); + assertEquals(INDEX_NAME, result.get(i).getIndices().get(0).getIndex()); + assertEquals(1, result.get(i).getIndices().get(0).getIdDocPairs().size()); + assertEquals(DOC_ID, result.get(i).getIndices().get(0).getIdDocPairs().get(0).getDocId()); + }); + + assertEquals(document, result.get(0).getIndices().get(0).getIdDocPairs().get(0).getDocument()); + final BytesReference filteredDocument = getDocument("{\"field1\":\"value1\"}"); + assertEquals(filteredDocument, result.get(1).getIndices().get(0).getIdDocPairs().get(0).getDocument()); + } + + private Map> getIndexToIdDocPairs(final BytesReference document) { + return Map.of(INDEX_NAME, List.of(new IdDocPair(DOC_ID, document))); + } + + private StreamingDetectorMetadata getStreamingDetectorMetadata(final Set queryFields) { + final StreamingDetectorMetadata metadata = new StreamingDetectorMetadata(List.of(INDEX_NAME), WORKFLOW_ID, null); + metadata.addQueryFields(queryFields); + + return metadata; + } + + private BytesReference getDocument(final String docString) { + return 
BytesReference.fromByteBuffer(ByteBuffer.wrap(docString.getBytes(StandardCharsets.UTF_8))); + } +} diff --git a/src/test/java/org/opensearch/securityanalytics/converters/IndexNameToDocDataConverterTests.java b/src/test/java/org/opensearch/securityanalytics/converters/IndexNameToDocDataConverterTests.java new file mode 100644 index 000000000..81914aa6d --- /dev/null +++ b/src/test/java/org/opensearch/securityanalytics/converters/IndexNameToDocDataConverterTests.java @@ -0,0 +1,170 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.converters; + +import org.junit.Before; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.DocWriteRequest; +import org.opensearch.action.bulk.BulkItemResponse; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.commons.alerting.model.IdDocPair; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.securityanalytics.util.SecurityAnalyticsException; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.mockito.Mockito.when; + +public class IndexNameToDocDataConverterTests extends OpenSearchTestCase { + private static final String INDEX_NAME = UUID.randomUUID().toString(); + private static final String DOC_ID = UUID.randomUUID().toString(); + + @Mock + private BulkRequest bulkRequest; + @Mock + private IndexRequest indexRequest; + @Mock + private BytesReference indexRequestSource; + @Mock + private UpdateRequest updateRequest; + @Mock + private IndexRequest updateRequestIndexRequest; + @Mock + private BytesReference updateRequestSource; + @Mock + private BulkResponse bulkResponse; + @Mock + private BulkItemResponse response; + + private IndexNameToDocDataConverter converter; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + converter = new IndexNameToDocDataConverter(); + + when(response.getId()).thenReturn(DOC_ID); + when(indexRequest.index()).thenReturn(INDEX_NAME); + when(indexRequest.opType()).thenReturn(DocWriteRequest.OpType.INDEX); + when(indexRequest.source()).thenReturn(indexRequestSource); + when(updateRequest.index()).thenReturn(INDEX_NAME); + when(updateRequest.opType()).thenReturn(DocWriteRequest.OpType.UPDATE); + when(updateRequest.doc()).thenReturn(updateRequestIndexRequest); + when(updateRequestIndexRequest.source()).thenReturn(updateRequestSource); + } + + public void testBulkRequestAndResponseLengthsDiffer() { + when(bulkRequest.requests()).thenReturn(getDocWriteRequestList(1)); + when(bulkResponse.getItems()).thenReturn(getBulkItemResponseArray(2)); + + assertThrows(SecurityAnalyticsException.class, () -> converter.convert(bulkRequest, bulkResponse)); + } + + public void testFiltersDeleteOperations() { + when(bulkRequest.requests()).thenReturn(getDocWriteRequestList(2)); + when(indexRequest.opType()).thenReturn(DocWriteRequest.OpType.DELETE) + .thenReturn(DocWriteRequest.OpType.INDEX); + when(bulkResponse.getItems()).thenReturn(getBulkItemResponseArray(2)); + + final Map> result = converter.convert(bulkRequest, bulkResponse); + assertEquals(1, result.size()); + assertTrue(result.containsKey(INDEX_NAME)); + assertEquals(1, 
result.get(INDEX_NAME).size()); + } + + public void testFiltersFailedBulkItem() { + when(bulkRequest.requests()).thenReturn(getDocWriteRequestList(2)); + when(bulkResponse.getItems()).thenReturn(getBulkItemResponseArray(2)); + when(response.isFailed()).thenReturn(true) + .thenReturn(false); + + final Map> result = converter.convert(bulkRequest, bulkResponse); + assertEquals(1, result.size()); + assertTrue(result.containsKey(INDEX_NAME)); + assertEquals(1, result.get(INDEX_NAME).size()); + } + + public void testCreateOperation() { + when(bulkRequest.requests()).thenReturn(List.of(indexRequest)); + when(indexRequest.opType()).thenReturn(DocWriteRequest.OpType.CREATE); + when(bulkResponse.getItems()).thenReturn(getBulkItemResponseArray(1)); + + + final Map> result = converter.convert(bulkRequest, bulkResponse); + validateSingleDocSingleIndexCommons(result); + assertEquals(indexRequestSource, result.get(INDEX_NAME).get(0).getDocument()); + } + + public void testIndexOperation() { + when(bulkRequest.requests()).thenReturn(List.of(indexRequest)); + when(bulkResponse.getItems()).thenReturn(getBulkItemResponseArray(1)); + + final Map> result = converter.convert(bulkRequest, bulkResponse); + validateSingleDocSingleIndexCommons(result); + assertEquals(indexRequestSource, result.get(INDEX_NAME).get(0).getDocument()); + } + + public void testUpdateOperation() { + when(bulkRequest.requests()).thenReturn(List.of(updateRequest)); + when(bulkResponse.getItems()).thenReturn(getBulkItemResponseArray(1)); + + final Map> result = converter.convert(bulkRequest, bulkResponse); + validateSingleDocSingleIndexCommons(result); + assertEquals(updateRequestSource, result.get(INDEX_NAME).get(0).getDocument()); + } + + public void testMultipleIndicesGenerateUniqueMapEntries() { + final String secondIndex = UUID.randomUUID().toString(); + + when(updateRequest.index()).thenReturn(secondIndex); + when(bulkRequest.requests()).thenReturn(List.of(indexRequest, updateRequest)); + when(bulkResponse.getItems()).thenReturn(getBulkItemResponseArray(2)); + + final Map> result = converter.convert(bulkRequest, bulkResponse); + assertEquals(2, result.size()); + assertTrue(result.containsKey(INDEX_NAME)); + assertTrue(result.containsKey(secondIndex)); + assertEquals(1, result.get(INDEX_NAME).size()); + assertEquals(indexRequestSource, result.get(INDEX_NAME).get(0).getDocument()); + assertEquals(1, result.get(secondIndex).size()); + assertEquals(updateRequestSource, result.get(secondIndex).get(0).getDocument()); + } + + public void testMultipleRequestsForSameIndexAddedToMapEntry() { + when(bulkRequest.requests()).thenReturn(List.of(indexRequest, updateRequest)); + when(bulkResponse.getItems()).thenReturn(getBulkItemResponseArray(2)); + + final Map> result = converter.convert(bulkRequest, bulkResponse); + assertEquals(1, result.size()); + assertTrue(result.containsKey(INDEX_NAME)); + assertEquals(2, result.get(INDEX_NAME).size()); + assertEquals(indexRequestSource, result.get(INDEX_NAME).get(0).getDocument()); + assertEquals(updateRequestSource, result.get(INDEX_NAME).get(1).getDocument()); + } + + private void validateSingleDocSingleIndexCommons(final Map> result) { + assertEquals(1, result.size()); + assertTrue(result.containsKey(INDEX_NAME)); + assertEquals(1, result.get(INDEX_NAME).size()); + assertEquals(DOC_ID, result.get(INDEX_NAME).get(0).getDocId()); + } + + private List> getDocWriteRequestList(final int length) { + return IntStream.range(0, length).mapToObj(i -> indexRequest).collect(Collectors.toList()); + } + + private 
BulkItemResponse[] getBulkItemResponseArray(final int length) { + return IntStream.range(0, length).mapToObj(i -> response).toArray(BulkItemResponse[]::new); + } +} diff --git a/src/test/java/org/opensearch/securityanalytics/converters/StreamingDetectorMetadataConverterTests.java b/src/test/java/org/opensearch/securityanalytics/converters/StreamingDetectorMetadataConverterTests.java new file mode 100644 index 000000000..601c64472 --- /dev/null +++ b/src/test/java/org/opensearch/securityanalytics/converters/StreamingDetectorMetadataConverterTests.java @@ -0,0 +1,110 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.converters; + +import org.junit.Before; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.securityanalytics.model.Detector; +import org.opensearch.securityanalytics.model.DetectorInput; +import org.opensearch.securityanalytics.model.StreamingDetectorMetadata; +import org.opensearch.securityanalytics.util.SecurityAnalyticsException; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Collections; +import java.util.List; +import java.util.Set; +import java.util.UUID; + +import static org.mockito.Mockito.when; + +public class StreamingDetectorMetadataConverterTests extends OpenSearchTestCase { + private static final String INDEX_NAME = UUID.randomUUID().toString(); + private static final String WORKFLOW_ID = UUID.randomUUID().toString(); + private static final String MONITOR_ID = UUID.randomUUID().toString(); + + @Mock + private Detector detector; + @Mock + private Detector detector2; + @Mock + private DetectorInput detectorInput; + @Mock + private DetectorInput detectorInput2; + + private StreamingDetectorMetadataConverter converter; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + converter = new StreamingDetectorMetadataConverter(); + + when(detector.getInputs()).thenReturn(List.of(detectorInput)); + when(detector.getWorkflowIds()).thenReturn(List.of(WORKFLOW_ID)); + when(detector.getMonitorIds()).thenReturn(List.of(MONITOR_ID)); + when(detector.isStreamingDetector()).thenReturn(true); + when(detectorInput.getIndices()).thenReturn(List.of(INDEX_NAME)); + } + + public void testInvalidDetectorThrows() { + when(detector.getWorkflowIds()).thenReturn(List.of(UUID.randomUUID().toString(), UUID.randomUUID().toString())); + + assertThrows(SecurityAnalyticsException.class, () -> converter.convert(List.of(detector), Collections.emptySet())); + } + + public void testFiltersNonStreamingDetectors() { + when(detector.isStreamingDetector()).thenReturn(false); + + final List result = converter.convert(List.of(detector), Set.of(INDEX_NAME)); + assertTrue(result.isEmpty()); + } + + public void testFiltersNoIndexMatchesDetectors() { + when(detectorInput.getIndices()).thenReturn(List.of(UUID.randomUUID().toString())); + + final List result = converter.convert(List.of(detector), Set.of(INDEX_NAME)); + assertTrue(result.isEmpty()); + } + + public void testDetectorMatch() { + final List result = converter.convert(List.of(detector), Set.of(INDEX_NAME)); + assertEquals(1, result.size()); + assertEquals(WORKFLOW_ID, result.get(0).getWorkflowId()); + assertEquals(MONITOR_ID, result.get(0).getMonitorId()); + assertEquals(List.of(INDEX_NAME), result.get(0).getIndices()); + } + + public void testDetectorMatchesOnlyOneIndex() { + final Set indexNames = Set.of(INDEX_NAME, UUID.randomUUID().toString(), UUID.randomUUID().toString()); + final String 
secondDetectorIndexName = UUID.randomUUID().toString(); + when(detectorInput.getIndices()).thenReturn(List.of(INDEX_NAME, secondDetectorIndexName)); + final List result = converter.convert(List.of(detector), indexNames); + assertEquals(1, result.size()); + assertEquals(WORKFLOW_ID, result.get(0).getWorkflowId()); + assertEquals(MONITOR_ID, result.get(0).getMonitorId()); + assertEquals(List.of(INDEX_NAME, secondDetectorIndexName), result.get(0).getIndices()); + } + + public void testMultipleDetectors() { + final String indexName2 = UUID.randomUUID().toString(); + final String workflow2 = UUID.randomUUID().toString(); + final String monitor2 = UUID.randomUUID().toString(); + when(detector2.getInputs()).thenReturn(List.of(detectorInput2)); + when(detectorInput2.getIndices()).thenReturn(List.of(indexName2)); + when(detector2.getWorkflowIds()).thenReturn(List.of(workflow2)); + when(detector2.getMonitorIds()).thenReturn(List.of(monitor2)); + when(detector2.isStreamingDetector()).thenReturn(true); + + final Set indexNames = Set.of(INDEX_NAME, indexName2, UUID.randomUUID().toString()); + final List result = converter.convert(List.of(detector, detector2), indexNames); + assertEquals(2, result.size()); + assertEquals(WORKFLOW_ID, result.get(0).getWorkflowId()); + assertEquals(MONITOR_ID, result.get(0).getMonitorId()); + assertEquals(List.of(INDEX_NAME), result.get(0).getIndices()); + assertEquals(workflow2, result.get(1).getWorkflowId()); + assertEquals(monitor2, result.get(1).getMonitorId()); + assertEquals(List.of(indexName2), result.get(1).getIndices()); + } +} diff --git a/src/test/java/org/opensearch/securityanalytics/findings/FindingServiceTests.java b/src/test/java/org/opensearch/securityanalytics/findings/FindingServiceTests.java index 0fb9376b6..7b9d1a716 100644 --- a/src/test/java/org/opensearch/securityanalytics/findings/FindingServiceTests.java +++ b/src/test/java/org/opensearch/securityanalytics/findings/FindingServiceTests.java @@ -65,7 +65,8 @@ public void testGetFindings_success() { null, DetectorMonitorConfig.getFindingsIndex("others_application"), Collections.emptyMap(), - Collections.emptyList() + Collections.emptyList(), + false ); GetDetectorResponse getDetectorResponse = new GetDetectorResponse("detector_id123", 1L, RestStatus.OK, detector); @@ -186,7 +187,8 @@ public void testGetFindings_getFindingsByMonitorIdFailure() { null, DetectorMonitorConfig.getFindingsIndex("others_application"), Collections.emptyMap(), - Collections.emptyList() + Collections.emptyList(), + false ); GetDetectorResponse getDetectorResponse = new GetDetectorResponse("detector_id123", 1L, RestStatus.OK, detector); diff --git a/src/test/java/org/opensearch/securityanalytics/validators/StreamingDetectorValidatorsTests.java b/src/test/java/org/opensearch/securityanalytics/validators/StreamingDetectorValidatorsTests.java new file mode 100644 index 000000000..76e8e9c33 --- /dev/null +++ b/src/test/java/org/opensearch/securityanalytics/validators/StreamingDetectorValidatorsTests.java @@ -0,0 +1,56 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.validators; + +import org.junit.Before; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.securityanalytics.model.Detector; +import org.opensearch.securityanalytics.model.DetectorInput; +import org.opensearch.securityanalytics.util.SecurityAnalyticsException; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.List; +import 
java.util.UUID; + +import static org.mockito.Mockito.when; + +public class StreamingDetectorValidatorsTests extends OpenSearchTestCase { + @Mock + private Detector detector; + @Mock + private DetectorInput detectorInput; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + + when(detector.getInputs()).thenReturn(List.of(detectorInput)); + when(detector.getWorkflowIds()).thenReturn(List.of(UUID.randomUUID().toString())); + when(detector.getMonitorIds()).thenReturn(List.of(UUID.randomUUID().toString())); + } + + public void testValidDetector() { + StreamingDetectorValidators.validateDetector(detector); + } + + public void testInvalidInputsLength() { + when(detector.getInputs()).thenReturn(List.of(detectorInput, detectorInput)); + + assertThrows(SecurityAnalyticsException.class, () -> StreamingDetectorValidators.validateDetector(detector)); + } + + public void testInvalidWorkflowIdsLength() { + when(detector.getWorkflowIds()).thenReturn(List.of(UUID.randomUUID().toString(), UUID.randomUUID().toString())); + + assertThrows(SecurityAnalyticsException.class, () -> StreamingDetectorValidators.validateDetector(detector)); + } + + public void testInvalidMonitorIdsLength() { + when(detector.getMonitorIds()).thenReturn(List.of(UUID.randomUUID().toString(), UUID.randomUUID().toString())); + + assertThrows(SecurityAnalyticsException.class, () -> StreamingDetectorValidators.validateDetector(detector)); + } +}
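
Reviewer note: a minimal sketch of how the new action could be driven from a NodeClient. The action name and the BulkResponse type come from this change; the request type is not shown here, so the assumption that the transport handler accepts a standard BulkRequest, as well as the index name and document, are illustrative only.

import org.opensearch.action.bulk.BulkRequest;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.securityanalytics.action.ExecuteStreamingDetectorsAction;

public final class StreamingDetectorsClientExample {
    // Hypothetical caller; assumes the transport action accepts a plain BulkRequest.
    public static void execute(final NodeClient client) {
        final BulkRequest bulkRequest = new BulkRequest()
            .add(new IndexRequest("windows-logs").source("{\"EventID\":\"4688\"}", XContentType.JSON));
        client.execute(ExecuteStreamingDetectorsAction.INSTANCE, bulkRequest, ActionListener.wrap(
            // On success the caller gets back the bulk result after streaming detectors ran.
            bulkResponse -> System.out.println("Indexed " + bulkResponse.getItems().length + " docs"),
            e -> System.err.println("Streaming detector execution failed: " + e)
        ));
    }
}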
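
The validators above enforce the single-input, single-workflow, single-monitor invariants that streaming execution relies on. A small sketch of how a caller would apply them to a batch of detectors before conversion; only the validateDetector call is from this change, the surrounding framing is illustrative:

import org.opensearch.securityanalytics.model.Detector;
import org.opensearch.securityanalytics.validators.StreamingDetectorValidators;
import java.util.List;

public final class StreamingValidationExample {
    // Throws SecurityAnalyticsException on the first detector that violates an invariant.
    public static void validateAll(final List<Detector> detectors) {
        detectors.forEach(StreamingDetectorValidators::validateDetector);
    }
}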
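
Finally, the converter tests assert that each document is reduced to the fields referenced by the detector's queries before being handed to the streaming workflow (for example, {"field1":"value1","field2":"value2"} with query fields {field1} becomes {"field1":"value1"}). One way to express that filtering over a parsed source map is XContentMapValues.filter; this is a sketch of the semantics, not necessarily the production implementation:

import org.opensearch.common.xcontent.support.XContentMapValues;
import java.util.Map;

public final class QueryFieldFilteringExample {
    public static void main(final String[] args) {
        // Document source as a map; only fields named by the detector's queries survive.
        final Map<String, Object> source = Map.of("field1", "value1", "field2", "value2");
        final Map<String, Object> filtered =
            XContentMapValues.filter(source, new String[] { "field1" }, new String[0]);
        System.out.println(filtered); // {field1=value1}
    }
}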