diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java index 1d9d035f8..7bdd12816 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java @@ -95,11 +95,9 @@ import org.opensearch.securityanalytics.rules.backend.OSQueryBackend; import org.opensearch.securityanalytics.rules.backend.OSQueryBackend.AggregationQueries; import org.opensearch.securityanalytics.rules.backend.QueryBackend; -import org.opensearch.securityanalytics.rules.exceptions.SigmaError; import org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings; import org.opensearch.securityanalytics.threatIntel.DetectorThreatIntelService; import org.opensearch.securityanalytics.util.DetectorIndices; -import org.opensearch.securityanalytics.util.DetectorUtils; import org.opensearch.securityanalytics.util.ExceptionChecker; import org.opensearch.securityanalytics.util.IndexUtils; import org.opensearch.securityanalytics.util.MonitorService; @@ -122,6 +120,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Objects; import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; @@ -827,20 +826,30 @@ private void addThreatIntelBasedDocLevelQueries(Detector detector, ActionListene */ private IndexMonitorRequest createDocLevelMonitorMatchAllRequest( Detector detector, - WriteRequest.RefreshPolicy refreshPolicy, + RefreshPolicy refreshPolicy, String monitorId, - RestRequest.Method restMethod - ) { + Method restMethod, + List> queries) { List docLevelMonitorInputs = new ArrayList<>(); List docLevelQueries = new ArrayList<>(); String monitorName = detector.getName() + "_chained_findings"; String actualQuery = "_id:*"; + Set tags = new HashSet<>(); + for (Pair query: queries) { + if(query.getRight().isAggregationRule()) { + Rule rule = query.getRight(); + tags.add(rule.getLevel()); + tags.add(rule.getCategory()); + tags.addAll(rule.getTags().stream().map(Value::getValue).collect(Collectors.toList())); + } + } + tags.removeIf(Objects::isNull); DocLevelQuery docLevelQuery = new DocLevelQuery( monitorName, monitorName + "doc", Collections.emptyList(), actualQuery, - Collections.emptyList() + new ArrayList<>(tags) ); docLevelQueries.add(docLevelQuery); @@ -900,8 +909,8 @@ public void onResponse(Map> ruleFieldMappings) { @Override public void onResponse(Collection indexMonitorRequests) { // if workflow usage enabled, add chained findings monitor request if there are bucket level requests and if the detector triggers have any group by rules configured to trigger - if (enabledWorkflowUsage && !monitorRequests.isEmpty() && !DetectorUtils.getAggRuleIdsConfiguredToTrigger(detector, queries).isEmpty()) { - monitorRequests.add(createDocLevelMonitorMatchAllRequest(detector, RefreshPolicy.IMMEDIATE, detector.getId() + "_chained_findings", Method.POST)); + if (enabledWorkflowUsage && !monitorRequests.isEmpty() && queries.stream().anyMatch(it -> it.getRight().isAggregationRule())) { + monitorRequests.add(createDocLevelMonitorMatchAllRequest(detector, RefreshPolicy.IMMEDIATE, detector.getId() + "_chained_findings", Method.POST, queries)); } listener.onResponse(monitorRequests); } @@ -1053,7 +1062,7 @@ public void onFailure(Exception e) { listener.onFailure(e); } }); - } catch (SigmaError e) { + } catch (Exception e) { log.error("Failed to create bucket level monitor request", e); listener.onFailure(e); } diff --git a/src/main/java/org/opensearch/securityanalytics/util/DetectorUtils.java b/src/main/java/org/opensearch/securityanalytics/util/DetectorUtils.java index 119de62cf..14c241f83 100644 --- a/src/main/java/org/opensearch/securityanalytics/util/DetectorUtils.java +++ b/src/main/java/org/opensearch/securityanalytics/util/DetectorUtils.java @@ -104,17 +104,12 @@ public void onFailure(Exception e) { }); } - public static List getBucketLevelMonitorIdsWhoseRulesAreConfiguredToTrigger( - Detector detector, - List> rulesById, + public static List getBucketLevelMonitorIds( List monitorResponses ) { - List aggRuleIdsConfiguredToTrigger = getAggRuleIdsConfiguredToTrigger(detector, rulesById); return monitorResponses.stream().filter( // In the case of bucket level monitors rule id is trigger id it -> Monitor.MonitorType.BUCKET_LEVEL_MONITOR == it.getMonitor().getMonitorType() - && !it.getMonitor().getTriggers().isEmpty() - && aggRuleIdsConfiguredToTrigger.contains(it.getMonitor().getTriggers().get(0).getId()) ).map(IndexMonitorResponse::getId).collect(Collectors.toList()); } public static List getAggRuleIdsConfiguredToTrigger(Detector detector, List> rulesById) { diff --git a/src/main/java/org/opensearch/securityanalytics/util/WorkflowService.java b/src/main/java/org/opensearch/securityanalytics/util/WorkflowService.java index 5ce495b98..fa19d9958 100644 --- a/src/main/java/org/opensearch/securityanalytics/util/WorkflowService.java +++ b/src/main/java/org/opensearch/securityanalytics/util/WorkflowService.java @@ -21,7 +21,6 @@ import org.opensearch.commons.alerting.model.ChainedMonitorFindings; import org.opensearch.commons.alerting.model.CompositeInput; import org.opensearch.commons.alerting.model.Delegate; -import org.opensearch.commons.alerting.model.Monitor.MonitorType; import org.opensearch.commons.alerting.model.Sequence; import org.opensearch.commons.alerting.model.Workflow; import org.opensearch.commons.alerting.model.Workflow.WorkflowType; @@ -34,12 +33,11 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; -import java.util.Map; import java.util.Objects; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; -import static org.opensearch.securityanalytics.util.DetectorUtils.getBucketLevelMonitorIdsWhoseRulesAreConfiguredToTrigger; +import static org.opensearch.securityanalytics.util.DetectorUtils.getBucketLevelMonitorIds; /** * Alerting common clas used for workflow manipulation @@ -101,7 +99,7 @@ public void upsertWorkflow( monitorResponses.addAll(updatedMonitorResponses); } cmfMonitorId = addedMonitorResponses.stream().filter(res -> (detector.getName() + "_chained_findings").equals(res.getMonitor().getName())).findFirst().get().getId(); - chainedMonitorFindings = new ChainedMonitorFindings(null, getBucketLevelMonitorIdsWhoseRulesAreConfiguredToTrigger(detector, rulesById, monitorResponses)); + chainedMonitorFindings = new ChainedMonitorFindings(null, getBucketLevelMonitorIds(monitorResponses)); } IndexWorkflowRequest indexWorkflowRequest = createWorkflowRequest(monitorIds, @@ -149,16 +147,21 @@ public void deleteWorkflow(String workflowId, ActionListener monitorIds, Detector detector, RefreshPolicy refreshPolicy, String workflowId, Method method, ChainedMonitorFindings chainedMonitorFindings, String cmfMonitorId) { AtomicInteger index = new AtomicInteger(); - List delegates = monitorIds.stream().map( - monitorId -> { - ChainedMonitorFindings cmf = null; - if (cmfMonitorId != null && chainedMonitorFindings != null && Objects.equals(monitorId, cmfMonitorId)) { - cmf = Objects.equals(monitorId, cmfMonitorId) ? chainedMonitorFindings : null; - } - Delegate delegate = new Delegate(index.incrementAndGet(), monitorId, cmf); - return delegate; - } - ).collect(Collectors.toList()); + List delegates = new ArrayList<>(); + ChainedMonitorFindings cmf = null; + for (String monitorId : monitorIds) { + if (cmfMonitorId != null && chainedMonitorFindings != null && Objects.equals(monitorId, cmfMonitorId)) { + cmf = Objects.equals(monitorId, cmfMonitorId) ? chainedMonitorFindings : null; + } else { + Delegate delegate = new Delegate(index.incrementAndGet(), monitorId, null); + delegates.add(delegate); + } + } + if (cmf != null) { + // Add cmf with maximum value on "index" + Delegate cmfDelegate = new Delegate(index.incrementAndGet(), cmfMonitorId, cmf); + delegates.add(cmfDelegate); + } Sequence sequence = new Sequence(delegates); CompositeInput compositeInput = new CompositeInput(sequence); diff --git a/src/test/java/org/opensearch/securityanalytics/alerts/AlertsIT.java b/src/test/java/org/opensearch/securityanalytics/alerts/AlertsIT.java index 5e763eba8..ddb0432fd 100644 --- a/src/test/java/org/opensearch/securityanalytics/alerts/AlertsIT.java +++ b/src/test/java/org/opensearch/securityanalytics/alerts/AlertsIT.java @@ -7,11 +7,14 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Set; import java.util.stream.Collectors; import org.apache.http.HttpStatus; import org.apache.http.entity.StringEntity; @@ -22,6 +25,7 @@ import org.opensearch.client.Request; import org.opensearch.client.Response; import org.opensearch.client.ResponseException; +import org.opensearch.commons.alerting.model.Monitor; import org.opensearch.commons.alerting.model.action.Action; import org.opensearch.core.rest.RestStatus; import org.opensearch.search.SearchHit; @@ -36,7 +40,9 @@ import static org.opensearch.securityanalytics.TestHelpers.netFlowMappings; import static org.opensearch.securityanalytics.TestHelpers.randomAction; +import static org.opensearch.securityanalytics.TestHelpers.randomAggregationRule; import static org.opensearch.securityanalytics.TestHelpers.randomDetectorType; +import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputs; import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputsAndThreatIntel; import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputsAndTriggers; import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithTriggers; @@ -662,6 +668,132 @@ public void testAlertHistoryRollover_maxAge() throws IOException, InterruptedExc restoreAlertsFindingsIMSettings(); } + /** + * 1. Creates detector with aggregation and prepackaged rules + * (sum rule - should match docIds: 1, 2, 3; maxRule - 4, 5, 6, 7; minRule - 7) + * 2. Verifies monitor execution + * 3. Verifies alerts + * + * @throws IOException + */ + public void testMultipleAggregationAndDocRules_alertSuccess() throws IOException { + String index = createTestIndex(randomIndex(), windowsIndexMapping()); + + Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); + createMappingRequest.setJsonEntity( + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" + ); + + Response createMappingResponse = client().performRequest(createMappingRequest); + + assertEquals(HttpStatus.SC_OK, createMappingResponse.getStatusLine().getStatusCode()); + + String infoOpCode = "Info"; + + String sumRuleId = createRule(randomAggregationRule("sum", " > 1", infoOpCode)); + + + List detectorRules = List.of(new DetectorRule(sumRuleId)); + + DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, + Collections.emptyList()); + Detector detector = randomDetectorWithInputsAndTriggers(List.of(input), + List.of(new DetectorTrigger("randomtrigegr", "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of())) + ); + + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); + + + String request = "{\n" + + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; + SearchResponse response = executeSearchAndGetResponse(DetectorMonitorConfig.getRuleIndex(randomDetectorType()), request, true); + + assertEquals(1, response.getHits().getTotalHits().value); // 5 for rules, 1 for match_all query in chained findings monitor + + assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + Map responseBody = asMap(createResponse); + String detectorId = responseBody.get("_id").toString(); + request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId + "\"\n" + + " }\n" + + " }\n" + + "}"; + List hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + Map updatedDetectorMap = (HashMap) (hit.getSourceAsMap().get("detector")); + + List monitorIds = ((List) (updatedDetectorMap).get("monitor_id")); + + indexDoc(index, "1", randomDoc(2, 4, infoOpCode)); + indexDoc(index, "2", randomDoc(3, 4, infoOpCode)); + + Map numberOfMonitorTypes = new HashMap<>(); + + for (String monitorId : monitorIds) { + Map monitor = (Map) (entityAsMap(client().performRequest(new Request("GET", "/_plugins/_alerting/monitors/" + monitorId)))).get("monitor"); + numberOfMonitorTypes.merge(monitor.get("monitor_type"), 1, Integer::sum); + Response executeResponse = executeAlertingMonitor(monitorId, Collections.emptyMap()); + + // Assert monitor executions + Map executeResults = entityAsMap(executeResponse); + if (Monitor.MonitorType.DOC_LEVEL_MONITOR.getValue().equals(monitor.get("monitor_type")) && false == monitor.get("name").equals(detector.getName() + "_chained_findings")) { + int noOfSigmaRuleMatches = ((List>) ((Map) executeResults.get("input_results")).get("results")).get(0).size(); + assertEquals(5, noOfSigmaRuleMatches); + } + } + + assertEquals(1, numberOfMonitorTypes.get(Monitor.MonitorType.BUCKET_LEVEL_MONITOR.getValue()).intValue()); + assertEquals(1, numberOfMonitorTypes.get(Monitor.MonitorType.DOC_LEVEL_MONITOR.getValue()).intValue()); + + Map params = new HashMap<>(); + params.put("detector_id", detectorId); + Response getFindingsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.FINDINGS_BASE_URI + "/_search", params, null); + Map getFindingsBody = entityAsMap(getFindingsResponse); + + assertNotNull(getFindingsBody); + assertEquals(1, getFindingsBody.get("total_findings")); + + String findingDetectorId = ((Map) ((List) getFindingsBody.get("findings")).get(0)).get("detectorId").toString(); + assertEquals(detectorId, findingDetectorId); + + String findingIndex = ((Map) ((List) getFindingsBody.get("findings")).get(0)).get("index").toString(); + assertEquals(index, findingIndex); + + List docLevelFinding = new ArrayList<>(); + List> findings = (List) getFindingsBody.get("findings"); + + + for (Map finding : findings) { + List> queries = (List>) finding.get("queries"); + Set findingRuleIds = queries.stream().map(it -> it.get("id").toString()).collect(Collectors.toSet()); + + // In the case of bucket level monitors, queries will always contain one value + String aggRuleId = findingRuleIds.iterator().next(); + List findingDocs = (List) finding.get("related_doc_ids"); + + if (aggRuleId.equals(sumRuleId)) { + assertTrue(List.of("1", "2", "3", "4", "5", "6", "7").containsAll(findingDocs)); + } + } + + assertTrue(Arrays.asList("1", "2", "3", "4", "5", "6", "7", "8").containsAll(docLevelFinding)); + + Map params1 = new HashMap<>(); + params1.put("detector_id", detectorId); + Response getAlertsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.ALERTS_BASE_URI, params1, null); + Map getAlertsBody = asMap(getAlertsResponse); + // TODO enable asserts here when able + Assert.assertEquals(3, getAlertsBody.get("total_alerts")); // 2 doc level alerts for each doc, 1 bucket level alert + } public void testAlertHistoryRollover_maxAge_low_retention() throws IOException, InterruptedException { updateClusterSetting(ALERT_HISTORY_ROLLOVER_PERIOD.getKey(), "1s"); diff --git a/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorMonitorRestApiIT.java b/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorMonitorRestApiIT.java index 89a8c0efb..04e5008d9 100644 --- a/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorMonitorRestApiIT.java +++ b/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorMonitorRestApiIT.java @@ -890,7 +890,7 @@ public void testMultipleAggregationAndDocRules_findingSuccess() throws IOExcepti "}"; SearchResponse response = executeSearchAndGetResponse(DetectorMonitorConfig.getRuleIndex(randomDetectorType()), request, true); - assertEquals(6, response.getHits().getTotalHits().value); + assertEquals(7, response.getHits().getTotalHits().value); // 6 for rules, 1 for match_all query in chained findings monitor assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); Map responseBody = asMap(createResponse); @@ -910,8 +910,7 @@ public void testMultipleAggregationAndDocRules_findingSuccess() throws IOExcepti assertEquals(6, ((Map>) inputArr.get(0)).get("detector_input").get("custom_rules").size()); List monitorIds = ((List) (updatedDetectorMap).get("monitor_id")); - - assertEquals(6, monitorIds.size()); + assertEquals(7, monitorIds.size()); indexDoc(index, "1", randomDoc(2, 4, infoOpCode)); indexDoc(index, "2", randomDoc(3, 4, infoOpCode)); @@ -952,7 +951,7 @@ public void testMultipleAggregationAndDocRules_findingSuccess() throws IOExcepti } assertEquals(5, numberOfMonitorTypes.get(MonitorType.BUCKET_LEVEL_MONITOR.getValue()).intValue()); - assertEquals(1, numberOfMonitorTypes.get(MonitorType.DOC_LEVEL_MONITOR.getValue()).intValue()); + assertEquals(2, numberOfMonitorTypes.get(MonitorType.DOC_LEVEL_MONITOR.getValue()).intValue()); Map params = new HashMap<>(); params.put("detector_id", detectorId); @@ -1037,7 +1036,7 @@ public void testCreateDetector_verifyWorkflowCreation_success_WithoutGroupByRule "}"; SearchResponse response = executeSearchAndGetResponse(DetectorMonitorConfig.getRuleIndex(randomDetectorType()), request, true); - assertEquals(1, response.getHits().getTotalHits().value); + assertEquals(2, response.getHits().getTotalHits().value); assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); Map responseBody = asMap(createResponse); @@ -1058,13 +1057,13 @@ public void testCreateDetector_verifyWorkflowCreation_success_WithoutGroupByRule assertEquals(2, ((Map>) inputArr.get(0)).get("detector_input").get("custom_rules").size()); List monitorIds = ((List) (detectorMap).get("monitor_id")); - assertEquals(2, monitorIds.size()); + assertEquals(3, monitorIds.size()); assertNotNull("Workflow not created", detectorMap.get("workflow_ids")); assertEquals("Number of workflows not correct", 1, ((List) detectorMap.get("workflow_ids")).size()); // Verify workflow - verifyWorkflow(detectorMap, monitorIds, 2); + verifyWorkflow(detectorMap, monitorIds, 3); } public void testCreateDetector_verifyWorkflowCreation_success_WithGroupByRulesInTrigger() throws IOException { @@ -1612,6 +1611,7 @@ public void testCreateDetector_verifyWorkflowExecutionMultipleBucketLevelDocLeve indexDoc(index, "7", randomDoc(6, 2, testOpCode)); indexDoc(index, "8", randomDoc(1, 1, testOpCode)); // Verify workflow + verifyWorkflow(detectorMap, monitorIds, 7); String workflowId = ((List) detectorMap.get("workflow_ids")).get(0);