diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java b/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java index 8a907cb71..cfff7da26 100644 --- a/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java +++ b/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java @@ -310,7 +310,7 @@ private void getValidDocuments(String detectorType, List indices, List> filteredCorrelationRules = new ArrayList<>(); + List filteredCorrelationRules = new ArrayList<>(); int idx = 0; for (MultiSearchResponse.Item response : responses) { @@ -320,7 +320,7 @@ public void onResponse(MultiSearchResponse items) { } if (response.getResponse().getHits().getTotalHits().value > 0L) { - filteredCorrelationRules.add(Triple.of(validCorrelationRules.get(idx), + filteredCorrelationRules.add(new FilteredCorrelationRule(validCorrelationRules.get(idx), response.getResponse().getHits().getHits(), validFields.get(idx))); } ++idx; @@ -328,9 +328,9 @@ public void onResponse(MultiSearchResponse items) { Map> categoryToQueriesMap = new HashMap<>(); Map categoryToTimeWindowMap = new HashMap<>(); - for (Triple rule: filteredCorrelationRules) { - List queries = rule.getLeft().getCorrelationQueries(); - Long timeWindow = rule.getLeft().getCorrTimeWindow(); + for (FilteredCorrelationRule rule: filteredCorrelationRules) { + List queries = rule.correlationRule.getCorrelationQueries(); + Long timeWindow = rule.correlationRule.getCorrTimeWindow(); for (CorrelationQuery query: queries) { List correlationQueries; @@ -348,10 +348,10 @@ public void onResponse(MultiSearchResponse items) { if (query.getField() == null) { correlationQueries.add(query); } else { - SearchHit[] hits = rule.getMiddle(); + SearchHit[] hits = rule.filteredDocs; StringBuilder qb = new StringBuilder(query.getField()).append(":("); for (int i = 0; i < hits.length; ++i) { - String value = hits[i].field(rule.getRight()).getValue(); + String value = hits[i].field(rule.field).getValue(); qb.append(value); if (i < hits.length-1) { qb.append(" OR "); @@ -368,7 +368,7 @@ public void onResponse(MultiSearchResponse items) { } } searchFindingsByTimestamp(detectorType, categoryToQueriesMap, categoryToTimeWindowMap, - filteredCorrelationRules.stream().map(Triple::getLeft).map(CorrelationRule::getId).collect(Collectors.toList()), + filteredCorrelationRules.stream().map(it -> it.correlationRule).map(CorrelationRule::getId).collect(Collectors.toList()), autoCorrelations ); } @@ -630,15 +630,15 @@ public DocSearchCriteria(List indices, List queries, List ENABLE_AUTO_CORRELATIONS = Setting.boolSetting( - "plugins.security_analytics.enable_auto_correlations", + "plugins.security_analytics.auto_correlations_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic ); diff --git a/src/test/java/org/opensearch/securityanalytics/TestHelpers.java b/src/test/java/org/opensearch/securityanalytics/TestHelpers.java index 97c192104..a07cc6eb1 100644 --- a/src/test/java/org/opensearch/securityanalytics/TestHelpers.java +++ b/src/test/java/org/opensearch/securityanalytics/TestHelpers.java @@ -259,7 +259,7 @@ public static String randomRule() { "level: high"; } - public static String randomRuleForCorrelations(String value) { + public static String randomCloudtrailRuleForCorrelations(String value) { return "id: 5f92fff9-82e2-48ab-8fc1-8b133556a551\n" + "logsource:\n" + " product: cloudtrail\n" + diff --git a/src/test/java/org/opensearch/securityanalytics/correlation/CorrelationEngineRestApiIT.java b/src/test/java/org/opensearch/securityanalytics/correlation/CorrelationEngineRestApiIT.java index a1ae87fba..d4352a565 100644 --- a/src/test/java/org/opensearch/securityanalytics/correlation/CorrelationEngineRestApiIT.java +++ b/src/test/java/org/opensearch/securityanalytics/correlation/CorrelationEngineRestApiIT.java @@ -4,6 +4,7 @@ */ package org.opensearch.securityanalytics.correlation; +import org.apache.hc.core5.http.HttpStatus; import org.apache.hc.core5.http.io.entity.StringEntity; import org.apache.hc.core5.http.message.BasicHeader; import org.junit.Assert; @@ -13,8 +14,10 @@ import org.opensearch.search.SearchHit; import org.opensearch.securityanalytics.SecurityAnalyticsPlugin; import org.opensearch.securityanalytics.SecurityAnalyticsRestTestCase; +import org.opensearch.securityanalytics.TestHelpers; import org.opensearch.securityanalytics.model.CorrelationQuery; import org.opensearch.securityanalytics.model.CorrelationRule; +import org.opensearch.securityanalytics.model.CustomLogType; import org.opensearch.securityanalytics.model.Detector; import org.opensearch.securityanalytics.model.DetectorInput; import org.opensearch.securityanalytics.model.DetectorRule; @@ -553,14 +556,14 @@ public void testBasicCorrelationEngineWorkflowWithFieldBasedRules() throws IOExc Response response = client().performRequest(createMappingRequest); assertEquals(RestStatus.OK.getStatus(), response.getStatusLine().getStatusCode()); - String rule1 = randomRuleForCorrelations("CreateUser"); + String rule1 = randomCloudtrailRuleForCorrelations("CreateUser"); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.RULE_BASE_URI, Collections.singletonMap("category", "cloudtrail"), new StringEntity(rule1), new BasicHeader("Content-Type", "application/json")); Assert.assertEquals("Create rule failed", RestStatus.CREATED, restStatus(createResponse)); Map responseBody = asMap(createResponse); String createdId1 = responseBody.get("_id").toString(); - String rule2 = randomRuleForCorrelations("DeleteUser"); + String rule2 = randomCloudtrailRuleForCorrelations("DeleteUser"); createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.RULE_BASE_URI, Collections.singletonMap("category", "cloudtrail"), new StringEntity(rule2), new BasicHeader("Content-Type", "application/json")); Assert.assertEquals("Create rule failed", RestStatus.CREATED, restStatus(createResponse)); @@ -722,14 +725,14 @@ public void testBasicCorrelationEngineWorkflowWithFieldBasedRulesAndDynamicTimeW Response response = client().performRequest(createMappingRequest); assertEquals(RestStatus.OK.getStatus(), response.getStatusLine().getStatusCode()); - String rule1 = randomRuleForCorrelations("CreateUser"); + String rule1 = randomCloudtrailRuleForCorrelations("CreateUser"); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.RULE_BASE_URI, Collections.singletonMap("category", "cloudtrail"), new StringEntity(rule1), new BasicHeader("Content-Type", "application/json")); Assert.assertEquals("Create rule failed", RestStatus.CREATED, restStatus(createResponse)); Map responseBody = asMap(createResponse); String createdId1 = responseBody.get("_id").toString(); - String rule2 = randomRuleForCorrelations("DeleteUser"); + String rule2 = randomCloudtrailRuleForCorrelations("DeleteUser"); createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.RULE_BASE_URI, Collections.singletonMap("category", "cloudtrail"), new StringEntity(rule2), new BasicHeader("Content-Type", "application/json")); Assert.assertEquals("Create rule failed", RestStatus.CREATED, restStatus(createResponse)); @@ -805,6 +808,114 @@ public void testBasicCorrelationEngineWorkflowWithFieldBasedRulesAndDynamicTimeW Assert.assertEquals(2, count); } + public void testBasicCorrelationEngineWorkflowWithCustomLogTypes() throws IOException, InterruptedException { + LogIndices indices = new LogIndices(); + indices.vpcFlowsIndex = createTestIndex("vpc_flow1", vpcFlowMappings()); + + String vpcFlowMonitorId = createVpcFlowDetector(indices.vpcFlowsIndex); + String index = createTestIndex(randomIndex(), windowsIndexMapping()); + + CustomLogType customLogType = TestHelpers.randomCustomLogType(null, null, null, "Custom"); + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.CUSTOM_LOG_TYPE_URI, Collections.emptyMap(), toHttpEntity(customLogType)); + Assert.assertEquals("Create custom log type failed", RestStatus.CREATED, restStatus(createResponse)); + + // Execute CreateMappingsAction to add alias mapping for index + Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); + // both req params and req body are supported + createMappingRequest.setJsonEntity( + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + customLogType.getName() + "\", " + + " \"partial\":true, " + + " \"alias_mappings\":{}" + + "}" + ); + + Response response = client().performRequest(createMappingRequest); + assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode()); + + String rule = randomRule(); + + createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.RULE_BASE_URI, Collections.singletonMap("category", customLogType.getName()), + new StringEntity(rule), new BasicHeader("Content-Type", "application/json")); + Assert.assertEquals("Create rule failed", RestStatus.CREATED, restStatus(createResponse)); + + Map responseBody = asMap(createResponse); + String createdId = responseBody.get("_id").toString(); + + DetectorInput input = new DetectorInput("custom log type detector for security analytics", List.of(index), List.of(new DetectorRule(createdId)), + List.of()); + Detector detector = randomDetectorWithInputs(List.of(input), customLogType.getName()); + + createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); + Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + + responseBody = asMap(createResponse); + createdId = responseBody.get("_id").toString(); + + String detectorTypeInResponse = (String) ((Map)responseBody.get("detector")).get("detector_type"); + Assert.assertEquals("Detector type incorrect", customLogType.getName(), detectorTypeInResponse); + + String request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + createdId + "\"\n" + + " }\n" + + " }\n" + + "}"; + List hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + + String monitorId = ((List) ((Map) hit.getSourceAsMap().get("detector")).get("monitor_id")).get(0); + String ruleId = createNetworkToCustomLogTypeFieldBasedRule(indices, customLogType.getName(), index); + + indexDoc(index, "1", randomDoc()); + Response executeResponse = executeAlertingMonitor(monitorId, Collections.emptyMap()); + Map executeResults = entityAsMap(executeResponse); + int noOfSigmaRuleMatches = ((List>) ((Map) executeResults.get("input_results")).get("results")).get(0).size(); + Assert.assertEquals(1, noOfSigmaRuleMatches); + + indexDoc(indices.vpcFlowsIndex, "1", randomVpcFlowDoc()); + executeResponse = executeAlertingMonitor(vpcFlowMonitorId, Collections.emptyMap()); + executeResults = entityAsMap(executeResponse); + noOfSigmaRuleMatches = ((List>) ((Map) executeResults.get("input_results")).get("results")).get(0).size(); + Assert.assertEquals(1, noOfSigmaRuleMatches); + Thread.sleep(5000); + + Map params = new HashMap<>(); + params.put("detectorType", customLogType.getName()); + Response getFindingsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.FINDINGS_BASE_URI + "/_search", params, null); + Map getFindingsBody = entityAsMap(getFindingsResponse); + String finding = ((List>) getFindingsBody.get("findings")).get(0).get("id").toString(); + + int count = 0; + while (true) { + try { + List> correlatedFindings = searchCorrelatedFindings(finding, customLogType.getName(), 300000L, 10); + if (correlatedFindings.size() == 1) { + Assert.assertTrue(true); + + Assert.assertTrue(correlatedFindings.get(0).get("rules") instanceof List); + + for (var correlatedFinding: correlatedFindings) { + if (correlatedFinding.get("detector_type").equals("network")) { + Assert.assertEquals(1, ((List) correlatedFinding.get("rules")).size()); + Assert.assertTrue(((List) correlatedFinding.get("rules")).contains(ruleId)); + } + } + break; + } + } catch (Exception ex) { + // suppress ex + } + ++count; + Thread.sleep(5000); + if (count >= 12) { + Assert.assertTrue(false); + break; + } + } + } + private LogIndices createIndices() throws IOException { LogIndices indices = new LogIndices(); indices.adLdapLogsIndex = createTestIndex("ad_logs", adLdapLogMappings()); @@ -828,6 +939,19 @@ private String createNetworkToWindowsFieldBasedRule(LogIndices indices) throws I return entityAsMap(response).get("_id").toString(); } + private String createNetworkToCustomLogTypeFieldBasedRule(LogIndices indices, String customLogTypeName, String customLogTypeIndex) throws IOException { + CorrelationQuery query1 = new CorrelationQuery(indices.vpcFlowsIndex, null, "network", "srcaddr"); + CorrelationQuery query4 = new CorrelationQuery(customLogTypeIndex, null, customLogTypeName, "SourceIp"); + + CorrelationRule rule = new CorrelationRule(CorrelationRule.NO_ID, CorrelationRule.NO_VERSION, "network to custom log type", List.of(query1, query4), 300000L); + Request request = new Request("POST", "/_plugins/_security_analytics/correlation/rules"); + request.setJsonEntity(toJsonString(rule)); + Response response = client().performRequest(request); + + Assert.assertEquals(201, response.getStatusLine().getStatusCode()); + return entityAsMap(response).get("_id").toString(); + } + private String createNetworkToAdLdapToWindowsRule(LogIndices indices) throws IOException { CorrelationQuery query1 = new CorrelationQuery(indices.vpcFlowsIndex, "dstaddr:4.5.6.7", "network", null); CorrelationQuery query2 = new CorrelationQuery(indices.adLdapLogsIndex, "ResultType:50126", "ad_ldap", null);