Skip to content

Commit

Permalink
[BUG_FIX] fix check for agg rules in detector trigger condition to cr…
Browse files Browse the repository at this point in the history
…eate chained findings monitor (#992) (#1002)

* remove chekc for agg rules in detector trigger condition to create bucket level monitor

* add agg rules tags in chained monitor query to match trigger condition of detector

* fix check to evaluate agg rules present when creating chained findings monitor

* fix tests where check on group by trigger existed earlier

* fix race condition while creating first monitor

* add test to verify detector trigger function for aggregation rules

* revert step listener change

---------

Signed-off-by: Surya Sashank Nistala <[email protected]>
(cherry picked from commit 07bf73f)
  • Loading branch information
eirsep authored and github-actions[bot] committed Apr 27, 2024
1 parent 14f1e20 commit 7acd535
Show file tree
Hide file tree
Showing 5 changed files with 175 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,9 @@
import org.opensearch.securityanalytics.rules.backend.OSQueryBackend;
import org.opensearch.securityanalytics.rules.backend.OSQueryBackend.AggregationQueries;
import org.opensearch.securityanalytics.rules.backend.QueryBackend;
import org.opensearch.securityanalytics.rules.exceptions.SigmaError;
import org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings;
import org.opensearch.securityanalytics.threatIntel.DetectorThreatIntelService;
import org.opensearch.securityanalytics.util.DetectorIndices;
import org.opensearch.securityanalytics.util.DetectorUtils;
import org.opensearch.securityanalytics.util.ExceptionChecker;
import org.opensearch.securityanalytics.util.IndexUtils;
import org.opensearch.securityanalytics.util.MonitorService;
Expand All @@ -122,6 +120,7 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
Expand Down Expand Up @@ -827,20 +826,30 @@ private void addThreatIntelBasedDocLevelQueries(Detector detector, ActionListene
*/
private IndexMonitorRequest createDocLevelMonitorMatchAllRequest(
Detector detector,
WriteRequest.RefreshPolicy refreshPolicy,
RefreshPolicy refreshPolicy,
String monitorId,
RestRequest.Method restMethod
) {
Method restMethod,
List<Pair<String, Rule>> queries) {
List<DocLevelMonitorInput> docLevelMonitorInputs = new ArrayList<>();
List<DocLevelQuery> docLevelQueries = new ArrayList<>();
String monitorName = detector.getName() + "_chained_findings";
String actualQuery = "_id:*";
Set<String> tags = new HashSet<>();
for (Pair<String, Rule> query: queries) {
if(query.getRight().isAggregationRule()) {
Rule rule = query.getRight();
tags.add(rule.getLevel());
tags.add(rule.getCategory());
tags.addAll(rule.getTags().stream().map(Value::getValue).collect(Collectors.toList()));
}
}
tags.removeIf(Objects::isNull);
DocLevelQuery docLevelQuery = new DocLevelQuery(
monitorName,
monitorName + "doc",
Collections.emptyList(),
actualQuery,
Collections.emptyList()
new ArrayList<>(tags)
);
docLevelQueries.add(docLevelQuery);

Expand Down Expand Up @@ -900,8 +909,8 @@ public void onResponse(Map<String, Map<String, String>> ruleFieldMappings) {
@Override
public void onResponse(Collection<IndexMonitorRequest> indexMonitorRequests) {
// if workflow usage enabled, add chained findings monitor request if there are bucket level requests and if the detector triggers have any group by rules configured to trigger
if (enabledWorkflowUsage && !monitorRequests.isEmpty() && !DetectorUtils.getAggRuleIdsConfiguredToTrigger(detector, queries).isEmpty()) {
monitorRequests.add(createDocLevelMonitorMatchAllRequest(detector, RefreshPolicy.IMMEDIATE, detector.getId() + "_chained_findings", Method.POST));
if (enabledWorkflowUsage && !monitorRequests.isEmpty() && queries.stream().anyMatch(it -> it.getRight().isAggregationRule())) {
monitorRequests.add(createDocLevelMonitorMatchAllRequest(detector, RefreshPolicy.IMMEDIATE, detector.getId() + "_chained_findings", Method.POST, queries));
}
listener.onResponse(monitorRequests);
}
Expand Down Expand Up @@ -1053,7 +1062,7 @@ public void onFailure(Exception e) {
listener.onFailure(e);
}
});
} catch (SigmaError e) {
} catch (Exception e) {
log.error("Failed to create bucket level monitor request", e);
listener.onFailure(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,17 +104,12 @@ public void onFailure(Exception e) {
});
}

public static List<String> getBucketLevelMonitorIdsWhoseRulesAreConfiguredToTrigger(
Detector detector,
List<Pair<String, Rule>> rulesById,
public static List<String> getBucketLevelMonitorIds(
List<IndexMonitorResponse> monitorResponses
) {
List<String> aggRuleIdsConfiguredToTrigger = getAggRuleIdsConfiguredToTrigger(detector, rulesById);
return monitorResponses.stream().filter(
// In the case of bucket level monitors rule id is trigger id
it -> Monitor.MonitorType.BUCKET_LEVEL_MONITOR == it.getMonitor().getMonitorType()
&& !it.getMonitor().getTriggers().isEmpty()
&& aggRuleIdsConfiguredToTrigger.contains(it.getMonitor().getTriggers().get(0).getId())
).map(IndexMonitorResponse::getId).collect(Collectors.toList());
}
public static List<String> getAggRuleIdsConfiguredToTrigger(Detector detector, List<Pair<String, Rule>> rulesById) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.opensearch.commons.alerting.model.ChainedMonitorFindings;
import org.opensearch.commons.alerting.model.CompositeInput;
import org.opensearch.commons.alerting.model.Delegate;
import org.opensearch.commons.alerting.model.Monitor.MonitorType;
import org.opensearch.commons.alerting.model.Sequence;
import org.opensearch.commons.alerting.model.Workflow;
import org.opensearch.commons.alerting.model.Workflow.WorkflowType;
Expand All @@ -34,12 +33,11 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

import static org.opensearch.securityanalytics.util.DetectorUtils.getBucketLevelMonitorIdsWhoseRulesAreConfiguredToTrigger;
import static org.opensearch.securityanalytics.util.DetectorUtils.getBucketLevelMonitorIds;

/**
* Alerting common clas used for workflow manipulation
Expand Down Expand Up @@ -101,7 +99,7 @@ public void upsertWorkflow(
monitorResponses.addAll(updatedMonitorResponses);
}
cmfMonitorId = addedMonitorResponses.stream().filter(res -> (detector.getName() + "_chained_findings").equals(res.getMonitor().getName())).findFirst().get().getId();
chainedMonitorFindings = new ChainedMonitorFindings(null, getBucketLevelMonitorIdsWhoseRulesAreConfiguredToTrigger(detector, rulesById, monitorResponses));
chainedMonitorFindings = new ChainedMonitorFindings(null, getBucketLevelMonitorIds(monitorResponses));
}

IndexWorkflowRequest indexWorkflowRequest = createWorkflowRequest(monitorIds,
Expand Down Expand Up @@ -149,16 +147,21 @@ public void deleteWorkflow(String workflowId, ActionListener<DeleteWorkflowRespo
private IndexWorkflowRequest createWorkflowRequest(List<String> monitorIds, Detector detector, RefreshPolicy refreshPolicy, String workflowId, Method method,
ChainedMonitorFindings chainedMonitorFindings, String cmfMonitorId) {
AtomicInteger index = new AtomicInteger();
List<Delegate> delegates = monitorIds.stream().map(
monitorId -> {
ChainedMonitorFindings cmf = null;
if (cmfMonitorId != null && chainedMonitorFindings != null && Objects.equals(monitorId, cmfMonitorId)) {
cmf = Objects.equals(monitorId, cmfMonitorId) ? chainedMonitorFindings : null;
}
Delegate delegate = new Delegate(index.incrementAndGet(), monitorId, cmf);
return delegate;
}
).collect(Collectors.toList());
List<Delegate> delegates = new ArrayList<>();
ChainedMonitorFindings cmf = null;
for (String monitorId : monitorIds) {
if (cmfMonitorId != null && chainedMonitorFindings != null && Objects.equals(monitorId, cmfMonitorId)) {
cmf = Objects.equals(monitorId, cmfMonitorId) ? chainedMonitorFindings : null;
} else {
Delegate delegate = new Delegate(index.incrementAndGet(), monitorId, null);
delegates.add(delegate);
}
}
if (cmf != null) {
// Add cmf with maximum value on "index"
Delegate cmfDelegate = new Delegate(index.incrementAndGet(), cmfMonitorId, cmf);
delegates.add(cmfDelegate);
}

Sequence sequence = new Sequence(delegates);
CompositeInput compositeInput = new CompositeInput(sequence);
Expand Down
132 changes: 132 additions & 0 deletions src/test/java/org/opensearch/securityanalytics/alerts/AlertsIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.http.HttpStatus;
import org.apache.http.entity.StringEntity;
Expand All @@ -22,6 +25,7 @@
import org.opensearch.client.Request;
import org.opensearch.client.Response;
import org.opensearch.client.ResponseException;
import org.opensearch.commons.alerting.model.Monitor;
import org.opensearch.commons.alerting.model.action.Action;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.search.SearchHit;
Expand All @@ -36,7 +40,9 @@

import static org.opensearch.securityanalytics.TestHelpers.netFlowMappings;
import static org.opensearch.securityanalytics.TestHelpers.randomAction;
import static org.opensearch.securityanalytics.TestHelpers.randomAggregationRule;
import static org.opensearch.securityanalytics.TestHelpers.randomDetectorType;
import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputs;
import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputsAndThreatIntel;
import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputsAndTriggers;
import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithTriggers;
Expand Down Expand Up @@ -662,6 +668,132 @@ public void testAlertHistoryRollover_maxAge() throws IOException, InterruptedExc

restoreAlertsFindingsIMSettings();
}
/**
* 1. Creates detector with aggregation and prepackaged rules
* (sum rule - should match docIds: 1, 2, 3; maxRule - 4, 5, 6, 7; minRule - 7)
* 2. Verifies monitor execution
* 3. Verifies alerts
*
* @throws IOException
*/
public void testMultipleAggregationAndDocRules_alertSuccess() throws IOException {
String index = createTestIndex(randomIndex(), windowsIndexMapping());

Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI);
createMappingRequest.setJsonEntity(
"{ \"index_name\":\"" + index + "\"," +
" \"rule_topic\":\"" + randomDetectorType() + "\", " +
" \"partial\":true" +
"}"
);

Response createMappingResponse = client().performRequest(createMappingRequest);

assertEquals(HttpStatus.SC_OK, createMappingResponse.getStatusLine().getStatusCode());

String infoOpCode = "Info";

String sumRuleId = createRule(randomAggregationRule("sum", " > 1", infoOpCode));


List<DetectorRule> detectorRules = List.of(new DetectorRule(sumRuleId));

DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules,
Collections.emptyList());
Detector detector = randomDetectorWithInputsAndTriggers(List.of(input),
List.of(new DetectorTrigger("randomtrigegr", "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of()))
);

Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector));


String request = "{\n" +
" \"query\" : {\n" +
" \"match_all\":{\n" +
" }\n" +
" }\n" +
"}";
SearchResponse response = executeSearchAndGetResponse(DetectorMonitorConfig.getRuleIndex(randomDetectorType()), request, true);

assertEquals(1, response.getHits().getTotalHits().value); // 5 for rules, 1 for match_all query in chained findings monitor

assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse));
Map<String, Object> responseBody = asMap(createResponse);
String detectorId = responseBody.get("_id").toString();
request = "{\n" +
" \"query\" : {\n" +
" \"match\":{\n" +
" \"_id\": \"" + detectorId + "\"\n" +
" }\n" +
" }\n" +
"}";
List<SearchHit> hits = executeSearch(Detector.DETECTORS_INDEX, request);
SearchHit hit = hits.get(0);
Map<String, List> updatedDetectorMap = (HashMap<String, List>) (hit.getSourceAsMap().get("detector"));

