diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java b/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java index 3b4314e12..83f3713e7 100644 --- a/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java +++ b/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java @@ -21,6 +21,7 @@ import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.alerting.action.PublishFindingsRequest; import org.opensearch.commons.alerting.model.Finding; +import org.opensearch.commons.authuser.User; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.query.BoolQueryBuilder; @@ -32,9 +33,11 @@ import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.securityanalytics.config.monitors.DetectorMonitorConfig; +import org.opensearch.securityanalytics.correlation.alert.notifications.NotificationService; import org.opensearch.securityanalytics.logtype.LogTypeService; import org.opensearch.securityanalytics.model.CorrelationQuery; import org.opensearch.securityanalytics.model.CorrelationRule; +import org.opensearch.securityanalytics.model.CorrelationRuleTrigger; import org.opensearch.securityanalytics.model.Detector; import org.opensearch.securityanalytics.transport.TransportCorrelateFindingAction; import org.opensearch.securityanalytics.util.AutoCorrelationsRepo; @@ -68,18 +71,27 @@ public class JoinEngine { private final LogTypeService logTypeService; + private final NotificationService notificationService; + + private volatile TimeValue indexTimeout; + private static final Logger log = LogManager.getLogger(JoinEngine.class); + private final User user; + public JoinEngine(Client client, PublishFindingsRequest request, NamedXContentRegistry xContentRegistry, - long corrTimeWindow, TransportCorrelateFindingAction.AsyncCorrelateFindingAction correlateFindingAction, - LogTypeService logTypeService, boolean enableAutoCorrelations) { + long corrTimeWindow, TimeValue indexTimeout, TransportCorrelateFindingAction.AsyncCorrelateFindingAction correlateFindingAction, + LogTypeService logTypeService, boolean enableAutoCorrelations, NotificationService notificationService, User user) { this.client = client; this.request = request; this.xContentRegistry = xContentRegistry; this.corrTimeWindow = corrTimeWindow; + this.indexTimeout = indexTimeout; this.correlateFindingAction = correlateFindingAction; this.logTypeService = logTypeService; this.enableAutoCorrelations = enableAutoCorrelations; + this.notificationService = notificationService; + this.user = user; } public void onSearchDetectorResponse(Detector detector, Finding finding) { @@ -349,7 +361,7 @@ private void getValidDocuments(String detectorType, List indices, List it.correlationRule).map(CorrelationRule::getId).collect(Collectors.toList()), + filteredCorrelationRules.stream().map(it -> it.correlationRule).collect(Collectors.toList()), autoCorrelations ); }, this::onFailure)); @@ -362,7 +374,7 @@ private void getValidDocuments(String detectorType, List indices, List> categoryToQueriesMap, Map categoryToTimeWindowMap, List correlationRules, Map> autoCorrelations) { + private void searchFindingsByTimestamp(String detectorType, Map> categoryToQueriesMap, Map categoryToTimeWindowMap, List correlationRules, Map> autoCorrelations) { long findingTimestamp = request.getFinding().getTimestamp().toEpochMilli(); MultiSearchRequest mSearchRequest = new MultiSearchRequest(); List>> categoryToQueriesPairs = new ArrayList<>(); @@ -418,14 +430,14 @@ private void searchFindingsByTimestamp(String detectorType, Map relatedDocsMap, Map categoryToTimeWindowMap, List correlationRules, Map> autoCorrelations) { + private void searchDocsWithFilterKeys(String detectorType, Map relatedDocsMap, Map categoryToTimeWindowMap, List correlationRules, Map> autoCorrelations) { MultiSearchRequest mSearchRequest = new MultiSearchRequest(); List categories = new ArrayList<>(); @@ -476,7 +488,7 @@ private void searchDocsWithFilterKeys(String detectorType, Map> filteredRelatedDocIds, Map categoryToTimeWindowMap, List correlationRules, Map> autoCorrelations) { + private void getCorrelatedFindings(String detectorType, Map> filteredRelatedDocIds, Map categoryToTimeWindowMap, List correlationRules, Map> autoCorrelations) { long findingTimestamp = request.getFinding().getTimestamp().toEpochMilli(); MultiSearchRequest mSearchRequest = new MultiSearchRequest(); List categories = new ArrayList<>(); @@ -549,10 +561,10 @@ private void getCorrelatedFindings(String detectorType, Map correlatedFindings.put(autoCorrelation.getKey(), autoCorrelation.getValue()); } } - correlateFindingAction.initCorrelationIndex(detectorType, correlatedFindings, correlationRules); + correlateFindingAction.initCorrelationIndex(detectorType, correlatedFindings, correlationRules.stream().map(CorrelationRule::getId).collect(Collectors.toList())); }, this::onFailure)); } else { - getTimestampFeature(detectorType, correlationRules, autoCorrelations); + getTimestampFeature(detectorType, correlationRules.stream().map(CorrelationRule::getId).collect(Collectors.toList()), autoCorrelations); } } diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java index 910794556..a76dd4d0b 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java @@ -35,6 +35,7 @@ import org.opensearch.commons.alerting.action.PublishFindingsRequest; import org.opensearch.commons.alerting.action.SubscribeFindingsResponse; import org.opensearch.commons.alerting.action.AlertingActions; +import org.opensearch.commons.authuser.User; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.xcontent.NamedXContentRegistry; @@ -49,6 +50,7 @@ import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.securityanalytics.correlation.JoinEngine; import org.opensearch.securityanalytics.correlation.VectorEmbeddingsEngine; +import org.opensearch.securityanalytics.correlation.alert.notifications.NotificationService; import org.opensearch.securityanalytics.logtype.LogTypeService; import org.opensearch.securityanalytics.model.CustomLogType; import org.opensearch.securityanalytics.model.Detector; @@ -99,6 +101,8 @@ public class TransportCorrelateFindingAction extends HandledTransportAction actionListener) { try { PublishFindingsRequest transformedRequest = transformRequest(request); - AsyncCorrelateFindingAction correlateFindingAction = new AsyncCorrelateFindingAction(task, transformedRequest, actionListener); + AsyncCorrelateFindingAction correlateFindingAction = new AsyncCorrelateFindingAction(task, transformedRequest, readUserFromThreadContext(this.threadPool), actionListener); if (!this.correlationIndices.correlationIndexExists()) { try { @@ -146,7 +151,6 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { @@ -168,6 +172,19 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { + if (createIndexResponse.isAcknowledged()) { + IndexUtils.correlationAlertIndexUpdated(); + } else { + correlateFindingAction.onFailures(new OpenSearchStatusException("Failed to create correlation metadata Index", RestStatus.INTERNAL_SERVER_ERROR)); + } + }, correlateFindingAction::onFailures)); + } catch (Exception ex) { + correlateFindingAction.onFailures(ex); + } + } } else { correlateFindingAction.onFailures(new OpenSearchStatusException("Failed to create correlation Index", RestStatus.INTERNAL_SERVER_ERROR)); } @@ -193,14 +210,12 @@ public class AsyncCorrelateFindingAction { private final AtomicBoolean counter = new AtomicBoolean(); private final Task task; - AsyncCorrelateFindingAction(Task task, PublishFindingsRequest request, ActionListener listener) { + AsyncCorrelateFindingAction(Task task, PublishFindingsRequest request, User user, ActionListener listener) { this.task = task; this.request = request; this.listener = listener; - this.response =new AtomicReference<>(); - - this.joinEngine = new JoinEngine(client, request, xContentRegistry, corrTimeWindow, this, logTypeService, enableAutoCorrelation); + this.joinEngine = new JoinEngine(client, request, xContentRegistry, corrTimeWindow, indexTimeout, this, logTypeService, enableAutoCorrelation, notificationService, user); this.vectorEmbeddingsEngine = new VectorEmbeddingsEngine(client, indexTimeout, corrTimeWindow, this); }