Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backports #992 to 2.x #1002

Merged
merged 1 commit into from
Apr 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,9 @@
import org.opensearch.securityanalytics.rules.backend.OSQueryBackend;
import org.opensearch.securityanalytics.rules.backend.OSQueryBackend.AggregationQueries;
import org.opensearch.securityanalytics.rules.backend.QueryBackend;
import org.opensearch.securityanalytics.rules.exceptions.SigmaError;
import org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings;
import org.opensearch.securityanalytics.threatIntel.DetectorThreatIntelService;
import org.opensearch.securityanalytics.util.DetectorIndices;
import org.opensearch.securityanalytics.util.DetectorUtils;
import org.opensearch.securityanalytics.util.ExceptionChecker;
import org.opensearch.securityanalytics.util.IndexUtils;
import org.opensearch.securityanalytics.util.MonitorService;
Expand All @@ -122,6 +120,7 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
Expand Down Expand Up @@ -827,20 +826,30 @@ private void addThreatIntelBasedDocLevelQueries(Detector detector, ActionListene
*/
private IndexMonitorRequest createDocLevelMonitorMatchAllRequest(
Detector detector,
WriteRequest.RefreshPolicy refreshPolicy,
RefreshPolicy refreshPolicy,
String monitorId,
RestRequest.Method restMethod
) {
Method restMethod,
List<Pair<String, Rule>> queries) {
List<DocLevelMonitorInput> docLevelMonitorInputs = new ArrayList<>();
List<DocLevelQuery> docLevelQueries = new ArrayList<>();
String monitorName = detector.getName() + "_chained_findings";
String actualQuery = "_id:*";
Set<String> tags = new HashSet<>();
for (Pair<String, Rule> query: queries) {
if(query.getRight().isAggregationRule()) {
Rule rule = query.getRight();
tags.add(rule.getLevel());
tags.add(rule.getCategory());
tags.addAll(rule.getTags().stream().map(Value::getValue).collect(Collectors.toList()));
}
}
tags.removeIf(Objects::isNull);
DocLevelQuery docLevelQuery = new DocLevelQuery(
monitorName,
monitorName + "doc",
Collections.emptyList(),
actualQuery,
Collections.emptyList()
new ArrayList<>(tags)
);
docLevelQueries.add(docLevelQuery);

Expand Down Expand Up @@ -900,8 +909,8 @@ public void onResponse(Map<String, Map<String, String>> ruleFieldMappings) {
@Override
public void onResponse(Collection<IndexMonitorRequest> indexMonitorRequests) {
// if workflow usage enabled, add chained findings monitor request if there are bucket level requests and if the detector triggers have any group by rules configured to trigger
if (enabledWorkflowUsage && !monitorRequests.isEmpty() && !DetectorUtils.getAggRuleIdsConfiguredToTrigger(detector, queries).isEmpty()) {
monitorRequests.add(createDocLevelMonitorMatchAllRequest(detector, RefreshPolicy.IMMEDIATE, detector.getId() + "_chained_findings", Method.POST));
if (enabledWorkflowUsage && !monitorRequests.isEmpty() && queries.stream().anyMatch(it -> it.getRight().isAggregationRule())) {
monitorRequests.add(createDocLevelMonitorMatchAllRequest(detector, RefreshPolicy.IMMEDIATE, detector.getId() + "_chained_findings", Method.POST, queries));
}
listener.onResponse(monitorRequests);
}
Expand Down Expand Up @@ -1053,7 +1062,7 @@ public void onFailure(Exception e) {
listener.onFailure(e);
}
});
} catch (SigmaError e) {
} catch (Exception e) {
log.error("Failed to create bucket level monitor request", e);
listener.onFailure(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,17 +104,12 @@ public void onFailure(Exception e) {
});
}

public static List<String> getBucketLevelMonitorIdsWhoseRulesAreConfiguredToTrigger(
Detector detector,
List<Pair<String, Rule>> rulesById,
public static List<String> getBucketLevelMonitorIds(
List<IndexMonitorResponse> monitorResponses
) {
List<String> aggRuleIdsConfiguredToTrigger = getAggRuleIdsConfiguredToTrigger(detector, rulesById);
return monitorResponses.stream().filter(
// In the case of bucket level monitors rule id is trigger id
it -> Monitor.MonitorType.BUCKET_LEVEL_MONITOR == it.getMonitor().getMonitorType()
&& !it.getMonitor().getTriggers().isEmpty()
&& aggRuleIdsConfiguredToTrigger.contains(it.getMonitor().getTriggers().get(0).getId())
).map(IndexMonitorResponse::getId).collect(Collectors.toList());
}
public static List<String> getAggRuleIdsConfiguredToTrigger(Detector detector, List<Pair<String, Rule>> rulesById) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.opensearch.commons.alerting.model.ChainedMonitorFindings;
import org.opensearch.commons.alerting.model.CompositeInput;
import org.opensearch.commons.alerting.model.Delegate;
import org.opensearch.commons.alerting.model.Monitor.MonitorType;
import org.opensearch.commons.alerting.model.Sequence;
import org.opensearch.commons.alerting.model.Workflow;
import org.opensearch.commons.alerting.model.Workflow.WorkflowType;
Expand All @@ -34,12 +33,11 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

import static org.opensearch.securityanalytics.util.DetectorUtils.getBucketLevelMonitorIdsWhoseRulesAreConfiguredToTrigger;
import static org.opensearch.securityanalytics.util.DetectorUtils.getBucketLevelMonitorIds;

/**
* Alerting common clas used for workflow manipulation
Expand Down Expand Up @@ -101,7 +99,7 @@ public void upsertWorkflow(
monitorResponses.addAll(updatedMonitorResponses);
}
cmfMonitorId = addedMonitorResponses.stream().filter(res -> (detector.getName() + "_chained_findings").equals(res.getMonitor().getName())).findFirst().get().getId();
chainedMonitorFindings = new ChainedMonitorFindings(null, getBucketLevelMonitorIdsWhoseRulesAreConfiguredToTrigger(detector, rulesById, monitorResponses));
chainedMonitorFindings = new ChainedMonitorFindings(null, getBucketLevelMonitorIds(monitorResponses));
}

IndexWorkflowRequest indexWorkflowRequest = createWorkflowRequest(monitorIds,
Expand Down Expand Up @@ -149,16 +147,21 @@ public void deleteWorkflow(String workflowId, ActionListener<DeleteWorkflowRespo
private IndexWorkflowRequest createWorkflowRequest(List<String> monitorIds, Detector detector, RefreshPolicy refreshPolicy, String workflowId, Method method,
ChainedMonitorFindings chainedMonitorFindings, String cmfMonitorId) {
AtomicInteger index = new AtomicInteger();
List<Delegate> delegates = monitorIds.stream().map(
monitorId -> {
ChainedMonitorFindings cmf = null;
if (cmfMonitorId != null && chainedMonitorFindings != null && Objects.equals(monitorId, cmfMonitorId)) {
cmf = Objects.equals(monitorId, cmfMonitorId) ? chainedMonitorFindings : null;
}
Delegate delegate = new Delegate(index.incrementAndGet(), monitorId, cmf);
return delegate;
}
).collect(Collectors.toList());
List<Delegate> delegates = new ArrayList<>();
ChainedMonitorFindings cmf = null;
for (String monitorId : monitorIds) {
if (cmfMonitorId != null && chainedMonitorFindings != null && Objects.equals(monitorId, cmfMonitorId)) {
cmf = Objects.equals(monitorId, cmfMonitorId) ? chainedMonitorFindings : null;
} else {
Delegate delegate = new Delegate(index.incrementAndGet(), monitorId, null);
delegates.add(delegate);
}
}
if (cmf != null) {
// Add cmf with maximum value on "index"
Delegate cmfDelegate = new Delegate(index.incrementAndGet(), cmfMonitorId, cmf);
delegates.add(cmfDelegate);
}

Sequence sequence = new Sequence(delegates);
CompositeInput compositeInput = new CompositeInput(sequence);
Expand Down
132 changes: 132 additions & 0 deletions src/test/java/org/opensearch/securityanalytics/alerts/AlertsIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.http.HttpStatus;
import org.apache.http.entity.StringEntity;
Expand All @@ -22,6 +25,7 @@
import org.opensearch.client.Request;
import org.opensearch.client.Response;
import org.opensearch.client.ResponseException;
import org.opensearch.commons.alerting.model.Monitor;
import org.opensearch.commons.alerting.model.action.Action;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.search.SearchHit;
Expand All @@ -36,7 +40,9 @@

import static org.opensearch.securityanalytics.TestHelpers.netFlowMappings;
import static org.opensearch.securityanalytics.TestHelpers.randomAction;
import static org.opensearch.securityanalytics.TestHelpers.randomAggregationRule;
import static org.opensearch.securityanalytics.TestHelpers.randomDetectorType;
import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputs;
import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputsAndThreatIntel;
import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputsAndTriggers;
import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithTriggers;
Expand Down Expand Up @@ -662,6 +668,132 @@ public void testAlertHistoryRollover_maxAge() throws IOException, InterruptedExc

restoreAlertsFindingsIMSettings();
}
/**
* 1. Creates detector with aggregation and prepackaged rules
* (sum rule - should match docIds: 1, 2, 3; maxRule - 4, 5, 6, 7; minRule - 7)
* 2. Verifies monitor execution
* 3. Verifies alerts
*
* @throws IOException
*/
public void testMultipleAggregationAndDocRules_alertSuccess() throws IOException {
String index = createTestIndex(randomIndex(), windowsIndexMapping());

Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI);
createMappingRequest.setJsonEntity(
"{ \"index_name\":\"" + index + "\"," +
" \"rule_topic\":\"" + randomDetectorType() + "\", " +
" \"partial\":true" +
"}"
);

Response createMappingResponse = client().performRequest(createMappingRequest);

assertEquals(HttpStatus.SC_OK, createMappingResponse.getStatusLine().getStatusCode());

String infoOpCode = "Info";

String sumRuleId = createRule(randomAggregationRule("sum", " > 1", infoOpCode));


List<DetectorRule> detectorRules = List.of(new DetectorRule(sumRuleId));

DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules,
Collections.emptyList());
Detector detector = randomDetectorWithInputsAndTriggers(List.of(input),
List.of(new DetectorTrigger("randomtrigegr", "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of()))
);

Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector));


String request = "{\n" +
" \"query\" : {\n" +
" \"match_all\":{\n" +
" }\n" +
" }\n" +
"}";
SearchResponse response = executeSearchAndGetResponse(DetectorMonitorConfig.getRuleIndex(randomDetectorType()), request, true);

assertEquals(1, response.getHits().getTotalHits().value); // 5 for rules, 1 for match_all query in chained findings monitor

assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse));
Map<String, Object> responseBody = asMap(createResponse);
String detectorId = responseBody.get("_id").toString();
request = "{\n" +
" \"query\" : {\n" +
" \"match\":{\n" +
" \"_id\": \"" + detectorId + "\"\n" +
" }\n" +
" }\n" +
"}";
List<SearchHit> hits = executeSearch(Detector.DETECTORS_INDEX, request);
SearchHit hit = hits.get(0);
Map<String, List> updatedDetectorMap = (HashMap<String, List>) (hit.getSourceAsMap().get("detector"));