List<String> monitorIds = ((List<String>) (updatedDetectorMap).get("monitor_id"));

indexDoc(index, "1", randomDoc(2, 4, infoOpCode));
indexDoc(index, "2", randomDoc(3, 4, infoOpCode));

Map<String, Integer> numberOfMonitorTypes = new HashMap<>();

for (String monitorId : monitorIds) {
Map<String, String> monitor = (Map<String, String>) (entityAsMap(client().performRequest(new Request("GET", "/_plugins/_alerting/monitors/" + monitorId)))).get("monitor");
numberOfMonitorTypes.merge(monitor.get("monitor_type"), 1, Integer::sum);
Response executeResponse = executeAlertingMonitor(monitorId, Collections.emptyMap());

// Assert monitor executions
Map<String, Object> executeResults = entityAsMap(executeResponse);
if (Monitor.MonitorType.DOC_LEVEL_MONITOR.getValue().equals(monitor.get("monitor_type")) && false == monitor.get("name").equals(detector.getName() + "_chained_findings")) {
int noOfSigmaRuleMatches = ((List<Map<String, Object>>) ((Map<String, Object>) executeResults.get("input_results")).get("results")).get(0).size();
assertEquals(5, noOfSigmaRuleMatches);
}
}

assertEquals(1, numberOfMonitorTypes.get(Monitor.MonitorType.BUCKET_LEVEL_MONITOR.getValue()).intValue());
assertEquals(1, numberOfMonitorTypes.get(Monitor.MonitorType.DOC_LEVEL_MONITOR.getValue()).intValue());

