diff --git a/src/main/java/org/opensearch/securityanalytics/model/DetectorTrigger.java b/src/main/java/org/opensearch/securityanalytics/model/DetectorTrigger.java index 5e6bd3bf7..2805e85be 100644 --- a/src/main/java/org/opensearch/securityanalytics/model/DetectorTrigger.java +++ b/src/main/java/org/opensearch/securityanalytics/model/DetectorTrigger.java @@ -348,6 +348,90 @@ public Script convertToCondition() { return new Script(condition.toString()); } + public Script convertToConditionForChainedFindings() { + StringBuilder condition = new StringBuilder(); + + boolean triggerFlag = false; + + int size = 0; + if (detectionTypes.contains(RULES_DETECTION_TYPE)) { // trigger should match rules based queries based on conditions + StringBuilder ruleTypeBuilder = new StringBuilder(); + size = ruleTypes.size(); + for (int idx = 0; idx < size; ++idx) { + ruleTypeBuilder.append(String.format(Locale.getDefault(), "query[tag=%s]", ruleTypes.get(idx))); + if (idx < size - 1) { + ruleTypeBuilder.append(" || "); + } + } + if (size > 0) { + condition.append("(").append(ruleTypeBuilder).append(")"); + triggerFlag = true; + } + + StringBuilder ruleNameBuilder = new StringBuilder(); + size = ruleIds.size(); + for (int idx = 0; idx < size; ++idx) { + ruleNameBuilder.append(String.format(Locale.getDefault(), "query[tag=%s]", ruleIds.get(idx))); + if (idx < size - 1) { + ruleNameBuilder.append(" || "); + } + } + if (size > 0) { + if (triggerFlag) { + condition.append(" && ").append("(").append(ruleNameBuilder).append(")"); + } else { + condition.append("(").append(ruleNameBuilder).append(")"); + triggerFlag = true; + } + } + + StringBuilder ruleSevLevelBuilder = new StringBuilder(); + size = ruleSeverityLevels.size(); + for (int idx = 0; idx < size; ++idx) { + ruleSevLevelBuilder.append(String.format(Locale.getDefault(), "query[tag=%s]", ruleSeverityLevels.get(idx))); + if (idx < size - 1) { + ruleSevLevelBuilder.append(" || "); + } + } + + if (size > 0) { + if (triggerFlag) { + condition.append(" && ").append("(").append(ruleSevLevelBuilder).append(")"); + } else { + condition.append("(").append(ruleSevLevelBuilder).append(")"); + triggerFlag = true; + } + } + + StringBuilder tagBuilder = new StringBuilder(); + size = tags.size(); + for (int idx = 0; idx < size; ++idx) { + tagBuilder.append(String.format(Locale.getDefault(), "query[tag=%s]", tags.get(idx))); + if (idx < size - 1) { + ruleSevLevelBuilder.append(" || "); + } + } + + if (size > 0) { + if (triggerFlag) { + condition.append(" && ").append("(").append(tagBuilder).append(")"); + } else { + condition.append("(").append(tagBuilder).append(")"); + } + } + } + if(detectionTypes.contains(THREAT_INTEL_DETECTION_TYPE)) { + StringBuilder threatIntelClauseBuilder = new StringBuilder(); + threatIntelClauseBuilder.append(String.format(Locale.getDefault(), "query[tag=%s]", "threat_intel")); + if (condition.length() > 0) { + condition.append(" || "); + } + condition.append("(").append(threatIntelClauseBuilder).append(")"); + } + + return new Script(condition.toString()); + } + public String getId() { return id; } diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java index 8f4c4f1fd..aaf64b6ff 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java @@ -859,6 +859,7 @@ private IndexMonitorRequest createDocLevelMonitorMatchAllRequest( if(query.getRight().isAggregationRule()) { Rule rule = query.getRight(); tags.add(rule.getLevel()); + tags.add(rule.getId()); tags.add(rule.getCategory()); tags.addAll(rule.getTags().stream().map(Value::getValue).collect(Collectors.toList())); } @@ -888,7 +889,7 @@ private IndexMonitorRequest createDocLevelMonitorMatchAllRequest( String name = detectorTrigger.getName(); String severity = detectorTrigger.getSeverity(); List actions = detectorTrigger.getActions(); - Script condition = detectorTrigger.convertToCondition(); + Script condition = detectorTrigger.convertToConditionForChainedFindings(); triggers.add(new DocumentLevelTrigger(id, name, severity, actions, condition)); } diff --git a/src/test/java/org/opensearch/securityanalytics/alerts/AlertsIT.java b/src/test/java/org/opensearch/securityanalytics/alerts/AlertsIT.java index fe97a13be..7eecaa3b7 100644 --- a/src/test/java/org/opensearch/securityanalytics/alerts/AlertsIT.java +++ b/src/test/java/org/opensearch/securityanalytics/alerts/AlertsIT.java @@ -959,6 +959,125 @@ public void testMultipleAggregationAndDocRules_alertSuccess() throws IOException Assert.assertEquals(2, getAlertsBody.get("total_alerts")); } + public void test_detectorWith1AggRuleAndTriggeronRule_updateWithSecondAggRule() throws IOException { + String index = createTestIndex(randomIndex(), windowsIndexMapping()); + + Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); + createMappingRequest.setJsonEntity( + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" + ); + + Response createMappingResponse = client().performRequest(createMappingRequest); + + assertEquals(org.apache.http.HttpStatus.SC_OK, createMappingResponse.getStatusLine().getStatusCode()); + + String infoOpCode = "Info"; + /** 1st agg rule*/ + String sumRuleId = createRule(randomAggregationRule("sum", " > 1", infoOpCode)); + + + List detectorRules = List.of(new DetectorRule(sumRuleId)); + + DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, + Collections.emptyList()); + Detector detector = randomDetectorWithInputsAndTriggers(List.of(input), + List.of(new DetectorTrigger("randomtrigegr", "test-trigger", "1", List.of(randomDetectorType()), List.of(sumRuleId), List.of(), List.of(), List.of(), List.of())) + ); + + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); + + + String request = "{\n" + + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; + SearchResponse response = executeSearchAndGetResponse(DetectorMonitorConfig.getRuleIndex(randomDetectorType()) + "*", request, true); + + assertEquals(1, response.getHits().getTotalHits().value); // 5 for rules, 1 for match_all query in chained findings monitor + assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + Map responseBody = asMap(createResponse); + String detectorId = responseBody.get("_id").toString(); + request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId + "\"\n" + + " }\n" + + " }\n" + + "}"; + List hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + Map updatedDetectorMap = (HashMap) (hit.getSourceAsMap().get("detector")); + + String workflowId = ((List) (updatedDetectorMap).get("workflow_ids")).get(0); + + indexDoc(index, "1", randomDoc(2, 4, infoOpCode)); + indexDoc(index, "2", randomDoc(3, 4, infoOpCode)); + executeAlertingWorkflow(workflowId, Collections.emptyMap()); + + Map params = new HashMap<>(); + params.put("detector_id", detectorId); + Response getFindingsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.FINDINGS_BASE_URI + "/_search", params, null); + Map getFindingsBody = entityAsMap(getFindingsResponse); + + /** assert findings */ + assertNotNull(getFindingsBody); + assertEquals(1, getFindingsBody.get("total_findings")); + + /**assert alerts */ + Map params1 = new HashMap<>(); + params1.put("detector_id", detectorId); + Response getAlertsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.ALERTS_BASE_URI, params1, null); + Map getAlertsBody = asMap(getAlertsResponse); + + Assert.assertEquals(1, getAlertsBody.get("total_alerts")); + /** 2nd agg rule*/ + String sumRuleId2 = createRule(randomAggregationRule("sum", " > 1", infoOpCode)); + String sumRuleId3 = createRule(randomAggregationRule("sum", " > 100", infoOpCode)); + + detectorRules = List.of(new DetectorRule(sumRuleId), new DetectorRule(sumRuleId2)); + input = new DetectorInput("updated", List.of("windows"), detectorRules, Collections.emptyList()); + Detector updatedDetector = randomDetectorWithInputsAndTriggers(List.of(input), + List.of(new DetectorTrigger("updated1", "test-trigger1", "1", List.of(randomDetectorType()), List.of(sumRuleId2, sumRuleId), List.of(), List.of(), List.of(), List.of()), + new DetectorTrigger("updated2", "test-trigger2", "1", List.of(randomDetectorType()), List.of(sumRuleId2, sumRuleId3), List.of(), List.of(), List.of(), List.of()), + new DetectorTrigger("noAlertsExpected", "test-trigger2", "1", List.of(randomDetectorType()), List.of(sumRuleId3), List.of(), List.of(), List.of(), List.of())) + ); + /** update detector and verify chained findings monitor should still exist*/ + makeRequest(client(), "PUT", SecurityAnalyticsPlugin.DETECTOR_BASE_URI + "/" + detectorId, Collections.emptyMap(), toHttpEntity(updatedDetector)); + hits = executeSearch(Detector.DETECTORS_INDEX, request); + hit = hits.get(0); + updatedDetectorMap = (HashMap) (hit.getSourceAsMap().get("detector")); + + assertEquals(3, ((List) (updatedDetectorMap).get("monitor_id")).size()); + indexDoc(index, "3", randomDoc(2, 5, infoOpCode)); + indexDoc(index, "4", randomDoc(3, 5, infoOpCode)); + + hits = executeSearch(Detector.DETECTORS_INDEX, request); + hit = hits.get(0); + updatedDetectorMap = (HashMap) (hit.getSourceAsMap().get("detector")); + + workflowId = ((List) (updatedDetectorMap).get("workflow_ids")).get(0); + executeAlertingWorkflow(workflowId, Collections.emptyMap()); + + params = new HashMap<>(); + params.put("detector_id", detectorId); + getFindingsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.FINDINGS_BASE_URI + "/_search", params, null); + getFindingsBody = entityAsMap(getFindingsResponse); + + assertNotNull(getFindingsBody); + assertEquals(3, getFindingsBody.get("total_findings")); + + params1 = new HashMap<>(); + params1.put("detector_id", detectorId); + getAlertsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.ALERTS_BASE_URI, params1, null); + getAlertsBody = asMap(getAlertsResponse); + Assert.assertEquals(3, getAlertsBody.get("total_alerts")); + } + @Ignore public void testAlertHistoryRollover_maxAge_low_retention() throws IOException, InterruptedException { updateClusterSetting(ALERT_HISTORY_ROLLOVER_PERIOD.getKey(), "1s");