Skip to content

Commit

Permalink
add field based rules support in correlation engine
Browse files Browse the repository at this point in the history
Signed-off-by: Subhobrata Dey <[email protected]>
  • Loading branch information
sbcd90 committed Nov 29, 2023
1 parent 43040d6 commit 5d5d6dd
Show file tree
Hide file tree
Showing 12 changed files with 984 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ public List<Setting<?>> getSettings() {
SecurityAnalyticsSettings.CORRELATION_HISTORY_RETENTION_PERIOD,
SecurityAnalyticsSettings.IS_CORRELATION_INDEX_SETTING,
SecurityAnalyticsSettings.CORRELATION_TIME_WINDOW,
SecurityAnalyticsSettings.ENABLE_AUTO_CORRELATIONS,
SecurityAnalyticsSettings.DEFAULT_MAPPING_SCHEMA,
SecurityAnalyticsSettings.ENABLE_WORKFLOW_USAGE,
SecurityAnalyticsSettings.TIF_UPDATE_INTERVAL,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
*/
package org.opensearch.securityanalytics.correlation;

import kotlin.Pair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.lang3.tuple.Triple;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.join.ScoreMode;
Expand Down Expand Up @@ -49,7 +50,6 @@
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;


Expand All @@ -63,6 +63,8 @@ public class JoinEngine {

private volatile long corrTimeWindow;

private volatile boolean enableAutoCorrelations;

private final TransportCorrelateFindingAction.AsyncCorrelateFindingAction correlateFindingAction;

private final LogTypeService logTypeService;
Expand All @@ -71,18 +73,23 @@ public class JoinEngine {

public JoinEngine(Client client, PublishFindingsRequest request, NamedXContentRegistry xContentRegistry,
long corrTimeWindow, TransportCorrelateFindingAction.AsyncCorrelateFindingAction correlateFindingAction,
LogTypeService logTypeService) {
LogTypeService logTypeService, boolean enableAutoCorrelations) {
this.client = client;
this.request = request;
this.xContentRegistry = xContentRegistry;
this.corrTimeWindow = corrTimeWindow;
this.correlateFindingAction = correlateFindingAction;
this.logTypeService = logTypeService;
this.enableAutoCorrelations = enableAutoCorrelations;
}

public void onSearchDetectorResponse(Detector detector, Finding finding) {
try {
generateAutoCorrelations(detector, finding);
if (enableAutoCorrelations) {
generateAutoCorrelations(detector, finding);
} else {
onAutoCorrelations(detector, finding, Map.of());
}
} catch (IOException ex) {
correlateFindingAction.onFailures(ex);
}
Expand Down Expand Up @@ -265,26 +272,35 @@ public void onFailure(Exception e) {
private void getValidDocuments(String detectorType, List<String> indices, List<CorrelationRule> correlationRules, List<String> relatedDocIds, Map<String, List<String>> autoCorrelations) {
MultiSearchRequest mSearchRequest = new MultiSearchRequest();
List<CorrelationRule> validCorrelationRules = new ArrayList<>();
List<String> validFields = new ArrayList<>();

for (CorrelationRule rule: correlationRules) {
Optional<CorrelationQuery> query = rule.getCorrelationQueries().stream()
.filter(correlationQuery -> correlationQuery.getCategory().equals(detectorType)).findFirst();

if (query.isPresent()) {
BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery()
.filter(QueryBuilders.termsQuery("_id", relatedDocIds))
.must(QueryBuilders.queryStringQuery(query.get().getQuery()));
.filter(QueryBuilders.termsQuery("_id", relatedDocIds));

if (query.get().getField() != null) {
queryBuilder = queryBuilder.must(QueryBuilders.existsQuery(query.get().getField()));
} else {
queryBuilder = queryBuilder.must(QueryBuilders.queryStringQuery(query.get().getQuery()));
}
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.query(queryBuilder);
searchSourceBuilder.fetchSource(false);
if (query.get().getField() != null) {
searchSourceBuilder.fetchField(query.get().getField());
}
searchSourceBuilder.size(10000);
SearchRequest searchRequest = new SearchRequest();
searchRequest.indices(indices.toArray(new String[]{}));
searchRequest.source(searchSourceBuilder);
searchRequest.preference(Preference.PRIMARY_FIRST.type());

validCorrelationRules.add(rule);
validFields.add(query.get().getField());
mSearchRequest.add(searchRequest);
}
}
Expand All @@ -294,7 +310,7 @@ private void getValidDocuments(String detectorType, List<String> indices, List<C
@Override
public void onResponse(MultiSearchResponse items) {
MultiSearchResponse.Item[] responses = items.getResponses();
List<CorrelationRule> filteredCorrelationRules = new ArrayList<>();
List<Triple<CorrelationRule, SearchHit[], String>> filteredCorrelationRules = new ArrayList<>();

int idx = 0;
for (MultiSearchResponse.Item response : responses) {
Expand All @@ -304,14 +320,17 @@ public void onResponse(MultiSearchResponse items) {
}

if (response.getResponse().getHits().getTotalHits().value > 0L) {
filteredCorrelationRules.add(validCorrelationRules.get(idx));
filteredCorrelationRules.add(Triple.of(validCorrelationRules.get(idx),
response.getResponse().getHits().getHits(), validFields.get(idx)));
}
++idx;
}

Map<String, List<CorrelationQuery>> categoryToQueriesMap = new HashMap<>();
for (CorrelationRule rule: filteredCorrelationRules) {
List<CorrelationQuery> queries = rule.getCorrelationQueries();
Map<String, Long> categoryToTimeWindowMap = new HashMap<>();
for (Triple<CorrelationRule, SearchHit[], String> rule: filteredCorrelationRules) {
List<CorrelationQuery> queries = rule.getLeft().getCorrelationQueries();
Long timeWindow = rule.getLeft().getCorrTimeWindow();

for (CorrelationQuery query: queries) {
List<CorrelationQuery> correlationQueries;
Expand All @@ -320,12 +339,36 @@ public void onResponse(MultiSearchResponse items) {
} else {
correlationQueries = new ArrayList<>();
}
correlationQueries.add(query);
if (categoryToTimeWindowMap.containsKey(query.getCategory())) {
categoryToTimeWindowMap.put(query.getCategory(), Math.max(timeWindow, categoryToTimeWindowMap.get(query.getCategory())));
} else {
categoryToTimeWindowMap.put(query.getCategory(), timeWindow);
}

if (query.getField() == null) {
correlationQueries.add(query);
} else {
SearchHit[] hits = rule.getMiddle();
StringBuilder qb = new StringBuilder(query.getField()).append(":(");
for (int i = 0; i < hits.length; ++i) {
String value = hits[i].field(rule.getRight()).getValue();
qb.append(value);
if (i < hits.length-1) {
qb.append(" OR ");
} else {
qb.append(")");
}
}
if (query.getQuery() != null) {
qb.append(" AND ").append(query.getQuery());
}
correlationQueries.add(new CorrelationQuery(query.getIndex(), qb.toString(), query.getCategory(), null));
}
categoryToQueriesMap.put(query.getCategory(), correlationQueries);
}
}
searchFindingsByTimestamp(detectorType, categoryToQueriesMap,
filteredCorrelationRules.stream().map(CorrelationRule::getId).collect(Collectors.toList()),
searchFindingsByTimestamp(detectorType, categoryToQueriesMap, categoryToTimeWindowMap,
filteredCorrelationRules.stream().map(Triple::getLeft).map(CorrelationRule::getId).collect(Collectors.toList()),
autoCorrelations
);
}
Expand All @@ -348,15 +391,15 @@ public void onFailure(Exception e) {
* this method searches for parent findings given the log category & correlation time window & collects all related docs
* for them.
*/
private void searchFindingsByTimestamp(String detectorType, Map<String, List<CorrelationQuery>> categoryToQueriesMap, List<String> correlationRules, Map<String, List<String>> autoCorrelations) {
private void searchFindingsByTimestamp(String detectorType, Map<String, List<CorrelationQuery>> categoryToQueriesMap, Map<String, Long> categoryToTimeWindowMap, List<String> correlationRules, Map<String, List<String>> autoCorrelations) {
long findingTimestamp = request.getFinding().getTimestamp().toEpochMilli();
MultiSearchRequest mSearchRequest = new MultiSearchRequest();
List<Pair<String, List<CorrelationQuery>>> categoryToQueriesPairs = new ArrayList<>();

for (Map.Entry<String, List<CorrelationQuery>> categoryToQueries: categoryToQueriesMap.entrySet()) {
RangeQueryBuilder queryBuilder = QueryBuilders.rangeQuery("timestamp")
.gte(findingTimestamp - corrTimeWindow)
.lte(findingTimestamp + corrTimeWindow);
.gte(findingTimestamp - categoryToTimeWindowMap.get(categoryToQueries.getKey()))
.lte(findingTimestamp + categoryToTimeWindowMap.get(categoryToQueries.getKey()));

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.query(queryBuilder);
Expand All @@ -368,7 +411,7 @@ private void searchFindingsByTimestamp(String detectorType, Map<String, List<Cor
searchRequest.source(searchSourceBuilder);
searchRequest.preference(Preference.PRIMARY_FIRST.type());
mSearchRequest.add(searchRequest);
categoryToQueriesPairs.add(new Pair<>(categoryToQueries.getKey(), categoryToQueries.getValue()));
categoryToQueriesPairs.add(Pair.of(categoryToQueries.getKey(), categoryToQueries.getValue()));
}

if (!mSearchRequest.requests().isEmpty()) {
Expand All @@ -392,17 +435,17 @@ public void onResponse(MultiSearchResponse items) {
.map(Object::toString).collect(Collectors.toList()));
}

List<CorrelationQuery> correlationQueries = categoryToQueriesPairs.get(idx).getSecond();
List<CorrelationQuery> correlationQueries = categoryToQueriesPairs.get(idx).getValue();
List<String> indices = correlationQueries.stream().map(CorrelationQuery::getIndex).collect(Collectors.toList());
List<String> queries = correlationQueries.stream().map(CorrelationQuery::getQuery).collect(Collectors.toList());
relatedDocsMap.put(categoryToQueriesPairs.get(idx).getFirst(),
relatedDocsMap.put(categoryToQueriesPairs.get(idx).getKey(),
new DocSearchCriteria(
indices,
queries,
relatedDocIds));
++idx;
}
searchDocsWithFilterKeys(detectorType, relatedDocsMap, correlationRules, autoCorrelations);
searchDocsWithFilterKeys(detectorType, relatedDocsMap, categoryToTimeWindowMap, correlationRules, autoCorrelations);
}

@Override
Expand All @@ -422,7 +465,7 @@ public void onFailure(Exception e) {
/**
* Given the related docs from parent findings, this method filters only those related docs which match parent join criteria.
*/
private void searchDocsWithFilterKeys(String detectorType, Map<String, DocSearchCriteria> relatedDocsMap, List<String> correlationRules, Map<String, List<String>> autoCorrelations) {
private void searchDocsWithFilterKeys(String detectorType, Map<String, DocSearchCriteria> relatedDocsMap, Map<String, Long> categoryToTimeWindowMap, List<String> correlationRules, Map<String, List<String>> autoCorrelations) {
MultiSearchRequest mSearchRequest = new MultiSearchRequest();
List<String> categories = new ArrayList<>();

Expand All @@ -433,6 +476,7 @@ private void searchDocsWithFilterKeys(String detectorType, Map<String, DocSearch
for (String query: docSearchCriteria.getValue().queries) {
queryBuilder = queryBuilder.should(QueryBuilders.queryStringQuery(query));
}
queryBuilder.minimumShouldMatch(1).boost(1.0f);

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.query(queryBuilder);
Expand Down Expand Up @@ -470,7 +514,7 @@ public void onResponse(MultiSearchResponse items) {
filteredRelatedDocIds.put(categories.get(idx), docIds);
++idx;
}
getCorrelatedFindings(detectorType, filteredRelatedDocIds, correlationRules, autoCorrelations);
getCorrelatedFindings(detectorType, filteredRelatedDocIds, categoryToTimeWindowMap, correlationRules, autoCorrelations);
}

@Override
Expand All @@ -491,16 +535,16 @@ public void onFailure(Exception e) {
* Given the filtered related docs of the parent findings, this method gets the actual filtered parent findings for
* the finding to be correlated.
*/
private void getCorrelatedFindings(String detectorType, Map<String, List<String>> filteredRelatedDocIds, List<String> correlationRules, Map<String, List<String>> autoCorrelations) {
private void getCorrelatedFindings(String detectorType, Map<String, List<String>> filteredRelatedDocIds, Map<String, Long> categoryToTimeWindowMap, List<String> correlationRules, Map<String, List<String>> autoCorrelations) {
long findingTimestamp = request.getFinding().getTimestamp().toEpochMilli();
MultiSearchRequest mSearchRequest = new MultiSearchRequest();
List<String> categories = new ArrayList<>();

for (Map.Entry<String, List<String>> relatedDocIds: filteredRelatedDocIds.entrySet()) {
BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery()
.filter(QueryBuilders.rangeQuery("timestamp")
.gte(findingTimestamp - corrTimeWindow)
.lte(findingTimestamp + corrTimeWindow))
.gte(findingTimestamp - categoryToTimeWindowMap.get(relatedDocIds.getKey()))
.lte(findingTimestamp + categoryToTimeWindowMap.get(relatedDocIds.getKey())))
.must(QueryBuilders.termsQuery("correlated_doc_ids", relatedDocIds.getValue()));

if (relatedDocIds.getKey().equals(detectorType)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,40 +22,53 @@ public class CorrelationQuery implements Writeable, ToXContentObject {
private static final String QUERY = "query";
private static final String CATEGORY = "category";

private static final String FIELD = "field";

private String index;

private String query;

private String category;

public CorrelationQuery(String index, String query, String category) {
private String field;

public CorrelationQuery(String index, String query, String category, String field) {
this.index = index;
this.query = query;
this.category = category;
this.field = field;
}

public CorrelationQuery(StreamInput sin) throws IOException {
this(sin.readString(), sin.readString(), sin.readString());
this(sin.readString(), sin.readOptionalString(), sin.readString(), sin.readOptionalString());
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(index);
out.writeString(query);
out.writeOptionalString(query);
out.writeString(category);
out.writeOptionalString(field);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(INDEX, index).field(QUERY, query).field(CATEGORY, category);
builder.field(INDEX, index).field(CATEGORY, category);
if (query != null) {
builder.field(QUERY, query);
}
if (field != null) {
builder.field(FIELD, field);
}
return builder.endObject();
}

public static CorrelationQuery parse(XContentParser xcp) throws IOException {
String index = null;
String query = null;
String category = null;
String field = null;

XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, xcp.currentToken(), xcp);
while (xcp.nextToken() != XContentParser.Token.END_OBJECT) {
Expand All @@ -72,11 +85,14 @@ public static CorrelationQuery parse(XContentParser xcp) throws IOException {
case CATEGORY:
category = xcp.text();
break;
case FIELD:
field = xcp.text();
break;
default:
xcp.skipChildren();
}
}
return new CorrelationQuery(index, query, category);
return new CorrelationQuery(index, query, category, field);
}

public static CorrelationQuery readFrom(StreamInput sin) throws IOException {
Expand All @@ -94,4 +110,8 @@ public String getQuery() {
public String getCategory() {
return category;
}

public String getField() {
return field;
}
}
Loading

0 comments on commit 5d5d6dd

Please sign in to comment.