List<String> monitorIds = ((List<String>) (updatedDetectorMap).get("monitor_id"));

indexDoc(index, "1", randomDoc(2, 4, infoOpCode));
indexDoc(index, "2", randomDoc(3, 4, infoOpCode));

Map<String, Integer> numberOfMonitorTypes = new HashMap<>();

for (String monitorId : monitorIds) {
Map<String, String> monitor = (Map<String, String>) (entityAsMap(client().performRequest(new Request("GET", "/_plugins/_alerting/monitors/" + monitorId)))).get("monitor");
numberOfMonitorTypes.merge(monitor.get("monitor_type"), 1, Integer::sum);
Response executeResponse = executeAlertingMonitor(monitorId, Collections.emptyMap());

// Assert monitor executions
Map<String, Object> executeResults = entityAsMap(executeResponse);
if (Monitor.MonitorType.DOC_LEVEL_MONITOR.getValue().equals(monitor.get("monitor_type")) && false == monitor.get("name").equals(detector.getName() + "_chained_findings")) {
int noOfSigmaRuleMatches = ((List<Map<String, Object>>) ((Map<String, Object>) executeResults.get("input_results")).get("results")).get(0).size();
assertEquals(5, noOfSigmaRuleMatches);
}
}

assertEquals(1, numberOfMonitorTypes.get(Monitor.MonitorType.BUCKET_LEVEL_MONITOR.getValue()).intValue());
assertEquals(1, numberOfMonitorTypes.get(Monitor.MonitorType.DOC_LEVEL_MONITOR.getValue()).intValue());

