[BUG_FIX] fix check for agg rules in detector trigger condition to create chained findings monitor #992

Merged
7 commits merged on Apr 27, 2024
@@ -96,11 +96,9 @@
import org.opensearch.securityanalytics.rules.backend.OSQueryBackend.AggregationQueries;
import org.opensearch.securityanalytics.rules.backend.QueryBackend;
import org.opensearch.securityanalytics.rules.exceptions.SigmaConditionError;
import org.opensearch.securityanalytics.rules.exceptions.CompositeSigmaErrors;
import org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings;
import org.opensearch.securityanalytics.threatIntel.DetectorThreatIntelService;
import org.opensearch.securityanalytics.util.DetectorIndices;
import org.opensearch.securityanalytics.util.DetectorUtils;
import org.opensearch.securityanalytics.util.ExceptionChecker;
import org.opensearch.securityanalytics.util.IndexUtils;
import org.opensearch.securityanalytics.util.MonitorService;
@@ -123,6 +121,7 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
@@ -828,20 +827,30 @@ private void addThreatIntelBasedDocLevelQueries(Detector detector, ActionListene
*/
private IndexMonitorRequest createDocLevelMonitorMatchAllRequest(
Detector detector,
WriteRequest.RefreshPolicy refreshPolicy,
RefreshPolicy refreshPolicy,
String monitorId,
RestRequest.Method restMethod
) {
Method restMethod,
List<Pair<String, Rule>> queries) {
Collaborator:
If there are multiple aggregation rules/queries, do the queries passed in here include all of them, or only the aggregation rule that created the bucket level monitor?

Member Author:
We only add the tags of each aggregation query to the doc level monitor's DocLevelQuery tags. Note that each aggregation rule creates one bucket level monitor, and for all bucket level monitors there is a single chained findings doc level monitor.
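For illustration only (not part of this PR's diff): a minimal, self-contained sketch of the topology described above, assuming a simplified rule model. `SimpleRule` and the printed strings are hypothetical stand-ins for the plugin's `Rule`, `IndexMonitorRequest`, and `DocLevelQuery` classes; the real code below gathers the tags the same way via `rule.getLevel()`, `rule.getCategory()`, and `rule.getTags()`.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

public class ChainedFindingsTopologySketch {

    // Hypothetical stand-in for the plugin's Rule model (not the real class).
    record SimpleRule(String id, String level, String category,
                      List<String> tags, boolean aggregationRule) {}

    public static void main(String[] args) {
        List<SimpleRule> rules = List.of(
                new SimpleRule("sum_rule", "high", "windows", List.of("attack.t1110"), true),
                new SimpleRule("max_rule", "low", "windows", List.of("attack.t1078"), true),
                new SimpleRule("doc_rule", "medium", "windows", List.of("attack.t1003"), false));

        // One bucket level monitor per aggregation rule.
        List<String> bucketLevelMonitors = rules.stream()
                .filter(SimpleRule::aggregationRule)
                .map(rule -> "bucket_monitor_" + rule.id())
                .collect(Collectors.toList());

        // A single chained findings doc level monitor; its match-all query ("_id:*")
        // carries the union of the aggregation rules' level, category, and tags.
        Set<String> chainedQueryTags = new HashSet<>();
        for (SimpleRule rule : rules) {
            if (rule.aggregationRule()) {
                chainedQueryTags.add(rule.level());
                chainedQueryTags.add(rule.category());
                chainedQueryTags.addAll(rule.tags());
            }
        }
        chainedQueryTags.removeIf(Objects::isNull);

        System.out.println("bucket level monitors: " + bucketLevelMonitors);
        System.out.println("chained findings query: _id:* tags=" + new ArrayList<>(chainedQueryTags));
    }
}
```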

List<DocLevelMonitorInput> docLevelMonitorInputs = new ArrayList<>();
List<DocLevelQuery> docLevelQueries = new ArrayList<>();
String monitorName = detector.getName() + "_chained_findings";
String actualQuery = "_id:*";
Set<String> tags = new HashSet<>();
for (Pair<String, Rule> query: queries) {
if(query.getRight().isAggregationRule()) {
Rule rule = query.getRight();
tags.add(rule.getLevel());
tags.add(rule.getCategory());
tags.addAll(rule.getTags().stream().map(Value::getValue).collect(Collectors.toList()));
}
}
tags.removeIf(Objects::isNull);
DocLevelQuery docLevelQuery = new DocLevelQuery(
monitorName,
monitorName + "doc",
Collections.emptyList(),
actualQuery,
Collections.emptyList()
new ArrayList<>(tags)
);
docLevelQueries.add(docLevelQuery);

@@ -901,8 +910,8 @@ public void onResponse(Map<String, Map<String, String>> ruleFieldMappings) {
@Override
public void onResponse(Collection<IndexMonitorRequest> indexMonitorRequests) {
// if workflow usage enabled, add chained findings monitor request if there are bucket level requests and if the detector triggers have any group by rules configured to trigger
if (enabledWorkflowUsage && !monitorRequests.isEmpty() && !DetectorUtils.getAggRuleIdsConfiguredToTrigger(detector, queries).isEmpty()) {
monitorRequests.add(createDocLevelMonitorMatchAllRequest(detector, RefreshPolicy.IMMEDIATE, detector.getId() + "_chained_findings", Method.POST));
if (enabledWorkflowUsage && !monitorRequests.isEmpty() && queries.stream().anyMatch(it -> it.getRight().isAggregationRule())) {
monitorRequests.add(createDocLevelMonitorMatchAllRequest(detector, RefreshPolicy.IMMEDIATE, detector.getId() + "_chained_findings", Method.POST, queries));
}
listener.onResponse(monitorRequests);
}
@@ -1058,7 +1067,7 @@ public void onFailure(Exception e) {
listener.onFailure(e);
}
});
} catch (CompositeSigmaErrors e) {
} catch (Exception e) {
log.error("Failed to create bucket level monitor request", e);
listener.onFailure(e);
}
@@ -104,17 +104,12 @@ public void onFailure(Exception e) {
});
}

public static List<String> getBucketLevelMonitorIdsWhoseRulesAreConfiguredToTrigger(
Detector detector,
List<Pair<String, Rule>> rulesById,
public static List<String> getBucketLevelMonitorIds(
List<IndexMonitorResponse> monitorResponses
) {
List<String> aggRuleIdsConfiguredToTrigger = getAggRuleIdsConfiguredToTrigger(detector, rulesById);
return monitorResponses.stream().filter(
// In the case of bucket level monitors rule id is trigger id
it -> Monitor.MonitorType.BUCKET_LEVEL_MONITOR == it.getMonitor().getMonitorType()
&& !it.getMonitor().getTriggers().isEmpty()
&& aggRuleIdsConfiguredToTrigger.contains(it.getMonitor().getTriggers().get(0).getId())
).map(IndexMonitorResponse::getId).collect(Collectors.toList());
}
public static List<String> getAggRuleIdsConfiguredToTrigger(Detector detector, List<Pair<String, Rule>> rulesById) {
@@ -21,7 +21,6 @@
import org.opensearch.commons.alerting.model.ChainedMonitorFindings;
import org.opensearch.commons.alerting.model.CompositeInput;
import org.opensearch.commons.alerting.model.Delegate;
import org.opensearch.commons.alerting.model.Monitor.MonitorType;
import org.opensearch.commons.alerting.model.Sequence;
import org.opensearch.commons.alerting.model.Workflow;
import org.opensearch.commons.alerting.model.Workflow.WorkflowType;
@@ -34,12 +33,11 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

import static org.opensearch.securityanalytics.util.DetectorUtils.getBucketLevelMonitorIdsWhoseRulesAreConfiguredToTrigger;
import static org.opensearch.securityanalytics.util.DetectorUtils.getBucketLevelMonitorIds;

/**
 * Alerting common class used for workflow manipulation
@@ -101,7 +99,7 @@ public void upsertWorkflow(
monitorResponses.addAll(updatedMonitorResponses);
}
cmfMonitorId = addedMonitorResponses.stream().filter(res -> (detector.getName() + "_chained_findings").equals(res.getMonitor().getName())).findFirst().get().getId();
chainedMonitorFindings = new ChainedMonitorFindings(null, getBucketLevelMonitorIdsWhoseRulesAreConfiguredToTrigger(detector, rulesById, monitorResponses));
chainedMonitorFindings = new ChainedMonitorFindings(null, getBucketLevelMonitorIds(monitorResponses));
}

IndexWorkflowRequest indexWorkflowRequest = createWorkflowRequest(monitorIds,
@@ -149,16 +147,21 @@ public void deleteWorkflow(String workflowId, ActionListener<DeleteWorkflowRespo
private IndexWorkflowRequest createWorkflowRequest(List<String> monitorIds, Detector detector, RefreshPolicy refreshPolicy, String workflowId, Method method,
ChainedMonitorFindings chainedMonitorFindings, String cmfMonitorId) {
AtomicInteger index = new AtomicInteger();
List<Delegate> delegates = monitorIds.stream().map(
monitorId -> {
ChainedMonitorFindings cmf = null;
if (cmfMonitorId != null && chainedMonitorFindings != null && Objects.equals(monitorId, cmfMonitorId)) {
cmf = Objects.equals(monitorId, cmfMonitorId) ? chainedMonitorFindings : null;
}
Delegate delegate = new Delegate(index.incrementAndGet(), monitorId, cmf);
return delegate;
}
).collect(Collectors.toList());
List<Delegate> delegates = new ArrayList<>();
ChainedMonitorFindings cmf = null;
for (String monitorId : monitorIds) {
if (cmfMonitorId != null && chainedMonitorFindings != null && Objects.equals(monitorId, cmfMonitorId)) {
cmf = Objects.equals(monitorId, cmfMonitorId) ? chainedMonitorFindings : null;
} else {
Delegate delegate = new Delegate(index.incrementAndGet(), monitorId, null);
delegates.add(delegate);
}
}
if (cmf != null) {
// Add cmf with maximum value on "index"
Delegate cmfDelegate = new Delegate(index.incrementAndGet(), cmfMonitorId, cmf);
delegates.add(cmfDelegate);
}

Sequence sequence = new Sequence(delegates);
CompositeInput compositeInput = new CompositeInput(sequence);
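As an illustrative aside (not part of the diff): a minimal, runnable sketch of the delegate ordering the loop above produces, assuming the commons-alerting `Delegate` and `ChainedMonitorFindings` behave as simple value holders. `SimpleDelegate` and the monitor ID strings below are hypothetical stand-ins, not the plugin's API; the point is that the chained findings monitor is appended last and references the bucket level monitor IDs whose findings it consumes.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;

public class DelegateOrderingSketch {

    // Hypothetical stand-in for commons-alerting's Delegate; chainedMonitorIds
    // stands in for ChainedMonitorFindings.
    record SimpleDelegate(int order, String monitorId, List<String> chainedMonitorIds) {}

    static List<SimpleDelegate> buildDelegates(List<String> monitorIds,
                                               String cmfMonitorId,
                                               List<String> bucketLevelMonitorIds) {
        AtomicInteger index = new AtomicInteger();
        List<SimpleDelegate> delegates = new ArrayList<>();
        // Every monitor except the chained findings monitor keeps its relative order.
        for (String monitorId : monitorIds) {
            if (!Objects.equals(monitorId, cmfMonitorId)) {
                delegates.add(new SimpleDelegate(index.incrementAndGet(), monitorId, null));
            }
        }
        // The chained findings monitor is appended last so it runs after every
        // bucket level monitor whose findings it consumes.
        if (cmfMonitorId != null) {
            delegates.add(new SimpleDelegate(index.incrementAndGet(), cmfMonitorId, bucketLevelMonitorIds));
        }
        return delegates;
    }

    public static void main(String[] args) {
        buildDelegates(
                List.of("bucket_monitor_1", "chained_findings_monitor", "bucket_monitor_2"),
                "chained_findings_monitor",
                List.of("bucket_monitor_1", "bucket_monitor_2"))
            .forEach(System.out::println);
        // Prints the two bucket level delegates (order 1 and 2) followed by the
        // chained findings delegate (order 3) carrying both bucket monitor IDs.
    }
}
```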
132 changes: 132 additions & 0 deletions src/test/java/org/opensearch/securityanalytics/alerts/AlertsIT.java
@@ -7,11 +7,14 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import org.apache.hc.core5.http.HttpStatus;
@@ -23,6 +26,7 @@
import org.opensearch.client.Request;
import org.opensearch.client.Response;
import org.opensearch.client.ResponseException;
import org.opensearch.commons.alerting.model.Monitor;
import org.opensearch.commons.alerting.model.action.Action;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.search.SearchHit;
@@ -37,7 +41,9 @@

import static org.opensearch.securityanalytics.TestHelpers.netFlowMappings;
import static org.opensearch.securityanalytics.TestHelpers.randomAction;
import static org.opensearch.securityanalytics.TestHelpers.randomAggregationRule;
import static org.opensearch.securityanalytics.TestHelpers.randomDetectorType;
import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputs;
import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputsAndThreatIntel;
import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputsAndTriggers;
import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithTriggers;
@@ -663,6 +669,132 @@ public void testAlertHistoryRollover_maxAge() throws IOException, InterruptedExc

restoreAlertsFindingsIMSettings();
}
/**
 * 1. Creates a detector with a sum aggregation rule and a detector trigger
 * 2. Verifies execution of the bucket level monitor and the chained findings doc level monitor
 * 3. Verifies findings and alerts
*
* @throws IOException
*/
public void testMultipleAggregationAndDocRules_alertSuccess() throws IOException {
String index = createTestIndex(randomIndex(), windowsIndexMapping());

Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI);
createMappingRequest.setJsonEntity(
"{ \"index_name\":\"" + index + "\"," +
" \"rule_topic\":\"" + randomDetectorType() + "\", " +
" \"partial\":true" +
"}"
);

Response createMappingResponse = client().performRequest(createMappingRequest);

assertEquals(HttpStatus.SC_OK, createMappingResponse.getStatusLine().getStatusCode());

String infoOpCode = "Info";

String sumRuleId = createRule(randomAggregationRule("sum", " > 1", infoOpCode));


List<DetectorRule> detectorRules = List.of(new DetectorRule(sumRuleId));

DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules,
Collections.emptyList());
Detector detector = randomDetectorWithInputsAndTriggers(List.of(input),
List.of(new DetectorTrigger("randomtrigegr", "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of()))
);

Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector));


String request = "{\n" +
" \"query\" : {\n" +
" \"match_all\":{\n" +
" }\n" +
" }\n" +
"}";
SearchResponse response = executeSearchAndGetResponse(DetectorMonitorConfig.getRuleIndex(randomDetectorType()), request, true);

assertEquals(1, response.getHits().getTotalHits().value); // 5 for rules, 1 for match_all query in chained findings monitor

assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse));
Map<String, Object> responseBody = asMap(createResponse);
String detectorId = responseBody.get("_id").toString();
request = "{\n" +
" \"query\" : {\n" +
" \"match\":{\n" +
" \"_id\": \"" + detectorId + "\"\n" +
" }\n" +
" }\n" +
"}";
List<SearchHit> hits = executeSearch(Detector.DETECTORS_INDEX, request);
SearchHit hit = hits.get(0);
Map<String, List> updatedDetectorMap = (HashMap<String, List>) (hit.getSourceAsMap().get("detector"));

List<String> monitorIds = ((List<String>) (updatedDetectorMap).get("monitor_id"));

indexDoc(index, "1", randomDoc(2, 4, infoOpCode));
indexDoc(index, "2", randomDoc(3, 4, infoOpCode));

Map<String, Integer> numberOfMonitorTypes = new HashMap<>();

for (String monitorId : monitorIds) {
Map<String, String> monitor = (Map<String, String>) (entityAsMap(client().performRequest(new Request("GET", "/_plugins/_alerting/monitors/" + monitorId)))).get("monitor");
numberOfMonitorTypes.merge(monitor.get("monitor_type"), 1, Integer::sum);
Response executeResponse = executeAlertingMonitor(monitorId, Collections.emptyMap());

// Assert monitor executions
Map<String, Object> executeResults = entityAsMap(executeResponse);
if (Monitor.MonitorType.DOC_LEVEL_MONITOR.getValue().equals(monitor.get("monitor_type")) && false == monitor.get("name").equals(detector.getName() + "_chained_findings")) {
int noOfSigmaRuleMatches = ((List<Map<String, Object>>) ((Map<String, Object>) executeResults.get("input_results")).get("results")).get(0).size();
assertEquals(5, noOfSigmaRuleMatches);
}
}

assertEquals(1, numberOfMonitorTypes.get(Monitor.MonitorType.BUCKET_LEVEL_MONITOR.getValue()).intValue());
assertEquals(1, numberOfMonitorTypes.get(Monitor.MonitorType.DOC_LEVEL_MONITOR.getValue()).intValue());

Map<String, String> params = new HashMap<>();
params.put("detector_id", detectorId);
Response getFindingsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.FINDINGS_BASE_URI + "/_search", params, null);
Map<String, Object> getFindingsBody = entityAsMap(getFindingsResponse);

assertNotNull(getFindingsBody);
assertEquals(1, getFindingsBody.get("total_findings"));

String findingDetectorId = ((Map<String, Object>) ((List) getFindingsBody.get("findings")).get(0)).get("detectorId").toString();
assertEquals(detectorId, findingDetectorId);

String findingIndex = ((Map<String, Object>) ((List) getFindingsBody.get("findings")).get(0)).get("index").toString();
assertEquals(index, findingIndex);

List<String> docLevelFinding = new ArrayList<>();
List<Map<String, Object>> findings = (List) getFindingsBody.get("findings");


for (Map<String, Object> finding : findings) {
List<Map<String, Object>> queries = (List<Map<String, Object>>) finding.get("queries");
Set<String> findingRuleIds = queries.stream().map(it -> it.get("id").toString()).collect(Collectors.toSet());

// In the case of bucket level monitors, queries will always contain one value
String aggRuleId = findingRuleIds.iterator().next();
List<String> findingDocs = (List<String>) finding.get("related_doc_ids");

if (aggRuleId.equals(sumRuleId)) {
assertTrue(List.of("1", "2", "3", "4", "5", "6", "7").containsAll(findingDocs));
}
}

assertTrue(Arrays.asList("1", "2", "3", "4", "5", "6", "7", "8").containsAll(docLevelFinding));

Map<String, String> params1 = new HashMap<>();
params1.put("detector_id", detectorId);
Response getAlertsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.ALERTS_BASE_URI, params1, null);
Map<String, Object> getAlertsBody = asMap(getAlertsResponse);
// TODO enable asserts here when able
Assert.assertEquals(3, getAlertsBody.get("total_alerts")); // 2 doc level alerts (one per doc), 1 bucket level alert
}

public void testAlertHistoryRollover_maxAge_low_retention() throws IOException, InterruptedException {
updateClusterSetting(ALERT_HISTORY_ROLLOVER_PERIOD.getKey(), "1s");