From 4bfa45f1e261652b2e4821db86e0c032952741df Mon Sep 17 00:00:00 2001 From: Riya <69919272+riysaxen-amzn@users.noreply.github.com> Date: Mon, 1 Jul 2024 19:14:21 -0700 Subject: [PATCH] added correlationAlert integ tests (#1099) * added correlationAlert integ tests Signed-off-by: Riya Saxena * added licences Signed-off-by: Riya Saxena * fixed imports Signed-off-by: Riya Saxena * deleted SecureCorrelationAlerts Tests, will add later Signed-off-by: Riya Saxena --------- Signed-off-by: Riya Saxena (cherry picked from commit e8d78790ffbf344a68efd05417c8af2469948f57) --- .../SecurityAnalyticsRestTestCase.java | 360 +++++++++++++++++- .../securityanalytics/TestHelpers.java | 12 + .../CorrelationEngineRestApiIT.java | 303 --------------- .../CorrelationEngineRuleRestApiIT.java | 40 ++ .../alerts/CorrelationAlertServiceTests.java | 79 ++++ .../alerts/CorrelationAlertsRestApiIT.java | 284 ++++++++++++++ 6 files changed, 765 insertions(+), 313 deletions(-) create mode 100644 src/test/java/org/opensearch/securityanalytics/correlation/alerts/CorrelationAlertServiceTests.java create mode 100644 src/test/java/org/opensearch/securityanalytics/correlation/alerts/CorrelationAlertsRestApiIT.java diff --git a/src/test/java/org/opensearch/securityanalytics/SecurityAnalyticsRestTestCase.java b/src/test/java/org/opensearch/securityanalytics/SecurityAnalyticsRestTestCase.java index 59e4ba9f0..024f43e4f 100644 --- a/src/test/java/org/opensearch/securityanalytics/SecurityAnalyticsRestTestCase.java +++ b/src/test/java/org/opensearch/securityanalytics/SecurityAnalyticsRestTestCase.java @@ -26,6 +26,7 @@ import org.opensearch.client.WarningsHandler; import org.opensearch.cluster.ClusterModule; import org.opensearch.cluster.metadata.MappingMetadata; +import org.opensearch.core.common.Strings; import org.opensearch.common.UUIDs; import org.opensearch.common.io.PathUtils; import org.opensearch.common.settings.Settings; @@ -33,6 +34,7 @@ import org.opensearch.common.xcontent.XContentType; import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.commons.ConfigConstants; +import org.opensearch.commons.alerting.model.action.Action; import org.opensearch.commons.alerting.model.ScheduledJob; import org.opensearch.commons.alerting.util.IndexUtilsKt; import org.opensearch.commons.rest.SecureRestClientBuilder; @@ -53,16 +55,21 @@ import org.opensearch.securityanalytics.action.CreateIndexMappingsRequest; import org.opensearch.securityanalytics.action.UpdateIndexMappingsRequest; import org.opensearch.securityanalytics.config.monitors.DetectorMonitorConfig; +import org.opensearch.securityanalytics.correlation.CorrelationEngineRestApiIT; import org.opensearch.securityanalytics.correlation.index.query.CorrelationQueryBuilder; import org.opensearch.securityanalytics.mapper.MappingsTraverser; +import org.opensearch.securityanalytics.model.CorrelationQuery; import org.opensearch.securityanalytics.model.CorrelationRule; +import org.opensearch.securityanalytics.model.CorrelationRuleTrigger; import org.opensearch.securityanalytics.model.CustomLogType; import org.opensearch.securityanalytics.model.Detector; +import org.opensearch.securityanalytics.model.DetectorInput; +import org.opensearch.securityanalytics.model.DetectorTrigger; +import org.opensearch.securityanalytics.model.DetectorRule; import org.opensearch.securityanalytics.model.Rule; import org.opensearch.securityanalytics.model.ThreatIntelFeedData; import org.opensearch.securityanalytics.util.CorrelationIndices; import org.opensearch.test.rest.OpenSearchRestTestCase; - import javax.management.MBeanServerInvocationHandler; import javax.management.MalformedObjectNameException; import javax.management.ObjectName; @@ -84,11 +91,20 @@ import java.util.Set; import java.util.function.BiConsumer; import java.util.stream.Collectors; - import static org.opensearch.action.admin.indices.create.CreateIndexRequest.MAPPINGS; import static org.opensearch.securityanalytics.SecurityAnalyticsPlugin.MAPPER_BASE_URI; import static org.opensearch.securityanalytics.TestHelpers.productIndexAvgAggRule; import static org.opensearch.securityanalytics.TestHelpers.sumAggregationTestRule; +import static org.opensearch.securityanalytics.TestHelpers.adLdapLogMappings; +import static org.opensearch.securityanalytics.TestHelpers.appLogMappings; +import static org.opensearch.securityanalytics.TestHelpers.productIndexAvgAggRule; +import static org.opensearch.securityanalytics.TestHelpers.randomIndex; +import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputsAndTriggers; +import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputsAndTriggersAndType; +import static org.opensearch.securityanalytics.TestHelpers.randomDetectorType; +import static org.opensearch.securityanalytics.TestHelpers.sumAggregationTestRule; +import static org.opensearch.securityanalytics.TestHelpers.s3AccessLogMappings; +import static org.opensearch.securityanalytics.TestHelpers.vpcFlowMappings; import static org.opensearch.securityanalytics.TestHelpers.windowsIndexMapping; import static org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings.ALERT_HISTORY_INDEX_MAX_AGE; import static org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings.ALERT_HISTORY_MAX_DOCS; @@ -1360,13 +1376,6 @@ protected void enableOrDisableFilterBy(String trueOrFalse) throws IOException { client().performRequest(request); } - protected void createUserWithDataAndCustomRole(String userName, String userPasswd, String roleName, String[] backendRoles, String clusterPermissions ) throws IOException { - String[] users = {userName}; - createUser(userName, backendRoles); - createCustomRole(roleName, clusterPermissions); - createUserRolesMapping(roleName, users); - } - protected void createUserWithDataAndCustomRole(String userName, String userPasswd, String roleName, String[] backendRoles, List clusterPermissions, List indexPermissions, List indexPatterns) throws IOException { String[] users = {userName}; createUser(userName, backendRoles); @@ -1792,6 +1801,329 @@ public String getMatchAllSearchRequestString(int num) { "}"; } + protected CorrelationEngineRestApiIT.LogIndices createIndices() throws IOException { + CorrelationEngineRestApiIT.LogIndices indices = new CorrelationEngineRestApiIT.LogIndices(); + indices.adLdapLogsIndex = createTestIndex("ad_logs", adLdapLogMappings()); + indices.s3AccessLogsIndex = createTestIndex("s3_access_logs", s3AccessLogMappings()); + indices.appLogsIndex = createTestIndex("app_logs", appLogMappings()); + indices.windowsIndex = createTestIndex(randomIndex(), windowsIndexMapping()); + indices.vpcFlowsIndex = createTestIndex("vpc_flow", vpcFlowMappings()); + return indices; + } + + protected String createNetworkToWindowsFieldBasedRule(CorrelationEngineRestApiIT.LogIndices indices) throws IOException { + CorrelationQuery query1 = new CorrelationQuery(indices.vpcFlowsIndex, null, "network", "srcaddr"); + CorrelationQuery query4 = new CorrelationQuery(indices.windowsIndex, null, "test_windows", "SourceIp"); + + CorrelationRule rule = new CorrelationRule(CorrelationRule.NO_ID, CorrelationRule.NO_VERSION, "network to windows", List.of(query1, query4), 300000L, null); + Request request = new Request("POST", "/_plugins/_security_analytics/correlation/rules"); + request.setJsonEntity(toJsonString(rule)); + Response response = client().performRequest(request); + + Assert.assertEquals(201, response.getStatusLine().getStatusCode()); + return entityAsMap(response).get("_id").toString(); + } + + protected String createNetworkToWindowsFilterQueryBasedRule(LogIndices indices) throws IOException { + CorrelationQuery query1 = new CorrelationQuery(indices.vpcFlowsIndex, "srcaddr:1.2.3.4", "network", null); + CorrelationQuery query4 = new CorrelationQuery(indices.windowsIndex, "SourceIp:1.2.3.4", "test_windows", null); + + CorrelationRule rule = new CorrelationRule(CorrelationRule.NO_ID, CorrelationRule.NO_VERSION, "network to windows", List.of(query1, query4), 300000L, null); + Request request = new Request("POST", "/_plugins/_security_analytics/correlation/rules"); + request.setJsonEntity(toJsonString(rule)); + Response response = client().performRequest(request); + + Assert.assertEquals(201, response.getStatusLine().getStatusCode()); + return entityAsMap(response).get("_id").toString(); + } + + protected String createNetworkToCustomLogTypeFieldBasedRule(LogIndices indices, String customLogTypeName, String customLogTypeIndex) throws IOException { + CorrelationQuery query1 = new CorrelationQuery(indices.vpcFlowsIndex, null, "network", "srcaddr"); + CorrelationQuery query4 = new CorrelationQuery(customLogTypeIndex, null, customLogTypeName, "SourceIp"); + + CorrelationRule rule = new CorrelationRule(CorrelationRule.NO_ID, CorrelationRule.NO_VERSION, "network to custom log type", List.of(query1, query4), 300000L, null); + Request request = new Request("POST", "/_plugins/_security_analytics/correlation/rules"); + request.setJsonEntity(toJsonString(rule)); + Response response = client().performRequest(request); + + Assert.assertEquals(201, response.getStatusLine().getStatusCode()); + return entityAsMap(response).get("_id").toString(); + } + + protected String createNetworkToAdLdapToWindowsRule(LogIndices indices) throws IOException { + CorrelationQuery query1 = new CorrelationQuery(indices.vpcFlowsIndex, "dstaddr:4.5.6.7", "network", null); + CorrelationQuery query2 = new CorrelationQuery(indices.adLdapLogsIndex, "ResultType:50126", "ad_ldap", null); + CorrelationQuery query4 = new CorrelationQuery(indices.windowsIndex, "Domain:NTAUTHORI*", "test_windows", null); + + CorrelationRule rule = new CorrelationRule(CorrelationRule.NO_ID, CorrelationRule.NO_VERSION, "network to ad_ldap to windows", List.of(query1, query2, query4), 300000L, null); + Request request = new Request("POST", "/_plugins/_security_analytics/correlation/rules"); + request.setJsonEntity(toJsonString(rule)); + Response response = client().performRequest(request); + + Assert.assertEquals(201, response.getStatusLine().getStatusCode()); + return entityAsMap(response).get("_id").toString(); + } + + protected String createNetworkToAdLdapToWindowsRuleWithTrigger(LogIndices indices) throws IOException { + CorrelationQuery query1 = new CorrelationQuery(indices.vpcFlowsIndex, "dstaddr:4.5.6.7", "network", null); + CorrelationQuery query2 = new CorrelationQuery(indices.adLdapLogsIndex, "ResultType:50126", "ad_ldap", null); + CorrelationQuery query4 = new CorrelationQuery(indices.windowsIndex, "Domain:NTAUTHORI*", "test_windows", null); + List actions = new ArrayList<>(); + CorrelationRuleTrigger trigger = new CorrelationRuleTrigger("trigger-123", "Trigger 1", "high", actions); + + CorrelationRule rule = new CorrelationRule(CorrelationRule.NO_ID, CorrelationRule.NO_VERSION, "network to ad_ldap to windows", List.of(query1, query2, query4), 300000L, trigger); + Request request = new Request("POST", "/_plugins/_security_analytics/correlation/rules"); + request.setJsonEntity(toJsonString(rule)); + Response response = client().performRequest(request); + + Assert.assertEquals(201, response.getStatusLine().getStatusCode()); + return entityAsMap(response).get("_id").toString(); + } + + protected String createWindowsToAppLogsToS3LogsRule(LogIndices indices) throws IOException { + CorrelationQuery query1 = new CorrelationQuery(indices.windowsIndex, "HostName:EC2AMAZ*", "test_windows", null); + CorrelationQuery query2 = new CorrelationQuery(indices.appLogsIndex, "endpoint:\\/customer_records.txt", "others_application", null); + CorrelationQuery query4 = new CorrelationQuery(indices.s3AccessLogsIndex, "aws.cloudtrail.eventName:ReplicateObject", "s3", null); + + CorrelationRule rule = new CorrelationRule(CorrelationRule.NO_ID, CorrelationRule.NO_VERSION, "windows to app_logs to s3 logs", List.of(query1, query2, query4), 300000L, null); + Request request = new Request("POST", "/_plugins/_security_analytics/correlation/rules"); + request.setJsonEntity(toJsonString(rule)); + Response response = client().performRequest(request); + + Assert.assertEquals(201, response.getStatusLine().getStatusCode()); + return entityAsMap(response).get("_id").toString(); + } + + protected String createCloudtrailFieldBasedRule(String index, String field, Long timeWindow) throws IOException { + CorrelationQuery query1 = new CorrelationQuery(index, "EventName:CreateUser", "cloudtrail", field); + CorrelationQuery query2 = new CorrelationQuery(index, "EventName:DeleteUser", "cloudtrail", field); + + CorrelationRule rule = new CorrelationRule(CorrelationRule.NO_ID, CorrelationRule.NO_VERSION, "cloudtrail field based", List.of(query1, query2), timeWindow, null); + Request request = new Request("POST", "/_plugins/_security_analytics/correlation/rules"); + request.setJsonEntity(toJsonString(rule)); + Response response = client().performRequest(request); + + Assert.assertEquals(201, response.getStatusLine().getStatusCode()); + return entityAsMap(response).get("_id").toString(); + } + + protected String createCloudtrailFieldBasedRuleWithTrigger(String index, String field, Long timeWindow) throws IOException { + CorrelationQuery query1 = new CorrelationQuery(index, "EventName:CreateUser", "cloudtrail", field); + CorrelationQuery query2 = new CorrelationQuery(index, "EventName:DeleteUser", "cloudtrail", field); + List actions = new ArrayList<>(); + CorrelationRuleTrigger trigger = new CorrelationRuleTrigger("trigger-345", "Trigger 2", "high", actions); + CorrelationRule rule = new CorrelationRule("correlation-rule-1", CorrelationRule.NO_VERSION, "cloudtrail field based", List.of(query1, query2), timeWindow, trigger); + Request request = new Request("POST", "/_plugins/_security_analytics/correlation/rules"); + request.setJsonEntity(toJsonString(rule)); + Response response = client().performRequest(request); + + Assert.assertEquals(201, response.getStatusLine().getStatusCode()); + return entityAsMap(response).get("_id").toString(); + } + + @SuppressWarnings("unchecked") + protected String createVpcFlowDetector(String indexName) throws IOException { + Detector vpcFlowDetector = randomDetectorWithInputsAndTriggersAndType(List.of(new DetectorInput("vpc flow detector for security analytics", List.of(indexName), List.of(), + getPrePackagedRules("network").stream().map(DetectorRule::new).collect(Collectors.toList()))), + List.of(new DetectorTrigger(null, "test-trigger", "1", List.of("network"), List.of(), List.of(), List.of(), List.of(), List.of())), "network"); + + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(vpcFlowDetector)); + Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + + Map responseBody = asMap(createResponse); + + String createdId = responseBody.get("_id").toString(); + + String request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + createdId + "\"\n" + + " }\n" + + " }\n" + + "}"; + List hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + + return ((List) ((Map) hit.getSourceAsMap().get("detector")).get("monitor_id")).get(0); + } + + @SuppressWarnings("unchecked") + protected String createAdLdapDetector(String indexName) throws IOException { + // Execute CreateMappingsAction to add alias mapping for index + Request createMappingRequest = new Request("POST", MAPPER_BASE_URI); + // both req params and req body are supported + createMappingRequest.setJsonEntity( + "{\n" + + " \"index_name\": \"" + indexName + "\",\n" + + " \"rule_topic\": \"ad_ldap\",\n" + + " \"partial\": true,\n" + + " \"alias_mappings\": {\n" + + " \"properties\": {\n" + + " \"azure.signinlogs.properties.user_id\": {\n" + + " \"path\": \"azure.signinlogs.props.user_id\",\n" + + " \"type\": \"alias\"\n" + + " },\n" + + " \"azure-platformlogs-result_type\": {\n" + + " \"path\": \"azure.platformlogs.result_type\",\n" + + " \"type\": \"alias\"\n" + + " },\n" + + " \"azure-signinlogs-result_description\": {\n" + + " \"path\": \"azure.signinlogs.result_description\",\n" + + " \"type\": \"alias\"\n" + + " },\n" + + " \"timestamp\": {\n" + + " \"path\": \"creationTime\",\n" + + " \"type\": \"alias\"\n" + + " }\n" + + " }\n" + + " }\n" + + "}" + ); + + Response response = client().performRequest(createMappingRequest); + assertEquals(RestStatus.OK.getStatus(), response.getStatusLine().getStatusCode()); + + Detector adLdapDetector = randomDetectorWithInputsAndTriggersAndType(List.of(new DetectorInput("ad_ldap logs detector for security analytics", List.of(indexName), List.of(), + getPrePackagedRules("ad_ldap").stream().map(DetectorRule::new).collect(Collectors.toList()))), + List.of(new DetectorTrigger(null, "test-trigger", "1", List.of("ad_ldap"), List.of(), List.of(), List.of(), List.of(), List.of())), "ad_ldap"); + + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(adLdapDetector)); + Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + + Map responseBody = asMap(createResponse); + + String createdId = responseBody.get("_id").toString(); + + String request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + createdId + "\"\n" + + " }\n" + + " }\n" + + "}"; + List hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + + return ((List) ((Map) hit.getSourceAsMap().get("detector")).get("monitor_id")).get(0); + } + + @SuppressWarnings("unchecked") + protected String createTestWindowsDetector(String indexName) throws IOException { + // Execute CreateMappingsAction to add alias mapping for index + Request createMappingRequest = new Request("POST", MAPPER_BASE_URI); + // both req params and req body are supported + createMappingRequest.setJsonEntity( + "{ \"index_name\":\"" + indexName + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" + ); + + Response response = client().performRequest(createMappingRequest); + assertEquals(RestStatus.OK.getStatus(), response.getStatusLine().getStatusCode()); + + Detector windowsDetector = randomDetectorWithInputsAndTriggers(List.of(new DetectorInput("windows detector for security analytics", List.of(indexName), List.of(), + getRandomPrePackagedRules().stream().map(DetectorRule::new).collect(Collectors.toList()))), + List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of()))); + + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(windowsDetector)); + Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + + Map responseBody = asMap(createResponse); + + String createdId = responseBody.get("_id").toString(); + + String request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + createdId + "\"\n" + + " }\n" + + " }\n" + + "}"; + List hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + + return ((List) ((Map) hit.getSourceAsMap().get("detector")).get("monitor_id")).get(0); + } + + @SuppressWarnings("unchecked") + protected String createAppLogsDetector(String indexName) throws IOException { + Detector appLogsDetector = randomDetectorWithInputsAndTriggersAndType(List.of(new DetectorInput("app logs detector for security analytics", List.of(indexName), List.of(), + getPrePackagedRules("others_application").stream().map(DetectorRule::new).collect(Collectors.toList()))), + List.of(new DetectorTrigger(null, "test-trigger", "1", List.of("others_application"), List.of(), List.of(), List.of(), List.of(), List.of())), "others_application"); + + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(appLogsDetector)); + Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + + Map responseBody = asMap(createResponse); + + String createdId = responseBody.get("_id").toString(); + + String request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + createdId + "\"\n" + + " }\n" + + " }\n" + + "}"; + List hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + + return ((List) ((Map) hit.getSourceAsMap().get("detector")).get("monitor_id")).get(0); + } + + @SuppressWarnings("unchecked") + protected String createS3Detector(String indexName) throws IOException { + // Execute CreateMappingsAction to add alias mapping for index + Request createMappingRequest = new Request("POST", MAPPER_BASE_URI); + // both req params and req body are supported + createMappingRequest.setJsonEntity( + "{\n" + + " \"index_name\": \"s3_access_logs\",\n" + + " \"rule_topic\": \"s3\",\n" + + " \"partial\": true,\n" + + " \"alias_mappings\": {\n" + + " \"properties\": {\n" + + " \"aws-cloudtrail-event_source\": {\n" + + " \"type\": \"alias\",\n" + + " \"path\": \"aws.cloudtrail.event_source\"\n" + + " },\n" + + " \"aws.cloudtrail.event_name\": {\n" + + " \"type\": \"alias\",\n" + + " \"path\": \"aws.cloudtrail.event_name\"\n" + + " }\n" + + " }\n" + + " }\n" + + "}" + ); + + Response response = client().performRequest(createMappingRequest); + assertEquals(RestStatus.OK.getStatus(), response.getStatusLine().getStatusCode()); + + Detector s3AccessLogsDetector = randomDetectorWithInputsAndTriggersAndType(List.of(new DetectorInput("s3 access logs detector for security analytics", List.of(indexName), List.of(), + getPrePackagedRules("s3").stream().map(DetectorRule::new).collect(Collectors.toList()))), + List.of(new DetectorTrigger(null, "test-trigger", "1", List.of("s3"), List.of(), List.of(), List.of(), List.of(), List.of())), "s3"); + + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(s3AccessLogsDetector)); + Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + + Map responseBody = asMap(createResponse); + + String createdId = responseBody.get("_id").toString(); + + String request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + createdId + "\"\n" + + " }\n" + + " }\n" + + "}"; + List hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + + return ((List) ((Map) hit.getSourceAsMap().get("detector")).get("monitor_id")).get(0); + } + /** * We need to be able to dump the jacoco coverage before cluster is shut down. * The new internal testing framework removed some of the gradle tasks we were listening to @@ -1830,4 +2162,12 @@ public static void dumpCoverage() throws IOException, MalformedObjectNameExcepti throw new RuntimeException("Failed to dump coverage: " + ex); } } -} \ No newline at end of file + + public static class LogIndices { + public String vpcFlowsIndex; + public String adLdapLogsIndex; + public String windowsIndex; + public String appLogsIndex; + public String s3AccessLogsIndex; + } +} diff --git a/src/test/java/org/opensearch/securityanalytics/TestHelpers.java b/src/test/java/org/opensearch/securityanalytics/TestHelpers.java index ad50a0dde..2d3519832 100644 --- a/src/test/java/org/opensearch/securityanalytics/TestHelpers.java +++ b/src/test/java/org/opensearch/securityanalytics/TestHelpers.java @@ -23,6 +23,7 @@ import org.opensearch.script.ScriptType; import org.opensearch.securityanalytics.model.CorrelationQuery; import org.opensearch.securityanalytics.model.CorrelationRule; +import org.opensearch.securityanalytics.model.CorrelationRuleTrigger; import org.opensearch.securityanalytics.model.CustomLogType; import org.opensearch.securityanalytics.model.Detector; import org.opensearch.securityanalytics.model.DetectorInput; @@ -230,6 +231,17 @@ public static CorrelationRule randomCorrelationRule(String name) { ), 300000L, null); } + public static CorrelationRule randomCorrelationRuleWithTrigger(String name) { + name = name.isEmpty()? ">": name; + List actions = new ArrayList(); + CorrelationRuleTrigger trigger = new CorrelationRuleTrigger("trigger-123", "Trigger 1", "high", actions); + return new CorrelationRule(CorrelationRule.NO_ID, CorrelationRule.NO_VERSION, name, + List.of( + new CorrelationQuery("vpc_flow1", "dstaddr:192.168.1.*", "network", null), + new CorrelationQuery("ad_logs1", "azure.platformlogs.result_type:50126", "ad_ldap", null) + ), 300000L, trigger); + } + public static String randomRule() { return "title: Remote Encrypting File System Abuse\n" + "id: 5f92fff9-82e2-48eb-8fc1-8b133556a551\n" + diff --git a/src/test/java/org/opensearch/securityanalytics/correlation/CorrelationEngineRestApiIT.java b/src/test/java/org/opensearch/securityanalytics/correlation/CorrelationEngineRestApiIT.java index f05092d01..b2d56c1f5 100644 --- a/src/test/java/org/opensearch/securityanalytics/correlation/CorrelationEngineRestApiIT.java +++ b/src/test/java/org/opensearch/securityanalytics/correlation/CorrelationEngineRestApiIT.java @@ -15,8 +15,6 @@ import org.opensearch.securityanalytics.SecurityAnalyticsPlugin; import org.opensearch.securityanalytics.SecurityAnalyticsRestTestCase; import org.opensearch.securityanalytics.TestHelpers; -import org.opensearch.securityanalytics.model.CorrelationQuery; -import org.opensearch.securityanalytics.model.CorrelationRule; import org.opensearch.securityanalytics.model.CustomLogType; import org.opensearch.securityanalytics.model.Detector; import org.opensearch.securityanalytics.model.DetectorInput; @@ -32,7 +30,6 @@ import java.util.Map; import java.util.concurrent.TimeUnit; import java.util.function.BooleanSupplier; -import java.util.stream.Collectors; import static org.opensearch.securityanalytics.TestHelpers.*; @@ -954,304 +951,4 @@ public void testBasicCorrelationEngineWorkflowWithCustomLogTypes() throws IOExce ); } - private LogIndices createIndices() throws IOException { - LogIndices indices = new LogIndices(); - indices.adLdapLogsIndex = createTestIndex("ad_logs", adLdapLogMappings()); - indices.s3AccessLogsIndex = createTestIndex("s3_access_logs", s3AccessLogMappings()); - indices.appLogsIndex = createTestIndex("app_logs", appLogMappings()); - indices.windowsIndex = createTestIndex(randomIndex(), windowsIndexMapping()); - indices.vpcFlowsIndex = createTestIndex("vpc_flow", vpcFlowMappings()); - return indices; - } - - private String createNetworkToWindowsFieldBasedRule(LogIndices indices) throws IOException { - CorrelationQuery query1 = new CorrelationQuery(indices.vpcFlowsIndex, null, "network", "srcaddr"); - CorrelationQuery query4 = new CorrelationQuery(indices.windowsIndex, null, "test_windows", "SourceIp"); - - CorrelationRule rule = new CorrelationRule(CorrelationRule.NO_ID, CorrelationRule.NO_VERSION, "network to windows", List.of(query1, query4), 300000L, null); - Request request = new Request("POST", "/_plugins/_security_analytics/correlation/rules"); - request.setJsonEntity(toJsonString(rule)); - Response response = client().performRequest(request); - - Assert.assertEquals(201, response.getStatusLine().getStatusCode()); - return entityAsMap(response).get("_id").toString(); - } - - private String createNetworkToWindowsFilterQueryBasedRule(LogIndices indices) throws IOException { - CorrelationQuery query1 = new CorrelationQuery(indices.vpcFlowsIndex, "srcaddr:1.2.3.4", "network", null); - CorrelationQuery query4 = new CorrelationQuery(indices.windowsIndex, "SourceIp:1.2.3.4", "test_windows", null); - - CorrelationRule rule = new CorrelationRule(CorrelationRule.NO_ID, CorrelationRule.NO_VERSION, "network to windows", List.of(query1, query4), 300000L, null); - Request request = new Request("POST", "/_plugins/_security_analytics/correlation/rules"); - request.setJsonEntity(toJsonString(rule)); - Response response = client().performRequest(request); - - Assert.assertEquals(201, response.getStatusLine().getStatusCode()); - return entityAsMap(response).get("_id").toString(); - } - - private String createNetworkToCustomLogTypeFieldBasedRule(LogIndices indices, String customLogTypeName, String customLogTypeIndex) throws IOException { - CorrelationQuery query1 = new CorrelationQuery(indices.vpcFlowsIndex, null, "network", "srcaddr"); - CorrelationQuery query4 = new CorrelationQuery(customLogTypeIndex, null, customLogTypeName, "SourceIp"); - - CorrelationRule rule = new CorrelationRule(CorrelationRule.NO_ID, CorrelationRule.NO_VERSION, "network to custom log type", List.of(query1, query4), 300000L, null); - Request request = new Request("POST", "/_plugins/_security_analytics/correlation/rules"); - request.setJsonEntity(toJsonString(rule)); - Response response = client().performRequest(request); - - Assert.assertEquals(201, response.getStatusLine().getStatusCode()); - return entityAsMap(response).get("_id").toString(); - } - - private String createNetworkToAdLdapToWindowsRule(LogIndices indices) throws IOException { - CorrelationQuery query1 = new CorrelationQuery(indices.vpcFlowsIndex, "dstaddr:4.5.6.7", "network", null); - CorrelationQuery query2 = new CorrelationQuery(indices.adLdapLogsIndex, "ResultType:50126", "ad_ldap", null); - CorrelationQuery query4 = new CorrelationQuery(indices.windowsIndex, "Domain:NTAUTHORI*", "test_windows", null); - - CorrelationRule rule = new CorrelationRule(CorrelationRule.NO_ID, CorrelationRule.NO_VERSION, "network to ad_ldap to windows", List.of(query1, query2, query4), 300000L, null); - Request request = new Request("POST", "/_plugins/_security_analytics/correlation/rules"); - request.setJsonEntity(toJsonString(rule)); - Response response = client().performRequest(request); - - Assert.assertEquals(201, response.getStatusLine().getStatusCode()); - return entityAsMap(response).get("_id").toString(); - } - - private String createWindowsToAppLogsToS3LogsRule(LogIndices indices) throws IOException { - CorrelationQuery query1 = new CorrelationQuery(indices.windowsIndex, "HostName:EC2AMAZ*", "test_windows", null); - CorrelationQuery query2 = new CorrelationQuery(indices.appLogsIndex, "endpoint:\\/customer_records.txt", "others_application", null); - CorrelationQuery query4 = new CorrelationQuery(indices.s3AccessLogsIndex, "aws.cloudtrail.eventName:ReplicateObject", "s3", null); - - CorrelationRule rule = new CorrelationRule(CorrelationRule.NO_ID, CorrelationRule.NO_VERSION, "windows to app_logs to s3 logs", List.of(query1, query2, query4), 300000L, null); - Request request = new Request("POST", "/_plugins/_security_analytics/correlation/rules"); - request.setJsonEntity(toJsonString(rule)); - Response response = client().performRequest(request); - - Assert.assertEquals(201, response.getStatusLine().getStatusCode()); - return entityAsMap(response).get("_id").toString(); - } - - private String createCloudtrailFieldBasedRule(String index, String field, Long timeWindow) throws IOException { - CorrelationQuery query1 = new CorrelationQuery(index, "EventName:CreateUser", "cloudtrail", field); - CorrelationQuery query2 = new CorrelationQuery(index, "EventName:DeleteUser", "cloudtrail", field); - - CorrelationRule rule = new CorrelationRule(CorrelationRule.NO_ID, CorrelationRule.NO_VERSION, "cloudtrail field based", List.of(query1, query2), timeWindow, null); - Request request = new Request("POST", "/_plugins/_security_analytics/correlation/rules"); - request.setJsonEntity(toJsonString(rule)); - Response response = client().performRequest(request); - - Assert.assertEquals(201, response.getStatusLine().getStatusCode()); - return entityAsMap(response).get("_id").toString(); - } - - @SuppressWarnings("unchecked") - private String createVpcFlowDetector(String indexName) throws IOException { - Detector vpcFlowDetector = randomDetectorWithInputsAndTriggersAndType(List.of(new DetectorInput("vpc flow detector for security analytics", List.of(indexName), List.of(), - getPrePackagedRules("network").stream().map(DetectorRule::new).collect(Collectors.toList()))), - List.of(new DetectorTrigger(null, "test-trigger", "1", List.of("network"), List.of(), List.of(), List.of(), List.of(), List.of())), "network"); - - Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(vpcFlowDetector)); - Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); - - Map responseBody = asMap(createResponse); - - String createdId = responseBody.get("_id").toString(); - - String request = "{\n" + - " \"query\" : {\n" + - " \"match\":{\n" + - " \"_id\": \"" + createdId + "\"\n" + - " }\n" + - " }\n" + - "}"; - List hits = executeSearch(Detector.DETECTORS_INDEX, request); - SearchHit hit = hits.get(0); - - return ((List) ((Map) hit.getSourceAsMap().get("detector")).get("monitor_id")).get(0); - } - - @SuppressWarnings("unchecked") - private String createAdLdapDetector(String indexName) throws IOException { - // Execute CreateMappingsAction to add alias mapping for index - Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); - // both req params and req body are supported - createMappingRequest.setJsonEntity( - "{\n" + - " \"index_name\": \"" + indexName + "\",\n" + - " \"rule_topic\": \"ad_ldap\",\n" + - " \"partial\": true,\n" + - " \"alias_mappings\": {\n" + - " \"properties\": {\n" + - " \"azure.signinlogs.properties.user_id\": {\n" + - " \"path\": \"azure.signinlogs.props.user_id\",\n" + - " \"type\": \"alias\"\n" + - " },\n" + - " \"azure-platformlogs-result_type\": {\n" + - " \"path\": \"azure.platformlogs.result_type\",\n" + - " \"type\": \"alias\"\n" + - " },\n" + - " \"azure-signinlogs-result_description\": {\n" + - " \"path\": \"azure.signinlogs.result_description\",\n" + - " \"type\": \"alias\"\n" + - " },\n" + - " \"timestamp\": {\n" + - " \"path\": \"creationTime\",\n" + - " \"type\": \"alias\"\n" + - " }\n" + - " }\n" + - " }\n" + - "}" - ); - - Response response = client().performRequest(createMappingRequest); - assertEquals(RestStatus.OK.getStatus(), response.getStatusLine().getStatusCode()); - - Detector adLdapDetector = randomDetectorWithInputsAndTriggersAndType(List.of(new DetectorInput("ad_ldap logs detector for security analytics", List.of(indexName), List.of(), - getPrePackagedRules("ad_ldap").stream().map(DetectorRule::new).collect(Collectors.toList()))), - List.of(new DetectorTrigger(null, "test-trigger", "1", List.of("ad_ldap"), List.of(), List.of(), List.of(), List.of(), List.of())), "ad_ldap"); - - Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(adLdapDetector)); - Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); - - Map responseBody = asMap(createResponse); - - String createdId = responseBody.get("_id").toString(); - - String request = "{\n" + - " \"query\" : {\n" + - " \"match\":{\n" + - " \"_id\": \"" + createdId + "\"\n" + - " }\n" + - " }\n" + - "}"; - List hits = executeSearch(Detector.DETECTORS_INDEX, request); - SearchHit hit = hits.get(0); - - return ((List) ((Map) hit.getSourceAsMap().get("detector")).get("monitor_id")).get(0); - } - - @SuppressWarnings("unchecked") - private String createTestWindowsDetector(String indexName) throws IOException { - // Execute CreateMappingsAction to add alias mapping for index - Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); - // both req params and req body are supported - createMappingRequest.setJsonEntity( - "{ \"index_name\":\"" + indexName + "\"," + - " \"rule_topic\":\"" + randomDetectorType() + "\", " + - " \"partial\":true" + - "}" - ); - - Response response = client().performRequest(createMappingRequest); - assertEquals(RestStatus.OK.getStatus(), response.getStatusLine().getStatusCode()); - - Detector windowsDetector = randomDetectorWithInputsAndTriggers(List.of(new DetectorInput("windows detector for security analytics", List.of(indexName), List.of(), - getRandomPrePackagedRules().stream().map(DetectorRule::new).collect(Collectors.toList()))), - List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of()))); - - Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(windowsDetector)); - Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); - - Map responseBody = asMap(createResponse); - - String createdId = responseBody.get("_id").toString(); - - String request = "{\n" + - " \"query\" : {\n" + - " \"match\":{\n" + - " \"_id\": \"" + createdId + "\"\n" + - " }\n" + - " }\n" + - "}"; - List hits = executeSearch(Detector.DETECTORS_INDEX, request); - SearchHit hit = hits.get(0); - - return ((List) ((Map) hit.getSourceAsMap().get("detector")).get("monitor_id")).get(0); - } - - @SuppressWarnings("unchecked") - private String createAppLogsDetector(String indexName) throws IOException { - Detector appLogsDetector = randomDetectorWithInputsAndTriggersAndType(List.of(new DetectorInput("app logs detector for security analytics", List.of(indexName), List.of(), - getPrePackagedRules("others_application").stream().map(DetectorRule::new).collect(Collectors.toList()))), - List.of(new DetectorTrigger(null, "test-trigger", "1", List.of("others_application"), List.of(), List.of(), List.of(), List.of(), List.of())), "others_application"); - - Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(appLogsDetector)); - Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); - - Map responseBody = asMap(createResponse); - - String createdId = responseBody.get("_id").toString(); - - String request = "{\n" + - " \"query\" : {\n" + - " \"match\":{\n" + - " \"_id\": \"" + createdId + "\"\n" + - " }\n" + - " }\n" + - "}"; - List hits = executeSearch(Detector.DETECTORS_INDEX, request); - SearchHit hit = hits.get(0); - - return ((List) ((Map) hit.getSourceAsMap().get("detector")).get("monitor_id")).get(0); - } - - @SuppressWarnings("unchecked") - private String createS3Detector(String indexName) throws IOException { - // Execute CreateMappingsAction to add alias mapping for index - Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); - // both req params and req body are supported - createMappingRequest.setJsonEntity( - "{\n" + - " \"index_name\": \"s3_access_logs\",\n" + - " \"rule_topic\": \"s3\",\n" + - " \"partial\": true,\n" + - " \"alias_mappings\": {\n" + - " \"properties\": {\n" + - " \"aws-cloudtrail-event_source\": {\n" + - " \"type\": \"alias\",\n" + - " \"path\": \"aws.cloudtrail.event_source\"\n" + - " },\n" + - " \"aws.cloudtrail.event_name\": {\n" + - " \"type\": \"alias\",\n" + - " \"path\": \"aws.cloudtrail.event_name\"\n" + - " }\n" + - " }\n" + - " }\n" + - "}" - ); - - Response response = client().performRequest(createMappingRequest); - assertEquals(RestStatus.OK.getStatus(), response.getStatusLine().getStatusCode()); - - Detector s3AccessLogsDetector = randomDetectorWithInputsAndTriggersAndType(List.of(new DetectorInput("s3 access logs detector for security analytics", List.of(indexName), List.of(), - getPrePackagedRules("s3").stream().map(DetectorRule::new).collect(Collectors.toList()))), - List.of(new DetectorTrigger(null, "test-trigger", "1", List.of("s3"), List.of(), List.of(), List.of(), List.of(), List.of())), "s3"); - - Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(s3AccessLogsDetector)); - Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); - - Map responseBody = asMap(createResponse); - - String createdId = responseBody.get("_id").toString(); - - String request = "{\n" + - " \"query\" : {\n" + - " \"match\":{\n" + - " \"_id\": \"" + createdId + "\"\n" + - " }\n" + - " }\n" + - "}"; - List hits = executeSearch(Detector.DETECTORS_INDEX, request); - SearchHit hit = hits.get(0); - - return ((List) ((Map) hit.getSourceAsMap().get("detector")).get("monitor_id")).get(0); - } - - static class LogIndices { - String vpcFlowsIndex; - String adLdapLogsIndex; - String windowsIndex; - String appLogsIndex; - String s3AccessLogsIndex; - } } \ No newline at end of file diff --git a/src/test/java/org/opensearch/securityanalytics/correlation/CorrelationEngineRuleRestApiIT.java b/src/test/java/org/opensearch/securityanalytics/correlation/CorrelationEngineRuleRestApiIT.java index d8cdcfdc5..4694fe523 100644 --- a/src/test/java/org/opensearch/securityanalytics/correlation/CorrelationEngineRuleRestApiIT.java +++ b/src/test/java/org/opensearch/securityanalytics/correlation/CorrelationEngineRuleRestApiIT.java @@ -15,9 +15,11 @@ import java.io.IOException; import java.util.Collections; +import java.util.List; import java.util.Map; import static org.opensearch.securityanalytics.TestHelpers.randomCorrelationRule; +import static org.opensearch.securityanalytics.TestHelpers.randomCorrelationRuleWithTrigger; public class CorrelationEngineRuleRestApiIT extends SecurityAnalyticsRestTestCase { @@ -113,4 +115,42 @@ public void testSearchCorrelationRule() throws IOException { responseMap = responseAsMap(response); Assert.assertEquals(1, Integer.parseInt(((Map) ((Map) responseMap.get("hits")).get("total")).get("value").toString())); } + + public void testSearchCorrelationRuleWithTrigger() throws IOException { + CorrelationRule rule = randomCorrelationRuleWithTrigger("custom-rule"); + Response response = makeRequest(client(), "POST", SecurityAnalyticsPlugin.CORRELATION_RULES_BASE_URI, Collections.emptyMap(), toHttpEntity(rule)); + Assert.assertEquals(201, response.getStatusLine().getStatusCode()); + Map responseMap = responseAsMap(response); + Assert.assertEquals("custom-rule", ((Map) responseMap.get("rule")).get("name")); + + String request = "{\n" + + " \"query\": {\n" + + " \"nested\": {\n" + + " \"path\": \"correlate\",\n" + + " \"query\": {\n" + + " \"bool\": {\n" + + " \"must\": [\n" + + " { \"match\": {\"correlate.category\": \"network\"}}\n" + + " ]\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + "}"; + response = makeRequest(client(), "POST", SecurityAnalyticsPlugin.CORRELATION_RULES_BASE_URI + "/_search", Collections.emptyMap(), new StringEntity(request), new BasicHeader("Content-type", "application/json")); + responseMap = responseAsMap(response); + // Assuming the hits contain the matched documents + Map hits = (Map) responseMap.get("hits"); + Assert.assertNotNull(hits); + + List> hitsList = (List>) hits.get("hits"); + Assert.assertEquals(1, hitsList.size()); // Assuming you expect exactly one hit + + Map hit = hitsList.get(0); + Map source = (Map) hit.get("_source"); + Assert.assertNotNull(source); + + Object trigger = source.get("trigger"); + Assert.assertNotNull(trigger); + } } \ No newline at end of file diff --git a/src/test/java/org/opensearch/securityanalytics/correlation/alerts/CorrelationAlertServiceTests.java b/src/test/java/org/opensearch/securityanalytics/correlation/alerts/CorrelationAlertServiceTests.java new file mode 100644 index 000000000..6a8ea14b3 --- /dev/null +++ b/src/test/java/org/opensearch/securityanalytics/correlation/alerts/CorrelationAlertServiceTests.java @@ -0,0 +1,79 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.correlation.alerts; + +import org.opensearch.client.Client; +import org.opensearch.commons.alerting.model.Alert; +import org.opensearch.commons.alerting.model.CorrelationAlert; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.securityanalytics.correlation.alert.CorrelationAlertService; +import org.opensearch.securityanalytics.correlation.alert.CorrelationAlertsList; +import org.opensearch.test.OpenSearchTestCase; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.UUID; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; + +public class CorrelationAlertServiceTests extends OpenSearchTestCase { + + public void testGetActiveAlerts() { + // Mock setup + Client client = mock(Client.class); + NamedXContentRegistry xContentRegistry = mock(NamedXContentRegistry.class); + CorrelationAlertService alertsService = spy(new CorrelationAlertService(client, xContentRegistry)); + + + // Fake data + String ruleId = "correlation_rule_id_123"; + long currentTime = System.currentTimeMillis(); + + // Define a fake correlation alert + CorrelationAlert correlationAlert = new CorrelationAlert( + Collections.emptyList(), + ruleId, + "mock-rule", + UUID.randomUUID().toString(), + 1L, + 1, + null, + "mock-trigger", + Alert.State.ACTIVE, + Instant.ofEpochMilli(currentTime).minusMillis(1000L), + Instant.ofEpochMilli(currentTime).plusMillis(1000L), + null, + null, + "high", + new ArrayList<>() + ); + + List correlationAlerts = Collections.singletonList(correlationAlert); + + // Call getActiveAlerts + alertsService.getActiveAlerts(ruleId, currentTime, new ActionListener() { + @Override + public void onResponse(CorrelationAlertsList correlationAlertsList) { + // Assertion + assertEquals(correlationAlerts.size(), correlationAlertsList.getCorrelationAlertList().size()); + + // Additional assertions can be added here to verify specific fields or states + CorrelationAlert returnedAlert = correlationAlertsList.getCorrelationAlertList().get(0); + assertEquals(correlationAlert.getId(), returnedAlert.getId()); + assertEquals(correlationAlert.getCorrelationRuleId(), returnedAlert.getCorrelationRuleId()); + assertEquals(correlationAlert.getStartTime(), returnedAlert.getStartTime()); + assertEquals(correlationAlert.getEndTime(), returnedAlert.getEndTime()); + } + + @Override + public void onFailure(Exception e) { + + } + }); + } +} diff --git a/src/test/java/org/opensearch/securityanalytics/correlation/alerts/CorrelationAlertsRestApiIT.java b/src/test/java/org/opensearch/securityanalytics/correlation/alerts/CorrelationAlertsRestApiIT.java new file mode 100644 index 000000000..6ff00926d --- /dev/null +++ b/src/test/java/org/opensearch/securityanalytics/correlation/alerts/CorrelationAlertsRestApiIT.java @@ -0,0 +1,284 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.correlation.alerts; + +import org.apache.hc.core5.http.io.entity.StringEntity; +import org.apache.hc.core5.http.message.BasicHeader; +import org.junit.Assert; +import org.opensearch.client.Request; +import org.opensearch.client.Response; +import org.opensearch.commons.alerting.model.CorrelationAlert; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.search.SearchHit; +import org.opensearch.securityanalytics.SecurityAnalyticsPlugin; +import org.opensearch.securityanalytics.model.Detector; +import org.opensearch.securityanalytics.model.DetectorInput; +import org.opensearch.securityanalytics.model.DetectorRule; +import org.opensearch.securityanalytics.model.DetectorTrigger; +import static org.opensearch.securityanalytics.TestHelpers.cloudtrailMappings; +import static org.opensearch.securityanalytics.TestHelpers.randomCloudtrailDoc; +import static org.opensearch.securityanalytics.TestHelpers.randomCloudtrailRuleForCorrelations; +import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputsAndTriggersAndType; +import static org.opensearch.securityanalytics.TestHelpers.randomDoc; +import static org.opensearch.securityanalytics.TestHelpers.randomVpcFlowDoc; +import org.opensearch.test.rest.OpenSearchRestTestCase; +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import org.opensearch.securityanalytics.SecurityAnalyticsRestTestCase; + + +public class CorrelationAlertsRestApiIT extends SecurityAnalyticsRestTestCase { + + public void testGetCorrelationAlertsAPI() throws IOException, InterruptedException { + LogIndices indices = createIndices(); + + String vpcFlowMonitorId = createVpcFlowDetector(indices.vpcFlowsIndex); + String testWindowsMonitorId = createTestWindowsDetector(indices.windowsIndex); + + createNetworkToAdLdapToWindowsRuleWithTrigger(indices); + Thread.sleep(5000); + + indexDoc(indices.windowsIndex, "2", randomDoc()); + Response executeResponse = executeAlertingMonitor(testWindowsMonitorId, Collections.emptyMap()); + Map executeResults = entityAsMap(executeResponse); + int noOfSigmaRuleMatches = ((List>) ((Map) executeResults.get("input_results")).get("results")).get(0).size(); + Assert.assertEquals(5, noOfSigmaRuleMatches); + + Thread.sleep(5000); + indexDoc(indices.vpcFlowsIndex, "1", randomVpcFlowDoc()); + executeResponse = executeAlertingMonitor(vpcFlowMonitorId, Collections.emptyMap()); + executeResults = entityAsMap(executeResponse); + noOfSigmaRuleMatches = ((List>) ((Map) executeResults.get("input_results")).get("results")).get(0).size(); + Assert.assertEquals(1, noOfSigmaRuleMatches); + + Thread.sleep(5000); + + OpenSearchRestTestCase.waitUntil( + () -> { + try { + Long endTime = System.currentTimeMillis(); + Request request = new Request("GET", "/_plugins/_security_analytics/correlationAlerts"); + Response response = client().performRequest(request); + + Map responseMap = entityAsMap(response); + List correlationAlerts = (List) responseMap.get("correlationAlerts"); + if (correlationAlerts.size() == 1) { + Assert.assertEquals(correlationAlerts.get(0).getTriggerName(), "Trigger 1"); + Assert.assertTrue(true); + return true; + } + return false; + } catch (Exception ex) { + return false; + } + }, + 2, TimeUnit.MINUTES + ); + } + + public void testGetCorrelationAlertsByRuleIdAPI() throws IOException, InterruptedException { + String index = createTestIndex("cloudtrail", cloudtrailMappings()); + // Execute CreateMappingsAction to add alias mapping for index + Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); + // both req params and req body are supported + createMappingRequest.setJsonEntity( + "{\n" + + " \"index_name\": \"" + index + "\",\n" + + " \"rule_topic\": \"cloudtrail\",\n" + + " \"partial\": true,\n" + + " \"alias_mappings\": {\n" + + " \"properties\": {\n" + + " \"aws.cloudtrail.event_name\": {\n" + + " \"path\": \"Records.eventName\",\n" + + " \"type\": \"alias\"\n" + + " }\n" + + " }\n" + + " }\n" + + "}" + ); + + Response response = client().performRequest(createMappingRequest); + assertEquals(RestStatus.OK.getStatus(), response.getStatusLine().getStatusCode()); + + String rule1 = randomCloudtrailRuleForCorrelations("CreateUser"); + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.RULE_BASE_URI, Collections.singletonMap("category", "cloudtrail"), + new StringEntity(rule1), new BasicHeader("Content-Type", "application/json")); + Assert.assertEquals("Create rule failed", RestStatus.CREATED, restStatus(createResponse)); + Map responseBody = asMap(createResponse); + String createdId1 = responseBody.get("_id").toString(); + + String rule2 = randomCloudtrailRuleForCorrelations("DeleteUser"); + createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.RULE_BASE_URI, Collections.singletonMap("category", "cloudtrail"), + new StringEntity(rule2), new BasicHeader("Content-Type", "application/json")); + Assert.assertEquals("Create rule failed", RestStatus.CREATED, restStatus(createResponse)); + responseBody = asMap(createResponse); + String createdId2 = responseBody.get("_id").toString(); + + createCloudtrailFieldBasedRuleWithTrigger(index, "requestParameters.userName", null); + + Detector cloudtrailDetector = randomDetectorWithInputsAndTriggersAndType(List.of(new DetectorInput("cloudtrail detector for security analytics", List.of(index), + List.of(new DetectorRule(createdId1), new DetectorRule(createdId2)), + List.of())), + List.of(new DetectorTrigger(null, "test-trigger", "1", List.of("cloudtrail"), List.of(), List.of(), List.of(), List.of(), List.of())), "cloudtrail"); + + createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(cloudtrailDetector)); + Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + + responseBody = asMap(createResponse); + + String createdId = responseBody.get("_id").toString(); + + String request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + createdId + "\"\n" + + " }\n" + + " }\n" + + "}"; + List hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + + String monitorId = ((List) ((Map) hit.getSourceAsMap().get("detector")).get("monitor_id")).get(0); + + indexDoc(index, "1", randomCloudtrailDoc("Richard", "CreateUser")); + executeAlertingMonitor(monitorId, Collections.emptyMap()); + Thread.sleep(1000); + indexDoc(index, "4", randomCloudtrailDoc("deysubho", "CreateUser")); + executeAlertingMonitor(monitorId, Collections.emptyMap()); + Thread.sleep(1000); + + indexDoc(index, "2", randomCloudtrailDoc("Richard", "DeleteUser")); + executeAlertingMonitor(monitorId, Collections.emptyMap()); + + Thread.sleep(5000); + + OpenSearchRestTestCase.waitUntil( + () -> { + try { + Request restRequest = new Request("GET", "/_plugins/_security_analytics/correlationAlerts?correlation_rule_id=correlation-rule-1"); + Response restResponse = client().performRequest(restRequest); + + Map responseMap = entityAsMap(restResponse); + int totalAlerts = (int) responseMap.get("total_alerts"); + if (totalAlerts == 1) { + Assert.assertTrue(true); + return true; + } + return false; + } catch (Exception ex) { + return false; + } + }, + 2, TimeUnit.MINUTES + ); + } + + public void testGetCorrelationAlertsAcknowledgeAPI() throws IOException, InterruptedException { + String index = createTestIndex("cloudtrail", cloudtrailMappings()); + // Execute CreateMappingsAction to add alias mapping for index + Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); + // both req params and req body are supported + createMappingRequest.setJsonEntity( + "{\n" + + " \"index_name\": \"" + index + "\",\n" + + " \"rule_topic\": \"cloudtrail\",\n" + + " \"partial\": true,\n" + + " \"alias_mappings\": {\n" + + " \"properties\": {\n" + + " \"aws.cloudtrail.event_name\": {\n" + + " \"path\": \"Records.eventName\",\n" + + " \"type\": \"alias\"\n" + + " }\n" + + " }\n" + + " }\n" + + "}" + ); + + Response response = client().performRequest(createMappingRequest); + assertEquals(RestStatus.OK.getStatus(), response.getStatusLine().getStatusCode()); + + String rule1 = randomCloudtrailRuleForCorrelations("CreateUser"); + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.RULE_BASE_URI, Collections.singletonMap("category", "cloudtrail"), + new StringEntity(rule1), new BasicHeader("Content-Type", "application/json")); + Assert.assertEquals("Create rule failed", RestStatus.CREATED, restStatus(createResponse)); + Map responseBody = asMap(createResponse); + String createdId1 = responseBody.get("_id").toString(); + + String rule2 = randomCloudtrailRuleForCorrelations("DeleteUser"); + createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.RULE_BASE_URI, Collections.singletonMap("category", "cloudtrail"), + new StringEntity(rule2), new BasicHeader("Content-Type", "application/json")); + Assert.assertEquals("Create rule failed", RestStatus.CREATED, restStatus(createResponse)); + responseBody = asMap(createResponse); + String createdId2 = responseBody.get("_id").toString(); + + createCloudtrailFieldBasedRuleWithTrigger(index, "requestParameters.userName", null); + + Detector cloudtrailDetector = randomDetectorWithInputsAndTriggersAndType(List.of(new DetectorInput("cloudtrail detector for security analytics", List.of(index), + List.of(new DetectorRule(createdId1), new DetectorRule(createdId2)), + List.of())), + List.of(new DetectorTrigger(null, "test-trigger", "1", List.of("cloudtrail"), List.of(), List.of(), List.of(), List.of(), List.of())), "cloudtrail"); + + createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(cloudtrailDetector)); + Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + + responseBody = asMap(createResponse); + + String createdId = responseBody.get("_id").toString(); + + String request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + createdId + "\"\n" + + " }\n" + + " }\n" + + "}"; + List hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + + String monitorId = ((List) ((Map) hit.getSourceAsMap().get("detector")).get("monitor_id")).get(0); + + indexDoc(index, "1", randomCloudtrailDoc("Richard", "CreateUser")); + executeAlertingMonitor(monitorId, Collections.emptyMap()); + Thread.sleep(1000); + indexDoc(index, "4", randomCloudtrailDoc("John", "CreateUser")); + executeAlertingMonitor(monitorId, Collections.emptyMap()); + Thread.sleep(1000); + + indexDoc(index, "2", randomCloudtrailDoc("Richard", "DeleteUser")); + executeAlertingMonitor(monitorId, Collections.emptyMap()); + + Thread.sleep(5000); + OpenSearchRestTestCase.waitUntil( + () -> { + try { + Request request1 = new Request("GET", "/_plugins/_security_analytics/correlationAlerts"); + Response getCorrelationAlertResp = client().performRequest(request1); + Map responseGetCorrelationAlertMap = entityAsMap(getCorrelationAlertResp); + List correlationAlerts = (List) responseGetCorrelationAlertMap.get("correlationAlerts"); + // Execute CreateMappingsAction to add alias mapping for index + Thread.sleep(2000); + Request restRequest = new Request("POST", "/_plugins/_security_analytics/_acknowledge/correlationAlerts"); + restRequest.setJsonEntity( + "{\"alertIds\": [\"" + correlationAlerts.get(0).getId() + "\"]}" + ); + Response restResponse = client().performRequest(restRequest); + Map responseMap = entityAsMap(restResponse); + List results = (List) responseMap.get("acknowledged"); + if (results.size() == 1) { + Assert.assertTrue(true); + return true; + } + return false; + } catch (Exception ex) { + return false; + } + }, + 2, TimeUnit.MINUTES + ); + } +}