Map<String, String> params = new HashMap<>();
params.put("detector_id", detectorId);
Response getFindingsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.FINDINGS_BASE_URI + "/_search", params, null);
Map<String, Object> getFindingsBody = entityAsMap(getFindingsResponse);

assertNotNull(getFindingsBody);
assertEquals(1, getFindingsBody.get("total_findings"));

String findingDetectorId = ((Map<String, Object>) ((List) getFindingsBody.get("findings")).get(0)).get("detectorId").toString();
assertEquals(detectorId, findingDetectorId);

String findingIndex = ((Map<String, Object>) ((List) getFindingsBody.get("findings")).get(0)).get("index").toString();
assertEquals(index, findingIndex);

List<String> docLevelFinding = new ArrayList<>();
List<Map<String, Object>> findings = (List) getFindingsBody.get("findings");


for (Map<String, Object> finding : findings) {
List<Map<String, Object>> queries = (List<Map<String, Object>>) finding.get("queries");
Set<String> findingRuleIds = queries.stream().map(it -> it.get("id").toString()).collect(Collectors.toSet());

// In the case of bucket level monitors, queries will always contain one value
String aggRuleId = findingRuleIds.iterator().next();
List<String> findingDocs = (List<String>) finding.get("related_doc_ids");

if (aggRuleId.equals(sumRuleId)) {
assertTrue(List.of("1", "2", "3", "4", "5", "6", "7").containsAll(findingDocs));
}
}

assertTrue(Arrays.asList("1", "2", "3", "4", "5", "6", "7", "8").containsAll(docLevelFinding));

Map<String, String> params1 = new HashMap<>();
params1.put("detector_id", detectorId);
Response getAlertsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.ALERTS_BASE_URI, params1, null);
Map<String, Object> getAlertsBody = asMap(getAlertsResponse);
// TODO enable asserts here when able
Assert.assertEquals(3, getAlertsBody.get("total_alerts")); // 2 doc level alerts for each doc, 1 bucket level alert
}

public void testAlertHistoryRollover_maxAge_low_retention() throws IOException, InterruptedException {
updateClusterSetting(ALERT_HISTORY_ROLLOVER_PERIOD.getKey(), "1s");
Expand Down
Loading
Loading