Map<String, String> params = new HashMap<>();
params.put("detector_id", detectorId);
Response getFindingsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.FINDINGS_BASE_URI + "/_search", params, null);
Map<String, Object> getFindingsBody = entityAsMap(getFindingsResponse);

assertNotNull(getFindingsBody);
assertEquals(1, getFindingsBody.get("total_findings"));

String findingDetectorId = ((Map<String, Object>) ((List) getFindingsBody.get("findings")).get(0)).get("detectorId").toString();
assertEquals(detectorId, findingDetectorId);

String findingIndex = ((Map<String, Object>) ((List) getFindingsBody.get("findings")).get(0)).get("index").toString();
assertEquals(index, findingIndex);

List<String> docLevelFinding = new ArrayList<>();
List<Map<String, Object>> findings = (List) getFindingsBody.get("findings");


for (Map<String, Object> finding : findings) {
List<Map<String, Object>> queries = (List<Map<String, Object>>) finding.get("queries");
Set<String> findingRuleIds = queries.stream().map(it -> it.get("id").toString()).collect(Collectors.toSet());

// In the case of bucket level monitors, queries will always contain one value
String aggRuleId = findingRuleIds.iterator().next();
List<String> findingDocs = (List<String>) finding.get("related_doc_ids");

if (aggRuleId.equals(sumRuleId)) {
assertTrue(List.of("1", "2", "3", "4", "5", "6", "7").containsAll(findingDocs));
}
}

assertTrue(Arrays.asList("1", "2", "3", "4", "5", "6", "7", "8").containsAll(docLevelFinding));

Map<String, String> params1 = new HashMap<>();
params1.put("detector_id", detectorId);
Response getAlertsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.ALERTS_BASE_URI, params1, null);
Map<String, Object> getAlertsBody = asMap(getAlertsResponse);
// TODO enable asserts here when able
Assert.assertEquals(3, getAlertsBody.get("total_alerts")); // 2 doc level alerts for each doc, 1 bucket level alert
}

public void testAlertHistoryRollover_maxAge_low_retention() throws IOException, InterruptedException {
updateClusterSetting(ALERT_HISTORY_ROLLOVER_PERIOD.getKey(), "1s");
Expand Down
Loading

0 comments on commit 7acd535

Please sign in to comment.