From 0ef8543b609673bb9ef404ee6c8945725800a080 Mon Sep 17 00:00:00 2001 From: Megha Goyal <56077967+goyamegh@users.noreply.github.com> Date: Wed, 14 Feb 2024 11:57:07 -0800 Subject: [PATCH 01/12] [BUG] ArrayIndexOutOfBoundsException for inconsistent detector index behavior (#843) * Catch ArrayIndexOutOfBoundsException when detector is missing Signed-off-by: Megha Goyal * Add a check on SearchHits.getHits() length Signed-off-by: Megha Goyal * Remove index out of bounds exception Signed-off-by: Megha Goyal --------- Signed-off-by: Megha Goyal --- .../transport/TransportCorrelateFindingAction.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java index b7a906159..e79af28d3 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java @@ -261,7 +261,8 @@ public void onResponse(SearchResponse response) { } SearchHits hits = response.getHits(); - if (hits.getTotalHits().value == 1) { + // Detectors Index hits count could be more even if we fetch one + if (hits.getTotalHits().value >= 1 && hits.getHits().length > 0) { try { SearchHit hit = hits.getAt(0); @@ -272,6 +273,7 @@ public void onResponse(SearchResponse response) { Detector detector = Detector.docParse(xcp, hit.getId(), hit.getVersion()); joinEngine.onSearchDetectorResponse(detector, finding); } catch (IOException e) { + log.error("IOException for request {}", searchRequest.toString(), e); onFailures(e); } } else { From 8d19912fe1515a515b3d6a4c3f46064ab8047bb0 Mon Sep 17 00:00:00 2001 From: Megha Goyal <56077967+goyamegh@users.noreply.github.com> Date: Wed, 14 Feb 2024 14:54:41 -0800 Subject: [PATCH 02/12] Fail the flow the when detectot type is missing in the log types index (#845) Signed-off-by: Megha Goyal --- .../correlation/VectorEmbeddingsEngine.java | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/VectorEmbeddingsEngine.java b/src/main/java/org/opensearch/securityanalytics/correlation/VectorEmbeddingsEngine.java index 9a423f6fb..0f9866766 100644 --- a/src/main/java/org/opensearch/securityanalytics/correlation/VectorEmbeddingsEngine.java +++ b/src/main/java/org/opensearch/securityanalytics/correlation/VectorEmbeddingsEngine.java @@ -229,6 +229,11 @@ public void onFailure(Exception e) { } public void insertOrphanFindings(String detectorType, Finding finding, float timestampFeature, Map logTypes) { + if (logTypes.get(detectorType) == null) { + log.error("LogTypes Index is missing the detector type {}", detectorType); + correlateFindingAction.onFailures(new OpenSearchStatusException("LogTypes Index is missing the detector type", RestStatus.INTERNAL_SERVER_ERROR)); + } + Map tags = logTypes.get(detectorType).getTags(); String correlationId = tags.get("correlation_id").toString(); From 1e0f1adf756e5394e4ae02144a8311cd483e7902 Mon Sep 17 00:00:00 2001 From: Megha Goyal <56077967+goyamegh@users.noreply.github.com> Date: Thu, 22 Feb 2024 13:42:25 -0800 Subject: [PATCH 03/12] Add goyamegh as a maintainer (#868) Signed-off-by: Megha Goyal --- .github/CODEOWNERS | 2 +- MAINTAINERS.md | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 875b82375..1793dcf5d 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1 +1 @@ -* @amsiglan @AWSHurneyt @getsaurabh02 @lezzago @praveensameneni @sbcd90 @eirsep @jowg-amazon @engechas +* @amsiglan @AWSHurneyt @getsaurabh02 @lezzago @praveensameneni @sbcd90 @eirsep @jowg-amazon @engechas @goyamegh diff --git a/MAINTAINERS.md b/MAINTAINERS.md index 11720e824..1393aa3bf 100644 --- a/MAINTAINERS.md +++ b/MAINTAINERS.md @@ -15,3 +15,4 @@ This document contains a list of maintainers in this repo. See [opensearch-proje | Saurabh Singh | [getsaurabh02](https://github.com/getsaurabh02) | Amazon | | Joanne Wang | [jowg-amazon](https://github.com/jowg-amazon) | Amazon | | Chase Engelbrecht | [engechas](https://github.com/engechas) | Amazon | +| Megha Goyal | [goyamegh](https://github.com/goyamegh) | Amazon | From 8ef0a3f88e1b460bde3ad1f14e10bb0287337578 Mon Sep 17 00:00:00 2001 From: Riya <69919272+riysaxen-amzn@users.noreply.github.com> Date: Mon, 26 Feb 2024 16:26:01 -0800 Subject: [PATCH 04/12] added riysaxen-amzn as a maintainer (#869) Signed-off-by: Riya Saxena Signed-off-by: AWSHurneyt Co-authored-by: AWSHurneyt --- .github/CODEOWNERS | 2 +- MAINTAINERS.md | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 1793dcf5d..f0b89b8ba 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1 +1 @@ -* @amsiglan @AWSHurneyt @getsaurabh02 @lezzago @praveensameneni @sbcd90 @eirsep @jowg-amazon @engechas @goyamegh +* @amsiglan @AWSHurneyt @getsaurabh02 @lezzago @praveensameneni @sbcd90 @eirsep @jowg-amazon @engechas @goyamegh @riysaxen-amzn diff --git a/MAINTAINERS.md b/MAINTAINERS.md index 1393aa3bf..f49cd0d59 100644 --- a/MAINTAINERS.md +++ b/MAINTAINERS.md @@ -16,3 +16,4 @@ This document contains a list of maintainers in this repo. See [opensearch-proje | Joanne Wang | [jowg-amazon](https://github.com/jowg-amazon) | Amazon | | Chase Engelbrecht | [engechas](https://github.com/engechas) | Amazon | | Megha Goyal | [goyamegh](https://github.com/goyamegh) | Amazon | +| Riya Saxena | [riysaxen-amzn](https://github.com/riysaxen-amzn)) | Amazon | From 172d58de0ec6f3a8e455a9033973b2f61fc77d87 Mon Sep 17 00:00:00 2001 From: Surya Sashank Nistala Date: Fri, 1 Mar 2024 17:28:05 -0800 Subject: [PATCH 05/12] Remove blocking calls and change threat intel feed flow to event driven (#871) * remove actionGet() and change threat intel feed flow to event driven Signed-off-by: Surya Sashank Nistala * fix javadocs Signed-off-by: Surya Sashank Nistala * revert try catch removals Signed-off-by: Surya Sashank Nistala * use action listener wrap() in detector threat intel code paths Signed-off-by: Surya Sashank Nistala * add try catch Signed-off-by: Surya Sashank Nistala --------- Signed-off-by: Surya Sashank Nistala --- .../DetectorThreatIntelService.java | 43 ++---- .../ThreatIntelFeedDataService.java | 145 +++++++----------- .../action/TransportPutTIFJobAction.java | 112 +++++++------- .../threatIntel/common/TIFLockService.java | 50 ++---- .../jobscheduler/TIFJobParameterService.java | 94 ++++++------ .../jobscheduler/TIFJobRunner.java | 132 ++++++++-------- .../TransportIndexDetectorAction.java | 4 +- .../TransportSearchDetectorAction.java | 13 +- .../SecurityAnalyticsPluginTransportIT.java | 33 ---- .../common/ThreatIntelLockServiceTests.java | 9 +- 10 files changed, 265 insertions(+), 370 deletions(-) delete mode 100644 src/test/java/org/opensearch/securityanalytics/SecurityAnalyticsPluginTransportIT.java diff --git a/src/main/java/org/opensearch/securityanalytics/threatIntel/DetectorThreatIntelService.java b/src/main/java/org/opensearch/securityanalytics/threatIntel/DetectorThreatIntelService.java index df4971b66..e541ee36c 100644 --- a/src/main/java/org/opensearch/securityanalytics/threatIntel/DetectorThreatIntelService.java +++ b/src/main/java/org/opensearch/securityanalytics/threatIntel/DetectorThreatIntelService.java @@ -32,8 +32,6 @@ import java.util.Map; import java.util.Set; import java.util.UUID; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import static org.opensearch.securityanalytics.model.Detector.DETECTORS_INDEX; @@ -121,35 +119,24 @@ public void createDocLevelQueryFromThreatIntel(List iocFieldL listener.onResponse(Collections.emptyList()); return; } - - CountDownLatch latch = new CountDownLatch(1); - threatIntelFeedDataService.getThreatIntelFeedData(new ActionListener<>() { - @Override - public void onResponse(List threatIntelFeedData) { - if (threatIntelFeedData.isEmpty()) { - listener.onResponse(Collections.emptyList()); - } else { - listener.onResponse( - createDocLevelQueriesFromThreatIntelList(iocFieldList, threatIntelFeedData, detector) - ); + threatIntelFeedDataService.getThreatIntelFeedData(ActionListener.wrap( + threatIntelFeedData -> { + if (threatIntelFeedData.isEmpty()) { + listener.onResponse(Collections.emptyList()); + } else { + listener.onResponse( + createDocLevelQueriesFromThreatIntelList(iocFieldList, threatIntelFeedData, detector) + ); + } + }, e -> { + log.error("Failed to get threat intel feeds for doc level query creation", e); + listener.onFailure(e); } - latch.countDown(); - } - - @Override - public void onFailure(Exception e) { - log.error("Failed to get threat intel feeds for doc level query creation", e); - listener.onFailure(e); - latch.countDown(); - } - }); - - latch.await(30, TimeUnit.SECONDS); - } catch (InterruptedException e) { - log.error("Failed to create doc level queries from threat intel feeds", e); + )); + } catch (Exception e) { + log.error("Failed to create doc level query from threat intel data", e); listener.onFailure(e); } - } private static String constructId(Detector detector, String iocType) { diff --git a/src/main/java/org/opensearch/securityanalytics/threatIntel/ThreatIntelFeedDataService.java b/src/main/java/org/opensearch/securityanalytics/threatIntel/ThreatIntelFeedDataService.java index f37018ae5..b9d8aa3ea 100644 --- a/src/main/java/org/opensearch/securityanalytics/threatIntel/ThreatIntelFeedDataService.java +++ b/src/main/java/org/opensearch/securityanalytics/threatIntel/ThreatIntelFeedDataService.java @@ -34,12 +34,12 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.securityanalytics.model.ThreatIntelFeedData; +import org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings; import org.opensearch.securityanalytics.threatIntel.action.PutTIFJobAction; import org.opensearch.securityanalytics.threatIntel.action.PutTIFJobRequest; import org.opensearch.securityanalytics.threatIntel.action.ThreatIntelIndicesResponse; -import org.opensearch.securityanalytics.threatIntel.common.TIFMetadata; import org.opensearch.securityanalytics.threatIntel.common.StashedThreadContext; -import org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings; +import org.opensearch.securityanalytics.threatIntel.common.TIFMetadata; import org.opensearch.securityanalytics.threatIntel.jobscheduler.TIFJobParameterService; import org.opensearch.securityanalytics.util.IndexUtils; import org.opensearch.securityanalytics.util.SecurityAnalyticsException; @@ -56,7 +56,6 @@ import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.concurrent.CountDownLatch; import java.util.regex.Matcher; import java.util.regex.Pattern; import java.util.stream.Collectors; @@ -104,21 +103,13 @@ public void getThreatIntelFeedData( ActionListener> listener ) { try { - String tifdIndex = getLatestIndexByCreationDate(); if (tifdIndex == null) { createThreatIntelFeedData(listener); } else { - SearchRequest searchRequest = new SearchRequest(tifdIndex); - searchRequest.source().size(9999); //TODO: convert to scroll - String finalTifdIndex = tifdIndex; - client.search(searchRequest, ActionListener.wrap(r -> listener.onResponse(ThreatIntelFeedDataUtils.getTifdList(r, xContentRegistry)), e -> { - log.error(String.format( - "Failed to fetch threat intel feed data from system index %s", finalTifdIndex), e); - listener.onFailure(e); - })); + fetchThreatIntelFeedDataFromIndex(tifdIndex, listener); } - } catch (InterruptedException e) { + } catch (Exception e) { log.error("Failed to get threat intel feed data", e); listener.onFailure(e); } @@ -150,21 +141,16 @@ public void createIndexIfNotExists(final String indexName, final ActionListener< .mapping(getIndexMapping()).timeout(clusterSettings.get(SecurityAnalyticsSettings.THREAT_INTEL_TIMEOUT)); StashedThreadContext.run( client, - () -> client.admin().indices().create(createIndexRequest, new ActionListener<>() { - @Override - public void onResponse(CreateIndexResponse response) { - if (response.isAcknowledged()) { - listener.onResponse(response); - } else { - onFailure(new OpenSearchStatusException("Threat intel feed index creation failed", RestStatus.INTERNAL_SERVER_ERROR)); - } - } - - @Override - public void onFailure(Exception e) { - listener.onFailure(e); - } - }) + () -> client.admin().indices().create(createIndexRequest, + ActionListener.wrap( + response -> { + if (response.isAcknowledged()) + listener.onResponse(response); + else + listener.onFailure(new OpenSearchStatusException("Threat intel feed index creation failed", RestStatus.INTERNAL_SERVER_ERROR)); + + }, listener::onFailure + )) ); } @@ -223,28 +209,20 @@ public void parseAndSaveThreatIntelFeedDataCSV( } bulkRequestList.add(bulkRequest); - GroupedActionListener bulkResponseListener = new GroupedActionListener<>(new ActionListener<>() { - @Override - public void onResponse(Collection bulkResponses) { - int idx = 0; - for (BulkResponse response: bulkResponses) { - BulkRequest request = bulkRequestList.get(idx); - if (response.hasFailures()) { - throw new OpenSearchException( - "error occurred while ingesting threat intel feed data in {} with an error {}", - StringUtils.join(request.getIndices()), - response.buildFailureMessage() - ); - } + GroupedActionListener bulkResponseListener = new GroupedActionListener<>(ActionListener.wrap(bulkResponses -> { + int idx = 0; + for (BulkResponse response : bulkResponses) { + BulkRequest request = bulkRequestList.get(idx); + if (response.hasFailures()) { + throw new OpenSearchException( + "error occurred while ingesting threat intel feed data in {} with an error {}", + StringUtils.join(request.getIndices()), + response.buildFailureMessage() + ); } - listener.onResponse(new ThreatIntelIndicesResponse(true, List.of(indexName))); } - - @Override - public void onFailure(Exception e) { - listener.onFailure(e); - } - }, bulkRequestList.size()); + listener.onResponse(new ThreatIntelIndicesResponse(true, List.of(indexName))); + }, listener::onFailure), bulkRequestList.size()); for (int i = 0; i < bulkRequestList.size(); ++i) { saveTifds(bulkRequestList.get(i), timeout, bulkResponseListener); @@ -291,52 +269,47 @@ public void deleteThreatIntelDataIndex(final List indices) { .prepareDelete(indices.toArray(new String[0])) .setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN_CLOSED_HIDDEN) .setTimeout(clusterSettings.get(SecurityAnalyticsSettings.THREAT_INTEL_TIMEOUT)) - .execute(new ActionListener<>() { - @Override - public void onResponse(AcknowledgedResponse response) { - if (response.isAcknowledged() == false) { - onFailure(new OpenSearchException("failed to delete data[{}]", String.join(",", indices))); - } - } - - @Override - public void onFailure(Exception e) { - log.error("unknown exception:", e); - } - }) + .execute(ActionListener.wrap( + response -> { + if (response.isAcknowledged() == false) { + log.error(new OpenSearchException("failed to delete threat intel feed index[{}]", + String.join(",", indices))); + } + }, e -> log.error("failed to delete threat intel feed index [{}]", e) + )) ); } - private void createThreatIntelFeedData(ActionListener> listener) throws InterruptedException { - CountDownLatch countDownLatch = new CountDownLatch(1); + private void createThreatIntelFeedData(ActionListener> listener) { client.execute( PutTIFJobAction.INSTANCE, new PutTIFJobRequest("feed_updater", clusterSettings.get(SecurityAnalyticsSettings.TIF_UPDATE_INTERVAL)), - new ActionListener<>() { - @Override - public void onResponse(AcknowledgedResponse acknowledgedResponse) { - log.debug("Acknowledged threat intel feed updater job created"); - countDownLatch.countDown(); - String tifdIndex = getLatestIndexByCreationDate(); - - SearchRequest searchRequest = new SearchRequest(tifdIndex); - searchRequest.source().size(9999); //TODO: convert to scroll - String finalTifdIndex = tifdIndex; - client.search(searchRequest, ActionListener.wrap(r -> listener.onResponse(ThreatIntelFeedDataUtils.getTifdList(r, xContentRegistry)), e -> { - log.error(String.format( - "Failed to fetch threat intel feed data from system index %s", finalTifdIndex), e); + ActionListener.wrap( + r -> { + if (false == r.isAcknowledged()) { + listener.onFailure(new Exception("Failed to acknowledge Put Tif job action")); + return; + } + log.debug("Acknowledged threat intel feed updater job created"); + String tifdIndex = getLatestIndexByCreationDate(); + fetchThreatIntelFeedDataFromIndex(tifdIndex, listener); + }, e -> { + log.debug("Failed to create threat intel feed updater job", e); listener.onFailure(e); - })); - } - - @Override - public void onFailure(Exception e) { - log.debug("Failed to create threat intel feed updater job", e); - countDownLatch.countDown(); - } - } + } + ) ); - countDownLatch.await(); + } + + private void fetchThreatIntelFeedDataFromIndex(String tifdIndex, ActionListener> listener) { + SearchRequest searchRequest = new SearchRequest(tifdIndex); + searchRequest.source().size(9999); //TODO: convert to scroll + String finalTifdIndex = tifdIndex; + client.search(searchRequest, ActionListener.wrap(r -> listener.onResponse(ThreatIntelFeedDataUtils.getTifdList(r, xContentRegistry)), e -> { + log.error(String.format( + "Failed to fetch threat intel feed data from system index %s", finalTifdIndex), e); + listener.onFailure(e); + })); } private String getIndexMapping() { diff --git a/src/main/java/org/opensearch/securityanalytics/threatIntel/action/TransportPutTIFJobAction.java b/src/main/java/org/opensearch/securityanalytics/threatIntel/action/TransportPutTIFJobAction.java index 393a0f102..a50beda35 100644 --- a/src/main/java/org/opensearch/securityanalytics/threatIntel/action/TransportPutTIFJobAction.java +++ b/src/main/java/org/opensearch/securityanalytics/threatIntel/action/TransportPutTIFJobAction.java @@ -41,7 +41,6 @@ public class TransportPutTIFJobAction extends HandledTransportAction listener) { - lockService.acquireLock(request.getName(), LOCK_DURATION_IN_SECONDS, ActionListener.wrap(lock -> { - if (lock == null) { - listener.onFailure( - new ConcurrentModificationException("another processor is holding a lock on the resource. Try again later") - ); - log.error("another processor is a lock, BAD_REQUEST error", RestStatus.BAD_REQUEST); - return; - } - try { - internalDoExecute(request, lock, listener); - } catch (Exception e) { - lockService.releaseLock(lock); - listener.onFailure(e); - log.error("listener failed when executing", e); - } - }, exception -> { - listener.onFailure(exception); - log.error("execution failed", exception); - })); + try { + lockService.acquireLock(request.getName(), LOCK_DURATION_IN_SECONDS, ActionListener.wrap(lock -> { + if (lock == null) { + listener.onFailure( + new ConcurrentModificationException("another processor is holding a lock on the resource. Try again later") + ); + log.error("another processor is a lock, BAD_REQUEST error", RestStatus.BAD_REQUEST); + return; + } + try { + internalDoExecute(request, lock, listener); + } catch (Exception e) { + lockService.releaseLock(lock); + listener.onFailure(e); + log.error("listener failed when executing", e); + } + }, exception -> { + listener.onFailure(exception); + log.error("execution failed", exception); + })); + } catch (Exception e) { + log.error("Failed to acquire lock for job", e); + listener.onFailure(e); + } } /** @@ -103,16 +106,21 @@ protected void internalDoExecute( final LockModel lock, final ActionListener listener ) { - StepListener createIndexStep = new StepListener<>(); - tifJobParameterService.createJobIndexIfNotExists(createIndexStep); - createIndexStep.whenComplete(v -> { - TIFJobParameter tifJobParameter = TIFJobParameter.Builder.build(request); - tifJobParameterService.saveTIFJobParameter(tifJobParameter, postIndexingTifJobParameter(tifJobParameter, lock, listener)); + StepListener createIndexStepListener = new StepListener<>(); + createIndexStepListener.whenComplete(v -> { + try { + TIFJobParameter tifJobParameter = TIFJobParameter.Builder.build(request); + tifJobParameterService.saveTIFJobParameter(tifJobParameter, postIndexingTifJobParameter(tifJobParameter, lock, listener)); + } catch (Exception e) { + listener.onFailure(e); + } }, exception -> { lockService.releaseLock(lock); log.error("failed to release lock", exception); listener.onFailure(exception); }); + tifJobParameterService.createJobIndexIfNotExists(createIndexStepListener); + } /** @@ -124,40 +132,30 @@ protected ActionListener postIndexingTifJobParameter( final LockModel lock, final ActionListener listener ) { - return new ActionListener<>() { - @Override - public void onResponse(final IndexResponse indexResponse) { - AtomicReference lockReference = new AtomicReference<>(lock); - createThreatIntelFeedData(tifJobParameter, lockService.getRenewLockRunnable(lockReference), new ActionListener<>() { - @Override - public void onResponse(ThreatIntelIndicesResponse threatIntelIndicesResponse) { - if (threatIntelIndicesResponse.isAcknowledged()) { - lockService.releaseLock(lockReference.get()); - listener.onResponse(new AcknowledgedResponse(true)); - } else { - onFailure(new OpenSearchStatusException("creation of threat intel feed data failed", RestStatus.INTERNAL_SERVER_ERROR)); - } - } - - @Override - public void onFailure(Exception e) { + return ActionListener.wrap( + indexResponse -> { + AtomicReference lockReference = new AtomicReference<>(lock); + createThreatIntelFeedData(tifJobParameter, lockService.getRenewLockRunnable(lockReference), ActionListener.wrap( + threatIntelIndicesResponse -> { + if (threatIntelIndicesResponse.isAcknowledged()) { + lockService.releaseLock(lockReference.get()); + listener.onResponse(new AcknowledgedResponse(true)); + } else { + listener.onFailure(new OpenSearchStatusException("creation of threat intel feed data failed", RestStatus.INTERNAL_SERVER_ERROR)); + } + }, listener::onFailure + )); + }, e -> { + lockService.releaseLock(lock); + if (e instanceof VersionConflictEngineException) { + log.error("tifJobParameter already exists"); + listener.onFailure(new ResourceAlreadyExistsException("tifJobParameter [{}] already exists", tifJobParameter.getName())); + } else { + log.error("Internal server error"); listener.onFailure(e); } - }); - } - - @Override - public void onFailure(final Exception e) { - lockService.releaseLock(lock); - if (e instanceof VersionConflictEngineException) { - log.error("tifJobParameter already exists"); - listener.onFailure(new ResourceAlreadyExistsException("tifJobParameter [{}] already exists", tifJobParameter.getName())); - } else { - log.error("Internal server error"); - listener.onFailure(e); } - } - }; + ); } protected void createThreatIntelFeedData(final TIFJobParameter tifJobParameter, final Runnable renewLock, final ActionListener listener) { diff --git a/src/main/java/org/opensearch/securityanalytics/threatIntel/common/TIFLockService.java b/src/main/java/org/opensearch/securityanalytics/threatIntel/common/TIFLockService.java index 7ec4e94f3..98abf040a 100644 --- a/src/main/java/org/opensearch/securityanalytics/threatIntel/common/TIFLockService.java +++ b/src/main/java/org/opensearch/securityanalytics/threatIntel/common/TIFLockService.java @@ -5,18 +5,8 @@ package org.opensearch.securityanalytics.threatIntel.common; -import static org.opensearch.securityanalytics.SecurityAnalyticsPlugin.JOB_INDEX_NAME; - - -import java.time.Instant; -import java.util.Optional; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; - import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; - import org.opensearch.OpenSearchException; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; @@ -25,6 +15,13 @@ import org.opensearch.jobscheduler.spi.utils.LockService; import org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings; +import java.time.Instant; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import static org.opensearch.securityanalytics.SecurityAnalyticsPlugin.JOB_INDEX_NAME; + /** * A wrapper of job scheduler's lock service */ @@ -48,52 +45,27 @@ public TIFLockService(final ClusterService clusterService, final Client client) this.lockService = new LockService(client, clusterService); } - /** - * Wrapper method of LockService#acquireLockWithId - * - * tif job uses its name as doc id in job scheduler. Therefore, we can use tif job name to acquire - * a lock on a tif job. - * - * @param tifJobName tifJobName to acquire lock on - * @param lockDurationSeconds the lock duration in seconds - * @param listener the listener - */ - public void acquireLock(final String tifJobName, final Long lockDurationSeconds, final ActionListener listener) { - lockService.acquireLockWithId(JOB_INDEX_NAME, lockDurationSeconds, tifJobName, listener); - } - /** * Synchronous method of #acquireLock * * @param tifJobName tifJobName to acquire lock on * @param lockDurationSeconds the lock duration in seconds - * @return lock model */ - public Optional acquireLock(final String tifJobName, final Long lockDurationSeconds) { + public void acquireLock(final String tifJobName, final Long lockDurationSeconds, ActionListener listener) { AtomicReference lockReference = new AtomicReference(); - CountDownLatch countDownLatch = new CountDownLatch(1); lockService.acquireLockWithId(JOB_INDEX_NAME, lockDurationSeconds, tifJobName, new ActionListener<>() { @Override public void onResponse(final LockModel lockModel) { lockReference.set(lockModel); - countDownLatch.countDown(); + listener.onResponse(lockReference.get()); } @Override public void onFailure(final Exception e) { - lockReference.set(null); - countDownLatch.countDown(); - log.error("aquiring lock failed", e); + log.error("Failed to acquire lock for tif job " + tifJobName, e); + listener.onFailure(e); } }); - - try { - countDownLatch.await(clusterService.getClusterSettings().get(SecurityAnalyticsSettings.THREAT_INTEL_TIMEOUT).getSeconds(), TimeUnit.SECONDS); - return Optional.ofNullable(lockReference.get()); - } catch (InterruptedException e) { - log.error("Waiting for the count down latch failed", e); - return Optional.empty(); - } } /** diff --git a/src/main/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/TIFJobParameterService.java b/src/main/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/TIFJobParameterService.java index de9bb5365..55387cb35 100644 --- a/src/main/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/TIFJobParameterService.java +++ b/src/main/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/TIFJobParameterService.java @@ -9,6 +9,7 @@ import org.apache.logging.log4j.Logger; import org.opensearch.OpenSearchStatusException; import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.ResourceNotFoundException; import org.opensearch.action.DocWriteRequest; import org.opensearch.action.StepListener; import org.opensearch.action.admin.indices.create.CreateIndexRequest; @@ -84,6 +85,7 @@ public void onFailure(final Exception e) { stepListener.onResponse(null); return; } + log.error("Failed to create security analytics job index", e); stepListener.onFailure(e); } })); @@ -104,82 +106,72 @@ private String getIndexMapping() { /** * Update jobSchedulerParameter in an index {@code TIFJobExtension.JOB_INDEX_NAME} + * * @param jobSchedulerParameter the jobSchedulerParameter */ public void updateJobSchedulerParameter(final TIFJobParameter jobSchedulerParameter, final ActionListener listener) { jobSchedulerParameter.setLastUpdateTime(Instant.now()); StashedThreadContext.run(client, () -> { try { - if (listener != null) { - client.prepareIndex(SecurityAnalyticsPlugin.JOB_INDEX_NAME) - .setId(jobSchedulerParameter.getName()) - .setOpType(DocWriteRequest.OpType.INDEX) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .setSource(jobSchedulerParameter.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)) - .execute(new ActionListener<>() { - @Override - public void onResponse(IndexResponse indexResponse) { - if (indexResponse.status().getStatus() >= 200 && indexResponse.status().getStatus() < 300) { - listener.onResponse(new ThreatIntelIndicesResponse(true, jobSchedulerParameter.getIndices())); - } else { - listener.onFailure(new OpenSearchStatusException("update of job scheduler parameter failed", RestStatus.INTERNAL_SERVER_ERROR)); - } + client.prepareIndex(SecurityAnalyticsPlugin.JOB_INDEX_NAME) + .setId(jobSchedulerParameter.getName()) + .setOpType(DocWriteRequest.OpType.INDEX) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .setSource(jobSchedulerParameter.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)) + .execute(new ActionListener<>() { + @Override + public void onResponse(IndexResponse indexResponse) { + if (indexResponse.status().getStatus() >= 200 && indexResponse.status().getStatus() < 300) { + listener.onResponse(new ThreatIntelIndicesResponse(true, jobSchedulerParameter.getIndices())); + } else { + listener.onFailure(new OpenSearchStatusException("update of job scheduler parameter failed", RestStatus.INTERNAL_SERVER_ERROR)); } + } - @Override - public void onFailure(Exception e) { - listener.onFailure(e); - } - }); - } else { - client.prepareIndex(SecurityAnalyticsPlugin.JOB_INDEX_NAME) - .setId(jobSchedulerParameter.getName()) - .setOpType(DocWriteRequest.OpType.INDEX) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .setSource(jobSchedulerParameter.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)) - .execute().actionGet(); - } + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + }); } catch (IOException e) { - throw new SecurityAnalyticsException("Runtime exception", RestStatus.INTERNAL_SERVER_ERROR, e); + log.error("failed to update job scheduler param for tif job", e); + listener.onFailure(e); } }); } /** * Get tif job from an index {@code TIFJobExtension.JOB_INDEX_NAME} + * * @param name the name of a tif job - * @return tif job - * @throws IOException exception */ - public TIFJobParameter getJobParameter(final String name) throws IOException { + public void getJobParameter(final String name, ActionListener listener) { GetRequest request = new GetRequest(SecurityAnalyticsPlugin.JOB_INDEX_NAME, name); - GetResponse response; - try { - response = StashedThreadContext.run(client, () -> client.get(request).actionGet(clusterSettings.get(SecurityAnalyticsSettings.THREAT_INTEL_TIMEOUT))); - if (response.isExists() == false) { - log.error("TIF job[{}] does not exist in an index[{}]", name, SecurityAnalyticsPlugin.JOB_INDEX_NAME); - return null; - } - } catch (IndexNotFoundException e) { - log.error("Index[{}] is not found", SecurityAnalyticsPlugin.JOB_INDEX_NAME); - return null; - } - - XContentParser parser = XContentHelper.createParser( - NamedXContentRegistry.EMPTY, - LoggingDeprecationHandler.INSTANCE, - response.getSourceAsBytesRef() - ); - return TIFJobParameter.PARSER.parse(parser, null); + StashedThreadContext.run(client, () -> client.get(request, ActionListener.wrap( + response -> { + if (response.isExists() == false) { + log.error("TIF job[{}] does not exist in an index[{}]", name, SecurityAnalyticsPlugin.JOB_INDEX_NAME); + listener.onFailure(new ResourceNotFoundException("name")); + } + XContentParser parser = XContentHelper.createParser( + NamedXContentRegistry.EMPTY, + LoggingDeprecationHandler.INSTANCE, + response.getSourceAsBytesRef() + ); + listener.onResponse(TIFJobParameter.PARSER.parse(parser, null)); + }, e -> { + log.error("Failed to fetch tif job document " + name, e); + listener.onFailure(e); + }))); } /** * Put tifJobParameter in an index {@code TIFJobExtension.JOB_INDEX_NAME} * * @param tifJobParameter the tifJobParameter - * @param listener the listener + * @param listener the listener */ - public void saveTIFJobParameter(final TIFJobParameter tifJobParameter, final ActionListener listener) { + public void saveTIFJobParameter(final TIFJobParameter tifJobParameter, final ActionListener listener) { tifJobParameter.setLastUpdateTime(Instant.now()); StashedThreadContext.run(client, () -> { try { diff --git a/src/main/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/TIFJobRunner.java b/src/main/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/TIFJobRunner.java index 13db6235d..1d8d8643f 100644 --- a/src/main/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/TIFJobRunner.java +++ b/src/main/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/TIFJobRunner.java @@ -109,72 +109,82 @@ public void runJob(final ScheduledJobParameter jobParameter, final JobExecutionC * @param jobParameter job parameter */ protected Runnable updateJobRunner(final ScheduledJobParameter jobParameter) { - return () -> { - Optional lockModel = lockService.acquireLock( - jobParameter.getName(), - TIFLockService.LOCK_DURATION_IN_SECONDS - ); - if (lockModel.isEmpty()) { - log.error("Failed to update. Another processor is holding a lock for job parameter[{}]", jobParameter.getName()); - return; - } - - LockModel lock = lockModel.get(); - try { - updateJobParameter(jobParameter, lockService.getRenewLockRunnable(new AtomicReference<>(lock))); - } catch (Exception e) { - log.error("Failed to update job parameter[{}]", jobParameter.getName(), e); - } finally { - lockService.releaseLock(lock); - } - }; + return () -> lockService.acquireLock( + jobParameter.getName(), + TIFLockService.LOCK_DURATION_IN_SECONDS, + ActionListener.wrap(lock -> { + updateJobParameter(jobParameter, lockService.getRenewLockRunnable(new AtomicReference<>(lock)), + ActionListener.wrap( + r -> lockService.releaseLock(lock), + e -> { + log.error("Failed to update job parameter " + jobParameter.getName(), e); + lockService.releaseLock(lock); + } + )); + }, e -> { + log.error("Failed to update. Another processor is holding a lock for job parameter[{}]", jobParameter.getName()); + }) + ); } - protected void updateJobParameter(final ScheduledJobParameter jobParameter, final Runnable renewLock) throws IOException { - TIFJobParameter jobSchedulerParameter = jobSchedulerParameterService.getJobParameter(jobParameter.getName()); - /** - * If delete request comes while update task is waiting on a queue for other update tasks to complete, - * because update task for this jobSchedulerParameter didn't acquire a lock yet, delete request is processed. - * When it is this jobSchedulerParameter's turn to run, it will find that the jobSchedulerParameter is deleted already. - * Therefore, we stop the update process when data source does not exist. - */ - if (jobSchedulerParameter == null) { - log.info("Job parameter[{}] does not exist", jobParameter.getName()); - return; - } + protected void updateJobParameter(final ScheduledJobParameter jobParameter, final Runnable renewLock, ActionListener listener) { + jobSchedulerParameterService.getJobParameter(jobParameter.getName(), ActionListener.wrap( + jobSchedulerParameter -> { + /** + * If delete request comes while update task is waiting on a queue for other update tasks to complete, + * because update task for this jobSchedulerParameter didn't acquire a lock yet, delete request is processed. + * When it is this jobSchedulerParameter's turn to run, it will find that the jobSchedulerParameter is deleted already. + * Therefore, we stop the update process when data source does not exist. + */ + if (jobSchedulerParameter == null) { + log.info("Job parameter[{}] does not exist", jobParameter.getName()); + return; + } - if (TIFJobState.AVAILABLE.equals(jobSchedulerParameter.getState()) == false) { - log.error("Invalid jobSchedulerParameter state. Expecting {} but received {}", TIFJobState.AVAILABLE, jobSchedulerParameter.getState()); - jobSchedulerParameter.disable(); - jobSchedulerParameter.getUpdateStats().setLastFailedAt(Instant.now()); - jobSchedulerParameterService.updateJobSchedulerParameter(jobSchedulerParameter, null); - return; - } - // create new TIF data and delete old ones - List oldIndices = new ArrayList<>(jobSchedulerParameter.getIndices()); - jobSchedulerUpdateService.createThreatIntelFeedData(jobSchedulerParameter, renewLock, new ActionListener<>() { - @Override - public void onResponse(ThreatIntelIndicesResponse response) { - if (response.isAcknowledged()) { - List newFeedIndices = response.getIndices(); - jobSchedulerUpdateService.deleteAllTifdIndices(oldIndices, newFeedIndices); - if (false == newFeedIndices.isEmpty()) { - detectorThreatIntelService.updateDetectorsWithLatestThreatIntelRules(); + if (TIFJobState.AVAILABLE.equals(jobSchedulerParameter.getState()) == false) { + log.error("Invalid jobSchedulerParameter state. Expecting {} but received {}", TIFJobState.AVAILABLE, jobSchedulerParameter.getState()); + jobSchedulerParameter.disable(); + jobSchedulerParameter.getUpdateStats().setLastFailedAt(Instant.now()); + jobSchedulerParameterService.updateJobSchedulerParameter(jobSchedulerParameter, ActionListener.wrap( + r-> {}, e -> log.error("Failed to update job scheduler parameter in Threat intel feed update job") + )); } - } else { - log.error("Failed to update jobSchedulerParameter for {}", jobSchedulerParameter.getName()); - jobSchedulerParameter.getUpdateStats().setLastFailedAt(Instant.now()); - jobSchedulerParameterService.updateJobSchedulerParameter(jobSchedulerParameter, null); - } - } - @Override - public void onFailure(Exception e) { - log.error("Failed to update jobSchedulerParameter for {}", jobSchedulerParameter.getName(), e); - jobSchedulerParameter.getUpdateStats().setLastFailedAt(Instant.now()); - jobSchedulerParameterService.updateJobSchedulerParameter(jobSchedulerParameter, null); - } - }); + // create new TIF data and delete old ones + List oldIndices = new ArrayList<>(jobSchedulerParameter.getIndices()); + jobSchedulerUpdateService.createThreatIntelFeedData(jobSchedulerParameter, renewLock, new ActionListener<>() { + @Override + public void onResponse(ThreatIntelIndicesResponse response) { + if (response.isAcknowledged()) { + List newFeedIndices = response.getIndices(); + jobSchedulerUpdateService.deleteAllTifdIndices(oldIndices, newFeedIndices); + if (false == newFeedIndices.isEmpty()) { + detectorThreatIntelService.updateDetectorsWithLatestThreatIntelRules(); + } + } else { + log.error("Failed to update jobSchedulerParameter for {}", jobSchedulerParameter.getName()); + jobSchedulerParameter.getUpdateStats().setLastFailedAt(Instant.now()); + jobSchedulerParameterService.updateJobSchedulerParameter(jobSchedulerParameter, ActionListener.wrap( + r-> {}, e -> log.error("Failed to update job scheduler parameter in Threat intel feed update job") + )); + } + } + + @Override + public void onFailure(Exception e) { + log.error("Failed to update jobSchedulerParameter for {}", jobSchedulerParameter.getName(), e); + jobSchedulerParameter.getUpdateStats().setLastFailedAt(Instant.now()); + jobSchedulerParameterService.updateJobSchedulerParameter(jobSchedulerParameter, ActionListener.wrap( + r-> {}, ex -> log.error("Failed to update job scheduler parameter in Threat intel feed update job") + )); + } + }); + listener.onResponse(null); + }, + e -> { + listener.onFailure(e); + } + )); } } \ No newline at end of file diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java index e6dea9947..883bf8ee7 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java @@ -717,7 +717,6 @@ private IndexMonitorRequest createDocLevelMonitorRequest(List private void addThreatIntelBasedDocLevelQueries(Detector detector, ActionListener> listener) { try { - if (detector.getThreatIntelEnabled()) { log.debug("threat intel enabled for detector {} . adding threat intel based doc level queries.", detector.getName()); List iocFieldsList = logTypeService.getIocFieldsList(detector.getDetectorType()); @@ -730,8 +729,7 @@ private void addThreatIntelBasedDocLevelQueries(Detector detector, ActionListene listener.onResponse(List.of()); } } catch (Exception e) { - // not failing detector creation if any fatal exception occurs during doc level query creation from threat intel feed data - log.error("Failed to convert threat intel feed to doc level query. Proceeding with detector creation", e); + log.error("Failed to add threat intel based doc level queries"); listener.onFailure(e); } } diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportSearchDetectorAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportSearchDetectorAction.java index 0643b34d7..3b7b36503 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportSearchDetectorAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportSearchDetectorAction.java @@ -6,30 +6,25 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; - -import org.opensearch.core.action.ActionListener; import org.opensearch.action.search.SearchResponse; - import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.commons.authuser.User; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.Settings; -import org.opensearch.cluster.service.ClusterService; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.securityanalytics.action.SearchDetectorAction; import org.opensearch.securityanalytics.action.SearchDetectorRequest; import org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings; import org.opensearch.securityanalytics.threatIntel.action.TransportPutTIFJobAction; import org.opensearch.securityanalytics.util.DetectorIndices; -import org.opensearch.threadpool.ThreadPool; - import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; -import java.util.concurrent.CountDownLatch; - import static org.opensearch.securityanalytics.util.DetectorUtils.getEmptySearchResponse; public class TransportSearchDetectorAction extends HandledTransportAction implements SecureTransportAction { diff --git a/src/test/java/org/opensearch/securityanalytics/SecurityAnalyticsPluginTransportIT.java b/src/test/java/org/opensearch/securityanalytics/SecurityAnalyticsPluginTransportIT.java deleted file mode 100644 index 688df56a0..000000000 --- a/src/test/java/org/opensearch/securityanalytics/SecurityAnalyticsPluginTransportIT.java +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ -package org.opensearch.securityanalytics; - -import org.junit.Assert; -import org.opensearch.action.admin.cluster.node.info.NodeInfo; -import org.opensearch.action.admin.cluster.node.info.NodesInfoRequest; -import org.opensearch.action.admin.cluster.node.info.NodesInfoResponse; -import org.opensearch.action.admin.cluster.node.info.PluginsAndModules; -import org.opensearch.plugins.PluginInfo; -import org.opensearch.test.OpenSearchIntegTestCase; - -import java.util.List; -import java.util.function.Function; -import java.util.stream.Collectors; -import java.util.stream.Stream; - -/*public class SecurityAnalyticsPluginTransportIT extends OpenSearchIntegTestCase { - - public void testPluginsAreInstalled() { - NodesInfoRequest nodesInfoRequest = new NodesInfoRequest(); - nodesInfoRequest.addMetric(NodesInfoRequest.Metric.PLUGINS.metricName()); - NodesInfoResponse nodesInfoResponse = OpenSearchIntegTestCase.client().admin().cluster().nodesInfo(nodesInfoRequest) - .actionGet(); - List pluginInfos = nodesInfoResponse.getNodes().stream() - .flatMap((Function>) nodeInfo -> nodeInfo.getInfo(PluginsAndModules.class) - .getPluginInfos().stream()).collect(Collectors.toList()); - Assert.assertTrue(pluginInfos.stream().anyMatch(pluginInfo -> pluginInfo.getName() - .equals("opensearch-security-analytics"))); - } -}*/ \ No newline at end of file diff --git a/src/test/java/org/opensearch/securityanalytics/threatIntel/common/ThreatIntelLockServiceTests.java b/src/test/java/org/opensearch/securityanalytics/threatIntel/common/ThreatIntelLockServiceTests.java index 4b6423a3e..7a95e746f 100644 --- a/src/test/java/org/opensearch/securityanalytics/threatIntel/common/ThreatIntelLockServiceTests.java +++ b/src/test/java/org/opensearch/securityanalytics/threatIntel/common/ThreatIntelLockServiceTests.java @@ -41,9 +41,12 @@ public void testAcquireLock_whenValidInput_thenSucceed() { public void testAcquireLock_whenCalled_thenNotBlocked() { long expectedDurationInMillis = 1000; Instant before = Instant.now(); - assertTrue(threatIntelLockService.acquireLock(null, null).isEmpty()); - Instant after = Instant.now(); - assertTrue(after.toEpochMilli() - before.toEpochMilli() < expectedDurationInMillis); + threatIntelLockService.acquireLock(null, null, ActionListener.wrap( + r -> fail("Should not have been blocked"), e -> { + Instant after = Instant.now(); + assertTrue(after.toEpochMilli() - before.toEpochMilli() < expectedDurationInMillis); + } + )); } public void testReleaseLock_whenValidInput_thenSucceed() { From db025ce69c0201798b9e862a9156a658d2f2d241 Mon Sep 17 00:00:00 2001 From: Megha Goyal <56077967+goyamegh@users.noreply.github.com> Date: Tue, 5 Mar 2024 11:44:56 -0800 Subject: [PATCH 06/12] Fixing hanging tasks for correlations (#874) Signed-off-by: Megha Goyal --- .../correlation/VectorEmbeddingsEngine.java | 8 +++++++- .../transport/TransportCorrelateFindingAction.java | 7 +++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/VectorEmbeddingsEngine.java b/src/main/java/org/opensearch/securityanalytics/correlation/VectorEmbeddingsEngine.java index 0f9866766..cab8798f2 100644 --- a/src/main/java/org/opensearch/securityanalytics/correlation/VectorEmbeddingsEngine.java +++ b/src/main/java/org/opensearch/securityanalytics/correlation/VectorEmbeddingsEngine.java @@ -7,6 +7,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.OpenSearchStatusException; +import org.opensearch.ResourceNotFoundException; import org.opensearch.cluster.routing.Preference; import org.opensearch.core.action.ActionListener; import org.opensearch.action.bulk.BulkRequest; @@ -84,6 +85,11 @@ public void onResponse(SearchResponse response) { correlateFindingAction.onFailures(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); } + if (response.getHits().getHits().length == 0) { + correlateFindingAction.onFailures( + new ResourceNotFoundException("Failed to find hits in metadata index for finding id {}", finding.getId())); + } + Map hitSource = response.getHits().getHits()[0].getSourceAsMap(); long counter = Long.parseLong(hitSource.get("counter").toString()); @@ -125,7 +131,7 @@ public void onResponse(MultiSearchResponse items) { continue; } - long totalHits = response.getResponse().getHits().getTotalHits().value; + long totalHits = response.getResponse().getHits().getHits().length; totalNeighbors += totalHits; for (int idx = 0; idx < totalHits; ++idx) { diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java index e79af28d3..63c31f99b 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java @@ -8,6 +8,7 @@ import org.apache.logging.log4j.Logger; import org.apache.lucene.search.join.ScoreMode; import org.opensearch.OpenSearchStatusException; +import org.opensearch.ResourceNotFoundException; import org.opensearch.cluster.routing.Preference; import org.opensearch.core.action.ActionListener; import org.opensearch.action.ActionRequest; @@ -517,6 +518,11 @@ public void onFailure(Exception e) { client.search(searchRequest, new ActionListener<>() { @Override public void onResponse(SearchResponse response) { + if (response.getHits().getHits().length == 0) { + onFailures(new ResourceNotFoundException( + "Failed to find hits in metadata index for finding id {}", request.getFinding().getId())); + } + String id = response.getHits().getHits()[0].getId(); Map hitSource = response.getHits().getHits()[0].getSourceAsMap(); long scoreTimestamp = (long) hitSource.get("scoreTimestamp"); @@ -655,6 +661,7 @@ public void onOperation() { } public void onFailures(Exception t) { + log.error("Exception occurred while processing correlations", t); if (counter.compareAndSet(false, true)) { finishHim(t); } From f4ee7bb9118a35ff4706a9e36abb9f2b042c069f Mon Sep 17 00:00:00 2001 From: Joanne Wang Date: Wed, 6 Mar 2024 05:34:20 -0800 Subject: [PATCH 07/12] Add throw for empty strings in rules with modifier contains, startwith, and endswith (#860) * add validation for empty strings with contains, startswith and endswith modifiers Signed-off-by: Joanne Wang * throw exception if empty string with contains, startswith, or endswith Signed-off-by: Joanne Wang * change var name Signed-off-by: Joanne Wang * add modifiers to log Signed-off-by: Joanne Wang --------- Signed-off-by: Joanne Wang --- .../rules/objects/SigmaDetectionItem.java | 10 ++- .../rules/backend/QueryBackendTests.java | 72 +++++++++++++++++++ 2 files changed, 81 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/securityanalytics/rules/objects/SigmaDetectionItem.java b/src/main/java/org/opensearch/securityanalytics/rules/objects/SigmaDetectionItem.java index a334ca758..c74bd9177 100644 --- a/src/main/java/org/opensearch/securityanalytics/rules/objects/SigmaDetectionItem.java +++ b/src/main/java/org/opensearch/securityanalytics/rules/objects/SigmaDetectionItem.java @@ -18,6 +18,7 @@ import org.opensearch.securityanalytics.rules.modifiers.SigmaModifierFacade; import org.opensearch.securityanalytics.rules.modifiers.SigmaValueModifier; import org.opensearch.securityanalytics.rules.types.SigmaNull; +import org.opensearch.securityanalytics.rules.types.SigmaString; import org.opensearch.securityanalytics.rules.types.SigmaType; import org.opensearch.securityanalytics.rules.types.SigmaTypeFacade; import org.opensearch.securityanalytics.rules.utils.AnyOneOf; @@ -111,7 +112,14 @@ public static SigmaDetectionItem fromMapping(String key, Either> List sigmaTypes = new ArrayList<>(); for (T v: values) { - sigmaTypes.add(SigmaTypeFacade.sigmaType(v)); + SigmaType sigmaType = SigmaTypeFacade.sigmaType(v); + // throws an error if sigmaType is an empty string and the modifier is "contains" or "startswith" or "endswith" + boolean invalidModifierWithEmptyString = modifierIds.contains("contains") || modifierIds.contains("startswith") || modifierIds.contains("endswith"); + if (sigmaType.getClass().equals(SigmaString.class) && v.toString().isEmpty() && invalidModifierWithEmptyString) { + throw new SigmaValueError("Cannot create rule with empty string and given modifier(s): " + modifierIds); + } else { + sigmaTypes.add(sigmaType); + } } return new SigmaDetectionItem(field, modifiers, sigmaTypes, null, null, true); diff --git a/src/test/java/org/opensearch/securityanalytics/rules/backend/QueryBackendTests.java b/src/test/java/org/opensearch/securityanalytics/rules/backend/QueryBackendTests.java index aff11d913..3f8196d3d 100644 --- a/src/test/java/org/opensearch/securityanalytics/rules/backend/QueryBackendTests.java +++ b/src/test/java/org/opensearch/securityanalytics/rules/backend/QueryBackendTests.java @@ -907,6 +907,78 @@ public void testConvertUnboundValuesAsWildcard() throws IOException, SigmaError Assert.assertEquals("((mappedA: \"value1\") OR (mappedA: \"value2\") OR (mappedA: \"value3\")) OR (test*)", queries.get(0).toString()); } + public void testConvertSkipEmptyStringStartsWithModifier() throws IOException, SigmaError { + OSQueryBackend queryBackend = testBackend(); + Assert.assertThrows(SigmaValueError.class, () -> { + queryBackend.convertRule(SigmaRule.fromYaml( + " title: Test\n" + + " id: 39f919f3-980b-4e6f-a975-8af7e507ef2b\n" + + " status: test\n" + + " level: critical\n" + + " description: Detects QuarksPwDump clearing access history in hive\n" + + " author: Florian Roth\n" + + " date: 2017/05/15\n" + + " logsource:\n" + + " category: test_category\n" + + " product: test_product\n" + + " detection:\n" + + " sel:\n" + + " fieldA1|startswith: \n" + + " - value1\n" + + " - value2\n" + + " - ''\n" + + " condition: sel", false)); + }); + } + + public void testConvertSkipEmptyStringEndsWithModifier() throws IOException, SigmaError { + OSQueryBackend queryBackend = testBackend(); + Assert.assertThrows(SigmaValueError.class, () -> { + queryBackend.convertRule(SigmaRule.fromYaml( + " title: Test\n" + + " id: 39f919f3-980b-4e6f-a975-8af7e507ef2b\n" + + " status: test\n" + + " level: critical\n" + + " description: Detects QuarksPwDump clearing access history in hive\n" + + " author: Florian Roth\n" + + " date: 2017/05/15\n" + + " logsource:\n" + + " category: test_category\n" + + " product: test_product\n" + + " detection:\n" + + " sel:\n" + + " fieldA1|endswith: \n" + + " - value1\n" + + " - value2\n" + + " - ''\n" + + " condition: sel", false)); + }); + } + + public void testConvertSkipEmptyStringContainsModifier() throws IOException, SigmaError { + OSQueryBackend queryBackend = testBackend(); + Assert.assertThrows(SigmaValueError.class, () -> { + queryBackend.convertRule(SigmaRule.fromYaml( + " title: Test\n" + + " id: 39f919f3-980b-4e6f-a975-8af7e507ef2b\n" + + " status: test\n" + + " level: critical\n" + + " description: Detects QuarksPwDump clearing access history in hive\n" + + " author: Florian Roth\n" + + " date: 2017/05/15\n" + + " logsource:\n" + + " category: test_category\n" + + " product: test_product\n" + + " detection:\n" + + " sel:\n" + + " fieldA1|contains: \n" + + " - value1\n" + + " - value2\n" + + " - ''\n" + + " condition: sel", false)); + }); + } + private OSQueryBackend testBackend() throws IOException { return new OSQueryBackend(testFieldMapping, false, true); } From ec0657d74a3b147f304e5985250f0e3d8e0e3e4b Mon Sep 17 00:00:00 2001 From: Megha Goyal <56077967+goyamegh@users.noreply.github.com> Date: Wed, 6 Mar 2024 05:54:18 -0800 Subject: [PATCH 08/12] Refactor invocation of Action listeners in correlations (#880) * Refactor invocation of Action listeners in correlations Signed-off-by: Megha Goyal * Close hanging tasks in correlations workflow Signed-off-by: Megha Goyal * Logging finding id and monitor id in error logs Signed-off-by: Megha Goyal --------- Signed-off-by: Megha Goyal --- .../correlation/JoinEngine.java | 541 ++++++------ .../correlation/VectorEmbeddingsEngine.java | 817 ++++++++---------- .../logtype/LogTypeService.java | 2 +- .../TransportCorrelateFindingAction.java | 601 +++++-------- .../util/CorrelationIndices.java | 8 +- 5 files changed, 808 insertions(+), 1161 deletions(-) diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java b/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java index cfff7da26..b33c4d43b 100644 --- a/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java +++ b/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java @@ -5,7 +5,6 @@ package org.opensearch.securityanalytics.correlation; import org.apache.commons.lang3.tuple.Pair; -import org.apache.commons.lang3.tuple.Triple; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.lucene.search.join.ScoreMode; @@ -16,7 +15,6 @@ import org.opensearch.action.search.MultiSearchRequest; import org.opensearch.action.search.MultiSearchResponse; import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Client; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentType; @@ -91,7 +89,7 @@ public void onSearchDetectorResponse(Detector detector, Finding finding) { onAutoCorrelations(detector, finding, Map.of()); } } catch (IOException ex) { - correlateFindingAction.onFailures(ex); + onFailure(ex); } } @@ -114,102 +112,88 @@ private void generateAutoCorrelations(Detector detector, Finding finding) throws SearchRequest request = new SearchRequest(); request.source(searchSourceBuilder); - logTypeService.searchLogTypes(request, new ActionListener<>() { - @Override - public void onResponse(SearchResponse response) { - MultiSearchRequest mSearchRequest = new MultiSearchRequest(); - SearchHit[] logTypes = response.getHits().getHits(); - List logTypeNames = new ArrayList<>(); - for (SearchHit logType: logTypes) { - String logTypeName = logType.getSourceAsMap().get("name").toString(); - logTypeNames.add(logTypeName); - - RangeQueryBuilder queryBuilder = QueryBuilders.rangeQuery("timestamp") - .gte(findingTimestamp - corrTimeWindow) - .lte(findingTimestamp + corrTimeWindow); - - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchSourceBuilder.query(queryBuilder); - searchSourceBuilder.size(10000); - searchSourceBuilder.fetchField("queries"); - SearchRequest searchRequest = new SearchRequest(); - searchRequest.indices(DetectorMonitorConfig.getAllFindingsIndicesPattern(logTypeName)); - searchRequest.source(searchSourceBuilder); - searchRequest.preference(Preference.PRIMARY_FIRST.type()); - mSearchRequest.add(searchRequest); - } + logTypeService.searchLogTypes(request, ActionListener.wrap(response -> { + MultiSearchRequest mSearchRequest = new MultiSearchRequest(); + SearchHit[] logTypes = response.getHits().getHits(); + List logTypeNames = new ArrayList<>(); + for (SearchHit logType: logTypes) { + String logTypeName = logType.getSourceAsMap().get("name").toString(); + logTypeNames.add(logTypeName); + + RangeQueryBuilder rangeQueryBuilder = QueryBuilders.rangeQuery("timestamp") + .gte(findingTimestamp - corrTimeWindow) + .lte(findingTimestamp + corrTimeWindow); + + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.query(rangeQueryBuilder); + sourceBuilder.size(10000); + sourceBuilder.fetchField("queries"); + SearchRequest searchRequest = new SearchRequest(); + searchRequest.indices(DetectorMonitorConfig.getAllFindingsIndicesPattern(logTypeName)); + searchRequest.source(sourceBuilder); + searchRequest.preference(Preference.PRIMARY_FIRST.type()); + mSearchRequest.add(searchRequest); + } + + if (!mSearchRequest.requests().isEmpty()) { + client.multiSearch(mSearchRequest, ActionListener.wrap(items -> { + MultiSearchResponse.Item[] responses = items.getResponses(); - if (!mSearchRequest.requests().isEmpty()) { - client.multiSearch(mSearchRequest, new ActionListener<>() { - @Override - public void onResponse(MultiSearchResponse items) { - MultiSearchResponse.Item[] responses = items.getResponses(); - - Map> autoCorrelationsMap = new HashMap<>(); - int idx = 0; - for (MultiSearchResponse.Item response : responses) { - if (response.isFailure()) { - log.info(response.getFailureMessage()); - continue; + Map> autoCorrelationsMap = new HashMap<>(); + int idx = 0; + for (MultiSearchResponse.Item item : responses) { + if (item.isFailure()) { + log.info(item.getFailureMessage()); + continue; + } + String logTypeName = logTypeNames.get(idx); + + SearchHit[] findings = item.getResponse().getHits().getHits(); + + for (SearchHit foundFinding : findings) { + if (!foundFinding.getId().equals(finding.getId())) { + Set findingTags = new HashSet<>(); + List> queries = (List>) foundFinding.getSourceAsMap().get("queries"); + for (Map query : queries) { + List queryTags = (List) query.get("tags"); + findingTags.addAll(queryTags.stream().filter(queryTag -> queryTag.startsWith("attack.")).collect(Collectors.toList())); } - String logTypeName = logTypeNames.get(idx); - - SearchHit[] findings = response.getResponse().getHits().getHits(); - - for (SearchHit foundFinding : findings) { - if (!foundFinding.getId().equals(finding.getId())) { - Set findingTags = new HashSet<>(); - List> queries = (List>) foundFinding.getSourceAsMap().get("queries"); - for (Map query : queries) { - List queryTags = (List) query.get("tags"); - findingTags.addAll(queryTags.stream().filter(queryTag -> queryTag.startsWith("attack.")).collect(Collectors.toList())); - } - - boolean canCorrelate = false; - for (String tag: tags) { - if (findingTags.contains(tag)) { - canCorrelate = true; - break; - } - } - - Set foundIntrusionSets = AutoCorrelationsRepo.validIntrusionSets(autoCorrelations, findingTags); - for (String validIntrusionSet: validIntrusionSets) { - if (foundIntrusionSets.contains(validIntrusionSet)) { - canCorrelate = true; - break; - } - } - - if (canCorrelate) { - if (autoCorrelationsMap.containsKey(logTypeName)) { - autoCorrelationsMap.get(logTypeName).add(foundFinding.getId()); - } else { - List autoCorrelatedFindings = new ArrayList<>(); - autoCorrelatedFindings.add(foundFinding.getId()); - autoCorrelationsMap.put(logTypeName, autoCorrelatedFindings); - } - } + + boolean canCorrelate = false; + for (String tag: tags) { + if (findingTags.contains(tag)) { + canCorrelate = true; + break; } } - ++idx; - } - onAutoCorrelations(detector, finding, autoCorrelationsMap); - } - @Override - public void onFailure(Exception e) { - correlateFindingAction.onFailures(e); - } - }); - } - } + Set foundIntrusionSets = AutoCorrelationsRepo.validIntrusionSets(autoCorrelations, findingTags); + for (String validIntrusionSet: validIntrusionSets) { + if (foundIntrusionSets.contains(validIntrusionSet)) { + canCorrelate = true; + break; + } + } - @Override - public void onFailure(Exception e) { - correlateFindingAction.onFailures(e); + if (canCorrelate) { + if (autoCorrelationsMap.containsKey(logTypeName)) { + autoCorrelationsMap.get(logTypeName).add(foundFinding.getId()); + } else { + List autoCorrelatedFindings = new ArrayList<>(); + autoCorrelatedFindings.add(foundFinding.getId()); + autoCorrelationsMap.put(logTypeName, autoCorrelatedFindings); + } + } + } + } + ++idx; + } + onAutoCorrelations(detector, finding, autoCorrelationsMap); + }, this::onFailure)); + } else { + onFailure(new OpenSearchStatusException("Empty findings for all log types", RestStatus.INTERNAL_SERVER_ERROR)); } - }); + }, this::onFailure)); } private void onAutoCorrelations(Detector detector, Finding finding, Map> autoCorrelations) { @@ -231,39 +215,34 @@ private void onAutoCorrelations(Detector detector, Finding finding, Map() { - @Override - public void onResponse(SearchResponse response) { - if (response.isTimedOut()) { - correlateFindingAction.onFailures(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); - } + client.search(searchRequest, ActionListener.wrap(response -> { + if (response.isTimedOut()) { + onFailure(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); + } - Iterator hits = response.getHits().iterator(); - List correlationRules = new ArrayList<>(); - while (hits.hasNext()) { - try { - SearchHit hit = hits.next(); - - XContentParser xcp = XContentType.JSON.xContent().createParser( - xContentRegistry, - LoggingDeprecationHandler.INSTANCE, hit.getSourceAsString() - ); - - CorrelationRule rule = CorrelationRule.parse(xcp, hit.getId(), hit.getVersion()); - correlationRules.add(rule); - } catch (IOException e) { - correlateFindingAction.onFailures(e); - } - } + Iterator hits = response.getHits().iterator(); + List correlationRules = new ArrayList<>(); + while (hits.hasNext()) { + SearchHit hit = hits.next(); - getValidDocuments(detectorType, indices, correlationRules, relatedDocIds, autoCorrelations); - } + XContentParser xcp = XContentType.JSON.xContent().createParser( + xContentRegistry, + LoggingDeprecationHandler.INSTANCE, + hit.getSourceAsString()); - @Override - public void onFailure(Exception e) { + CorrelationRule rule = CorrelationRule.parse(xcp, hit.getId(), hit.getVersion()); + correlationRules.add(rule); + } + getValidDocuments(detectorType, indices, correlationRules, relatedDocIds, autoCorrelations); + }, e -> { + try { + log.error("[CORRELATIONS] Exception encountered while searching correlation rule index for finding id {}", + finding.getId(), e); getValidDocuments(detectorType, indices, List.of(), List.of(), autoCorrelations); + } catch (Exception ex) { + onFailure(ex); } - }); + })); } /** @@ -306,84 +285,72 @@ private void getValidDocuments(String detectorType, List indices, List() { - @Override - public void onResponse(MultiSearchResponse items) { - MultiSearchResponse.Item[] responses = items.getResponses(); - List filteredCorrelationRules = new ArrayList<>(); - - int idx = 0; - for (MultiSearchResponse.Item response : responses) { - if (response.isFailure()) { - log.info(response.getFailureMessage()); - continue; - } + client.multiSearch(mSearchRequest, ActionListener.wrap(items -> { + MultiSearchResponse.Item[] responses = items.getResponses(); + List filteredCorrelationRules = new ArrayList<>(); + + int idx = 0; + for (MultiSearchResponse.Item response : responses) { + if (response.isFailure()) { + log.info(response.getFailureMessage()); + continue; + } - if (response.getResponse().getHits().getTotalHits().value > 0L) { - filteredCorrelationRules.add(new FilteredCorrelationRule(validCorrelationRules.get(idx), - response.getResponse().getHits().getHits(), validFields.get(idx))); - } - ++idx; + if (response.getResponse().getHits().getHits().length > 0L) { + filteredCorrelationRules.add(new FilteredCorrelationRule(validCorrelationRules.get(idx), + response.getResponse().getHits().getHits(), validFields.get(idx))); } + ++idx; + } - Map> categoryToQueriesMap = new HashMap<>(); - Map categoryToTimeWindowMap = new HashMap<>(); - for (FilteredCorrelationRule rule: filteredCorrelationRules) { - List queries = rule.correlationRule.getCorrelationQueries(); - Long timeWindow = rule.correlationRule.getCorrTimeWindow(); - - for (CorrelationQuery query: queries) { - List correlationQueries; - if (categoryToQueriesMap.containsKey(query.getCategory())) { - correlationQueries = categoryToQueriesMap.get(query.getCategory()); - } else { - correlationQueries = new ArrayList<>(); - } - if (categoryToTimeWindowMap.containsKey(query.getCategory())) { - categoryToTimeWindowMap.put(query.getCategory(), Math.max(timeWindow, categoryToTimeWindowMap.get(query.getCategory()))); - } else { - categoryToTimeWindowMap.put(query.getCategory(), timeWindow); - } + Map> categoryToQueriesMap = new HashMap<>(); + Map categoryToTimeWindowMap = new HashMap<>(); + for (FilteredCorrelationRule rule: filteredCorrelationRules) { + List queries = rule.correlationRule.getCorrelationQueries(); + Long timeWindow = rule.correlationRule.getCorrTimeWindow(); - if (query.getField() == null) { - correlationQueries.add(query); - } else { - SearchHit[] hits = rule.filteredDocs; - StringBuilder qb = new StringBuilder(query.getField()).append(":("); - for (int i = 0; i < hits.length; ++i) { - String value = hits[i].field(rule.field).getValue(); - qb.append(value); - if (i < hits.length-1) { - qb.append(" OR "); - } else { - qb.append(")"); - } - } - if (query.getQuery() != null) { - qb.append(" AND ").append(query.getQuery()); + for (CorrelationQuery query: queries) { + List correlationQueries; + if (categoryToQueriesMap.containsKey(query.getCategory())) { + correlationQueries = categoryToQueriesMap.get(query.getCategory()); + } else { + correlationQueries = new ArrayList<>(); + } + if (categoryToTimeWindowMap.containsKey(query.getCategory())) { + categoryToTimeWindowMap.put(query.getCategory(), Math.max(timeWindow, categoryToTimeWindowMap.get(query.getCategory()))); + } else { + categoryToTimeWindowMap.put(query.getCategory(), timeWindow); + } + + if (query.getField() == null) { + correlationQueries.add(query); + } else { + SearchHit[] hits = rule.filteredDocs; + StringBuilder qb = new StringBuilder(query.getField()).append(":("); + for (int i = 0; i < hits.length; ++i) { + String value = hits[i].field(rule.field).getValue(); + qb.append(value); + if (i < hits.length-1) { + qb.append(" OR "); + } else { + qb.append(")"); } - correlationQueries.add(new CorrelationQuery(query.getIndex(), qb.toString(), query.getCategory(), null)); } - categoryToQueriesMap.put(query.getCategory(), correlationQueries); + if (query.getQuery() != null) { + qb.append(" AND ").append(query.getQuery()); + } + correlationQueries.add(new CorrelationQuery(query.getIndex(), qb.toString(), query.getCategory(), null)); } + categoryToQueriesMap.put(query.getCategory(), correlationQueries); } - searchFindingsByTimestamp(detectorType, categoryToQueriesMap, categoryToTimeWindowMap, - filteredCorrelationRules.stream().map(it -> it.correlationRule).map(CorrelationRule::getId).collect(Collectors.toList()), - autoCorrelations - ); - } - - @Override - public void onFailure(Exception e) { - correlateFindingAction.onFailures(e); } - }); + searchFindingsByTimestamp(detectorType, categoryToQueriesMap, categoryToTimeWindowMap, + filteredCorrelationRules.stream().map(it -> it.correlationRule).map(CorrelationRule::getId).collect(Collectors.toList()), + autoCorrelations + ); + }, this::onFailure)); } else { - if (!autoCorrelations.isEmpty()) { - correlateFindingAction.getTimestampFeature(detectorType, autoCorrelations, null, List.of()); - } else { - correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), List.of()); - } + getTimestampFeature(detectorType, List.of(), autoCorrelations); } } @@ -415,50 +382,38 @@ private void searchFindingsByTimestamp(String detectorType, Map() { - @Override - public void onResponse(MultiSearchResponse items) { - MultiSearchResponse.Item[] responses = items.getResponses(); - Map relatedDocsMap = new HashMap<>(); - - int idx = 0; - for (MultiSearchResponse.Item response : responses) { - if (response.isFailure()) { - log.info(response.getFailureMessage()); - continue; - } - - List relatedDocIds = new ArrayList<>(); - SearchHit[] hits = response.getResponse().getHits().getHits(); - for (SearchHit hit : hits) { - relatedDocIds.addAll(hit.getFields().get("correlated_doc_ids").getValues().stream() - .map(Object::toString).collect(Collectors.toList())); - } + client.multiSearch(mSearchRequest, ActionListener.wrap(items -> { + MultiSearchResponse.Item[] responses = items.getResponses(); + Map relatedDocsMap = new HashMap<>(); + + int idx = 0; + for (MultiSearchResponse.Item response : responses) { + if (response.isFailure()) { + log.info(response.getFailureMessage()); + continue; + } - List correlationQueries = categoryToQueriesPairs.get(idx).getValue(); - List indices = correlationQueries.stream().map(CorrelationQuery::getIndex).collect(Collectors.toList()); - List queries = correlationQueries.stream().map(CorrelationQuery::getQuery).collect(Collectors.toList()); - relatedDocsMap.put(categoryToQueriesPairs.get(idx).getKey(), - new DocSearchCriteria( - indices, - queries, - relatedDocIds)); - ++idx; + List relatedDocIds = new ArrayList<>(); + SearchHit[] hits = response.getResponse().getHits().getHits(); + for (SearchHit hit : hits) { + relatedDocIds.addAll(hit.getFields().get("correlated_doc_ids").getValues().stream() + .map(Object::toString).collect(Collectors.toList())); } - searchDocsWithFilterKeys(detectorType, relatedDocsMap, categoryToTimeWindowMap, correlationRules, autoCorrelations); - } - @Override - public void onFailure(Exception e) { - correlateFindingAction.onFailures(e); + List correlationQueries = categoryToQueriesPairs.get(idx).getValue(); + List indices = correlationQueries.stream().map(CorrelationQuery::getIndex).collect(Collectors.toList()); + List queries = correlationQueries.stream().map(CorrelationQuery::getQuery).collect(Collectors.toList()); + relatedDocsMap.put(categoryToQueriesPairs.get(idx).getKey(), + new DocSearchCriteria( + indices, + queries, + relatedDocIds)); + ++idx; } - }); + searchDocsWithFilterKeys(detectorType, relatedDocsMap, categoryToTimeWindowMap, correlationRules, autoCorrelations); + }, this::onFailure)); } else { - if (!autoCorrelations.isEmpty()) { - correlateFindingAction.getTimestampFeature(detectorType, autoCorrelations, null, List.of()); - } else { - correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), correlationRules); - } + getTimestampFeature(detectorType, correlationRules, autoCorrelations); } } @@ -492,42 +447,30 @@ private void searchDocsWithFilterKeys(String detectorType, Map() { - @Override - public void onResponse(MultiSearchResponse items) { - MultiSearchResponse.Item[] responses = items.getResponses(); - Map> filteredRelatedDocIds = new HashMap<>(); - - int idx = 0; - for (MultiSearchResponse.Item response : responses) { - if (response.isFailure()) { - log.info(response.getFailureMessage()); - continue; - } + client.multiSearch(mSearchRequest, ActionListener.wrap( items -> { + MultiSearchResponse.Item[] responses = items.getResponses(); + Map> filteredRelatedDocIds = new HashMap<>(); + + int idx = 0; + for (MultiSearchResponse.Item response : responses) { + if (response.isFailure()) { + log.info(response.getFailureMessage()); + continue; + } - SearchHit[] hits = response.getResponse().getHits().getHits(); - List docIds = new ArrayList<>(); + SearchHit[] hits = response.getResponse().getHits().getHits(); + List docIds = new ArrayList<>(); - for (SearchHit hit : hits) { - docIds.add(hit.getId()); - } - filteredRelatedDocIds.put(categories.get(idx), docIds); - ++idx; + for (SearchHit hit : hits) { + docIds.add(hit.getId()); } - getCorrelatedFindings(detectorType, filteredRelatedDocIds, categoryToTimeWindowMap, correlationRules, autoCorrelations); + filteredRelatedDocIds.put(categories.get(idx), docIds); + ++idx; } - - @Override - public void onFailure(Exception e) { - correlateFindingAction.onFailures(e); - } - }); + getCorrelatedFindings(detectorType, filteredRelatedDocIds, categoryToTimeWindowMap, correlationRules, autoCorrelations); + }, this::onFailure)); } else { - if (!autoCorrelations.isEmpty()) { - correlateFindingAction.getTimestampFeature(detectorType, autoCorrelations, null, List.of()); - } else { - correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), correlationRules); - } + getTimestampFeature(detectorType, correlationRules, autoCorrelations); } } @@ -565,59 +508,59 @@ private void getCorrelatedFindings(String detectorType, Map } if (!mSearchRequest.requests().isEmpty()) { - client.multiSearch(mSearchRequest, new ActionListener<>() { - @Override - public void onResponse(MultiSearchResponse items) { - MultiSearchResponse.Item[] responses = items.getResponses(); - Map> correlatedFindings = new HashMap<>(); - - int idx = 0; - for (MultiSearchResponse.Item response : responses) { - if (response.isFailure()) { - log.info(response.getFailureMessage()); - ++idx; - continue; - } - - SearchHit[] hits = response.getResponse().getHits().getHits(); - List findings = new ArrayList<>(); + client.multiSearch(mSearchRequest, ActionListener.wrap(items -> { + MultiSearchResponse.Item[] responses = items.getResponses(); + Map> correlatedFindings = new HashMap<>(); + + int idx = 0; + for (MultiSearchResponse.Item response : responses) { + if (response.isFailure()) { + log.info(response.getFailureMessage()); + ++idx; + continue; + } - for (SearchHit hit : hits) { - findings.add(hit.getId()); - } + SearchHit[] hits = response.getResponse().getHits().getHits(); + List findings = new ArrayList<>(); - if (!findings.isEmpty()) { - correlatedFindings.put(categories.get(idx), findings); - } - ++idx; + for (SearchHit hit : hits) { + findings.add(hit.getId()); } - for (Map.Entry> autoCorrelation: autoCorrelations.entrySet()) { - if (correlatedFindings.containsKey(autoCorrelation.getKey())) { - Set alreadyCorrelatedFindings = new HashSet<>(correlatedFindings.get(autoCorrelation.getKey())); - alreadyCorrelatedFindings.addAll(autoCorrelation.getValue()); - correlatedFindings.put(autoCorrelation.getKey(), new ArrayList<>(alreadyCorrelatedFindings)); - } else { - correlatedFindings.put(autoCorrelation.getKey(), autoCorrelation.getValue()); - } + if (!findings.isEmpty()) { + correlatedFindings.put(categories.get(idx), findings); } - correlateFindingAction.initCorrelationIndex(detectorType, correlatedFindings, correlationRules); + ++idx; } - @Override - public void onFailure(Exception e) { - correlateFindingAction.onFailures(e); + for (Map.Entry> autoCorrelation: autoCorrelations.entrySet()) { + if (correlatedFindings.containsKey(autoCorrelation.getKey())) { + Set alreadyCorrelatedFindings = new HashSet<>(correlatedFindings.get(autoCorrelation.getKey())); + alreadyCorrelatedFindings.addAll(autoCorrelation.getValue()); + correlatedFindings.put(autoCorrelation.getKey(), new ArrayList<>(alreadyCorrelatedFindings)); + } else { + correlatedFindings.put(autoCorrelation.getKey(), autoCorrelation.getValue()); + } } - }); + correlateFindingAction.initCorrelationIndex(detectorType, correlatedFindings, correlationRules); + }, this::onFailure)); } else { - if (!autoCorrelations.isEmpty()) { - correlateFindingAction.getTimestampFeature(detectorType, autoCorrelations, null, List.of()); - } else { - correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), correlationRules); - } + getTimestampFeature(detectorType, correlationRules, autoCorrelations); } } + private void getTimestampFeature(String detectorType, List correlationRules, Map> autoCorrelations) { + if (!autoCorrelations.isEmpty()) { + correlateFindingAction.getTimestampFeature(detectorType, autoCorrelations, null, List.of()); + } else { + correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), correlationRules); + } + } + + private void onFailure(Exception e) { + correlateFindingAction.onFailures(e); + } + static class DocSearchCriteria { List indices; List queries; diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/VectorEmbeddingsEngine.java b/src/main/java/org/opensearch/securityanalytics/correlation/VectorEmbeddingsEngine.java index cab8798f2..86fc70bbd 100644 --- a/src/main/java/org/opensearch/securityanalytics/correlation/VectorEmbeddingsEngine.java +++ b/src/main/java/org/opensearch/securityanalytics/correlation/VectorEmbeddingsEngine.java @@ -11,13 +11,10 @@ import org.opensearch.cluster.routing.Preference; import org.opensearch.core.action.ActionListener; import org.opensearch.action.bulk.BulkRequest; -import org.opensearch.action.bulk.BulkResponse; import org.opensearch.action.index.IndexRequest; -import org.opensearch.action.index.IndexResponse; import org.opensearch.action.search.MultiSearchRequest; import org.opensearch.action.search.MultiSearchResponse; import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.WriteRequest; import org.opensearch.client.Client; import org.opensearch.common.unit.TimeValue; @@ -32,11 +29,9 @@ import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.securityanalytics.correlation.index.query.CorrelationQueryBuilder; import org.opensearch.securityanalytics.model.CustomLogType; -import org.opensearch.securityanalytics.model.Detector; import org.opensearch.securityanalytics.transport.TransportCorrelateFindingAction; import org.opensearch.securityanalytics.util.CorrelationIndices; -import java.io.IOException; import java.util.List; import java.util.Locale; import java.util.Map; @@ -62,213 +57,205 @@ public VectorEmbeddingsEngine(Client client, TimeValue indexTimeout, long corrTi } public void insertCorrelatedFindings(String detectorType, Finding finding, String logType, List correlatedFindings, float timestampFeature, List correlationRules, Map logTypes) { + SearchRequest searchRequest = getSearchMetadataIndexRequest(detectorType, finding, logTypes); Map tags = logTypes.get(detectorType).getTags(); String correlationId = tags.get("correlation_id").toString(); long findingTimestamp = finding.getTimestamp().toEpochMilli(); - MatchQueryBuilder queryBuilder = QueryBuilders.matchQuery( - "root", true - ); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchSourceBuilder.query(queryBuilder); - searchSourceBuilder.fetchSource(true); - searchSourceBuilder.size(1); - SearchRequest searchRequest = new SearchRequest(); - searchRequest.indices(CorrelationIndices.CORRELATION_METADATA_INDEX); - searchRequest.source(searchSourceBuilder); - searchRequest.preference(Preference.PRIMARY_FIRST.type()); - - client.search(searchRequest, new ActionListener<>() { - @Override - public void onResponse(SearchResponse response) { - if (response.isTimedOut()) { - correlateFindingAction.onFailures(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); - } - - if (response.getHits().getHits().length == 0) { - correlateFindingAction.onFailures( - new ResourceNotFoundException("Failed to find hits in metadata index for finding id {}", finding.getId())); - } - - Map hitSource = response.getHits().getHits()[0].getSourceAsMap(); - long counter = Long.parseLong(hitSource.get("counter").toString()); - - MultiSearchRequest mSearchRequest = new MultiSearchRequest(); - - for (String correlatedFinding: correlatedFindings) { - BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery() - .must(QueryBuilders.matchQuery( - "finding1", correlatedFinding - )).must(QueryBuilders.matchQuery( - "finding2", "" - ))/*.must(QueryBuilders.matchQuery( - "counter", counter - ))*/; - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchSourceBuilder.query(queryBuilder); - searchSourceBuilder.fetchSource(true); - searchSourceBuilder.size(10000); - SearchRequest searchRequest = new SearchRequest(); - searchRequest.indices(CorrelationIndices.CORRELATION_HISTORY_INDEX_PATTERN_REGEXP); - searchRequest.source(searchSourceBuilder); - searchRequest.preference(Preference.PRIMARY_FIRST.type()); - - mSearchRequest.add(searchRequest); - } - - client.multiSearch(mSearchRequest, new ActionListener<>() { - @Override - public void onResponse(MultiSearchResponse items) { - MultiSearchResponse.Item[] responses = items.getResponses(); - BulkRequest bulkRequest = new BulkRequest(); - bulkRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - - long prevCounter = -1L; - long totalNeighbors = 0L; - for (MultiSearchResponse.Item response: responses) { - if (response.isFailure()) { - log.info(response.getFailureMessage()); - continue; - } + client.search(searchRequest, ActionListener.wrap(response -> { + if (response.isTimedOut()) { + onFailure(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); + } - long totalHits = response.getResponse().getHits().getHits().length; - totalNeighbors += totalHits; + if (response.getHits().getHits().length == 0) { + onFailure( + new ResourceNotFoundException("Failed to find hits in metadata index for finding id {}", finding.getId())); + } - for (int idx = 0; idx < totalHits; ++idx) { - SearchHit hit = response.getResponse().getHits().getHits()[idx]; - Map hitSource = hit.getSourceAsMap(); - long neighborCounter = Long.parseLong(hitSource.get("counter").toString()); - String correlatedFinding = hitSource.get("finding1").toString(); + Map hitSource = response.getHits().getHits()[0].getSourceAsMap(); + long counter = Long.parseLong(hitSource.get("counter").toString()); + + MultiSearchRequest mSearchRequest = new MultiSearchRequest(); + + for (String correlatedFinding: correlatedFindings) { + BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery() + .must(QueryBuilders.matchQuery( + "finding1", correlatedFinding + )).must(QueryBuilders.matchQuery( + "finding2", "" + ))/*.must(QueryBuilders.matchQuery( + "counter", counter + ))*/; + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(queryBuilder); + searchSourceBuilder.fetchSource(true); + searchSourceBuilder.size(10000); + SearchRequest request = new SearchRequest(); + request.indices(CorrelationIndices.CORRELATION_HISTORY_INDEX_PATTERN_REGEXP); + request.source(searchSourceBuilder); + request.preference(Preference.PRIMARY_FIRST.type()); + + mSearchRequest.add(request); + } - try { - float[] corrVector = new float[3]; - if (counter != prevCounter) { - for (int i = 0; i < 2; ++i) { - corrVector[i] = ((float) counter) - 50.0f; - } + client.multiSearch(mSearchRequest, ActionListener.wrap(items -> { + MultiSearchResponse.Item[] responses = items.getResponses(); + BulkRequest bulkRequest = new BulkRequest(); + bulkRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + long prevCounter = -1L; + long totalNeighbors = 0L; + for (MultiSearchResponse.Item item: responses) { + if (item.isFailure()) { + log.info(item.getFailureMessage()); + continue; + } - corrVector[0] = (float) counter; - corrVector[2] = timestampFeature; - - XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); - builder.field("root", false); - builder.field("counter", counter); - builder.field("finding1", finding.getId()); - builder.field("finding2", ""); - builder.field("logType", correlationId); - builder.field("timestamp", findingTimestamp); - builder.field("corr_vector", corrVector); - builder.field("recordType", "finding"); - builder.field("scoreTimestamp", 0L); - builder.endObject(); - - IndexRequest indexRequest = new IndexRequest(CorrelationIndices.CORRELATION_HISTORY_WRITE_INDEX) - .source(builder) - .timeout(indexTimeout); - bulkRequest.add(indexRequest); - } + long totalHits = item.getResponse().getHits().getHits().length; + totalNeighbors += totalHits; - corrVector = new float[3]; - for (int i = 0; i < 2; ++i) { - corrVector[i] = ((float) counter) - 50.0f; - } - corrVector[0] = (2.0f * ((float) counter) - 50.0f) / 2.0f; - corrVector[1] = (2.0f * ((float) neighborCounter) - 50.0f) / 2.0f; - corrVector[2] = timestampFeature; + for (int idx = 0; idx < totalHits; ++idx) { + SearchHit hit = item.getResponse().getHits().getHits()[idx]; + Map sourceAsMap = hit.getSourceAsMap(); + long neighborCounter = Long.parseLong(sourceAsMap.get("counter").toString()); + String correlatedFinding = sourceAsMap.get("finding1").toString(); - XContentBuilder corrBuilder = XContentFactory.jsonBuilder().startObject(); - corrBuilder.field("root", false); - corrBuilder.field("counter", (long) ((2.0f * ((float) counter) - 50.0f) / 2.0f)); - corrBuilder.field("finding1", finding.getId()); - corrBuilder.field("finding2", correlatedFinding); - corrBuilder.field("logType", String.format(Locale.ROOT, "%s-%s", detectorType, logType)); - corrBuilder.field("timestamp", findingTimestamp); - corrBuilder.field("corr_vector", corrVector); - corrBuilder.field("recordType", "finding-finding"); - corrBuilder.field("scoreTimestamp", 0L); - corrBuilder.field("corrRules", correlationRules); - corrBuilder.endObject(); - - IndexRequest indexRequest = new IndexRequest(CorrelationIndices.CORRELATION_HISTORY_WRITE_INDEX) - .source(corrBuilder) - .timeout(indexTimeout); - bulkRequest.add(indexRequest); - } catch (IOException ex) { - correlateFindingAction.onFailures(ex); + try { + float[] corrVector = new float[3]; + if (counter != prevCounter) { + for (int i = 0; i < 2; ++i) { + corrVector[i] = ((float) counter) - 50.0f; } - prevCounter = counter; - } - } - if (totalNeighbors > 0L) { - client.bulk(bulkRequest, new ActionListener<>() { - @Override - public void onResponse(BulkResponse response) { - if (response.hasFailures()) { - correlateFindingAction.onFailures(new OpenSearchStatusException("Correlation of finding failed", RestStatus.INTERNAL_SERVER_ERROR)); - } - correlateFindingAction.onOperation(); - } + corrVector[0] = (float) counter; + corrVector[2] = timestampFeature; + + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + builder.field("root", false); + builder.field("counter", counter); + builder.field("finding1", finding.getId()); + builder.field("finding2", ""); + builder.field("logType", correlationId); + builder.field("timestamp", findingTimestamp); + builder.field("corr_vector", corrVector); + builder.field("recordType", "finding"); + builder.field("scoreTimestamp", 0L); + builder.endObject(); + + IndexRequest indexRequest = new IndexRequest(CorrelationIndices.CORRELATION_HISTORY_WRITE_INDEX) + .source(builder) + .timeout(indexTimeout); + bulkRequest.add(indexRequest); + } - @Override - public void onFailure(Exception e) { - correlateFindingAction.onFailures(e); - } - }); - } else { - insertOrphanFindings(detectorType, finding, timestampFeature, logTypes); + corrVector = new float[3]; + for (int i = 0; i < 2; ++i) { + corrVector[i] = ((float) counter) - 50.0f; + } + corrVector[0] = (2.0f * ((float) counter) - 50.0f) / 2.0f; + corrVector[1] = (2.0f * ((float) neighborCounter) - 50.0f) / 2.0f; + corrVector[2] = timestampFeature; + + XContentBuilder corrBuilder = XContentFactory.jsonBuilder().startObject(); + corrBuilder.field("root", false); + corrBuilder.field("counter", (long) ((2.0f * ((float) counter) - 50.0f) / 2.0f)); + corrBuilder.field("finding1", finding.getId()); + corrBuilder.field("finding2", correlatedFinding); + corrBuilder.field("logType", String.format(Locale.ROOT, "%s-%s", detectorType, logType)); + corrBuilder.field("timestamp", findingTimestamp); + corrBuilder.field("corr_vector", corrVector); + corrBuilder.field("recordType", "finding-finding"); + corrBuilder.field("scoreTimestamp", 0L); + corrBuilder.field("corrRules", correlationRules); + corrBuilder.endObject(); + + IndexRequest indexRequest = new IndexRequest(CorrelationIndices.CORRELATION_HISTORY_WRITE_INDEX) + .source(corrBuilder) + .timeout(indexTimeout); + bulkRequest.add(indexRequest); + } catch (Exception ex) { + onFailure(ex); } + prevCounter = counter; } + } - @Override - public void onFailure(Exception e) { - correlateFindingAction.onFailures(e); - } - }); - } - - @Override - public void onFailure(Exception e) { - correlateFindingAction.onFailures(e); - } - }); + if (totalNeighbors > 0L) { + client.bulk(bulkRequest, ActionListener.wrap( bulkResponse -> { + if (bulkResponse.hasFailures()) { + onFailure(new OpenSearchStatusException("Correlation of finding failed", RestStatus.INTERNAL_SERVER_ERROR)); + } + correlateFindingAction.onOperation(); + }, this::onFailure)); + } else { + insertOrphanFindings(detectorType, finding, timestampFeature, logTypes); + } + }, this::onFailure)); + }, this::onFailure)); } public void insertOrphanFindings(String detectorType, Finding finding, float timestampFeature, Map logTypes) { - if (logTypes.get(detectorType) == null) { - log.error("LogTypes Index is missing the detector type {}", detectorType); - correlateFindingAction.onFailures(new OpenSearchStatusException("LogTypes Index is missing the detector type", RestStatus.INTERNAL_SERVER_ERROR)); - } - + SearchRequest searchRequest = getSearchMetadataIndexRequest(detectorType, finding, logTypes); Map tags = logTypes.get(detectorType).getTags(); String correlationId = tags.get("correlation_id").toString(); - long findingTimestamp = finding.getTimestamp().toEpochMilli(); - MatchQueryBuilder queryBuilder = QueryBuilders.matchQuery( - "root", true - ); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchSourceBuilder.query(queryBuilder); - searchSourceBuilder.fetchSource(true); - searchSourceBuilder.size(1); - SearchRequest searchRequest = new SearchRequest(); - searchRequest.indices(CorrelationIndices.CORRELATION_METADATA_INDEX); - searchRequest.source(searchSourceBuilder); - searchRequest.preference(Preference.PRIMARY_FIRST.type()); - client.search(searchRequest, new ActionListener<>() { - @Override - public void onResponse(SearchResponse response) { - if (response.isTimedOut()) { - correlateFindingAction.onFailures(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); - } + client.search(searchRequest, ActionListener.wrap(response -> { + if (response.isTimedOut()) { + onFailure(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); + } - try { - Map hitSource = response.getHits().getHits()[0].getSourceAsMap(); - String id = response.getHits().getHits()[0].getId(); - long counter = Long.parseLong(hitSource.get("counter").toString()); - long timestamp = Long.parseLong(hitSource.get("timestamp").toString()); - if (counter == 0L) { + try { + Map hitSource = response.getHits().getHits()[0].getSourceAsMap(); + String id = response.getHits().getHits()[0].getId(); + long counter = Long.parseLong(hitSource.get("counter").toString()); + long timestamp = Long.parseLong(hitSource.get("timestamp").toString()); + if (counter == 0L) { + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + builder.field("root", true); + builder.field("counter", 50L); + builder.field("finding1", ""); + builder.field("finding2", ""); + builder.field("logType", ""); + builder.field("timestamp", findingTimestamp); + builder.field("scoreTimestamp", 0L); + builder.endObject(); + + IndexRequest indexRequest = new IndexRequest(CorrelationIndices.CORRELATION_METADATA_INDEX) + .id(id) + .source(builder) + .timeout(indexTimeout) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + client.index(indexRequest, ActionListener.wrap(indexResponse -> { + if (indexResponse.status().equals(RestStatus.OK)) { + try { + float[] corrVector = new float[3]; + corrVector[0] = 50.0f; + corrVector[2] = timestampFeature; + + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject(); + xContentBuilder.field("root", false); + xContentBuilder.field("counter", 50L); + xContentBuilder.field("finding1", finding.getId()); + xContentBuilder.field("finding2", ""); + xContentBuilder.field("logType", correlationId); + xContentBuilder.field("timestamp", findingTimestamp); + xContentBuilder.field("corr_vector", corrVector); + xContentBuilder.field("recordType", "finding"); + xContentBuilder.field("scoreTimestamp", 0L); + xContentBuilder.endObject(); + + indexCorrelatedFindings(xContentBuilder); + } catch (Exception ex) { + onFailure(ex); + } + } else { + onFailure(new OpenSearchStatusException(indexResponse.toString(), RestStatus.INTERNAL_SERVER_ERROR)); + } + }, this::onFailure)); + } else { + if (findingTimestamp - timestamp > corrTimeWindow) { XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); builder.field("root", true); builder.field("counter", 50L); @@ -285,308 +272,192 @@ public void onResponse(SearchResponse response) { .timeout(indexTimeout) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - client.index(indexRequest, new ActionListener<>() { - @Override - public void onResponse(IndexResponse response) { - if (response.status().equals(RestStatus.OK)) { - try { - float[] corrVector = new float[3]; - corrVector[0] = 50.0f; - corrVector[2] = timestampFeature; - - XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); - builder.field("root", false); - builder.field("counter", 50L); - builder.field("finding1", finding.getId()); - builder.field("finding2", ""); - builder.field("logType", correlationId); - builder.field("timestamp", findingTimestamp); - builder.field("corr_vector", corrVector); - builder.field("recordType", "finding"); - builder.field("scoreTimestamp", 0L); - builder.endObject(); - - IndexRequest indexRequest = new IndexRequest(CorrelationIndices.CORRELATION_HISTORY_WRITE_INDEX) - .source(builder) - .timeout(indexTimeout) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - - client.index(indexRequest, new ActionListener<>() { - @Override - public void onResponse(IndexResponse response) { - if (response.status().equals(RestStatus.CREATED)) { - correlateFindingAction.onOperation(); - } else { - correlateFindingAction.onFailures(new OpenSearchStatusException(response.toString(), RestStatus.INTERNAL_SERVER_ERROR)); - } - } + client.index(indexRequest, ActionListener.wrap(indexResponse -> { + if (indexResponse.status().equals(RestStatus.OK)) { + correlateFindingAction.onOperation(); + try { + float[] corrVector = new float[3]; + corrVector[0] = 50.0f; + corrVector[2] = timestampFeature; - @Override - public void onFailure(Exception e) { - correlateFindingAction.onFailures(e); - } - }); - } catch (IOException ex) { - correlateFindingAction.onFailures(ex); - } + XContentBuilder contentBuilder = XContentFactory.jsonBuilder().startObject(); + contentBuilder.field("root", false); + contentBuilder.field("counter", 50L); + contentBuilder.field("finding1", finding.getId()); + contentBuilder.field("finding2", ""); + contentBuilder.field("logType", Integer.valueOf(logTypes.get(detectorType).getTags().get("correlation_id").toString()).toString()); + contentBuilder.field("timestamp", findingTimestamp); + contentBuilder.field("corr_vector", corrVector); + contentBuilder.field("recordType", "finding"); + contentBuilder.field("scoreTimestamp", 0L); + contentBuilder.endObject(); + + indexCorrelatedFindings(contentBuilder); + } catch (Exception ex) { + onFailure(ex); } + } else { + onFailure(new OpenSearchStatusException(indexResponse.toString(), RestStatus.INTERNAL_SERVER_ERROR)); } - - @Override - public void onFailure(Exception e) { - correlateFindingAction.onFailures(e); - } - }); + }, this::onFailure)); } else { - if (findingTimestamp - timestamp > corrTimeWindow) { - XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); - builder.field("root", true); - builder.field("counter", 50L); - builder.field("finding1", ""); - builder.field("finding2", ""); - builder.field("logType", ""); - builder.field("timestamp", findingTimestamp); - builder.field("scoreTimestamp", 0L); - builder.endObject(); - - IndexRequest indexRequest = new IndexRequest(CorrelationIndices.CORRELATION_METADATA_INDEX) - .id(id) - .source(builder) - .timeout(indexTimeout) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - - client.index(indexRequest, new ActionListener<>() { - @Override - public void onResponse(IndexResponse response) { - if (response.status().equals(RestStatus.OK)) { - correlateFindingAction.onOperation(); - try { - float[] corrVector = new float[3]; - corrVector[0] = 50.0f; - corrVector[2] = timestampFeature; - - XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); - builder.field("root", false); - builder.field("counter", 50L); - builder.field("finding1", finding.getId()); - builder.field("finding2", ""); - builder.field("logType", Integer.valueOf(logTypes.get(detectorType).getTags().get("correlation_id").toString()).toString()); - builder.field("timestamp", findingTimestamp); - builder.field("corr_vector", corrVector); - builder.field("recordType", "finding"); - builder.field("scoreTimestamp", 0L); - builder.endObject(); - - IndexRequest indexRequest = new IndexRequest(CorrelationIndices.CORRELATION_HISTORY_WRITE_INDEX) - .source(builder) - .timeout(indexTimeout) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - - client.index(indexRequest, new ActionListener<>() { - @Override - public void onResponse(IndexResponse response) { - if (response.status().equals(RestStatus.CREATED)) { - correlateFindingAction.onOperation(); - } else { - correlateFindingAction.onFailures(new OpenSearchStatusException(response.toString(), RestStatus.INTERNAL_SERVER_ERROR)); - } - } + float[] query = new float[3]; + for (int i = 0; i < 2; ++i) { + query[i] = (2.0f * ((float) counter) - 50.0f) / 2.0f; + } + query[2] = timestampFeature; + + CorrelationQueryBuilder correlationQueryBuilder = new CorrelationQueryBuilder("corr_vector", query, 100, QueryBuilders.boolQuery() + .mustNot(QueryBuilders.matchQuery( + "finding1", "" + )).mustNot(QueryBuilders.matchQuery( + "finding2", "" + )).filter(QueryBuilders.rangeQuery("timestamp") + .gte(findingTimestamp - corrTimeWindow) + .lte(findingTimestamp + corrTimeWindow))); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(correlationQueryBuilder); + searchSourceBuilder.fetchSource(true); + searchSourceBuilder.size(1); + SearchRequest request = new SearchRequest(); + request.indices(CorrelationIndices.CORRELATION_HISTORY_INDEX_PATTERN_REGEXP); + request.source(searchSourceBuilder); + request.preference(Preference.PRIMARY_FIRST.type()); + + client.search(request, ActionListener.wrap(searchResponse -> { + if (searchResponse.isTimedOut()) { + onFailure(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); + } - @Override - public void onFailure(Exception e) { - correlateFindingAction.onFailures(e); - } - }); - } catch (IOException ex) { - correlateFindingAction.onFailures(ex); - } - } - } + long totalHits = searchResponse.getHits().getHits().length; + SearchHit hit = totalHits > 0? searchResponse.getHits().getHits()[0]: null; + long existCounter = 0L; - @Override - public void onFailure(Exception e) { - correlateFindingAction.onFailures(e); - } - }); - } else { - float[] query = new float[3]; - for (int i = 0; i < 2; ++i) { - query[i] = (2.0f * ((float) counter) - 50.0f) / 2.0f; + if (hit != null) { + Map sourceAsMap = searchResponse.getHits().getHits()[0].getSourceAsMap(); + existCounter = Long.parseLong(sourceAsMap.get("counter").toString()); } - query[2] = timestampFeature; - - CorrelationQueryBuilder correlationQueryBuilder = new CorrelationQueryBuilder("corr_vector", query, 100, QueryBuilders.boolQuery() - .mustNot(QueryBuilders.matchQuery( - "finding1", "" - )).mustNot(QueryBuilders.matchQuery( - "finding2", "" - )).filter(QueryBuilders.rangeQuery("timestamp") - .gte(findingTimestamp - corrTimeWindow) - .lte(findingTimestamp + corrTimeWindow))); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchSourceBuilder.query(correlationQueryBuilder); - searchSourceBuilder.fetchSource(true); - searchSourceBuilder.size(1); - SearchRequest searchRequest = new SearchRequest(); - searchRequest.indices(CorrelationIndices.CORRELATION_HISTORY_INDEX_PATTERN_REGEXP); - searchRequest.source(searchSourceBuilder); - searchRequest.preference(Preference.PRIMARY_FIRST.type()); - - client.search(searchRequest, new ActionListener<>() { - @Override - public void onResponse(SearchResponse response) { - if (response.isTimedOut()) { - correlateFindingAction.onFailures(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); - } - long totalHits = response.getHits().getTotalHits().value; - SearchHit hit = totalHits > 0? response.getHits().getHits()[0]: null; - long existCounter = 0L; - - if (hit != null) { - Map hitSource = response.getHits().getHits()[0].getSourceAsMap(); - existCounter = Long.parseLong(hitSource.get("counter").toString()); + if (totalHits == 0L || existCounter != ((long) (2.0f * ((float) counter) - 50.0f) / 2.0f)) { + try { + float[] corrVector = new float[3]; + for (int i = 0; i < 2; ++i) { + corrVector[i] = ((float) counter) - 50.0f; } + corrVector[0] = (float) counter; + corrVector[2] = timestampFeature; - if (totalHits == 0L || existCounter != ((long) (2.0f * ((float) counter) - 50.0f) / 2.0f)) { - try { - float[] corrVector = new float[3]; - for (int i = 0; i < 2; ++i) { - corrVector[i] = ((float) counter) - 50.0f; - } - corrVector[0] = (float) counter; - corrVector[2] = timestampFeature; - - XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); - builder.field("root", false); - builder.field("counter", counter); - builder.field("finding1", finding.getId()); - builder.field("finding2", ""); - builder.field("logType", Integer.valueOf(logTypes.get(detectorType).getTags().get("correlation_id").toString()).toString()); - builder.field("timestamp", findingTimestamp); - builder.field("corr_vector", corrVector); - builder.field("recordType", "finding"); - builder.field("scoreTimestamp", 0L); - builder.endObject(); - - IndexRequest indexRequest = new IndexRequest(CorrelationIndices.CORRELATION_HISTORY_WRITE_INDEX) - .source(builder) - .timeout(indexTimeout) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - - client.index(indexRequest, new ActionListener<>() { - @Override - public void onResponse(IndexResponse response) { - if (response.status().equals(RestStatus.CREATED)) { - correlateFindingAction.onOperation(); - } else { - correlateFindingAction.onFailures(new OpenSearchStatusException(response.toString(), RestStatus.INTERNAL_SERVER_ERROR)); - } - } - - @Override - public void onFailure(Exception e) { - correlateFindingAction.onFailures(e); - } - }); - } catch (IOException ex) { - correlateFindingAction.onFailures(ex); - } - } else { - try { - XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); - builder.field("root", true); - builder.field("counter", counter + 50L); - builder.field("finding1", ""); - builder.field("finding2", ""); - builder.field("logType", ""); - builder.field("timestamp", findingTimestamp); - builder.field("scoreTimestamp", 0L); - builder.endObject(); - - IndexRequest indexRequest = new IndexRequest(CorrelationIndices.CORRELATION_METADATA_INDEX) - .id(id) - .source(builder) - .timeout(indexTimeout) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - - client.index(indexRequest, new ActionListener<>() { - @Override - public void onResponse(IndexResponse response) { - if (response.status().equals(RestStatus.OK)) { - try { - float[] corrVector = new float[3]; - for (int i = 0; i < 2; ++i) { - corrVector[i] = (float) counter; - } - corrVector[0] = counter + 50.0f; - corrVector[2] = timestampFeature; - - XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); - builder.field("root", false); - builder.field("counter", counter + 50L); - builder.field("finding1", finding.getId()); - builder.field("finding2", ""); - builder.field("logType", Integer.valueOf(logTypes.get(detectorType).getTags().get("correlation_id").toString()).toString()); - builder.field("timestamp", findingTimestamp); - builder.field("corr_vector", corrVector); - builder.field("recordType", "finding"); - builder.field("scoreTimestamp", 0L); - builder.endObject(); - - IndexRequest indexRequest = new IndexRequest(CorrelationIndices.CORRELATION_HISTORY_WRITE_INDEX) - .source(builder) - .timeout(indexTimeout) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - - client.index(indexRequest, new ActionListener<>() { - @Override - public void onResponse(IndexResponse response) { - if (response.status().equals(RestStatus.CREATED)) { - correlateFindingAction.onOperation(); - } else { - correlateFindingAction.onFailures(new OpenSearchStatusException(response.toString(), RestStatus.INTERNAL_SERVER_ERROR)); - } - } - - @Override - public void onFailure(Exception e) { - correlateFindingAction.onFailures(e); - } - }); - } catch (IOException ex) { - correlateFindingAction.onFailures(ex); - } - } - } - - @Override - public void onFailure(Exception e) { - correlateFindingAction.onFailures(e); + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + builder.field("root", false); + builder.field("counter", counter); + builder.field("finding1", finding.getId()); + builder.field("finding2", ""); + builder.field("logType", Integer.valueOf(logTypes.get(detectorType).getTags().get("correlation_id").toString()).toString()); + builder.field("timestamp", findingTimestamp); + builder.field("corr_vector", corrVector); + builder.field("recordType", "finding"); + builder.field("scoreTimestamp", 0L); + builder.endObject(); + + indexCorrelatedFindings(builder); + } catch (Exception ex) { + onFailure(ex); + } + } else { + try { + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + builder.field("root", true); + builder.field("counter", counter + 50L); + builder.field("finding1", ""); + builder.field("finding2", ""); + builder.field("logType", ""); + builder.field("timestamp", findingTimestamp); + builder.field("scoreTimestamp", 0L); + builder.endObject(); + + IndexRequest indexRequest = new IndexRequest(CorrelationIndices.CORRELATION_METADATA_INDEX) + .id(id) + .source(builder) + .timeout(indexTimeout) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + client.index(indexRequest, ActionListener.wrap(indexResponse -> { + if (indexResponse.status().equals(RestStatus.OK)) { + try { + float[] corrVector = new float[3]; + for (int i = 0; i < 2; ++i) { + corrVector[i] = (float) counter; } - }); - } catch (IOException ex) { - correlateFindingAction.onFailures(ex); + corrVector[0] = counter + 50.0f; + corrVector[2] = timestampFeature; + + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject(); + xContentBuilder.field("root", false); + xContentBuilder.field("counter", counter + 50L); + xContentBuilder.field("finding1", finding.getId()); + xContentBuilder.field("finding2", ""); + xContentBuilder.field("logType", Integer.valueOf(logTypes.get(detectorType).getTags().get("correlation_id").toString()).toString()); + xContentBuilder.field("timestamp", findingTimestamp); + xContentBuilder.field("corr_vector", corrVector); + xContentBuilder.field("recordType", "finding"); + xContentBuilder.field("scoreTimestamp", 0L); + xContentBuilder.endObject(); + + indexCorrelatedFindings(xContentBuilder); + } catch (Exception ex) { + onFailure(ex); + } } - } + }, this::onFailure)); + } catch (Exception ex) { + onFailure(ex); } - - @Override - public void onFailure(Exception e) { - correlateFindingAction.onFailures(e); - } - }); - } + } + }, this::onFailure)); } - } catch (IOException ex) { - correlateFindingAction.onFailures(ex); } + } catch (Exception ex) { + onFailure(ex); } + }, this::onFailure)); + } - @Override - public void onFailure(Exception e) { - correlateFindingAction.onFailures(e); + private void indexCorrelatedFindings(XContentBuilder builder) { + IndexRequest indexRequest = new IndexRequest(CorrelationIndices.CORRELATION_HISTORY_WRITE_INDEX) + .source(builder) + .timeout(indexTimeout) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + client.index(indexRequest, ActionListener.wrap(response -> { + if (response.status().equals(RestStatus.CREATED)) { + correlateFindingAction.onOperation(); + } else { + onFailure(new OpenSearchStatusException(response.toString(), RestStatus.INTERNAL_SERVER_ERROR)); } - }); + }, this::onFailure)); + } + + private SearchRequest getSearchMetadataIndexRequest(String detectorType, Finding finding, Map logTypes) { + if (logTypes.get(detectorType) == null) { + throw new OpenSearchStatusException("LogTypes Index is missing the detector type", RestStatus.INTERNAL_SERVER_ERROR); + } + + Map tags = logTypes.get(detectorType).getTags(); + MatchQueryBuilder queryBuilder = QueryBuilders.matchQuery( + "root", true + ); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(queryBuilder); + searchSourceBuilder.fetchSource(true); + searchSourceBuilder.size(1); + SearchRequest searchRequest = new SearchRequest(); + searchRequest.indices(CorrelationIndices.CORRELATION_METADATA_INDEX); + searchRequest.source(searchSourceBuilder); + searchRequest.preference(Preference.PRIMARY_FIRST.type()); + return searchRequest; + } + + private void onFailure(Exception e) { + correlateFindingAction.onFailures(e); } } \ No newline at end of file diff --git a/src/main/java/org/opensearch/securityanalytics/logtype/LogTypeService.java b/src/main/java/org/opensearch/securityanalytics/logtype/LogTypeService.java index 9036f514d..80ec1270a 100644 --- a/src/main/java/org/opensearch/securityanalytics/logtype/LogTypeService.java +++ b/src/main/java/org/opensearch/securityanalytics/logtype/LogTypeService.java @@ -282,7 +282,7 @@ public void onResponse(SearchResponse response) { if (response.isTimedOut()) { listener.onFailure(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); } - if (response.getHits().getTotalHits().value > 0) { + if (response.getHits().getHits().length > 0) { listener.onResponse(null); } else { try { diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java index 63c31f99b..d5e0eed32 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java @@ -132,75 +132,50 @@ public TransportCorrelateFindingAction(TransportService transportService, protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { try { PublishFindingsRequest transformedRequest = transformRequest(request); + AsyncCorrelateFindingAction correlateFindingAction = new AsyncCorrelateFindingAction(task, transformedRequest, actionListener); if (!this.correlationIndices.correlationIndexExists()) { try { - this.correlationIndices.initCorrelationIndex(new ActionListener<>() { - @Override - public void onResponse(CreateIndexResponse response) { - if (response.isAcknowledged()) { - IndexUtils.correlationIndexUpdated(); - if (IndexUtils.correlationIndexUpdated) { - IndexUtils.lastUpdatedCorrelationHistoryIndex = IndexUtils.getIndexNameWithAlias( - clusterService.state(), - CorrelationIndices.CORRELATION_HISTORY_WRITE_INDEX - ); - } + this.correlationIndices.initCorrelationIndex(ActionListener.wrap(response -> { + if (response.isAcknowledged()) { + IndexUtils.correlationIndexUpdated(); + if (IndexUtils.correlationIndexUpdated) { + IndexUtils.lastUpdatedCorrelationHistoryIndex = IndexUtils.getIndexNameWithAlias( + clusterService.state(), + CorrelationIndices.CORRELATION_HISTORY_WRITE_INDEX + ); + } - if (!correlationIndices.correlationMetadataIndexExists()) { - try { - correlationIndices.initCorrelationMetadataIndex(new ActionListener<>() { - @Override - public void onResponse(CreateIndexResponse response) { - if (response.isAcknowledged()) { - IndexUtils.correlationMetadataIndexUpdated(); - - correlationIndices.setupCorrelationIndex(indexTimeout, setupTimestamp, new ActionListener<>() { - @Override - public void onResponse(BulkResponse response) { - if (response.hasFailures()) { - log.error(new OpenSearchStatusException(response.toString(), RestStatus.INTERNAL_SERVER_ERROR)); - } - AsyncCorrelateFindingAction correlateFindingAction = new AsyncCorrelateFindingAction(task, transformedRequest, actionListener); - correlateFindingAction.start(); - } + if (!correlationIndices.correlationMetadataIndexExists()) { + try { + correlationIndices.initCorrelationMetadataIndex(ActionListener.wrap(createIndexResponse -> { + if (createIndexResponse.isAcknowledged()) { + IndexUtils.correlationMetadataIndexUpdated(); - @Override - public void onFailure(Exception e) { - log.error(e); - } - }); - } else { - log.error(new OpenSearchStatusException("Failed to create correlation metadata Index", RestStatus.INTERNAL_SERVER_ERROR)); + correlationIndices.setupCorrelationIndex(indexTimeout, setupTimestamp, ActionListener.wrap(bulkResponse -> { + if (bulkResponse.hasFailures()) { + correlateFindingAction.onFailures(new OpenSearchStatusException(createIndexResponse.toString(), RestStatus.INTERNAL_SERVER_ERROR)); } - } - - @Override - public void onFailure(Exception e) { - - } - }); - } catch (Exception ex) { - onFailure(ex); - } + correlateFindingAction.start(); + }, correlateFindingAction::onFailures)); + } else { + correlateFindingAction.onFailures(new OpenSearchStatusException("Failed to create correlation metadata Index", RestStatus.INTERNAL_SERVER_ERROR)); + } + }, correlateFindingAction::onFailures)); + } catch (Exception ex) { + correlateFindingAction.onFailures(ex); } - } else { - log.error(new OpenSearchStatusException("Failed to create correlation Index", RestStatus.INTERNAL_SERVER_ERROR)); } + } else { + correlateFindingAction.onFailures(new OpenSearchStatusException("Failed to create correlation Index", RestStatus.INTERNAL_SERVER_ERROR)); } - - @Override - public void onFailure(Exception e) { - log.error(e); - } - }); + }, correlateFindingAction::onFailures)); } catch (IOException ex) { - log.error(ex); + correlateFindingAction.onFailures(ex); } } else { - AsyncCorrelateFindingAction correlateFindingAction = new AsyncCorrelateFindingAction(task, transformedRequest, actionListener); correlateFindingAction.start(); } } catch (IOException e) { @@ -254,39 +229,30 @@ void start() { searchRequest.source(searchSourceBuilder); searchRequest.preference(Preference.PRIMARY_FIRST.type()); - client.search(searchRequest, new ActionListener<>() { - @Override - public void onResponse(SearchResponse response) { - if (response.isTimedOut()) { - onFailures(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); - } - - SearchHits hits = response.getHits(); - // Detectors Index hits count could be more even if we fetch one - if (hits.getTotalHits().value >= 1 && hits.getHits().length > 0) { - try { - SearchHit hit = hits.getAt(0); - - XContentParser xcp = XContentType.JSON.xContent().createParser( - xContentRegistry, - LoggingDeprecationHandler.INSTANCE, hit.getSourceAsString() - ); - Detector detector = Detector.docParse(xcp, hit.getId(), hit.getVersion()); - joinEngine.onSearchDetectorResponse(detector, finding); - } catch (IOException e) { - log.error("IOException for request {}", searchRequest.toString(), e); - onFailures(e); - } - } else { - onFailures(new OpenSearchStatusException("detector not found given monitor id", RestStatus.INTERNAL_SERVER_ERROR)); - } + client.search(searchRequest, ActionListener.wrap(response -> { + if (response.isTimedOut()) { + onFailures(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); } - @Override - public void onFailure(Exception e) { - onFailures(e); + SearchHits hits = response.getHits(); + if (hits.getHits().length > 0) { + try { + SearchHit hit = hits.getAt(0); + + XContentParser xcp = XContentType.JSON.xContent().createParser( + xContentRegistry, + LoggingDeprecationHandler.INSTANCE, hit.getSourceAsString() + ); + Detector detector = Detector.docParse(xcp, hit.getId(), hit.getVersion()); + joinEngine.onSearchDetectorResponse(detector, finding); + } catch (IOException e) { + log.error("IOException for request {}", searchRequest.toString(), e); + onFailures(e); + } + } else { + onFailures(new OpenSearchStatusException("detector not found given monitor id " + request.getMonitorId(), RestStatus.INTERNAL_SERVER_ERROR)); } - }); + }, this::onFailures)); } else { onFailures(new SecurityAnalyticsException(String.format(Locale.getDefault(), "Detector index %s doesnt exist", Detector.DETECTORS_INDEX), RestStatus.INTERNAL_SERVER_ERROR, new RuntimeException())); } @@ -298,22 +264,14 @@ public void initCorrelationIndex(String detectorType, Map> IndexUtils.updateIndexMapping( CorrelationIndices.CORRELATION_HISTORY_WRITE_INDEX, CorrelationIndices.correlationMappings(), clusterService.state(), client.admin().indices(), - new ActionListener<>() { - @Override - public void onResponse(AcknowledgedResponse response) { - if (response.isAcknowledged()) { - IndexUtils.correlationIndexUpdated(); - getTimestampFeature(detectorType, correlatedFindings, null, correlationRules); - } else { - onFailures(new OpenSearchStatusException("Failed to create correlation Index", RestStatus.INTERNAL_SERVER_ERROR)); - } + ActionListener.wrap(response -> { + if (response.isAcknowledged()) { + IndexUtils.correlationIndexUpdated(); + getTimestampFeature(detectorType, correlatedFindings, null, correlationRules); + } else { + onFailures(new OpenSearchStatusException("Failed to create correlation Index", RestStatus.INTERNAL_SERVER_ERROR)); } - - @Override - public void onFailure(Exception e) { - onFailures(e); - } - }, + }, this::onFailures), true ); } else { @@ -325,199 +283,84 @@ public void onFailure(Exception e) { } public void getTimestampFeature(String detectorType, Map> correlatedFindings, Finding orphanFinding, List correlationRules) { - if (!correlationIndices.correlationMetadataIndexExists()) { - try { - correlationIndices.initCorrelationMetadataIndex(new ActionListener<>() { - @Override - public void onResponse(CreateIndexResponse response) { + try { + if (!correlationIndices.correlationMetadataIndexExists()) { + correlationIndices.initCorrelationMetadataIndex(ActionListener.wrap(response -> { if (response.isAcknowledged()) { IndexUtils.correlationMetadataIndexUpdated(); - correlationIndices.setupCorrelationIndex(indexTimeout, setupTimestamp, new ActionListener<>() { - @Override - public void onResponse(BulkResponse response) { - if (response.hasFailures()) { - log.error(new OpenSearchStatusException(response.toString(), RestStatus.INTERNAL_SERVER_ERROR)); + correlationIndices.setupCorrelationIndex(indexTimeout, setupTimestamp, ActionListener.wrap(bulkResponse -> { + if (bulkResponse.hasFailures()) { + onFailures(new OpenSearchStatusException(bulkResponse.toString(), RestStatus.INTERNAL_SERVER_ERROR)); + } + + long findingTimestamp = request.getFinding().getTimestamp().toEpochMilli(); + SearchRequest searchMetadataIndexRequest = getSearchMetadataIndexRequest(); + + client.search(searchMetadataIndexRequest, ActionListener.wrap(searchMetadataResponse -> { + if (searchMetadataResponse.getHits().getHits().length == 0) { + onFailures(new ResourceNotFoundException( + "Failed to find hits in metadata index for finding id {}", request.getFinding().getId())); } - long findingTimestamp = request.getFinding().getTimestamp().toEpochMilli(); - BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery() - .mustNot(QueryBuilders.termQuery("scoreTimestamp", 0L)); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchSourceBuilder.query(queryBuilder); - searchSourceBuilder.fetchSource(true); - searchSourceBuilder.size(1); - SearchRequest searchRequest = new SearchRequest(); - searchRequest.indices(CorrelationIndices.CORRELATION_METADATA_INDEX); - searchRequest.source(searchSourceBuilder); - searchRequest.preference(Preference.PRIMARY_FIRST.type()); - - client.search(searchRequest, new ActionListener<>() { - @Override - public void onResponse(SearchResponse response) { - String id = response.getHits().getHits()[0].getId(); - Map hitSource = response.getHits().getHits()[0].getSourceAsMap(); - long scoreTimestamp = (long) hitSource.get("scoreTimestamp"); - - if (findingTimestamp - CorrelationIndices.FIXED_HISTORICAL_INTERVAL > scoreTimestamp) { - try { - XContentBuilder scoreBuilder = XContentFactory.jsonBuilder().startObject(); - scoreBuilder.field("scoreTimestamp", findingTimestamp - CorrelationIndices.FIXED_HISTORICAL_INTERVAL); - scoreBuilder.field("root", false); - scoreBuilder.endObject(); - - IndexRequest scoreIndexRequest = new IndexRequest(CorrelationIndices.CORRELATION_METADATA_INDEX) - .id(id) - .source(scoreBuilder) - .timeout(indexTimeout) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - - client.index(scoreIndexRequest, new ActionListener<>() { - @Override - public void onResponse(IndexResponse response) { - BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery() - .must(QueryBuilders.existsQuery("source")); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchSourceBuilder.query(queryBuilder); - searchSourceBuilder.fetchSource(true); - searchSourceBuilder.size(10000); - SearchRequest searchRequest = new SearchRequest(); - searchRequest.indices(LogTypeService.LOG_TYPE_INDEX); - searchRequest.source(searchSourceBuilder); - - client.search(searchRequest, new ActionListener<>() { - @Override - public void onResponse(SearchResponse response) { - if (response.isTimedOut()) { - onFailures(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); - } - - SearchHit[] hits = response.getHits().getHits(); - Map logTypes = new HashMap<>(); - for (SearchHit hit : hits) { - Map sourceMap = hit.getSourceAsMap(); - logTypes.put(sourceMap.get("name").toString(), - new CustomLogType(sourceMap)); - } - - if (correlatedFindings != null) { - if (correlatedFindings.isEmpty()) { - vectorEmbeddingsEngine.insertOrphanFindings(detectorType, request.getFinding(), Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue(), logTypes); - } - for (Map.Entry> correlatedFinding : correlatedFindings.entrySet()) { - vectorEmbeddingsEngine.insertCorrelatedFindings(detectorType, request.getFinding(), correlatedFinding.getKey(), correlatedFinding.getValue(), - Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue(), correlationRules, logTypes); - } - } else { - vectorEmbeddingsEngine.insertOrphanFindings(detectorType, orphanFinding, Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue(), logTypes); - } - } - - @Override - public void onFailure(Exception e) { - onFailures(e); - } - }); - } + String id = searchMetadataResponse.getHits().getHits()[0].getId(); + Map hitSource = searchMetadataResponse.getHits().getHits()[0].getSourceAsMap(); + long scoreTimestamp = (long) hitSource.get("scoreTimestamp"); - @Override - public void onFailure(Exception e) { - onFailures(e); - } - }); - } catch (Exception ex) { - onFailures(ex); - } - } else { - float timestampFeature = Long.valueOf((findingTimestamp - scoreTimestamp) / 1000L).floatValue(); - - BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery() - .must(QueryBuilders.existsQuery("source")); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchSourceBuilder.query(queryBuilder); - searchSourceBuilder.fetchSource(true); - searchSourceBuilder.size(10000); - SearchRequest searchRequest = new SearchRequest(); - searchRequest.indices(LogTypeService.LOG_TYPE_INDEX); - searchRequest.source(searchSourceBuilder); - - client.search(searchRequest, new ActionListener<>() { - @Override - public void onResponse(SearchResponse response) { - if (response.isTimedOut()) { - onFailures(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); - } + long newScoreTimestamp = findingTimestamp - CorrelationIndices.FIXED_HISTORICAL_INTERVAL; + if (newScoreTimestamp > scoreTimestamp) { + try { + IndexRequest scoreIndexRequest = getCorrelationMetadataIndexRequest(id, newScoreTimestamp); - SearchHit[] hits = response.getHits().getHits(); - Map logTypes = new HashMap<>(); - for (SearchHit hit : hits) { - Map sourceMap = hit.getSourceAsMap(); - logTypes.put(sourceMap.get("name").toString(), - new CustomLogType(sourceMap)); - } + client.index(scoreIndexRequest, ActionListener.wrap(indexResponse -> { + SearchRequest searchRequest = getSearchLogTypeIndexRequest(); - if (correlatedFindings != null) { - if (correlatedFindings.isEmpty()) { - vectorEmbeddingsEngine.insertOrphanFindings(detectorType, request.getFinding(), timestampFeature, logTypes); - } - for (Map.Entry> correlatedFinding : correlatedFindings.entrySet()) { - vectorEmbeddingsEngine.insertCorrelatedFindings(detectorType, request.getFinding(), correlatedFinding.getKey(), correlatedFinding.getValue(), - timestampFeature, correlationRules, logTypes); - } - } else { - vectorEmbeddingsEngine.insertOrphanFindings(detectorType, orphanFinding, timestampFeature, logTypes); - } + client.search(searchRequest, ActionListener.wrap(searchResponse -> { + if (searchResponse.isTimedOut()) { + onFailures(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); } - @Override - public void onFailure(Exception e) { - onFailures(e); + SearchHit[] hits = searchResponse.getHits().getHits(); + Map logTypes = new HashMap<>(); + for (SearchHit hit : hits) { + Map sourceMap = hit.getSourceAsMap(); + logTypes.put(sourceMap.get("name").toString(), + new CustomLogType(sourceMap)); } - }); - } - } - @Override - public void onFailure(Exception e) { - onFailures(e); + if (correlatedFindings != null) { + if (correlatedFindings.isEmpty()) { + vectorEmbeddingsEngine.insertOrphanFindings(detectorType, request.getFinding(), Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue(), logTypes); + } + for (Map.Entry> correlatedFinding : correlatedFindings.entrySet()) { + vectorEmbeddingsEngine.insertCorrelatedFindings(detectorType, request.getFinding(), correlatedFinding.getKey(), correlatedFinding.getValue(), + Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue(), correlationRules, logTypes); + } + } else { + vectorEmbeddingsEngine.insertOrphanFindings(detectorType, orphanFinding, Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue(), logTypes); + } + }, this::onFailures)); + }, this::onFailures)); + } catch (Exception ex) { + onFailures(ex); } - }); - } + } else { + float timestampFeature = Long.valueOf((findingTimestamp - scoreTimestamp) / 1000L).floatValue(); - @Override - public void onFailure(Exception e) { - log.error(e); - } - }); + SearchRequest searchRequest = getSearchLogTypeIndexRequest(); + insertFindings(timestampFeature, searchRequest, correlatedFindings, detectorType, correlationRules, orphanFinding); + } + }, this::onFailures)); + }, this::onFailures)); } else { log.error(new OpenSearchStatusException("Failed to create correlation metadata Index", RestStatus.INTERNAL_SERVER_ERROR)); } - } - - @Override - public void onFailure(Exception e) { - - } - }); - } catch (Exception ex) { - onFailures(ex); - } - } else { - long findingTimestamp = this.request.getFinding().getTimestamp().toEpochMilli(); - BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery() - .mustNot(QueryBuilders.termQuery("scoreTimestamp", 0L)); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchSourceBuilder.query(queryBuilder); - searchSourceBuilder.fetchSource(true); - searchSourceBuilder.size(1); - SearchRequest searchRequest = new SearchRequest(); - searchRequest.indices(CorrelationIndices.CORRELATION_METADATA_INDEX); - searchRequest.source(searchSourceBuilder); - searchRequest.preference(Preference.PRIMARY_FIRST.type()); + }, this::onFailures)); + } else { + long findingTimestamp = this.request.getFinding().getTimestamp().toEpochMilli(); + SearchRequest searchMetadataIndexRequest = getSearchMetadataIndexRequest(); - client.search(searchRequest, new ActionListener<>() { - @Override - public void onResponse(SearchResponse response) { + client.search(searchMetadataIndexRequest, ActionListener.wrap(response -> { if (response.getHits().getHits().length == 0) { onFailures(new ResourceNotFoundException( "Failed to find hits in metadata index for finding id {}", request.getFinding().getId())); @@ -527,130 +370,123 @@ public void onResponse(SearchResponse response) { Map hitSource = response.getHits().getHits()[0].getSourceAsMap(); long scoreTimestamp = (long) hitSource.get("scoreTimestamp"); - if (findingTimestamp - CorrelationIndices.FIXED_HISTORICAL_INTERVAL > scoreTimestamp) { + long newScoreTimestamp = findingTimestamp - CorrelationIndices.FIXED_HISTORICAL_INTERVAL; + if (newScoreTimestamp > scoreTimestamp) { try { - XContentBuilder scoreBuilder = XContentFactory.jsonBuilder().startObject(); - scoreBuilder.field("scoreTimestamp", findingTimestamp - CorrelationIndices.FIXED_HISTORICAL_INTERVAL); - scoreBuilder.field("root", false); - scoreBuilder.endObject(); - - IndexRequest scoreIndexRequest = new IndexRequest(CorrelationIndices.CORRELATION_METADATA_INDEX) - .id(id) - .source(scoreBuilder) - .timeout(indexTimeout) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - - client.index(scoreIndexRequest, new ActionListener<>() { - @Override - public void onResponse(IndexResponse response) { - BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery() - .must(QueryBuilders.existsQuery("source")); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchSourceBuilder.query(queryBuilder); - searchSourceBuilder.fetchSource(true); - searchSourceBuilder.size(10000); - SearchRequest searchRequest = new SearchRequest(); - searchRequest.indices(LogTypeService.LOG_TYPE_INDEX); - searchRequest.source(searchSourceBuilder); - - client.search(searchRequest, new ActionListener<>() { - @Override - public void onResponse(SearchResponse response) { - if (response.isTimedOut()) { - onFailures(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); - } + IndexRequest scoreIndexRequest = getCorrelationMetadataIndexRequest(id, newScoreTimestamp); - SearchHit[] hits = response.getHits().getHits(); - Map logTypes = new HashMap<>(); - for (SearchHit hit : hits) { - Map sourceMap = hit.getSourceAsMap(); - logTypes.put(sourceMap.get("name").toString(), - new CustomLogType(sourceMap)); - } + client.index(scoreIndexRequest, ActionListener.wrap(indexResponse -> { + SearchRequest searchRequest = getSearchLogTypeIndexRequest(); - if (correlatedFindings != null) { - if (correlatedFindings.isEmpty()) { - vectorEmbeddingsEngine.insertOrphanFindings(detectorType, request.getFinding(), Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue(), logTypes); - } - for (Map.Entry> correlatedFinding : correlatedFindings.entrySet()) { - vectorEmbeddingsEngine.insertCorrelatedFindings(detectorType, request.getFinding(), correlatedFinding.getKey(), correlatedFinding.getValue(), - Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue(), correlationRules, logTypes); - } - } else { - vectorEmbeddingsEngine.insertOrphanFindings(detectorType, orphanFinding, Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue(), logTypes); - } + client.search(searchRequest, ActionListener.wrap(searchResponse -> { + if (searchResponse.isTimedOut()) { + onFailures(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); } - @Override - public void onFailure(Exception e) { - onFailures(e); + SearchHit[] hits = searchResponse.getHits().getHits(); + Map logTypes = new HashMap<>(); + for (SearchHit hit : hits) { + Map sourceMap = hit.getSourceAsMap(); + logTypes.put(sourceMap.get("name").toString(), + new CustomLogType(sourceMap)); } - }); - } - @Override - public void onFailure(Exception e) { - onFailures(e); - } - }); + if (correlatedFindings != null) { + if (correlatedFindings.isEmpty()) { + vectorEmbeddingsEngine.insertOrphanFindings(detectorType, request.getFinding(), Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue(), logTypes); + } + for (Map.Entry> correlatedFinding : correlatedFindings.entrySet()) { + vectorEmbeddingsEngine.insertCorrelatedFindings(detectorType, request.getFinding(), correlatedFinding.getKey(), correlatedFinding.getValue(), + Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue(), correlationRules, logTypes); + } + } else { + vectorEmbeddingsEngine.insertOrphanFindings(detectorType, orphanFinding, Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue(), logTypes); + } + }, this::onFailures)); + }, this::onFailures)); } catch (Exception ex) { onFailures(ex); } } else { float timestampFeature = Long.valueOf((findingTimestamp - scoreTimestamp) / 1000L).floatValue(); - BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery() - .must(QueryBuilders.existsQuery("source")); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchSourceBuilder.query(queryBuilder); - searchSourceBuilder.fetchSource(true); - searchSourceBuilder.size(10000); - SearchRequest searchRequest = new SearchRequest(); - searchRequest.indices(LogTypeService.LOG_TYPE_INDEX); - searchRequest.source(searchSourceBuilder); - - client.search(searchRequest, new ActionListener<>() { - @Override - public void onResponse(SearchResponse response) { - if (response.isTimedOut()) { - onFailures(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); - } + SearchRequest searchRequest = getSearchLogTypeIndexRequest(); + insertFindings(timestampFeature, searchRequest, correlatedFindings, detectorType, correlationRules, orphanFinding); + } + }, this::onFailures)); + } + } catch (Exception ex) { + onFailures(ex); + } + } - SearchHit[] hits = response.getHits().getHits(); - Map logTypes = new HashMap<>(); - for (SearchHit hit : hits) { - Map sourceMap = hit.getSourceAsMap(); - logTypes.put(sourceMap.get("name").toString(), - new CustomLogType(sourceMap)); - } + private SearchRequest getSearchLogTypeIndexRequest() { + BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery() + .must(QueryBuilders.existsQuery("source")); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(queryBuilder); + searchSourceBuilder.fetchSource(true); + searchSourceBuilder.size(10000); + SearchRequest searchRequest = new SearchRequest(); + searchRequest.indices(LogTypeService.LOG_TYPE_INDEX); + searchRequest.source(searchSourceBuilder); + return searchRequest; + } - if (correlatedFindings != null) { - if (correlatedFindings.isEmpty()) { - vectorEmbeddingsEngine.insertOrphanFindings(detectorType, request.getFinding(), timestampFeature, logTypes); - } - for (Map.Entry> correlatedFinding : correlatedFindings.entrySet()) { - vectorEmbeddingsEngine.insertCorrelatedFindings(detectorType, request.getFinding(), correlatedFinding.getKey(), correlatedFinding.getValue(), - timestampFeature, correlationRules, logTypes); - } - } else { - vectorEmbeddingsEngine.insertOrphanFindings(detectorType, orphanFinding, timestampFeature, logTypes); - } - } + private IndexRequest getCorrelationMetadataIndexRequest(String id, long newScoreTimestamp) throws IOException { + XContentBuilder scoreBuilder = XContentFactory.jsonBuilder().startObject(); + scoreBuilder.field("scoreTimestamp", newScoreTimestamp); + scoreBuilder.field("root", false); + scoreBuilder.endObject(); + + IndexRequest scoreIndexRequest = new IndexRequest(CorrelationIndices.CORRELATION_METADATA_INDEX) + .id(id) + .source(scoreBuilder) + .timeout(indexTimeout) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + return scoreIndexRequest; + } + private void insertFindings(float timestampFeature, SearchRequest searchRequest, Map> correlatedFindings, String detectorType, List correlationRules, Finding orphanFinding) { + client.search(searchRequest, ActionListener.wrap(response -> { + if (response.isTimedOut()) { + onFailures(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); + } - @Override - public void onFailure(Exception e) { - onFailures(e); - } - }); - } - } + SearchHit[] hits = response.getHits().getHits(); + Map logTypes = new HashMap<>(); + for (SearchHit hit : hits) { + Map sourceMap = hit.getSourceAsMap(); + logTypes.put(sourceMap.get("name").toString(), + new CustomLogType(sourceMap)); + } - @Override - public void onFailure(Exception e) { - onFailures(e); + if (correlatedFindings != null) { + if (correlatedFindings.isEmpty()) { + vectorEmbeddingsEngine.insertOrphanFindings(detectorType, request.getFinding(), timestampFeature, logTypes); } - }); - } + for (Map.Entry> correlatedFinding : correlatedFindings.entrySet()) { + vectorEmbeddingsEngine.insertCorrelatedFindings(detectorType, request.getFinding(), correlatedFinding.getKey(), correlatedFinding.getValue(), + timestampFeature, correlationRules, logTypes); + } + } else { + vectorEmbeddingsEngine.insertOrphanFindings(detectorType, orphanFinding, timestampFeature, logTypes); + } + }, this::onFailures)); + } + + private SearchRequest getSearchMetadataIndexRequest() { + BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery() + .mustNot(QueryBuilders.termQuery("scoreTimestamp", 0L)); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(queryBuilder); + searchSourceBuilder.fetchSource(true); + searchSourceBuilder.size(1); + SearchRequest searchRequest = new SearchRequest(); + searchRequest.indices(CorrelationIndices.CORRELATION_METADATA_INDEX); + searchRequest.source(searchSourceBuilder); + searchRequest.preference(Preference.PRIMARY_FIRST.type()); + + return searchRequest; } public void onOperation() { @@ -661,7 +497,8 @@ public void onOperation() { } public void onFailures(Exception t) { - log.error("Exception occurred while processing correlations", t); + log.error("Exception occurred while processing correlations for monitor id " + + request.getMonitorId() + " and finding id " + request.getFinding().getId(), t); if (counter.compareAndSet(false, true)) { finishHim(t); } diff --git a/src/main/java/org/opensearch/securityanalytics/util/CorrelationIndices.java b/src/main/java/org/opensearch/securityanalytics/util/CorrelationIndices.java index 02229a57c..624d76d58 100644 --- a/src/main/java/org/opensearch/securityanalytics/util/CorrelationIndices.java +++ b/src/main/java/org/opensearch/securityanalytics/util/CorrelationIndices.java @@ -6,7 +6,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.OpenSearchStatusException; import org.opensearch.action.admin.indices.alias.Alias; import org.opensearch.core.action.ActionListener; import org.opensearch.action.admin.indices.create.CreateIndexRequest; @@ -17,15 +16,11 @@ import org.opensearch.action.support.WriteRequest; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterState; -import org.opensearch.cluster.health.ClusterIndexHealth; -import org.opensearch.cluster.metadata.IndexMetadata; -import org.opensearch.cluster.routing.IndexRoutingTable; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.rest.RestStatus; import java.io.IOException; import java.nio.charset.Charset; @@ -89,7 +84,7 @@ public boolean correlationMetadataIndexExists() { return clusterState.metadata().hasIndex(CORRELATION_METADATA_INDEX); } - public void setupCorrelationIndex(TimeValue indexTimeout, Long setupTimestamp, ActionListener listener) { + public void setupCorrelationIndex(TimeValue indexTimeout, Long setupTimestamp, ActionListener listener) throws IOException { try { long currentTimestamp = System.currentTimeMillis(); XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); @@ -124,6 +119,7 @@ public void setupCorrelationIndex(TimeValue indexTimeout, Long setupTimestamp, A client.bulk(bulkRequest, listener); } catch (IOException ex) { log.error(ex); + throw ex; } } } \ No newline at end of file From 887739066de2d636d008711f12eec226d2b935b1 Mon Sep 17 00:00:00 2001 From: Surya Sashank Nistala Date: Wed, 6 Mar 2024 11:51:25 -0800 Subject: [PATCH 09/12] Pass rule field names in doc level queries during monitor/creation. Remove blocking actionGet() calls (#873) * pass query field names in doc level queries during monitor creation/updation Signed-off-by: Surya Sashank Nistala * remove actionGet() and change get index mapping call to event driven flow Signed-off-by: Surya Sashank Nistala * fix chained findings monitor Signed-off-by: Surya Sashank Nistala * add finding mappings Signed-off-by: Surya Sashank Nistala * remove test messages from logs Signed-off-by: Surya Sashank Nistala * revert build.gradle change Signed-off-by: Surya Sashank Nistala --------- Signed-off-by: Surya Sashank Nistala --- .../mapper/MapperService.java | 6 +- .../securityanalytics/mapper/MapperUtils.java | 8 +- .../rules/backend/OSQueryBackend.java | 5 +- .../TransportIndexDetectorAction.java | 490 ++++++++++++------ .../resthandler/DetectorMonitorRestApiIT.java | 2 +- 5 files changed, 346 insertions(+), 165 deletions(-) diff --git a/src/main/java/org/opensearch/securityanalytics/mapper/MapperService.java b/src/main/java/org/opensearch/securityanalytics/mapper/MapperService.java index 5616fdbe0..7760a4ac1 100644 --- a/src/main/java/org/opensearch/securityanalytics/mapper/MapperService.java +++ b/src/main/java/org/opensearch/securityanalytics/mapper/MapperService.java @@ -78,9 +78,11 @@ public void createMappingAction(String indexName, String logType, String aliasMa // since you can't update documents in non-write indices String index = indexName; boolean shouldUpsertIndexTemplate = IndexUtils.isConcreteIndex(indexName, this.clusterService.state()) == false; - if (IndexUtils.isDataStream(indexName, this.clusterService.state())) { + if (IndexUtils.isDataStream(indexName, this.clusterService.state()) || IndexUtils.isAlias(indexName, this.clusterService.state())) { + log.debug("{} is an alias or datastream. Fetching write index for create mapping action.", indexName); String writeIndex = IndexUtils.getWriteIndex(indexName, this.clusterService.state()); if (writeIndex != null) { + log.debug("Write index for {} is {}", indexName, writeIndex); index = writeIndex; } } @@ -92,6 +94,7 @@ public void onResponse(GetMappingsResponse getMappingsResponse) { applyAliasMappings(getMappingsResponse.getMappings(), logType, aliasMappings, partial, new ActionListener<>() { @Override public void onResponse(Collection createMappingResponse) { + log.debug("Completed create mappings for {}", indexName); // We will return ack==false if one of the requests returned that // else return ack==true Optional notAckd = createMappingResponse.stream() @@ -110,6 +113,7 @@ public void onResponse(Collection createMappingResponse) { @Override public void onFailure(Exception e) { + log.debug("Failed to create mappings for {}", indexName ); actionListener.onFailure(e); } }); diff --git a/src/main/java/org/opensearch/securityanalytics/mapper/MapperUtils.java b/src/main/java/org/opensearch/securityanalytics/mapper/MapperUtils.java index 72dd36d11..8c8bf353f 100644 --- a/src/main/java/org/opensearch/securityanalytics/mapper/MapperUtils.java +++ b/src/main/java/org/opensearch/securityanalytics/mapper/MapperUtils.java @@ -5,6 +5,10 @@ package org.opensearch.securityanalytics.mapper; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.cluster.metadata.MappingMetadata; +import org.opensearch.securityanalytics.util.SecurityAnalyticsException; + import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; @@ -12,9 +16,6 @@ import java.util.Locale; import java.util.Map; import java.util.Set; -import org.apache.commons.lang3.tuple.Pair; -import org.opensearch.cluster.metadata.MappingMetadata; -import org.opensearch.securityanalytics.util.SecurityAnalyticsException; public class MapperUtils { @@ -246,7 +247,6 @@ public void onError(String error) { } }); mappingsTraverser.traverse(); - return presentPathsMappings; } } diff --git a/src/main/java/org/opensearch/securityanalytics/rules/backend/OSQueryBackend.java b/src/main/java/org/opensearch/securityanalytics/rules/backend/OSQueryBackend.java index 2d1763a43..ec7b09505 100644 --- a/src/main/java/org/opensearch/securityanalytics/rules/backend/OSQueryBackend.java +++ b/src/main/java/org/opensearch/securityanalytics/rules/backend/OSQueryBackend.java @@ -331,9 +331,12 @@ public Object convertConditionFieldEqValQueryExpr(ConditionFieldEqualsValueExpre @Override public Object convertConditionValStr(ConditionValueExpression condition) throws SigmaValueError { + String field = getFinalValueField(); + ruleQueryFields.put(field, Map.of("type", "text", "analyzer", "rule_analyzer")); SigmaString value = (SigmaString) condition.getValue(); boolean containsWildcard = value.containsWildcard(); - return String.format(Locale.getDefault(), (containsWildcard? this.unboundWildcardExpression: this.unboundValueStrExpression), this.convertValueStr((SigmaString) condition.getValue())); + return String.format(Locale.getDefault(), (containsWildcard? this.unboundWildcardExpression: this.unboundValueStrExpression), + this.convertValueStr((SigmaString) condition.getValue())); } @Override diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java index 883bf8ee7..ad90b795f 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java @@ -110,15 +110,17 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; +import java.io.IOException; import java.time.Instant; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Locale; import java.util.Map; -import java.util.concurrent.CountDownLatch; +import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; @@ -220,19 +222,22 @@ private void checkIndicesAndExecute( ActionListener listener, User user ) { + log.debug("check indices and execute began"); String [] detectorIndices = request.getDetector().getInputs().stream().flatMap(detectorInput -> detectorInput.getIndices().stream()).toArray(String[]::new); SearchRequest searchRequest = new SearchRequest(detectorIndices) - .source(SearchSourceBuilder.searchSource().size(1).query(QueryBuilders.matchAllQuery())) - .preference(Preference.PRIMARY_FIRST.type()); + .source(SearchSourceBuilder.searchSource().size(1).query(QueryBuilders.matchAllQuery())); + searchRequest.setCancelAfterTimeInterval(TimeValue.timeValueSeconds(30)); client.search(searchRequest, new ActionListener<>() { @Override public void onResponse(SearchResponse searchResponse) { + log.debug("check indices and execute completed. Took {} millis", searchResponse.getTook().millis()); AsyncIndexDetectorsAction asyncAction = new AsyncIndexDetectorsAction(user, task, request, listener); asyncAction.start(); } @Override public void onFailure(Exception e) { + log.debug("check indices and execute failed", e); if (e instanceof OpenSearchStatusException) { listener.onFailure(SecurityAnalyticsException.wrap( new OpenSearchStatusException(String.format(Locale.getDefault(), "User doesn't have read permissions for one or more configured index %s", detectorIndices), RestStatus.FORBIDDEN) @@ -249,7 +254,8 @@ public void onFailure(Exception e) { }); } - private void createMonitorFromQueries(List> rulesById, Detector detector, ActionListener> listener, WriteRequest.RefreshPolicy refreshPolicy) { + private void createMonitorFromQueries(List> rulesById, Detector detector, ActionListener> listener, WriteRequest.RefreshPolicy refreshPolicy, + List queryFieldNames) { List> docLevelRules = rulesById.stream().filter(it -> !it.getRight().isAggregationRule()).collect( Collectors.toList()); List> bucketLevelRules = rulesById.stream().filter(it -> it.getRight().isAggregationRule()).collect( @@ -262,13 +268,14 @@ public void onResponse(List dlqs) { List monitorRequests = new ArrayList<>(); if (!docLevelRules.isEmpty() || detector.getThreatIntelEnabled()) { - monitorRequests.add(createDocLevelMonitorRequest(docLevelRules, dlqs != null ? dlqs : List.of(), detector, refreshPolicy, Monitor.NO_ID, Method.POST)); + monitorRequests.add(createDocLevelMonitorRequest(docLevelRules, dlqs != null ? dlqs : List.of(), detector, refreshPolicy, Monitor.NO_ID, Method.POST, queryFieldNames)); } if (!bucketLevelRules.isEmpty()) { StepListener> bucketLevelMonitorRequests = new StepListener<>(); buildBucketLevelMonitorRequests(bucketLevelRules, detector, refreshPolicy, Monitor.NO_ID, Method.POST, bucketLevelMonitorRequests); bucketLevelMonitorRequests.whenComplete(indexMonitorRequests -> { + log.debug("bucket level monitor request built"); monitorRequests.addAll(indexMonitorRequests); // Do nothing if detector doesn't have any monitor if (monitorRequests.isEmpty()) { @@ -283,6 +290,7 @@ public void onResponse(List dlqs) { // https://github.com/opensearch-project/alerting/issues/646 AlertingPluginInterface.INSTANCE.indexMonitor((NodeClient) client, monitorRequests.get(0), namedWriteableRegistry, addFirstMonitorStep); addFirstMonitorStep.whenComplete(addedFirstMonitorResponse -> { + log.debug("first monitor created id {} of type {}", addedFirstMonitorResponse.getId(), addedFirstMonitorResponse.getMonitor().getMonitorType()); monitorResponses.add(addedFirstMonitorResponse); StepListener> indexMonitorsStep = new StepListener<>(); @@ -416,7 +424,12 @@ public void onFailure(Exception e) { } } - private void updateMonitorFromQueries(String index, List> rulesById, Detector detector, ActionListener> listener, WriteRequest.RefreshPolicy refreshPolicy) throws Exception { + private void updateMonitorFromQueries(String index, + List> rulesById, + Detector detector, + ActionListener> listener, + WriteRequest.RefreshPolicy refreshPolicy, + List queryFieldNames) { List monitorsToBeUpdated = new ArrayList<>(); List> bucketLevelRules = rulesById.stream().filter(it -> it.getRight().isAggregationRule()).collect( @@ -442,47 +455,78 @@ public void onResponse(Map> ruleFieldMappings) { // Pair of RuleId - MonitorId for existing monitors of the detector Map monitorPerRule = detector.getRuleIdMonitorIdMap(); + GroupedActionListener groupedActionListener = new GroupedActionListener<>( + new ActionListener<>() { + @Override + public void onResponse(Collection indexMonitorRequests) { + onIndexMonitorRequestCreation( + monitorsToBeUpdated, + monitorsToBeAdded, + rulesById, + detector, + refreshPolicy, + docLevelQueries, + queryFieldNames, + listener + ); + } + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + }, bucketLevelRules.size() + ); for (Pair query : bucketLevelRules) { Rule rule = query.getRight(); if (rule.getAggregationQueries() != null) { // Detect if the monitor should be added or updated if (monitorPerRule.containsKey(rule.getId())) { String monitorId = monitorPerRule.get(rule.getId()); - monitorsToBeUpdated.add(createBucketLevelMonitorRequest(query.getRight(), + createBucketLevelMonitorRequest(query.getRight(), detector, refreshPolicy, monitorId, Method.PUT, - queryBackendMap.get(rule.getCategory()))); + queryBackendMap.get(rule.getCategory()), + new ActionListener<>() { + @Override + public void onResponse(IndexMonitorRequest indexMonitorRequest) { + monitorsToBeUpdated.add(indexMonitorRequest); + groupedActionListener.onResponse(indexMonitorRequest); + } + + @Override + public void onFailure(Exception e) { + log.error("Failed to create bucket level monitor request", e); + listener.onFailure(e); + } + }); } else { - monitorsToBeAdded.add(createBucketLevelMonitorRequest(query.getRight(), + createBucketLevelMonitorRequest(query.getRight(), detector, refreshPolicy, Monitor.NO_ID, Method.POST, - queryBackendMap.get(rule.getCategory()))); + queryBackendMap.get(rule.getCategory()), + new ActionListener<>() { + @Override + public void onResponse(IndexMonitorRequest indexMonitorRequest) { + monitorsToBeAdded.add(indexMonitorRequest); + groupedActionListener.onResponse(indexMonitorRequest); + + } + + @Override + public void onFailure(Exception e) { + log.error("Failed to create bucket level monitor request", e); + listener.onFailure(e); + } + }); } } } - List> docLevelRules = rulesById.stream().filter(it -> !it.getRight().isAggregationRule()).collect( - Collectors.toList()); - - // Process doc level monitors - if (!docLevelRules.isEmpty() || detector.getThreatIntelEnabled()) { - if (detector.getDocLevelMonitorId() == null) { - monitorsToBeAdded.add(createDocLevelMonitorRequest(docLevelRules, docLevelQueries != null? docLevelQueries: List.of(), detector, refreshPolicy, Monitor.NO_ID, Method.POST)); - } else { - monitorsToBeUpdated.add(createDocLevelMonitorRequest(docLevelRules, docLevelQueries != null? docLevelQueries: List.of(), detector, refreshPolicy, detector.getDocLevelMonitorId(), Method.PUT)); - } - } - - List monitorIdsToBeDeleted = detector.getRuleIdMonitorIdMap().values().stream().collect(Collectors.toList()); - monitorIdsToBeDeleted.removeAll(monitorsToBeUpdated.stream().map(IndexMonitorRequest::getMonitorId).collect( - Collectors.toList())); - - updateAlertingMonitors(rulesById, detector, monitorsToBeAdded, monitorsToBeUpdated, monitorIdsToBeDeleted, refreshPolicy, listener); } catch (Exception ex) { listener.onFailure(ex); } @@ -494,23 +538,16 @@ public void onFailure(Exception e) { } }); } else { - List> docLevelRules = rulesById.stream().filter(it -> !it.getRight().isAggregationRule()).collect( - Collectors.toList()); - - // Process doc level monitors - if (!docLevelRules.isEmpty() || detector.getThreatIntelEnabled()) { - if (detector.getDocLevelMonitorId() == null) { - monitorsToBeAdded.add(createDocLevelMonitorRequest(docLevelRules, docLevelQueries != null? docLevelQueries: List.of(), detector, refreshPolicy, Monitor.NO_ID, Method.POST)); - } else { - monitorsToBeUpdated.add(createDocLevelMonitorRequest(docLevelRules, docLevelQueries != null? docLevelQueries: List.of(), detector, refreshPolicy, detector.getDocLevelMonitorId(), Method.PUT)); - } - } - - List monitorIdsToBeDeleted = detector.getRuleIdMonitorIdMap().values().stream().collect(Collectors.toList()); - monitorIdsToBeDeleted.removeAll(monitorsToBeUpdated.stream().map(IndexMonitorRequest::getMonitorId).collect( - Collectors.toList())); - - updateAlertingMonitors(rulesById, detector, monitorsToBeAdded, monitorsToBeUpdated, monitorIdsToBeDeleted, refreshPolicy, listener); + onIndexMonitorRequestCreation( + monitorsToBeUpdated, + monitorsToBeAdded, + rulesById, + detector, + refreshPolicy, + docLevelQueries, + queryFieldNames, + listener + ); } } @@ -521,6 +558,33 @@ public void onFailure(Exception e) { }); } + private void onIndexMonitorRequestCreation(List monitorsToBeUpdated, + List monitorsToBeAdded, + List> rulesById, + Detector detector, + RefreshPolicy refreshPolicy, + List docLevelQueries, + List queryFieldNames, + ActionListener> listener) { + List> docLevelRules = rulesById.stream().filter(it -> !it.getRight().isAggregationRule()).collect( + Collectors.toList()); + + // Process doc level monitors + if (!docLevelRules.isEmpty() || detector.getThreatIntelEnabled()) { + if (detector.getDocLevelMonitorId() == null) { + monitorsToBeAdded.add(createDocLevelMonitorRequest(docLevelRules, docLevelQueries != null? docLevelQueries: List.of(), detector, refreshPolicy, Monitor.NO_ID, Method.POST, queryFieldNames)); + } else { + monitorsToBeUpdated.add(createDocLevelMonitorRequest(docLevelRules, docLevelQueries != null? docLevelQueries: List.of(), detector, refreshPolicy, detector.getDocLevelMonitorId(), Method.PUT, queryFieldNames)); + } + } + + List monitorIdsToBeDeleted = detector.getRuleIdMonitorIdMap().values().stream().collect(Collectors.toList()); + monitorIdsToBeDeleted.removeAll(monitorsToBeUpdated.stream().map(IndexMonitorRequest::getMonitorId).collect( + Collectors.toList())); + + updateAlertingMonitors(rulesById, detector, monitorsToBeAdded, monitorsToBeUpdated, monitorIdsToBeDeleted, refreshPolicy, listener); + } + /** * Update list of monitors for the given detector * Executed in a steps: @@ -663,7 +727,7 @@ public void onFailure(Exception e) { } } - private IndexMonitorRequest createDocLevelMonitorRequest(List> queries, List threatIntelQueries, Detector detector, WriteRequest.RefreshPolicy refreshPolicy, String monitorId, RestRequest.Method restMethod) { + private IndexMonitorRequest createDocLevelMonitorRequest(List> queries, List threatIntelQueries, Detector detector, RefreshPolicy refreshPolicy, String monitorId, Method restMethod, List queryFieldNames) { List docLevelMonitorInputs = new ArrayList<>(); List docLevelQueries = new ArrayList<>(); @@ -673,7 +737,6 @@ private IndexMonitorRequest createDocLevelMonitorRequest(List Rule rule = query.getRight(); String name = query.getLeft(); - String actualQuery = rule.getQueries().get(0).getValue(); List tags = new ArrayList<>(); @@ -681,7 +744,7 @@ private IndexMonitorRequest createDocLevelMonitorRequest(List tags.add(rule.getCategory()); tags.addAll(rule.getTags().stream().map(Value::getValue).collect(Collectors.toList())); - DocLevelQuery docLevelQuery = new DocLevelQuery(id, name, Collections.emptyList(), actualQuery, tags); + DocLevelQuery docLevelQuery = new DocLevelQuery(id, name, Collections.emptyList(), actualQuery, tags, queryFieldNames); docLevelQueries.add(docLevelQuery); } docLevelQueries.addAll(threatIntelQueries); @@ -788,43 +851,75 @@ private IndexMonitorRequest createDocLevelMonitorMatchAllRequest( } private void buildBucketLevelMonitorRequests(List> queries, Detector detector, WriteRequest.RefreshPolicy refreshPolicy, String monitorId, RestRequest.Method restMethod, ActionListener> listener) throws Exception { - + log.debug("bucket level monitor request starting"); + log.debug("get rule field mappings request being made"); logTypeService.getRuleFieldMappings(new ActionListener<>() { @Override public void onResponse(Map> ruleFieldMappings) { - try { + log.debug("got rule field mapping success"); List ruleCategories = queries.stream().map(Pair::getRight).map(Rule::getCategory).distinct().collect( Collectors.toList()); Map queryBackendMap = new HashMap<>(); for(String category: ruleCategories) { Map fieldMappings = ruleFieldMappings.get(category); - queryBackendMap.put(category, new OSQueryBackend(fieldMappings, true, true)); + try { + queryBackendMap.put(category, new OSQueryBackend(fieldMappings, true, true)); + } catch (IOException e) { + logger.error("Failed to create OSQueryBackend from field mappings", e); + listener.onFailure(e); + } } List monitorRequests = new ArrayList<>(); + GroupedActionListener bucketLevelMonitorRequestsListener = new GroupedActionListener<>( + new ActionListener<>() { + @Override + public void onResponse(Collection indexMonitorRequests) { + // if workflow usage enabled, add chained findings monitor request if there are bucket level requests and if the detector triggers have any group by rules configured to trigger + if (enabledWorkflowUsage && !monitorRequests.isEmpty() && !DetectorUtils.getAggRuleIdsConfiguredToTrigger(detector, queries).isEmpty()) { + monitorRequests.add(createDocLevelMonitorMatchAllRequest(detector, RefreshPolicy.IMMEDIATE, detector.getId() + "_chained_findings", Method.POST)); + } + listener.onResponse(monitorRequests); + } + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + }, queries.size() + ); for (Pair query: queries) { Rule rule = query.getRight(); // Creating bucket level monitor per each aggregation rule - if (rule.getAggregationQueries() != null){ - monitorRequests.add(createBucketLevelMonitorRequest( + if (rule.getAggregationQueries() != null) { + createBucketLevelMonitorRequest( query.getRight(), detector, refreshPolicy, - Monitor.NO_ID, - Method.POST, - queryBackendMap.get(rule.getCategory()))); + monitorId, + restMethod, + queryBackendMap.get(rule.getCategory()), + new ActionListener<>() { + @Override + public void onResponse(IndexMonitorRequest indexMonitorRequest) { + monitorRequests.add(indexMonitorRequest); + bucketLevelMonitorRequestsListener.onResponse(indexMonitorRequest); + } + + + @Override + public void onFailure(Exception e) { + logger.error("Failed to build bucket level monitor requests", e); + bucketLevelMonitorRequestsListener.onFailure(e); + } + }); + + } else { + log.debug("Aggregation query is null in rule {}", rule.getId()); + bucketLevelMonitorRequestsListener.onResponse(null); } } - // if workflow usage enabled, add chained findings monitor request if there are bucket level requests and if the detector triggers have any group by rules configured to trigger - if (enabledWorkflowUsage && !monitorRequests.isEmpty() && !DetectorUtils.getAggRuleIdsConfiguredToTrigger(detector, queries).isEmpty()) { - monitorRequests.add(createDocLevelMonitorMatchAllRequest(detector, RefreshPolicy.IMMEDIATE, detector.getId()+"_chained_findings", Method.POST)); - } - listener.onResponse(monitorRequests); - } catch (Exception ex) { - listener.onFailure(ex); - } } @Override @@ -834,94 +929,110 @@ public void onFailure(Exception e) { }); } - private IndexMonitorRequest createBucketLevelMonitorRequest( + private void createBucketLevelMonitorRequest( Rule rule, Detector detector, WriteRequest.RefreshPolicy refreshPolicy, String monitorId, RestRequest.Method restMethod, - QueryBackend queryBackend - ) throws SigmaError { - + QueryBackend queryBackend, + ActionListener listener + ) { + log.debug(":create bucket level monitor response starting"); List indices = detector.getInputs().get(0).getIndices(); - - AggregationItem aggItem = rule.getAggregationItemsFromRule().get(0); - AggregationQueries aggregationQueries = queryBackend.convertAggregation(aggItem); - - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() - .seqNoAndPrimaryTerm(true) - .version(true) - // Build query string filter - .query(QueryBuilders.queryStringQuery(rule.getQueries().get(0).getValue())) - .aggregation(aggregationQueries.getAggBuilder()); - // input index can also be an index pattern or alias so we have to resolve it to concrete index - String concreteIndex = IndexUtils.getNewIndexByCreationDate( - clusterService.state(), - indexNameExpressionResolver, - indices.get(0) // taking first one is fine because we expect that all indices in list share same mappings - ); try { - GetIndexMappingsResponse getIndexMappingsResponse = client.execute( + AggregationItem aggItem = rule.getAggregationItemsFromRule().get(0); + AggregationQueries aggregationQueries = queryBackend.convertAggregation(aggItem); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() + .seqNoAndPrimaryTerm(true) + .version(true) + // Build query string filter + .query(QueryBuilders.queryStringQuery(rule.getQueries().get(0).getValue())) + .aggregation(aggregationQueries.getAggBuilder()); + // input index can also be an index pattern or alias so we have to resolve it to concrete index + String concreteIndex = IndexUtils.getNewIndexByCreationDate( + clusterService.state(), + indexNameExpressionResolver, + indices.get(0) // taking first one is fine because we expect that all indices in list share same mappings + ); + client.execute( GetIndexMappingsAction.INSTANCE, - new GetIndexMappingsRequest(concreteIndex)) - .actionGet(); - MappingMetadata mappingMetadata = getIndexMappingsResponse.mappings().get(concreteIndex); - List> pairs = MapperUtils.getAllAliasPathPairs(mappingMetadata); - boolean timeStampAliasPresent = pairs. - stream() - .anyMatch(p -> - TIMESTAMP_FIELD_ALIAS.equals(p.getLeft()) || TIMESTAMP_FIELD_ALIAS.equals(p.getRight())); - if(timeStampAliasPresent) { - BoolQueryBuilder boolQueryBuilder = searchSourceBuilder.query() == null - ? new BoolQueryBuilder() - : QueryBuilders.boolQuery().must(searchSourceBuilder.query()); - RangeQueryBuilder timeRangeFilter = QueryBuilders.rangeQuery(TIMESTAMP_FIELD_ALIAS) - .gt("{{period_end}}||-" + (aggItem.getTimeframe() != null? aggItem.getTimeframe(): "1h")) - .lte("{{period_end}}") - .format("epoch_millis"); - boolQueryBuilder.must(timeRangeFilter); - searchSourceBuilder.query(boolQueryBuilder); - } - } catch (Exception e) { - log.error( - String.format(Locale.getDefault(), - "Unable to verify presence of timestamp alias for index [%s] in detector [%s]. Not setting time range filter for bucket level monitor.", - concreteIndex, detector.getName()), e); - } - - List bucketLevelMonitorInputs = new ArrayList<>(); - bucketLevelMonitorInputs.add(new SearchInput(indices, searchSourceBuilder)); - - List triggers = new ArrayList<>(); - BucketLevelTrigger bucketLevelTrigger = new BucketLevelTrigger(rule.getId(), rule.getTitle(), rule.getLevel(), aggregationQueries.getCondition(), - Collections.emptyList()); - triggers.add(bucketLevelTrigger); - - /** TODO - Think how to use detector trigger - List detectorTriggers = detector.getTriggers(); - for (DetectorTrigger detectorTrigger: detectorTriggers) { - String id = detectorTrigger.getId(); - String name = detectorTrigger.getName(); - String severity = detectorTrigger.getSeverity(); - List actions = detectorTrigger.getActions(); - Script condition = detectorTrigger.convertToCondition(); - - BucketLevelTrigger bucketLevelTrigger1 = new BucketLevelTrigger(id, name, severity, condition, actions); - triggers.add(bucketLevelTrigger1); - } **/ - - Monitor monitor = new Monitor(monitorId, Monitor.NO_VERSION, detector.getName(), false, detector.getSchedule(), detector.getLastUpdateTime(), null, - MonitorType.BUCKET_LEVEL_MONITOR, detector.getUser(), 1, bucketLevelMonitorInputs, triggers, Map.of(), - new DataSources(detector.getRuleIndex(), - detector.getFindingsIndex(), - detector.getFindingsIndexPattern(), - detector.getAlertsIndex(), - detector.getAlertsHistoryIndex(), - detector.getAlertsHistoryIndexPattern(), - DetectorMonitorConfig.getRuleIndexMappingsByType(), - true), PLUGIN_OWNER_FIELD); + new GetIndexMappingsRequest(concreteIndex), + new ActionListener() { + @Override + public void onResponse(GetIndexMappingsResponse getIndexMappingsResponse) { + MappingMetadata mappingMetadata = getIndexMappingsResponse.mappings().get(concreteIndex); + List> pairs = null; + try { + pairs = MapperUtils.getAllAliasPathPairs(mappingMetadata); + } catch (IOException e) { + logger.debug("Failed to get alias path pairs from mapping metadata", e); + onFailure(e); + } + boolean timeStampAliasPresent = pairs. + stream() + .anyMatch(p -> + TIMESTAMP_FIELD_ALIAS.equals(p.getLeft()) || TIMESTAMP_FIELD_ALIAS.equals(p.getRight())); + if (timeStampAliasPresent) { + BoolQueryBuilder boolQueryBuilder = searchSourceBuilder.query() == null + ? new BoolQueryBuilder() + : QueryBuilders.boolQuery().must(searchSourceBuilder.query()); + RangeQueryBuilder timeRangeFilter = QueryBuilders.rangeQuery(TIMESTAMP_FIELD_ALIAS) + .gt("{{period_end}}||-" + (aggItem.getTimeframe() != null ? aggItem.getTimeframe() : "1h")) + .lte("{{period_end}}") + .format("epoch_millis"); + boolQueryBuilder.must(timeRangeFilter); + searchSourceBuilder.query(boolQueryBuilder); + } + List bucketLevelMonitorInputs = new ArrayList<>(); + bucketLevelMonitorInputs.add(new SearchInput(indices, searchSourceBuilder)); + + List triggers = new ArrayList<>(); + BucketLevelTrigger bucketLevelTrigger = new BucketLevelTrigger(rule.getId(), rule.getTitle(), rule.getLevel(), aggregationQueries.getCondition(), + Collections.emptyList()); + triggers.add(bucketLevelTrigger); + + /** TODO - Think how to use detector trigger + List detectorTriggers = detector.getTriggers(); + for (DetectorTrigger detectorTrigger: detectorTriggers) { + String id = detectorTrigger.getId(); + String name = detectorTrigger.getName(); + String severity = detectorTrigger.getSeverity(); + List actions = detectorTrigger.getActions(); + Script condition = detectorTrigger.convertToCondition(); + + BucketLevelTrigger bucketLevelTrigger1 = new BucketLevelTrigger(id, name, severity, condition, actions); + triggers.add(bucketLevelTrigger1); + } **/ + + Monitor monitor = new Monitor(monitorId, Monitor.NO_VERSION, detector.getName(), false, detector.getSchedule(), detector.getLastUpdateTime(), null, + MonitorType.BUCKET_LEVEL_MONITOR, detector.getUser(), 1, bucketLevelMonitorInputs, triggers, Map.of(), + new DataSources(detector.getRuleIndex(), + detector.getFindingsIndex(), + detector.getFindingsIndexPattern(), + detector.getAlertsIndex(), + detector.getAlertsHistoryIndex(), + detector.getAlertsHistoryIndexPattern(), + DetectorMonitorConfig.getRuleIndexMappingsByType(), + true), PLUGIN_OWNER_FIELD); + + listener.onResponse(new IndexMonitorRequest(monitorId, SequenceNumbers.UNASSIGNED_SEQ_NO, SequenceNumbers.UNASSIGNED_PRIMARY_TERM, refreshPolicy, restMethod, monitor, null)); + } - return new IndexMonitorRequest(monitorId, SequenceNumbers.UNASSIGNED_SEQ_NO, SequenceNumbers.UNASSIGNED_PRIMARY_TERM, refreshPolicy, restMethod, monitor, null); + @Override + public void onFailure(Exception e) { + log.error( + String.format(Locale.getDefault(), + "Unable to verify presence of timestamp alias for index [%s] in detector [%s]. Not setting time range filter for bucket level monitor.", + concreteIndex, detector.getName()), e); + listener.onFailure(e); + } + }); + } catch (SigmaError e) { + log.error("Failed to create bucket level monitor request", e); + listener.onFailure(e); + } } /** @@ -996,21 +1107,27 @@ class AsyncIndexDetectorsAction { } void start() { + log.debug("stash context"); TransportIndexDetectorAction.this.threadPool.getThreadContext().stashContext(); - + log.debug("log type check : {}", request.getDetector().getDetectorType()); logTypeService.doesLogTypeExist(request.getDetector().getDetectorType().toLowerCase(Locale.ROOT), new ActionListener<>() { @Override public void onResponse(Boolean exist) { if (exist) { + log.debug("log type exists : {}", request.getDetector().getDetectorType()); try { if (!detectorIndices.detectorIndexExists()) { + log.debug("detector index creation"); detectorIndices.initDetectorIndex(new ActionListener<>() { @Override public void onResponse(CreateIndexResponse response) { try { + log.debug("detector index created in {}"); + onCreateMappingsResponse(response); prepareDetectorIndexing(); } catch (Exception e) { + log.debug("detector index creation failed", e); onFailures(e); } } @@ -1021,16 +1138,19 @@ public void onFailure(Exception e) { } }); } else if (!IndexUtils.detectorIndexUpdated) { + log.debug("detector index update mapping"); IndexUtils.updateIndexMapping( Detector.DETECTORS_INDEX, DetectorIndices.detectorMappings(), clusterService.state(), client.admin().indices(), new ActionListener<>() { @Override public void onResponse(AcknowledgedResponse response) { + log.debug("detector index mapping updated"); onUpdateMappingsResponse(response); try { prepareDetectorIndexing(); } catch (Exception e) { + log.debug("detector index mapping FAILED updation", e); onFailures(e); } } @@ -1088,24 +1208,28 @@ void createDetector() { if (!detector.getInputs().isEmpty()) { try { + log.debug("init rule index template"); ruleTopicIndices.initRuleTopicIndexTemplate(new ActionListener<>() { @Override public void onResponse(AcknowledgedResponse acknowledgedResponse) { - + log.debug("init rule index template ack"); initRuleIndexAndImportRules(request, new ActionListener<>() { @Override public void onResponse(List monitorResponses) { + log.debug("monitors indexed"); request.getDetector().setMonitorIds(getMonitorIds(monitorResponses)); request.getDetector().setRuleIdMonitorIdMap(mapMonitorIds(monitorResponses)); try { indexDetector(); } catch (Exception e) { + logger.debug("create detector failed", e); onFailures(e); } } @Override public void onFailure(Exception e) { + logger.debug("import rules failed", e); onFailures(e); } }); @@ -1113,10 +1237,12 @@ public void onFailure(Exception e) { @Override public void onFailure(Exception e) { + logger.debug("init rules index failed", e); onFailures(e); } }); } catch (Exception e) { + logger.debug("init rules index failed", e); onFailures(e); } } @@ -1233,11 +1359,13 @@ public void initRuleIndexAndImportRules(IndexDetectorRequest request, ActionList new ActionListener<>() { @Override public void onResponse(CreateIndexResponse response) { + log.debug("prepackaged rule index created"); ruleIndices.onCreateMappingsResponse(response, true); ruleIndices.importRules(RefreshPolicy.IMMEDIATE, indexTimeout, new ActionListener<>() { @Override public void onResponse(BulkResponse response) { + log.debug("rules imported"); if (!response.hasFailures()) { importRules(request, listener); } else { @@ -1247,6 +1375,7 @@ public void onResponse(BulkResponse response) { @Override public void onFailure(Exception e) { + log.debug("failed to import rules", e); onFailures(e); } }); @@ -1358,13 +1487,14 @@ public void importRules(IndexDetectorRequest request, ActionListener() { @Override public void onResponse(SearchResponse response) { if (response.isTimedOut()) { onFailures(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); } + logger.debug("prepackaged rules fetch success"); SearchHits hits = response.getHits(); List> queries = new ArrayList<>(); @@ -1387,13 +1517,10 @@ public void onResponse(SearchResponse response) { } else if (detectorInput.getCustomRules().size() > 0) { onFailures(new OpenSearchStatusException("Custom Rule Index not found", RestStatus.NOT_FOUND)); } else { - if (request.getMethod() == RestRequest.Method.POST) { - createMonitorFromQueries(queries, detector, listener, request.getRefreshPolicy()); - } else if (request.getMethod() == RestRequest.Method.PUT) { - updateMonitorFromQueries(logIndex, queries, detector, listener, request.getRefreshPolicy()); - } + resolveRuleFieldNamesAndUpsertMonitorFromQueries(queries, detector, logIndex, listener); } } catch (Exception e) { + logger.debug("failed to fetch prepackaged rules", e); onFailures(e); } } @@ -1405,6 +1532,56 @@ public void onFailure(Exception e) { }); } + private void resolveRuleFieldNamesAndUpsertMonitorFromQueries(List> queries, Detector detector, String logIndex, ActionListener> listener) { + logger.error("PERF_DEBUG_SAP: Fetching alias path pairs to construct rule_field_names"); + long start = System.currentTimeMillis(); + Set ruleFieldNames = new HashSet<>(); + for (Pair query : queries) { + List queryFieldNames = query.getValue().getQueryFieldNames().stream().map(Value::getValue).collect(Collectors.toList()); + ruleFieldNames.addAll(queryFieldNames); + } + client.execute(GetIndexMappingsAction.INSTANCE, new GetIndexMappingsRequest(logIndex), new ActionListener<>() { + @Override + public void onResponse(GetIndexMappingsResponse getMappingsViewResponse) { + try { + List> aliasPathPairs; + + aliasPathPairs = MapperUtils.getAllAliasPathPairs(getMappingsViewResponse.getMappings().get(logIndex)); + for (Pair aliasPathPair : aliasPathPairs) { + if (ruleFieldNames.contains(aliasPathPair.getLeft())) { + ruleFieldNames.remove(aliasPathPair.getLeft()); + ruleFieldNames.add(aliasPathPair.getRight()); + } + } + long took = System.currentTimeMillis() - start; + log.debug("completed collecting rule_field_names in {} millis", took); + + } catch (Exception e) { + logger.error("Failure in parsing rule field names/aliases while " + + detector.getId() == null ? "creating" : "updating" + + " detector. Not optimizing detector queries with relevant fields", e); + ruleFieldNames.clear(); + } + upsertMonitorQueries(queries, detector, listener, ruleFieldNames, logIndex); + + } + + @Override + public void onFailure(Exception e) { + log.error("Failed to fetch mappings view response for log index " + logIndex, e); + listener.onFailure(e); + } + }); + } + + private void upsertMonitorQueries(List> queries, Detector detector, ActionListener> listener, Set ruleFieldNames, String logIndex) { + if (request.getMethod() == Method.POST) { + createMonitorFromQueries(queries, detector, listener, request.getRefreshPolicy(), new ArrayList<>(ruleFieldNames)); + } else if (request.getMethod() == Method.PUT) { + updateMonitorFromQueries(logIndex, queries, detector, listener, request.getRefreshPolicy(), new ArrayList<>(ruleFieldNames)); + } + } + @SuppressWarnings("unchecked") public void importCustomRules(Detector detector, DetectorInput detectorInput, List> queries, ActionListener> listener) { final String logIndex = detectorInput.getIndices().get(0); @@ -1418,14 +1595,14 @@ public void importCustomRules(Detector detector, DetectorInput detectorInput, Li .query(queryBuilder) .size(10000)) .preference(Preference.PRIMARY_FIRST.type()); - + logger.debug("importing custom rules"); client.search(searchRequest, new ActionListener<>() { @Override public void onResponse(SearchResponse response) { if (response.isTimedOut()) { onFailures(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); } - + logger.debug("custom rules fetch successful"); SearchHits hits = response.getHits(); try { @@ -1441,11 +1618,7 @@ public void onResponse(SearchResponse response) { queries.add(Pair.of(id, rule)); } - if (request.getMethod() == RestRequest.Method.POST) { - createMonitorFromQueries(queries, detector, listener, request.getRefreshPolicy()); - } else if (request.getMethod() == RestRequest.Method.PUT) { - updateMonitorFromQueries(logIndex, queries, detector, listener, request.getRefreshPolicy()); - } + resolveRuleFieldNamesAndUpsertMonitorFromQueries(queries, detector, logIndex, listener); } catch (Exception ex) { onFailures(ex); } @@ -1473,10 +1646,11 @@ public void indexDetector() throws Exception { .id(request.getDetectorId()) .timeout(indexTimeout); } - + log.debug("indexing detector"); client.index(indexRequest, new ActionListener<>() { @Override public void onResponse(IndexResponse response) { + log.debug("detector indexed success."); Detector responseDetector = request.getDetector(); responseDetector.setId(response.getId()); onOperation(response, responseDetector); diff --git a/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorMonitorRestApiIT.java b/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorMonitorRestApiIT.java index 3a11300ee..8de88a717 100644 --- a/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorMonitorRestApiIT.java +++ b/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorMonitorRestApiIT.java @@ -2056,7 +2056,7 @@ public void testCreateDetectorWithCloudtrailAggrRuleWithEcsFields() throws IOExc // both req params and req body are supported createMappingRequest.setJsonEntity( "{\n" + - " \"index_name\": \"" + index + "\",\n" + + " \"index_name\": \"cloudtrail\",\n" + " \"rule_topic\": \"cloudtrail\",\n" + " \"partial\": true,\n" + " \"alias_mappings\": {\n" + From 689760e897294530dea0d1181c2539e9b607d23c Mon Sep 17 00:00:00 2001 From: Joanne Wang Date: Thu, 7 Mar 2024 09:40:44 -0800 Subject: [PATCH 10/12] Fix duplicate ecs mappings which returns incorrect log index field in mapping view API (#786) (#788) * field mapping changes Signed-off-by: Joanne Wang * add integ test Signed-off-by: Joanne Wang * turn unmappedfieldaliases as set and add integ test Signed-off-by: Joanne Wang * add comments Signed-off-by: Joanne Wang * fix integ tests Signed-off-by: Joanne Wang * moved logic to method for better readability Signed-off-by: Joanne Wang --------- Signed-off-by: Joanne Wang --- .../mapper/MapperService.java | 48 ++++- .../mapper/MapperRestApiIT.java | 171 ++++++++++++++++++ .../resthandler/OCSFDetectorRestApiIT.java | 4 +- 3 files changed, 212 insertions(+), 11 deletions(-) diff --git a/src/main/java/org/opensearch/securityanalytics/mapper/MapperService.java b/src/main/java/org/opensearch/securityanalytics/mapper/MapperService.java index 7760a4ac1..42b374735 100644 --- a/src/main/java/org/opensearch/securityanalytics/mapper/MapperService.java +++ b/src/main/java/org/opensearch/securityanalytics/mapper/MapperService.java @@ -8,7 +8,6 @@ import org.apache.commons.lang3.tuple.Pair; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.OpenSearchStatusException; import org.opensearch.action.admin.indices.get.GetIndexRequest; import org.opensearch.action.admin.indices.get.GetIndexResponse; import org.opensearch.action.admin.indices.mapping.get.GetMappingsRequest; @@ -485,13 +484,16 @@ public void onResponse(GetMappingsResponse getMappingsResponse) { String rawPath = requiredField.getRawField(); String ocsfPath = requiredField.getOcsf(); if (allFieldsFromIndex.contains(rawPath)) { - if (alias != null) { - // Maintain list of found paths in index - applyableAliases.add(alias); - } else { - applyableAliases.add(rawPath); + // if the alias was already added into applyable aliases, then skip to avoid duplicates + if (!applyableAliases.contains(alias) && !applyableAliases.contains(rawPath)) { + if (alias != null) { + // Maintain list of found paths in index + applyableAliases.add(alias); + } else { + applyableAliases.add(rawPath); + } + pathsOfApplyableAliases.add(rawPath); } - pathsOfApplyableAliases.add(rawPath); } else if (allFieldsFromIndex.contains(ocsfPath)) { applyableAliases.add(alias); pathsOfApplyableAliases.add(ocsfPath); @@ -505,13 +507,21 @@ public void onResponse(GetMappingsResponse getMappingsResponse) { } } + // turn unmappedFieldAliases into a set to remove duplicates + Set setOfUnmappedFieldAliases = new HashSet<>(unmappedFieldAliases); + + // filter out aliases that were included in applyableAliases already + List filteredUnmappedFieldAliases = setOfUnmappedFieldAliases.stream() + .filter(e -> false == applyableAliases.contains(e)) + .collect(Collectors.toList()); + Map> aliasMappingFields = new HashMap<>(); XContentBuilder aliasMappingsObj = XContentFactory.jsonBuilder().startObject(); for (LogType.Mapping mapping : requiredFields) { if (allFieldsFromIndex.contains(mapping.getOcsf())) { aliasMappingFields.put(mapping.getEcs(), Map.of("type", "alias", "path", mapping.getOcsf())); } else if (mapping.getEcs() != null) { - aliasMappingFields.put(mapping.getEcs(), Map.of("type", "alias", "path", mapping.getRawField())); + shouldUpdateEcsMappingAndMaybeUpdates(mapping, aliasMappingFields, pathsOfApplyableAliases); } else if (mapping.getEcs() == null) { aliasMappingFields.put(mapping.getRawField(), Map.of("type", "alias", "path", mapping.getRawField())); } @@ -527,7 +537,7 @@ public void onResponse(GetMappingsResponse getMappingsResponse) { .filter(e -> pathsOfApplyableAliases.contains(e) == false) .collect(Collectors.toList()); actionListener.onResponse( - new GetMappingsViewResponse(aliasMappings, unmappedIndexFields, unmappedFieldAliases, logTypeService.getIocFieldsList(logType)) + new GetMappingsViewResponse(aliasMappings, unmappedIndexFields, filteredUnmappedFieldAliases, logTypeService.getIocFieldsList(logType)) ); } catch (Exception e) { actionListener.onFailure(e); @@ -542,6 +552,26 @@ public void onFailure(Exception e) { }); } + /** + * Only updates the alias mapping fields if the ecs key has not been mapped yet + * or if pathOfApplyableAliases contains the raw field + * + * @param mapping + * @param aliasMappingFields + * @param pathsOfApplyableAliases + */ + private static void shouldUpdateEcsMappingAndMaybeUpdates(LogType.Mapping mapping, Map> aliasMappingFields, List pathsOfApplyableAliases) { + // check if aliasMappingFields already contains a key + if (aliasMappingFields.containsKey(mapping.getEcs())) { + // if the pathOfApplyableAliases contains the raw field, then override the existing map + if (pathsOfApplyableAliases.contains(mapping.getRawField())) { + aliasMappingFields.put(mapping.getEcs(), Map.of("type", "alias", "path", mapping.getRawField())); + } + } else { + aliasMappingFields.put(mapping.getEcs(), Map.of("type", "alias", "path", mapping.getRawField())); + } + } + /** * Given index name, resolves it to single concrete index, depending on what initial indexName is. * In case of Datastream or Alias, WriteIndex would be returned. In case of index pattern, newest index by creation date would be returned. diff --git a/src/test/java/org/opensearch/securityanalytics/mapper/MapperRestApiIT.java b/src/test/java/org/opensearch/securityanalytics/mapper/MapperRestApiIT.java index ce86187d2..e32f19371 100644 --- a/src/test/java/org/opensearch/securityanalytics/mapper/MapperRestApiIT.java +++ b/src/test/java/org/opensearch/securityanalytics/mapper/MapperRestApiIT.java @@ -395,6 +395,114 @@ public void testGetMappingsViewLinuxSuccess() throws IOException { assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode()); } + // Tests mappings where multiple raw fields correspond to one ecs value + public void testGetMappingsViewWindowsSuccess() throws IOException { + + String testIndexName = "get_mappings_view_index"; + + createSampleWindex(testIndexName); + + // Execute GetMappingsViewAction to add alias mapping for index + Request request = new Request("GET", SecurityAnalyticsPlugin.MAPPINGS_VIEW_BASE_URI); + // both req params and req body are supported + request.addParameter("index_name", testIndexName); + request.addParameter("rule_topic", "windows"); + Response response = client().performRequest(request); + assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode()); + Map respMap = responseAsMap(response); + + // Verify alias mappings + Map props = (Map) respMap.get("properties"); + assertEquals(3, props.size()); + assertTrue(props.containsKey("winlog.event_data.LogonType")); + assertTrue(props.containsKey("winlog.provider_name")); + assertTrue(props.containsKey("host.hostname")); + + // Verify unmapped index fields + List unmappedIndexFields = (List) respMap.get("unmapped_index_fields"); + assertEquals(3, unmappedIndexFields.size()); + assert(unmappedIndexFields.contains("plain1")); + assert(unmappedIndexFields.contains("ParentUser.first")); + assert(unmappedIndexFields.contains("ParentUser.last")); + + // Verify unmapped field aliases + List filteredUnmappedFieldAliases = (List) respMap.get("unmapped_field_aliases"); + assertEquals(191, filteredUnmappedFieldAliases.size()); + assert(!filteredUnmappedFieldAliases.contains("winlog.event_data.LogonType")); + assert(!filteredUnmappedFieldAliases.contains("winlog.provider_name")); + assert(!filteredUnmappedFieldAliases.contains("host.hostname")); + List> iocFieldsList = (List>) respMap.get(GetMappingsViewResponse.THREAT_INTEL_FIELD_ALIASES); + assertEquals(iocFieldsList.size(), 1); + + // Index a doc for a field with multiple raw fields corresponding to one ecs field + indexDoc(testIndexName, "1", "{ \"EventID\": 1 }"); + // Execute GetMappingsViewAction to add alias mapping for index + request = new Request("GET", SecurityAnalyticsPlugin.MAPPINGS_VIEW_BASE_URI); + // both req params and req body are supported + request.addParameter("index_name", testIndexName); + request.addParameter("rule_topic", "windows"); + response = client().performRequest(request); + assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode()); + respMap = responseAsMap(response); + + // Verify alias mappings + props = (Map) respMap.get("properties"); + assertEquals(4, props.size()); + assertTrue(props.containsKey("winlog.event_id")); + + // verify unmapped index fields + unmappedIndexFields = (List) respMap.get("unmapped_index_fields"); + assertEquals(3, unmappedIndexFields.size()); + + // verify unmapped field aliases + filteredUnmappedFieldAliases = (List) respMap.get("unmapped_field_aliases"); + assertEquals(190, filteredUnmappedFieldAliases.size()); + assert(!filteredUnmappedFieldAliases.contains("winlog.event_id")); + } + + // Tests mappings where multiple raw fields correspond to one ecs value and all fields are present in the index + public void testGetMappingsViewMulitpleRawFieldsSuccess() throws IOException { + + String testIndexName = "get_mappings_view_index"; + + createSampleWindex(testIndexName); + String sampleDoc = "{" + + " \"EventID\": 1," + + " \"EventId\": 2," + + " \"event_uid\": 3" + + "}"; + indexDoc(testIndexName, "1", sampleDoc); + + // Execute GetMappingsViewAction to add alias mapping for index + Request request = new Request("GET", SecurityAnalyticsPlugin.MAPPINGS_VIEW_BASE_URI); + // both req params and req body are supported + request.addParameter("index_name", testIndexName); + request.addParameter("rule_topic", "windows"); + Response response = client().performRequest(request); + assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode()); + Map respMap = responseAsMap(response); + + // Verify alias mappings + Map props = (Map) respMap.get("properties"); + assertEquals(4, props.size()); + assertTrue(props.containsKey("winlog.event_data.LogonType")); + assertTrue(props.containsKey("winlog.provider_name")); + assertTrue(props.containsKey("host.hostname")); + assertTrue(props.containsKey("winlog.event_id")); + + // Verify unmapped index fields + List unmappedIndexFields = (List) respMap.get("unmapped_index_fields"); + assertEquals(5, unmappedIndexFields.size()); + + // Verify unmapped field aliases + List filteredUnmappedFieldAliases = (List) respMap.get("unmapped_field_aliases"); + assertEquals(190, filteredUnmappedFieldAliases.size()); + assert(!filteredUnmappedFieldAliases.contains("winlog.event_data.LogonType")); + assert(!filteredUnmappedFieldAliases.contains("winlog.provider_name")); + assert(!filteredUnmappedFieldAliases.contains("host.hostname")); + assert(!filteredUnmappedFieldAliases.contains("winlog.event_id")); + } + public void testCreateMappings_withDatastream_success() throws IOException { String datastream = "test_datastream"; @@ -1278,6 +1386,69 @@ private void createSampleIndex(String indexName, Settings settings, String alias assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode()); } + private void createSampleWindex(String indexName) throws IOException { + createSampleWindex(indexName, Settings.EMPTY, null); + } + + private void createSampleWindex(String indexName, Settings settings, String aliases) throws IOException { + String indexMapping = + " \"properties\": {" + + " \"LogonType\": {" + + " \"type\": \"integer\"" + + " }," + + " \"Provider\": {" + + " \"type\": \"text\"" + + " }," + + " \"hostname\": {" + + " \"type\": \"text\"" + + " }," + + " \"plain1\": {" + + " \"type\": \"integer\"" + + " }," + + " \"ParentUser\":{" + + " \"type\":\"nested\"," + + " \"properties\":{" + + " \"first\":{" + + " \"type\":\"text\"," + + " \"fields\":{" + + " \"keyword\":{" + + " \"type\":\"keyword\"," + + " \"ignore_above\":256" + + "}" + + "}" + + "}," + + " \"last\":{" + + "\"type\":\"text\"," + + "\"fields\":{" + + " \"keyword\":{" + + " \"type\":\"keyword\"," + + " \"ignore_above\":256" + + "}" + + "}" + + "}" + + "}" + + "}" + + " }"; + + createIndex(indexName, settings, indexMapping, aliases); + + // Insert sample doc with event_uid not explicitly mapped + String sampleDoc = "{" + + " \"LogonType\":1," + + " \"Provider\":\"Microsoft-Windows-Security-Auditing\"," + + " \"hostname\":\"FLUXCAPACITOR\"" + + "}"; + + // Index doc + Request indexRequest = new Request("POST", indexName + "/_doc?refresh=wait_for"); + indexRequest.setJsonEntity(sampleDoc); + Response response = client().performRequest(indexRequest); + assertEquals(HttpStatus.SC_CREATED, response.getStatusLine().getStatusCode()); + // Refresh everything + response = client().performRequest(new Request("POST", "_refresh")); + assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode()); + } + private void createSampleDatastream(String datastreamName) throws IOException { String indexMapping = " \"properties\": {" + diff --git a/src/test/java/org/opensearch/securityanalytics/resthandler/OCSFDetectorRestApiIT.java b/src/test/java/org/opensearch/securityanalytics/resthandler/OCSFDetectorRestApiIT.java index 812e5eebd..0de2322fd 100644 --- a/src/test/java/org/opensearch/securityanalytics/resthandler/OCSFDetectorRestApiIT.java +++ b/src/test/java/org/opensearch/securityanalytics/resthandler/OCSFDetectorRestApiIT.java @@ -436,7 +436,7 @@ public void testOCSFCloudtrailGetMappingsViewApi() throws IOException { assertEquals(20, unmappedIndexFields.size()); // Verify unmapped field aliases List unmappedFieldAliases = (List) respMap.get("unmapped_field_aliases"); - assertEquals(25, unmappedFieldAliases.size()); + assertEquals(24, unmappedFieldAliases.size()); } @SuppressWarnings("unchecked") @@ -502,7 +502,7 @@ public void testRawCloudtrailGetMappingsViewApi() throws IOException { assertEquals(17, unmappedIndexFields.size()); // Verify unmapped field aliases List unmappedFieldAliases = (List) respMap.get("unmapped_field_aliases"); - assertEquals(26, unmappedFieldAliases.size()); + assertEquals(25, unmappedFieldAliases.size()); } @SuppressWarnings("unchecked") From 656a5fecbc2e09212b2d621c86b525fcbf9e4086 Mon Sep 17 00:00:00 2001 From: Joanne Wang Date: Thu, 7 Mar 2024 17:24:57 -0800 Subject: [PATCH 11/12] Add an "exists" check for "not" condition in sigma rules (#852) * test design Signed-off-by: Joanne Wang * working version Signed-off-by: Joanne Wang * cleaning up Signed-off-by: Joanne Wang * testing Signed-off-by: Joanne Wang * working version Signed-off-by: Joanne Wang * working version Signed-off-by: Joanne Wang * refactored querybackend Signed-off-by: Joanne Wang * working on tests Signed-off-by: Joanne Wang * fixed alerting and finding tests Signed-off-by: Joanne Wang * fix correlation tests Signed-off-by: Joanne Wang * working all tests Signed-off-by: Joanne Wang * moved test and changed alias for adldap Signed-off-by: Joanne Wang * added more tests Signed-off-by: Joanne Wang * cleanup code Signed-off-by: Joanne Wang * remove exists flag Signed-off-by: Joanne Wang --------- Signed-off-by: Joanne Wang --- .../rules/backend/OSQueryBackend.java | 142 ++++++--- .../rules/backend/QueryBackend.java | 102 +++--- .../securityanalytics/TestHelpers.java | 147 ++++++++- .../securityanalytics/alerts/AlertsIT.java | 3 +- .../CorrelationEngineRestApiIT.java | 2 +- .../securityanalytics/findings/FindingIT.java | 298 +++++++++++++++++- .../rules/backend/QueryBackendTests.java | 174 +++++++++- 7 files changed, 779 insertions(+), 89 deletions(-) diff --git a/src/main/java/org/opensearch/securityanalytics/rules/backend/OSQueryBackend.java b/src/main/java/org/opensearch/securityanalytics/rules/backend/OSQueryBackend.java index ec7b09505..50d452f6b 100644 --- a/src/main/java/org/opensearch/securityanalytics/rules/backend/OSQueryBackend.java +++ b/src/main/java/org/opensearch/securityanalytics/rules/backend/OSQueryBackend.java @@ -48,7 +48,6 @@ import java.util.Map; public class OSQueryBackend extends QueryBackend { - private String tokenSeparator; private String orToken; @@ -57,6 +56,8 @@ public class OSQueryBackend extends QueryBackend { private String notToken; + private String existsToken; + private String escapeChar; private String wildcardMulti; @@ -119,6 +120,7 @@ public OSQueryBackend(Map fieldMappings, boolean collectErrors, this.orToken = "OR"; this.andToken = "AND"; this.notToken = "NOT"; + this.existsToken = "_exists_"; this.escapeChar = "\\"; this.wildcardMulti = "*"; this.wildcardSingle = "?"; @@ -145,15 +147,15 @@ public OSQueryBackend(Map fieldMappings, boolean collectErrors, } @Override - public Object convertConditionAsInExpression(Either condition) { + public Object convertConditionAsInExpression(Either condition, boolean isConditionNot, boolean applyDeMorgans) { if (condition.isLeft()) { - return this.convertConditionAnd(condition.getLeft()); + return this.convertConditionAnd(condition.getLeft(), isConditionNot, applyDeMorgans); } - return this.convertConditionOr(condition.get()); + return this.convertConditionOr(condition.get(), isConditionNot, applyDeMorgans); } @Override - public Object convertConditionAnd(ConditionAND condition) { + public Object convertConditionAnd(ConditionAND condition, boolean isConditionNot, boolean applyDeMorgans) { try { StringBuilder queryBuilder = new StringBuilder(); StringBuilder joiner = new StringBuilder(); @@ -171,21 +173,29 @@ public Object convertConditionAnd(ConditionAND condition) { ConditionType argType = arg.getLeft().getLeft().getClass().equals(ConditionAND.class)? new ConditionType(Either.left(AnyOneOf.leftVal((ConditionAND) arg.getLeft().getLeft()))): (arg.getLeft().getLeft().getClass().equals(ConditionOR.class)? new ConditionType(Either.left(AnyOneOf.middleVal((ConditionOR) arg.getLeft().getLeft()))): new ConditionType(Either.left(AnyOneOf.rightVal((ConditionNOT) arg.getLeft().getLeft())))); - converted = this.convertConditionGroup(argType); + converted = this.convertConditionGroup(argType, isConditionNot,applyDeMorgans ); } else if (arg.getLeft().isMiddle()) { - converted = this.convertConditionGroup(new ConditionType(Either.right(Either.left(arg.getLeft().getMiddle())))); + converted = this.convertConditionGroup(new ConditionType(Either.right(Either.left(arg.getLeft().getMiddle()))), isConditionNot, applyDeMorgans); } else if (arg.getLeft().isRight()) { - converted = this.convertConditionGroup(new ConditionType(Either.right(Either.right(arg.getLeft().get())))); + converted = this.convertConditionGroup(new ConditionType(Either.right(Either.right(arg.getLeft().get()))), isConditionNot, applyDeMorgans); } if (converted != null) { + // if applyDeMorgans is true, then use OR instead of AND + if (applyDeMorgans) { + joiner.setLength(0); // clear the joiner to convert it to OR + if (this.tokenSeparator.equals(this.andToken)) { + joiner.append(this.orToken); + } else { + joiner.append(this.tokenSeparator).append(this.orToken).append(this.tokenSeparator); + } + } if (!first) { queryBuilder.append(joiner).append(converted); } else { queryBuilder.append(converted); first = false; } - } } } @@ -196,7 +206,7 @@ public Object convertConditionAnd(ConditionAND condition) { } @Override - public Object convertConditionOr(ConditionOR condition) { + public Object convertConditionOr(ConditionOR condition, boolean isConditionNot, boolean applyDeMorgans) { try { StringBuilder queryBuilder = new StringBuilder(); StringBuilder joiner = new StringBuilder(); @@ -214,32 +224,41 @@ public Object convertConditionOr(ConditionOR condition) { ConditionType argType = arg.getLeft().getLeft().getClass().equals(ConditionAND.class)? new ConditionType(Either.left(AnyOneOf.leftVal((ConditionAND) arg.getLeft().getLeft()))): (arg.getLeft().getLeft().getClass().equals(ConditionOR.class)? new ConditionType(Either.left(AnyOneOf.middleVal((ConditionOR) arg.getLeft().getLeft()))): new ConditionType(Either.left(AnyOneOf.rightVal((ConditionNOT) arg.getLeft().getLeft())))); - converted = this.convertConditionGroup(argType); + converted = this.convertConditionGroup(argType, isConditionNot, applyDeMorgans); } else if (arg.getLeft().isMiddle()) { - converted = this.convertConditionGroup(new ConditionType(Either.right(Either.left(arg.getLeft().getMiddle())))); + converted = this.convertConditionGroup(new ConditionType(Either.right(Either.left(arg.getLeft().getMiddle()))), isConditionNot, applyDeMorgans); } else if (arg.getLeft().isRight()) { - converted = this.convertConditionGroup(new ConditionType(Either.right(Either.right(arg.getLeft().get())))); + converted = this.convertConditionGroup(new ConditionType(Either.right(Either.right(arg.getLeft().get()))), isConditionNot, applyDeMorgans); } if (converted != null) { + // if applyDeMorgans is true, then use AND instead of OR + if (applyDeMorgans) { + joiner.setLength(0); // clear the joiner to convert it to AND + if (this.tokenSeparator.equals(this.orToken)) { + joiner.append(this.andToken); + } else { + joiner.append(this.tokenSeparator).append(this.andToken).append(this.tokenSeparator); + } + } + if (!first) { queryBuilder.append(joiner).append(converted); } else { queryBuilder.append(converted); first = false; } - } } } return queryBuilder.toString(); } catch (Exception ex) { - throw new NotImplementedException("Operator 'and' not supported by the backend"); + throw new NotImplementedException("Operator 'or' not supported by the backend"); } } @Override - public Object convertConditionNot(ConditionNOT condition) { + public Object convertConditionNot(ConditionNOT condition, boolean isConditionNot, boolean applyDeMorgans) { Either, String> arg = condition.getArgs().get(0); try { if (arg.isLeft()) { @@ -247,13 +266,13 @@ public Object convertConditionNot(ConditionNOT condition) { ConditionType argType = arg.getLeft().getLeft().getClass().equals(ConditionAND.class) ? new ConditionType(Either.left(AnyOneOf.leftVal((ConditionAND) arg.getLeft().getLeft()))) : (arg.getLeft().getLeft().getClass().equals(ConditionOR.class) ? new ConditionType(Either.left(AnyOneOf.middleVal((ConditionOR) arg.getLeft().getLeft()))) : new ConditionType(Either.left(AnyOneOf.rightVal((ConditionNOT) arg.getLeft().getLeft())))); - return String.format(Locale.getDefault(), groupExpression, this.notToken + this.tokenSeparator + this.convertConditionGroup(argType)); + return String.format(Locale.getDefault(), groupExpression, this.convertConditionGroup(argType, true, true)); } else if (arg.getLeft().isMiddle()) { ConditionType argType = new ConditionType(Either.right(Either.left(arg.getLeft().getMiddle()))); - return String.format(Locale.getDefault(), groupExpression, this.notToken + this.tokenSeparator + this.convertCondition(argType).toString()); + return String.format(Locale.getDefault(), groupExpression, this.notToken + this.tokenSeparator + this.convertCondition(argType, true, applyDeMorgans).toString()); } else { ConditionType argType = new ConditionType(Either.right(Either.right(arg.getLeft().get()))); - return String.format(Locale.getDefault(), groupExpression, this.notToken + this.tokenSeparator + this.convertCondition(argType).toString()); + return String.format(Locale.getDefault(), groupExpression, this.notToken + this.tokenSeparator + this.convertCondition(argType, true, applyDeMorgans).toString()); } } } catch (Exception ex) { @@ -263,56 +282,89 @@ public Object convertConditionNot(ConditionNOT condition) { } @Override - public Object convertConditionFieldEqValStr(ConditionFieldEqualsValueExpression condition) throws SigmaValueError { + public Object convertExistsField(ConditionFieldEqualsValueExpression condition) { + String field = getFinalField(condition.getField()); + return String.format(Locale.getDefault(),tokenSeparator + this.andToken + this.tokenSeparator + this.existsToken + this.eqToken + this.tokenSeparator + field); + } + + @Override + public Object convertConditionFieldEqValStr(ConditionFieldEqualsValueExpression condition, boolean applyDeMorgans) throws SigmaValueError { SigmaString value = (SigmaString) condition.getValue(); boolean containsWildcard = value.containsWildcard(); String expr = "%s" + this.eqToken + " " + (containsWildcard? this.reQuote: this.strQuote) + "%s" + (containsWildcard? this.reQuote: this.strQuote); + String exprWithDeMorgansApplied = this.notToken + " " + "%s" + this.eqToken + " " + (containsWildcard? this.reQuote: this.strQuote) + "%s" + (containsWildcard? this.reQuote: this.strQuote); String field = getFinalField(condition.getField()); ruleQueryFields.put(field, Map.of("type", "text", "analyzer", "rule_analyzer")); - return String.format(Locale.getDefault(), expr, field, this.convertValueStr(value)); + String convertedExpr = String.format(Locale.getDefault(), expr, field, this.convertValueStr(value)); + if (applyDeMorgans) { + convertedExpr = String.format(Locale.getDefault(), exprWithDeMorgansApplied, field, this.convertValueStr(value)); + } + return convertedExpr; } @Override - public Object convertConditionFieldEqValNum(ConditionFieldEqualsValueExpression condition) { + public Object convertConditionFieldEqValNum(ConditionFieldEqualsValueExpression condition, boolean applyDeMorgans) { String field = getFinalField(condition.getField()); SigmaNumber number = (SigmaNumber) condition.getValue(); ruleQueryFields.put(field, number.getNumOpt().isLeft()? Collections.singletonMap("type", "integer"): Collections.singletonMap("type", "float")); - + if (applyDeMorgans) { + return this.notToken + " " +field + this.eqToken + " " + condition.getValue(); + } return field + this.eqToken + " " + condition.getValue(); } @Override - public Object convertConditionFieldEqValBool(ConditionFieldEqualsValueExpression condition) { + public Object convertConditionFieldEqValBool(ConditionFieldEqualsValueExpression condition, boolean applyDeMorgans) { String field = getFinalField(condition.getField()); ruleQueryFields.put(field, Collections.singletonMap("type", "boolean")); - + if (applyDeMorgans) { + return this.notToken + " " + field + this.eqToken + " " + ((SigmaBool) condition.getValue()).isaBoolean(); + } return field + this.eqToken + " " + ((SigmaBool) condition.getValue()).isaBoolean(); } - public Object convertConditionFieldEqValNull(ConditionFieldEqualsValueExpression condition) { + public Object convertConditionFieldEqValNull(ConditionFieldEqualsValueExpression condition, boolean applyDeMorgans) { String field = getFinalField(condition.getField()); ruleQueryFields.put(field, Map.of("type", "text", "analyzer", "rule_analyzer")); + String exprWithDeMorgansApplied = this.notToken + " " + this.fieldNullExpression; + if (applyDeMorgans) { + return String.format(Locale.getDefault(), exprWithDeMorgansApplied, field); + } return String.format(Locale.getDefault(), this.fieldNullExpression, field); } @Override - public Object convertConditionFieldEqValRe(ConditionFieldEqualsValueExpression condition) { + public Object convertConditionFieldEqValRe(ConditionFieldEqualsValueExpression condition, boolean applyDeMorgans) { String field = getFinalField(condition.getField()); ruleQueryFields.put(field, Map.of("type", "text", "analyzer", "rule_analyzer")); + String exprWithDeMorgansApplied = this.notToken + " " + this.reExpression; + if (applyDeMorgans) { + return String.format(Locale.getDefault(), exprWithDeMorgansApplied, field, convertValueRe((SigmaRegularExpression) condition.getValue())); + } return String.format(Locale.getDefault(), this.reExpression, field, convertValueRe((SigmaRegularExpression) condition.getValue())); } @Override - public Object convertConditionFieldEqValCidr(ConditionFieldEqualsValueExpression condition) { + public Object convertConditionFieldEqValCidr(ConditionFieldEqualsValueExpression condition, boolean applyDeMorgans) { String field = getFinalField(condition.getField()); ruleQueryFields.put(field, Map.of("type", "text", "analyzer", "rule_analyzer")); + String exprWithDeMorgansApplied = this.notToken + " " + this.cidrExpression; + if (applyDeMorgans) { + return String.format(Locale.getDefault(), exprWithDeMorgansApplied, field, convertValueCidr((SigmaCIDRExpression) condition.getValue())); + } return String.format(Locale.getDefault(), this.cidrExpression, field, convertValueCidr((SigmaCIDRExpression) condition.getValue())); } @Override - public Object convertConditionFieldEqValOpVal(ConditionFieldEqualsValueExpression condition) { + public Object convertConditionFieldEqValOpVal(ConditionFieldEqualsValueExpression condition, boolean applyDeMorgans) { + String exprWithDeMorgansApplied = this.notToken + " " + this.compareOpExpression; + if (applyDeMorgans) { + return String.format(Locale.getDefault(), exprWithDeMorgansApplied, this.getMappedField(condition.getField()), + compareOperators.get(((SigmaCompareExpression) condition.getValue()).getOp()), ((SigmaCompareExpression) condition.getValue()).getNumber().toString()); + } + return String.format(Locale.getDefault(), this.compareOpExpression, this.getMappedField(condition.getField()), compareOperators.get(((SigmaCompareExpression) condition.getValue()).getOp()), ((SigmaCompareExpression) condition.getValue()).getNumber().toString()); } @@ -330,23 +382,39 @@ public Object convertConditionFieldEqValQueryExpr(ConditionFieldEqualsValueExpre }*/ @Override - public Object convertConditionValStr(ConditionValueExpression condition) throws SigmaValueError { + public Object convertConditionValStr(ConditionValueExpression condition, boolean applyDeMorgans) throws SigmaValueError { String field = getFinalValueField(); ruleQueryFields.put(field, Map.of("type", "text", "analyzer", "rule_analyzer")); SigmaString value = (SigmaString) condition.getValue(); boolean containsWildcard = value.containsWildcard(); - return String.format(Locale.getDefault(), (containsWildcard? this.unboundWildcardExpression: this.unboundValueStrExpression), + String exprWithDeMorgansApplied = this.notToken + " " + "%s"; + + String conditionValStr = String.format(Locale.getDefault(), (containsWildcard? this.unboundWildcardExpression: this.unboundValueStrExpression), this.convertValueStr((SigmaString) condition.getValue())); + if (applyDeMorgans) { + conditionValStr = String.format(Locale.getDefault(), exprWithDeMorgansApplied, conditionValStr); + } + return conditionValStr; } @Override - public Object convertConditionValNum(ConditionValueExpression condition) { - return String.format(Locale.getDefault(), this.unboundValueNumExpression, condition.getValue().toString()); + public Object convertConditionValNum(ConditionValueExpression condition, boolean applyDeMorgans) { + String exprWithDeMorgansApplied = this.notToken + " " + "%s"; + String conditionValNum = String.format(Locale.getDefault(), String.format(Locale.getDefault(), this.unboundValueNumExpression, condition.getValue().toString())); + if (applyDeMorgans) { + conditionValNum = String.format(Locale.getDefault(), exprWithDeMorgansApplied, conditionValNum); + } + return conditionValNum; } @Override - public Object convertConditionValRe(ConditionValueExpression condition) { - return String.format(Locale.getDefault(), this.unboundReExpression, convertValueRe((SigmaRegularExpression) condition.getValue())); + public Object convertConditionValRe(ConditionValueExpression condition, boolean applyDeMorgans) { + String exprWithDeMorgansApplied = this.notToken + " " + "%s"; + String conditionValStr = String.format(Locale.getDefault(), this.unboundReExpression, convertValueRe((SigmaRegularExpression) condition.getValue())); + if (applyDeMorgans) { + conditionValStr = String.format(Locale.getDefault(), exprWithDeMorgansApplied, conditionValStr); + } + return conditionValStr; } // TODO: below methods will be supported when Sigma Expand Modifier is supported. @@ -421,8 +489,8 @@ private boolean comparePrecedence(ConditionType outer, ConditionType inner) { return idxInner <= precedence.indexOf(outerClass); } - private Object convertConditionGroup(ConditionType condition) throws SigmaValueError { - return String.format(Locale.getDefault(), groupExpression, this.convertCondition(condition)); + private Object convertConditionGroup(ConditionType condition, boolean isConditionNot, boolean applyDeMorgans) throws SigmaValueError { + return String.format(Locale.getDefault(), groupExpression, this.convertCondition(condition, isConditionNot, applyDeMorgans)); } private Object convertValueStr(SigmaString s) throws SigmaValueError { diff --git a/src/main/java/org/opensearch/securityanalytics/rules/backend/QueryBackend.java b/src/main/java/org/opensearch/securityanalytics/rules/backend/QueryBackend.java index c63dce05d..2c56a2c6a 100644 --- a/src/main/java/org/opensearch/securityanalytics/rules/backend/QueryBackend.java +++ b/src/main/java/org/opensearch/securityanalytics/rules/backend/QueryBackend.java @@ -4,8 +4,6 @@ */ package org.opensearch.securityanalytics.rules.backend; -import org.opensearch.commons.alerting.aggregation.bucketselectorext.BucketSelectorExtAggregationBuilder; -import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.securityanalytics.rules.aggregation.AggregationItem; import org.opensearch.securityanalytics.rules.backend.OSQueryBackend.AggregationQueries; import org.opensearch.securityanalytics.rules.condition.ConditionAND; @@ -47,7 +45,6 @@ import java.util.Set; public abstract class QueryBackend { - private boolean convertOrAsIn; private boolean convertAndAsIn; private boolean collectErrors; @@ -85,15 +82,15 @@ public List convertRule(SigmaRule rule) throws SigmaError { Object query; if (conditionItem instanceof ConditionAND) { - query = this.convertCondition(new ConditionType(Either.left(AnyOneOf.leftVal((ConditionAND) conditionItem)))); + query = this.convertCondition(new ConditionType(Either.left(AnyOneOf.leftVal((ConditionAND) conditionItem))), false, false); } else if (conditionItem instanceof ConditionOR) { - query = this.convertCondition(new ConditionType(Either.left(AnyOneOf.middleVal((ConditionOR) conditionItem)))); + query = this.convertCondition(new ConditionType(Either.left(AnyOneOf.middleVal((ConditionOR) conditionItem))), false, false); } else if (conditionItem instanceof ConditionNOT) { - query = this.convertCondition(new ConditionType(Either.left(AnyOneOf.rightVal((ConditionNOT) conditionItem)))); + query = this.convertCondition(new ConditionType(Either.left(AnyOneOf.rightVal((ConditionNOT) conditionItem))), true, false); } else if (conditionItem instanceof ConditionFieldEqualsValueExpression) { - query = this.convertCondition(new ConditionType(Either.right(Either.left((ConditionFieldEqualsValueExpression) conditionItem)))); + query = this.convertCondition(new ConditionType(Either.right(Either.left((ConditionFieldEqualsValueExpression) conditionItem))), false, false); } else { - query = this.convertCondition(new ConditionType(Either.right(Either.right((ConditionValueExpression) conditionItem)))); + query = this.convertCondition(new ConditionType(Either.right(Either.right((ConditionValueExpression) conditionItem))), false, false); } queries.add(query); if (aggItem != null) { @@ -113,30 +110,41 @@ public List convertRule(SigmaRule rule) throws SigmaError { return queries; } - public Object convertCondition(ConditionType conditionType) throws SigmaValueError { + public Object convertCondition(ConditionType conditionType, boolean isConditionNot, boolean applyDeMorgans) throws SigmaValueError { if (conditionType.isConditionOR()) { if (this.decideConvertConditionAsInExpression(Either.right(conditionType.getConditionOR()))) { - return this.convertConditionAsInExpression(Either.right(conditionType.getConditionOR())); + return this.convertConditionAsInExpression(Either.right(conditionType.getConditionOR()), isConditionNot, applyDeMorgans ); } else { - return this.convertConditionOr(conditionType.getConditionOR()); + return this.convertConditionOr(conditionType.getConditionOR(), isConditionNot, applyDeMorgans); } } else if (conditionType.isConditionAND()) { if (this.decideConvertConditionAsInExpression(Either.left(conditionType.getConditionAND()))) { - return this.convertConditionAsInExpression(Either.left(conditionType.getConditionAND())); + return this.convertConditionAsInExpression(Either.left(conditionType.getConditionAND()), isConditionNot, applyDeMorgans); } else { - return this.convertConditionAnd(conditionType.getConditionAND()); + return this.convertConditionAnd(conditionType.getConditionAND(), isConditionNot, applyDeMorgans); } } else if (conditionType.isConditionNOT()) { - return this.convertConditionNot(conditionType.getConditionNOT()); + return this.convertConditionNot(conditionType.getConditionNOT(), isConditionNot, applyDeMorgans); } else if (conditionType.isEqualsValueExpression()) { - return this.convertConditionFieldEqVal(conditionType.getEqualsValueExpression()); + // check to see if conditionNot is an ancestor of the parse tree, otherwise return as normal + if (isConditionNot) { + return this.convertConditionFieldEqValNot(conditionType, isConditionNot, applyDeMorgans); + } else { + return this.convertConditionFieldEqVal(conditionType.getEqualsValueExpression(), isConditionNot, applyDeMorgans); + } } else if (conditionType.isValueExpression()) { - return this.convertConditionVal(conditionType.getValueExpression()); + return this.convertConditionVal(conditionType.getValueExpression(), applyDeMorgans); } else { throw new IllegalArgumentException("Unexpected data type in condition parse tree"); } } + public String convertConditionFieldEqValNot(ConditionType conditionType, boolean isConditionNot, boolean applyDeMorgans) throws SigmaValueError { + String baseString = this.convertConditionFieldEqVal(conditionType.getEqualsValueExpression(), isConditionNot, applyDeMorgans).toString(); + String addExists = this.convertExistsField(conditionType.getEqualsValueExpression()).toString(); + return String.format(Locale.getDefault(), ("%s" + "%s"), baseString, addExists); + } + public boolean decideConvertConditionAsInExpression(Either condition) { if ((!this.convertOrAsIn && condition.isRight()) || (!this.convertAndAsIn && condition.isLeft())) { return false; @@ -181,74 +189,76 @@ public void resetQueryFields() { } } - public abstract Object convertConditionAsInExpression(Either condition); + public abstract Object convertConditionAsInExpression(Either condition, boolean isConditionNot, boolean applyDeMorgans); - public abstract Object convertConditionAnd(ConditionAND condition); + public abstract Object convertConditionAnd(ConditionAND condition, boolean isConditionNot, boolean applyDeMorgans); - public abstract Object convertConditionOr(ConditionOR condition); + public abstract Object convertConditionOr(ConditionOR condition, boolean isConditionNot, boolean applyDeMorgans); - public abstract Object convertConditionNot(ConditionNOT condition); + public abstract Object convertConditionNot(ConditionNOT condition, boolean isConditionNot, boolean applyDeMorgans); - public Object convertConditionFieldEqVal(ConditionFieldEqualsValueExpression condition) throws SigmaValueError { + public Object convertConditionFieldEqVal(ConditionFieldEqualsValueExpression condition, boolean isConditionNot, boolean applyDeMorgans) throws SigmaValueError { if (condition.getValue() instanceof SigmaString) { - return this.convertConditionFieldEqValStr(condition); + return this.convertConditionFieldEqValStr(condition, applyDeMorgans); } else if (condition.getValue() instanceof SigmaNumber) { - return this.convertConditionFieldEqValNum(condition); + return this.convertConditionFieldEqValNum(condition, applyDeMorgans); } else if (condition.getValue() instanceof SigmaBool) { - return this.convertConditionFieldEqValBool(condition); + return this.convertConditionFieldEqValBool(condition, applyDeMorgans); } else if (condition.getValue() instanceof SigmaRegularExpression) { - return this.convertConditionFieldEqValRe(condition); + return this.convertConditionFieldEqValRe(condition, applyDeMorgans); } else if (condition.getValue() instanceof SigmaCIDRExpression) { - return this.convertConditionFieldEqValCidr(condition); + return this.convertConditionFieldEqValCidr(condition, applyDeMorgans); } else if (condition.getValue() instanceof SigmaCompareExpression) { - return this.convertConditionFieldEqValOpVal(condition); + return this.convertConditionFieldEqValOpVal(condition, applyDeMorgans); } else if (condition.getValue() instanceof SigmaNull) { - return this.convertConditionFieldEqValNull(condition); + return this.convertConditionFieldEqValNull(condition, applyDeMorgans); }/* TODO: below methods will be supported when Sigma Expand Modifier is supported. else if (condition.getValue() instanceof SigmaQueryExpression) { return this.convertConditionFieldEqValQueryExpr(condition); }*/ else if (condition.getValue() instanceof SigmaExpansion) { - return this.convertConditionFieldEqValQueryExpansion(condition); + return this.convertConditionFieldEqValQueryExpansion(condition, isConditionNot, applyDeMorgans); } else { throw new IllegalArgumentException("Unexpected value type class in condition parse tree: " + condition.getValue().getClass().getName()); } } - public abstract Object convertConditionFieldEqValStr(ConditionFieldEqualsValueExpression condition) throws SigmaValueError; + public abstract Object convertConditionFieldEqValStr(ConditionFieldEqualsValueExpression condition, boolean applyDeMorgans) throws SigmaValueError; + + public abstract Object convertConditionFieldEqValNum(ConditionFieldEqualsValueExpression condition, boolean applyDeMorgans); - public abstract Object convertConditionFieldEqValNum(ConditionFieldEqualsValueExpression condition); + public abstract Object convertConditionFieldEqValBool(ConditionFieldEqualsValueExpression condition, boolean applyDeMorgans); - public abstract Object convertConditionFieldEqValBool(ConditionFieldEqualsValueExpression condition); + public abstract Object convertConditionFieldEqValRe(ConditionFieldEqualsValueExpression condition, boolean applyDeMorgans); - public abstract Object convertConditionFieldEqValRe(ConditionFieldEqualsValueExpression condition); + public abstract Object convertConditionFieldEqValCidr(ConditionFieldEqualsValueExpression condition, boolean applyDeMorgans); - public abstract Object convertConditionFieldEqValCidr(ConditionFieldEqualsValueExpression condition); + public abstract Object convertConditionFieldEqValOpVal(ConditionFieldEqualsValueExpression condition, boolean applyDeMorgans); - public abstract Object convertConditionFieldEqValOpVal(ConditionFieldEqualsValueExpression condition); + public abstract Object convertConditionFieldEqValNull(ConditionFieldEqualsValueExpression condition, boolean applyDeMorgans); - public abstract Object convertConditionFieldEqValNull(ConditionFieldEqualsValueExpression condition); + public abstract Object convertExistsField(ConditionFieldEqualsValueExpression condition); -/* public abstract Object convertConditionFieldEqValQueryExpr(ConditionFieldEqualsValueExpression condition);*/ + /* public abstract Object convertConditionFieldEqValQueryExpr(ConditionFieldEqualsValueExpression condition);*/ - public Object convertConditionFieldEqValQueryExpansion(ConditionFieldEqualsValueExpression condition) { + public Object convertConditionFieldEqValQueryExpansion(ConditionFieldEqualsValueExpression condition, boolean isConditionNot, boolean applyDeMorgans) { List, String>> args = new ArrayList<>(); for (SigmaType sigmaType: ((SigmaExpansion) condition.getValue()).getValues()) { args.add(Either.left(AnyOneOf.middleVal(new ConditionFieldEqualsValueExpression(condition.getField(), sigmaType)))); } ConditionOR conditionOR = new ConditionOR(false, args); - return this.convertConditionOr(conditionOR); + return this.convertConditionOr(conditionOR, isConditionNot, applyDeMorgans); } - public Object convertConditionVal(ConditionValueExpression condition) throws SigmaValueError { + public Object convertConditionVal(ConditionValueExpression condition, boolean applyDeMorgans) throws SigmaValueError { if (condition.getValue() instanceof SigmaString) { - return this.convertConditionValStr(condition); + return this.convertConditionValStr(condition, applyDeMorgans); } else if (condition.getValue() instanceof SigmaNumber) { - return this.convertConditionValNum(condition); + return this.convertConditionValNum(condition, applyDeMorgans); } else if (condition.getValue() instanceof SigmaBool) { throw new SigmaValueError("Boolean values can't appear as standalone value without a field name."); } else if (condition.getValue() instanceof SigmaRegularExpression) { - return this.convertConditionValRe(condition); + return this.convertConditionValRe(condition, applyDeMorgans); }/* else if (condition.getValue() instanceof SigmaCIDRExpression) { throw new SigmaValueError("CIDR values can't appear as standalone value without a field name."); } else if (condition.getValue() instanceof SigmaQueryExpression) { @@ -258,11 +268,11 @@ public Object convertConditionVal(ConditionValueExpression condition) throws Sig } } - public abstract Object convertConditionValStr(ConditionValueExpression condition) throws SigmaValueError; + public abstract Object convertConditionValStr(ConditionValueExpression condition, boolean applyDeMorgans) throws SigmaValueError; - public abstract Object convertConditionValNum(ConditionValueExpression condition); + public abstract Object convertConditionValNum(ConditionValueExpression condition, boolean applyDeMorgans); - public abstract Object convertConditionValRe(ConditionValueExpression condition); + public abstract Object convertConditionValRe(ConditionValueExpression condition, boolean applyDeMorgans); /* public abstract Object convertConditionValQueryExpr(ConditionValueExpression condition);*/ diff --git a/src/test/java/org/opensearch/securityanalytics/TestHelpers.java b/src/test/java/org/opensearch/securityanalytics/TestHelpers.java index d907b797c..2902dbaa7 100644 --- a/src/test/java/org/opensearch/securityanalytics/TestHelpers.java +++ b/src/test/java/org/opensearch/securityanalytics/TestHelpers.java @@ -259,6 +259,72 @@ public static String randomRule() { "level: high"; } + public static String randomRuleWithNotCondition() { + return "title: Remote Encrypting File System Abuse\n" + + "id: 5f92fff9-82e2-48eb-8fc1-8b133556a551\n" + + "description: Detects remote RPC calls to possibly abuse remote encryption service via MS-EFSR\n" + + "references:\n" + + " - https://attack.mitre.org/tactics/TA0008/\n" + + " - https://msrc.microsoft.com/update-guide/vulnerability/CVE-2021-36942\n" + + " - https://github.com/jsecurity101/MSRPC-to-ATTACK/blob/main/documents/MS-EFSR.md\n" + + " - https://github.com/zeronetworks/rpcfirewall\n" + + " - https://zeronetworks.com/blog/stopping_lateral_movement_via_the_rpc_firewall/\n" + + "tags:\n" + + " - attack.defense_evasion\n" + + "status: experimental\n" + + "author: Sagie Dulce, Dekel Paz\n" + + "date: 2022/01/01\n" + + "modified: 2022/01/01\n" + + "logsource:\n" + + " product: rpc_firewall\n" + + " category: application\n" + + " definition: 'Requirements: install and apply the RPC Firewall to all processes with \"audit:true action:block uuid:df1941c5-fe89-4e79-bf10-463657acf44d or c681d488-d850-11d0-8c52-00c04fd90f7e'\n" + + "detection:\n" + + " selection1:\n" + + " AccountType: TestAccountType\n" + + " selection2:\n" + + " AccountName: TestAccountName\n" + + " selection3:\n" + + " EventID: 22\n" + + " condition: (not selection1 and not selection2) and selection3\n" + + "falsepositives:\n" + + " - Legitimate usage of remote file encryption\n" + + "level: high"; + } + + public static String randomRuleWithNotConditionBoolAndNum() { + return "title: Remote Encrypting File System Abuse\n" + + "id: 5f92fff9-82e2-48eb-8fc1-8b133556a551\n" + + "description: Detects remote RPC calls to possibly abuse remote encryption service via MS-EFSR\n" + + "references:\n" + + " - https://attack.mitre.org/tactics/TA0008/\n" + + " - https://msrc.microsoft.com/update-guide/vulnerability/CVE-2021-36942\n" + + " - https://github.com/jsecurity101/MSRPC-to-ATTACK/blob/main/documents/MS-EFSR.md\n" + + " - https://github.com/zeronetworks/rpcfirewall\n" + + " - https://zeronetworks.com/blog/stopping_lateral_movement_via_the_rpc_firewall/\n" + + "tags:\n" + + " - attack.defense_evasion\n" + + "status: experimental\n" + + "author: Sagie Dulce, Dekel Paz\n" + + "date: 2022/01/01\n" + + "modified: 2022/01/01\n" + + "logsource:\n" + + " product: rpc_firewall\n" + + " category: application\n" + + " definition: 'Requirements: install and apply the RPC Firewall to all processes with \"audit:true action:block uuid:df1941c5-fe89-4e79-bf10-463657acf44d or c681d488-d850-11d0-8c52-00c04fd90f7e'\n" + + "detection:\n" + + " selection1:\n" + + " Initiated: \"false\"\n" + + " selection2:\n" + + " AccountName: TestAccountName\n" + + " selection3:\n" + + " EventID: 21\n" + + " condition: not selection1 and not selection3\n" + + "falsepositives:\n" + + " - Legitimate usage of remote file encryption\n" + + "level: high"; + } + public static String randomNullRule() { return "title: null field\n" + "id: 5f92fff9-82e2-48eb-8fc1-8b133556a551\n" + @@ -1701,6 +1767,44 @@ public static String randomDoc(int severity, int version, String opCode) { } + public static String randomDocForNotCondition(int severity, int version, String opCode) { + String doc = "{\n" + + "\"EventTime\":\"2020-02-04T14:59:39.343541+00:00\",\n" + + "\"HostName\":\"EC2AMAZ-EPO7HKA\",\n" + + "\"Keywords\":\"9223372036854775808\",\n" + + "\"SeverityValue\":%s,\n" + + "\"Severity\":\"INFO\",\n" + + "\"EventID\":22,\n" + + "\"SourceName\":\"Microsoft-Windows-Sysmon\",\n" + + "\"ProviderGuid\":\"{5770385F-C22A-43E0-BF4C-06F5698FFBD9}\",\n" + + "\"Version\":%s,\n" + + "\"TaskValue\":22,\n" + + "\"OpcodeValue\":0,\n" + + "\"RecordNumber\":9532,\n" + + "\"ExecutionProcessID\":1996,\n" + + "\"ExecutionThreadID\":2616,\n" + + "\"Channel\":\"Microsoft-Windows-Sysmon/Operational\",\n" + + "\"Domain\":\"NT AUTHORITY\",\n" + + "\"UserID\":\"S-1-5-18\",\n" + + "\"AccountType\":\"User\",\n" + + "\"Message\":\"Dns query:\\r\\nRuleName: \\r\\nUtcTime: 2020-02-04 14:59:38.349\\r\\nProcessGuid: {b3c285a4-3cda-5dc0-0000-001077270b00}\\r\\nProcessId: 1904\\r\\nQueryName: EC2AMAZ-EPO7HKA\\r\\nQueryStatus: 0\\r\\nQueryResults: 172.31.46.38;\\r\\nImage: C:\\\\Program Files\\\\nxlog\\\\nxlog.exe\",\n" + + "\"Category\":\"Dns query (rule: DnsQuery)\",\n" + + "\"Opcode\":\"%s\",\n" + + "\"UtcTime\":\"2020-02-04 14:59:38.349\",\n" + + "\"ProcessGuid\":\"{b3c285a4-3cda-5dc0-0000-001077270b00}\",\n" + + "\"ProcessId\":\"1904\",\"QueryName\":\"EC2AMAZ-EPO7HKA\",\"QueryStatus\":\"0\",\n" + + "\"QueryResults\":\"172.31.46.38;\",\n" + + "\"Image\":\"C:\\\\Program Files\\\\nxlog\\\\regsvr32.exe\",\n" + + "\"EventReceivedTime\":\"2020-02-04T14:59:40.780905+00:00\",\n" + + "\"SourceModuleName\":\"in\",\n" + + "\"SourceModuleType\":\"im_msvistalog\",\n" + + "\"CommandLine\": \"eachtest\",\n" + + "\"Initiated\": \"true\"\n" + + "}"; + return String.format(Locale.ROOT, doc, severity, version, opCode); + + } + public static String randomDocOnlyNumericAndDate(int severity, int version, String opCode) { String doc = "{\n" + "\"EventTime\":\"2020-02-04T14:59:39.343541+00:00\",\n" + @@ -1840,6 +1944,46 @@ public static String randomDoc() { "}"; } + public static String randomNetworkDoc() { + return "{\n" + + "\"@timestamp\":\"2020-02-04T14:59:39.343541+00:00\",\n" + + "\"EventTime\":\"2020-02-04T14:59:39.343541+00:00\",\n" + + "\"HostName\":\"EC2AMAZ-EPO7HKA\",\n" + + "\"Keywords\":\"9223372036854775808\",\n" + + "\"SeverityValue\":2,\n" + + "\"Severity\":\"INFO\",\n" + + "\"EventID\":22,\n" + + "\"SourceName\":\"Microsoft-Windows-Sysmon\",\n" + + "\"SourceIp\":\"1.2.3.4\",\n" + + "\"ProviderGuid\":\"{5770385F-C22A-43E0-BF4C-06F5698FFBD9}\",\n" + + "\"Version\":5,\n" + + "\"TaskValue\":22,\n" + + "\"OpcodeValue\":0,\n" + + "\"RecordNumber\":9532,\n" + + "\"ExecutionProcessID\":1996,\n" + + "\"ExecutionThreadID\":2616,\n" + + "\"Channel\":\"Microsoft-Windows-Sysmon/Operational\",\n" + + "\"Domain\":\"NTAUTHORITY\",\n" + + "\"AccountName\":\"SYSTEM\",\n" + + "\"UserID\":\"S-1-5-18\",\n" + + "\"AccountType\":\"User\",\n" + + "\"Message\":\"Dns query:\\r\\nRuleName: \\r\\nUtcTime: 2020-02-04 14:59:38.349\\r\\nProcessGuid: {b3c285a4-3cda-5dc0-0000-001077270b00}\\r\\nProcessId: 1904\\r\\nQueryName: EC2AMAZ-EPO7HKA\\r\\nQueryStatus: 0\\r\\nQueryResults: 172.31.46.38;\\r\\nImage: C:\\\\Program Files\\\\nxlog\\\\nxlog.exe\",\n" + + "\"Category\":\"Dns query (rule: DnsQuery)\",\n" + + "\"Opcode\":\"Info\",\n" + + "\"UtcTime\":\"2020-02-04 14:59:38.349\",\n" + + "\"ProcessGuid\":\"{b3c285a4-3cda-5dc0-0000-001077270b00}\",\n" + + "\"ProcessId\":\"1904\",\"QueryName\":\"EC2AMAZ-EPO7HKA\",\"QueryStatus\":\"0\",\n" + + "\"QueryResults\":\"172.31.46.38;\",\n" + + "\"Image\":\"C:\\\\Program Files\\\\nxlog\\\\regsvr32.exe\",\n" + + "\"EventReceivedTime\":\"2020-02-04T14:59:40.780905+00:00\",\n" + + "\"SourceModuleName\":\"in\",\n" + + "\"SourceModuleType\":\"im_msvistalog\",\n" + + "\"CommandLine\": \"eachtest\",\n" + + "\"id.orig_h\": \"123.12.123.12\",\n" + + "\"Initiated\": \"true\"\n" + + "}"; + } + public static String randomCloudtrailAggrDoc(String eventType, String accountId) { return "{\n" + " \"AccountName\": \"" + accountId + "\",\n" + @@ -1857,6 +2001,7 @@ public static String randomVpcFlowDoc() { " \"srcport\": 9000,\n" + " \"dstport\": 8000,\n" + " \"severity_id\": \"-1\",\n" + + " \"id.orig_h\": \"1.2.3.4\",\n" + " \"class_name\": \"Network Activity\"\n" + "}"; } @@ -2432,7 +2577,7 @@ public static List randomLowerCaseStringList() { stringList.add(randomLowerCaseString()); return stringList; } - + public static XContentParser parser(String xc) throws IOException { XContentParser parser = XContentType.JSON.xContent().createParser(xContentRegistry(), LoggingDeprecationHandler.INSTANCE, xc); parser.nextToken(); diff --git a/src/test/java/org/opensearch/securityanalytics/alerts/AlertsIT.java b/src/test/java/org/opensearch/securityanalytics/alerts/AlertsIT.java index fbd091595..347fb66f1 100644 --- a/src/test/java/org/opensearch/securityanalytics/alerts/AlertsIT.java +++ b/src/test/java/org/opensearch/securityanalytics/alerts/AlertsIT.java @@ -43,6 +43,7 @@ import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithTriggers; import static org.opensearch.securityanalytics.TestHelpers.randomDoc; import static org.opensearch.securityanalytics.TestHelpers.randomDocWithIpIoc; +import static org.opensearch.securityanalytics.TestHelpers.randomNetworkDoc; import static org.opensearch.securityanalytics.TestHelpers.randomIndex; import static org.opensearch.securityanalytics.TestHelpers.randomRule; import static org.opensearch.securityanalytics.TestHelpers.windowsIndexMapping; @@ -545,7 +546,7 @@ public void testGetAlerts_byDetectorType_multipleDetectors_success() throws IOEx String monitorId2 = ((List) ((Map) hit.getSourceAsMap().get("detector")).get("monitor_id")).get(0); indexDoc(index1, "1", randomDoc()); - indexDoc(index2, "1", randomDoc()); + indexDoc(index2, "1", randomNetworkDoc()); // execute monitor 1 Response executeResponse = executeAlertingMonitor(monitorId1, Collections.emptyMap()); Map executeResults = entityAsMap(executeResponse); diff --git a/src/test/java/org/opensearch/securityanalytics/correlation/CorrelationEngineRestApiIT.java b/src/test/java/org/opensearch/securityanalytics/correlation/CorrelationEngineRestApiIT.java index 149f8fd34..a4cdb6d1c 100644 --- a/src/test/java/org/opensearch/securityanalytics/correlation/CorrelationEngineRestApiIT.java +++ b/src/test/java/org/opensearch/securityanalytics/correlation/CorrelationEngineRestApiIT.java @@ -1109,7 +1109,7 @@ private String createAdLdapDetector(String indexName) throws IOException { " \"partial\": true,\n" + " \"alias_mappings\": {\n" + " \"properties\": {\n" + - " \"azure-signinlogs-properties-user_id\": {\n" + + " \"azure.signinlogs.properties.user_id\": {\n" + " \"path\": \"azure.signinlogs.props.user_id\",\n" + " \"type\": \"alias\"\n" + " },\n" + diff --git a/src/test/java/org/opensearch/securityanalytics/findings/FindingIT.java b/src/test/java/org/opensearch/securityanalytics/findings/FindingIT.java index 3b7ca3c0a..1f7d112de 100644 --- a/src/test/java/org/opensearch/securityanalytics/findings/FindingIT.java +++ b/src/test/java/org/opensearch/securityanalytics/findings/FindingIT.java @@ -10,29 +10,42 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; +import java.util.HashSet; +import java.util.ArrayList; +import java.util.Arrays; import java.util.stream.Collectors; import org.apache.hc.core5.http.HttpStatus; import org.junit.Assert; import org.junit.Ignore; +import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Request; import org.opensearch.client.Response; import org.opensearch.client.ResponseException; +import org.opensearch.commons.alerting.model.Monitor; import org.opensearch.core.rest.RestStatus; import org.opensearch.search.SearchHit; import org.opensearch.securityanalytics.SecurityAnalyticsPlugin; import org.opensearch.securityanalytics.SecurityAnalyticsRestTestCase; +import org.opensearch.securityanalytics.config.monitors.DetectorMonitorConfig; import org.opensearch.securityanalytics.model.Detector; import org.opensearch.securityanalytics.model.DetectorInput; import org.opensearch.securityanalytics.model.DetectorRule; import org.opensearch.securityanalytics.model.DetectorTrigger; +import static java.util.Collections.emptyList; import static org.opensearch.securityanalytics.TestHelpers.netFlowMappings; import static org.opensearch.securityanalytics.TestHelpers.randomDetectorType; import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithTriggers; import static org.opensearch.securityanalytics.TestHelpers.randomDoc; import static org.opensearch.securityanalytics.TestHelpers.randomIndex; import static org.opensearch.securityanalytics.TestHelpers.windowsIndexMapping; +import static org.opensearch.securityanalytics.TestHelpers.randomRuleWithNotConditionBoolAndNum; +import static org.opensearch.securityanalytics.TestHelpers.randomNetworkDoc; +import static org.opensearch.securityanalytics.TestHelpers.randomDocForNotCondition; +import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputs; +import static org.opensearch.securityanalytics.TestHelpers.randomRuleWithNotCondition; import static org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings.FINDING_HISTORY_INDEX_MAX_AGE; import static org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings.FINDING_HISTORY_MAX_DOCS; import static org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings.FINDING_HISTORY_RETENTION_PERIOD; @@ -234,7 +247,7 @@ public void testGetFindings_byDetectorType_success() throws IOException { String monitorId2 = ((List) ((Map) hit.getSourceAsMap().get("detector")).get("monitor_id")).get(0); indexDoc(index1, "1", randomDoc()); - indexDoc(index2, "1", randomDoc()); + indexDoc(index2, "1", randomNetworkDoc()); // execute monitor 1 Response executeResponse = executeAlertingMonitor(monitorId1, Collections.emptyMap()); Map executeResults = entityAsMap(executeResponse); @@ -400,6 +413,289 @@ public void testGetFindings_rolloverByMaxDoc_success() throws IOException, Inter restoreAlertsFindingsIMSettings(); } + public void testCreateDetectorWithNotCondition_verifyFindings_success() throws IOException { + String index = createTestIndex(randomIndex(), windowsIndexMapping()); + + // Execute CreateMappingsAction to add alias mapping for index + Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); + // both req params and req body are supported + createMappingRequest.setJsonEntity( + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" + ); + + Response createMappingResponse = client().performRequest(createMappingRequest); + + assertEquals(HttpStatus.SC_OK, createMappingResponse.getStatusLine().getStatusCode()); + + // Create random doc rule + String randomDocRuleId = createRule(randomRuleWithNotCondition()); + List prepackagedRules = getRandomPrePackagedRules(); + DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), List.of(new DetectorRule(randomDocRuleId)), + prepackagedRules.stream().map(DetectorRule::new).collect(Collectors.toList())); + Detector detector = randomDetectorWithInputs(List.of(input)); + + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); + + assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + + Map updateResponseBody = asMap(createResponse); + String detectorId = updateResponseBody.get("_id").toString(); + String request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId + "\"\n" + + " }\n" + + " }\n" + + "}"; + + // Verify newly created doc level monitor + List hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + Map detectorAsMap = (Map) hit.getSourceAsMap().get("detector"); + List monitorIds = ((List) (detectorAsMap).get("monitor_id")); + + assertEquals(1, monitorIds.size()); + + String monitorId = monitorIds.get(0); + String monitorType = ((Map) entityAsMap(client().performRequest(new Request("GET", "/_plugins/_alerting/monitors/" + monitorId))).get("monitor")).get("monitor_type"); + + assertEquals(Monitor.MonitorType.DOC_LEVEL_MONITOR.getValue(), monitorType); + + // Verify rules + request = "{\n" + + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; + SearchResponse response = executeSearchAndGetResponse(DetectorMonitorConfig.getRuleIndex(randomDetectorType()), request, true); + + assertEquals(6, response.getHits().getTotalHits().value); + + // Verify findings + indexDoc(index, "1", randomDoc(2, 5, "Test")); + indexDoc(index, "2", randomDoc(3, 5, "Test")); + + + Response executeResponse = executeAlertingMonitor(monitorId, Collections.emptyMap()); + Map executeResults = entityAsMap(executeResponse); + int noOfSigmaRuleMatches = ((List>) ((Map) executeResults.get("input_results")).get("results")).get(0).size(); + // Verify 5 prepackaged rules and 1 custom rule + assertEquals(6, noOfSigmaRuleMatches); + + Map params = new HashMap<>(); + params.put("detector_id", detectorId); + Response getFindingsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.FINDINGS_BASE_URI + "/_search", params, null); + Map getFindingsBody = entityAsMap(getFindingsResponse); + + assertNotNull(getFindingsBody); + // When doc level monitor is being applied one finding is generated per document + assertEquals(2, getFindingsBody.get("total_findings")); + + Set docRuleIds = new HashSet<>(prepackagedRules); + docRuleIds.add(randomDocRuleId); + + List> findings = (List) getFindingsBody.get("findings"); + List foundDocIds = new ArrayList<>(); + for (Map finding : findings) { + Set aggRulesFinding = ((List>) finding.get("queries")).stream().map(it -> it.get("id").toString()).collect( + Collectors.toSet()); + + assertTrue(docRuleIds.containsAll(aggRulesFinding)); + + List findingDocs = (List) finding.get("related_doc_ids"); + Assert.assertEquals(1, findingDocs.size()); + foundDocIds.addAll(findingDocs); + } + assertTrue(Arrays.asList("1", "2").containsAll(foundDocIds)); + } + + public void testCreateDetectorWithNotCondition_verifyFindings_success_boolAndNum() throws IOException { + String index = createTestIndex(randomIndex(), windowsIndexMapping()); + + // Execute CreateMappingsAction to add alias mapping for index + Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); + // both req params and req body are supported + createMappingRequest.setJsonEntity( + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" + ); + + Response createMappingResponse = client().performRequest(createMappingRequest); + + assertEquals(HttpStatus.SC_OK, createMappingResponse.getStatusLine().getStatusCode()); + + // Create random custom doc rule with NOT condition + String randomDocRuleId = createRule(randomRuleWithNotConditionBoolAndNum()); + DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), List.of(new DetectorRule(randomDocRuleId)), + emptyList()); + Detector detector = randomDetectorWithInputs(List.of(input)); + + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); + + String request = "{\n" + + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; + SearchResponse response = executeSearchAndGetResponse(DetectorMonitorConfig.getRuleIndex(randomDetectorType()), request, true); + + assertEquals(1, response.getHits().getTotalHits().value); + + assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + Map responseBody = asMap(createResponse); + + String detectorId = responseBody.get("_id").toString(); + request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId + "\"\n" + + " }\n" + + " }\n" + + "}"; + List hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + Map detectorMap = (HashMap) (hit.getSourceAsMap().get("detector")); + List inputArr = (List) detectorMap.get("inputs"); + + assertEquals(1, ((Map>) inputArr.get(0)).get("detector_input").get("custom_rules").size()); + + List monitorIds = ((List) (detectorMap).get("monitor_id")); + assertEquals(1, monitorIds.size()); + + String monitorId = monitorIds.get(0); + + // Verify findings + indexDoc(index, "1", randomDoc(2, 5, "Test")); + indexDoc(index, "2", randomDoc(2, 5, "Test")); + + + Response executeResponse = executeAlertingMonitor(monitorId, Collections.emptyMap()); + Map executeResults = entityAsMap(executeResponse); + int noOfSigmaRuleMatches = ((List>) ((Map) executeResults.get("input_results")).get("results")).get(0).size(); + + // Verify 1 custom rule + assertEquals(1, noOfSigmaRuleMatches); + + Map params = new HashMap<>(); + params.put("detector_id", detectorId); + Response getFindingsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.FINDINGS_BASE_URI + "/_search", params, null); + Map getFindingsBody = entityAsMap(getFindingsResponse); + + assertNotNull(getFindingsBody); + // When doc level monitor is being applied one finding is generated per document + assertEquals(2, getFindingsBody.get("total_findings")); + + List> findings = (List) getFindingsBody.get("findings"); + List foundDocIds = new ArrayList<>(); + for (Map finding : findings) { + List findingDocs = (List) finding.get("related_doc_ids"); + Assert.assertEquals(1, findingDocs.size()); + foundDocIds.addAll(findingDocs); + } + assertTrue(Arrays.asList("1", "2").containsAll(foundDocIds)); + } + + /* + Create a detector with custom rules that include a "not" condition in the sigma rule. + Insert two test documents one matching the rule and one without the field matching the condition to generate only one finding + */ + public void testCreateDetectorWithNotCondition_verifyFindingsAndNoFindings_success() throws IOException { + String index = createTestIndex(randomIndex(), windowsIndexMapping()); + + // Execute CreateMappingsAction to add alias mapping for index + Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); + // both req params and req body are supported + createMappingRequest.setJsonEntity( + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" + ); + + Response createMappingResponse = client().performRequest(createMappingRequest); + + assertEquals(HttpStatus.SC_OK, createMappingResponse.getStatusLine().getStatusCode()); + + // Create random custom doc rule with NOT condition + String randomDocRuleId = createRule(randomRuleWithNotCondition()); + DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), List.of(new DetectorRule(randomDocRuleId)), + emptyList()); + Detector detector = randomDetectorWithInputs(List.of(input)); + + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); + + String request = "{\n" + + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; + SearchResponse response = executeSearchAndGetResponse(DetectorMonitorConfig.getRuleIndex(randomDetectorType()), request, true); + + assertEquals(1, response.getHits().getTotalHits().value); + + assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + Map responseBody = asMap(createResponse); + + String detectorId = responseBody.get("_id").toString(); + request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId + "\"\n" + + " }\n" + + " }\n" + + "}"; + List hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + Map detectorMap = (HashMap) (hit.getSourceAsMap().get("detector")); + List inputArr = (List) detectorMap.get("inputs"); + + assertEquals(1, ((Map>) inputArr.get(0)).get("detector_input").get("custom_rules").size()); + + List monitorIds = ((List) (detectorMap).get("monitor_id")); + assertEquals(1, monitorIds.size()); + + String monitorId = monitorIds.get(0); + + // Verify findings + indexDoc(index, "1", randomDoc(2, 5, "Test")); + indexDoc(index, "2", randomDocForNotCondition(2, 5, "Test")); + indexDoc(index, "3", randomDocForNotCondition(2, 5, "Test")); + indexDoc(index, "4", randomDoc(2, 5, "Test")); + + Response executeResponse = executeAlertingMonitor(monitorId, Collections.emptyMap()); + Map executeResults = entityAsMap(executeResponse); + int noOfSigmaRuleMatches = ((List>) ((Map) executeResults.get("input_results")).get("results")).get(0).size(); + + // Verify 1 custom rule + assertEquals(1, noOfSigmaRuleMatches); + + Map params = new HashMap<>(); + params.put("detector_id", detectorId); + Response getFindingsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.FINDINGS_BASE_URI + "/_search", params, null); + Map getFindingsBody = entityAsMap(getFindingsResponse); + + assertNotNull(getFindingsBody); + // When doc level monitor is being applied one finding is generated per document + assertEquals(2, getFindingsBody.get("total_findings")); + + List> findings = (List) getFindingsBody.get("findings"); + List foundDocIds = new ArrayList<>(); + for (Map finding : findings) { + List findingDocs = (List) finding.get("related_doc_ids"); + Assert.assertEquals(1, findingDocs.size()); + foundDocIds.addAll(findingDocs); + } + assertTrue(Arrays.asList("1", "4").containsAll(foundDocIds)); + } + public void testGetFindings_rolloverByMaxDoc_short_retention_success() throws IOException, InterruptedException { updateClusterSetting(FINDING_HISTORY_ROLLOVER_PERIOD.getKey(), "1s"); updateClusterSetting(FINDING_HISTORY_MAX_DOCS.getKey(), "1"); diff --git a/src/test/java/org/opensearch/securityanalytics/rules/backend/QueryBackendTests.java b/src/test/java/org/opensearch/securityanalytics/rules/backend/QueryBackendTests.java index 3f8196d3d..1ec872f88 100644 --- a/src/test/java/org/opensearch/securityanalytics/rules/backend/QueryBackendTests.java +++ b/src/test/java/org/opensearch/securityanalytics/rules/backend/QueryBackendTests.java @@ -714,7 +714,177 @@ public void testConvertNot() throws IOException, SigmaError { " sel:\n" + " fieldA: value1\n" + " condition: not sel", false)); - Assert.assertEquals("(NOT fieldA: \"value1\")", queries.get(0).toString()); + Assert.assertEquals("(NOT fieldA: \"value1\" AND _exists_: fieldA)", queries.get(0).toString()); + } + + public void testConvertNotWithParenthesis() throws IOException, SigmaError { + OSQueryBackend queryBackend = testBackend(); + List queries = queryBackend.convertRule(SigmaRule.fromYaml( + " title: Test\n" + + " id: 39f919f3-980b-4e6f-a975-8af7e507ef2b\n" + + " status: test\n" + + " level: critical\n" + + " description: Detects QuarksPwDump clearing access history in hive\n" + + " author: Florian Roth\n" + + " date: 2017/05/15\n" + + " logsource:\n" + + " category: test_category\n" + + " product: test_product\n" + + " detection:\n" + + " sel1:\n" + + " Opcode: Info\n" + + " sel2:\n" + + " Severity: value2\n" + + " condition: not (sel1 or sel2)", false)); + Assert.assertEquals("(((NOT Opcode: \"Info\" AND _exists_: Opcode) AND (NOT Severity: \"value2\" AND _exists_: Severity)))", queries.get(0).toString()); + } + + public void testConvertNotComplicatedExpression() throws IOException, SigmaError { + OSQueryBackend queryBackend = testBackend(); + List queries = queryBackend.convertRule(SigmaRule.fromYaml( + " title: Test\n" + + " id: 39f919f3-980b-4e6f-a975-8af7e507ef2b\n" + + " status: test\n" + + " level: critical\n" + + " description: Detects QuarksPwDump clearing access history in hive\n" + + " author: Florian Roth\n" + + " date: 2017/05/15\n" + + " logsource:\n" + + " category: test_category\n" + + " product: test_product\n" + + " detection:\n" + + " selection1:\n" + + " CommandLine|endswith: '.cpl'\n" + + " filter:\n" + + " CommandLine|contains:\n" + + " - '\\System32\\'\n" + + " - '%System%'\n" + + " fp1_igfx:\n" + + " CommandLine|contains|all:\n" + + " - 'regsvr32 '\n" + + " - ' /s '\n" + + " - 'igfxCPL.cpl'\n" + + " condition: selection1 and not filter and not fp1_igfx", false)); + Assert.assertEquals("((CommandLine: *.cpl) AND ((((NOT CommandLine: *\\\\System32\\\\* AND _exists_: CommandLine) AND " + + "(NOT CommandLine: *%System%* AND _exists_: CommandLine))))) AND ((((NOT CommandLine: *regsvr32_ws_* AND _exists_: CommandLine) OR " + + "(NOT CommandLine: *_ws_\\/s_ws_* AND _exists_: CommandLine) OR (NOT CommandLine: *igfxCPL.cpl* AND _exists_: CommandLine))))", queries.get(0).toString()); + } + + public void testConvertNotWithAnd() throws IOException, SigmaError { + OSQueryBackend queryBackend = testBackend(); + List queries = queryBackend.convertRule(SigmaRule.fromYaml( + " title: Test\n" + + " id: 39f919f3-980b-4e6f-a975-8af7e507ef2b\n" + + " status: test\n" + + " level: critical\n" + + " description: Detects QuarksPwDump clearing access history in hive\n" + + " author: Florian Roth\n" + + " date: 2017/05/15\n" + + " logsource:\n" + + " category: test_category\n" + + " product: test_product\n" + + " detection:\n" + + " selection:\n" + + " EventType: SetValue\n" + + " TargetObject|endswith: '\\Software\\Microsoft\\WAB\\DLLPath'\n" + + " filter:\n" + + " Details: '%CommonProgramFiles%\\System\\wab32.dll'\n" + + " condition: selection and not filter", false)); + Assert.assertEquals("((EventType: \"SetValue\") AND (TargetObject: *\\\\Software\\\\Microsoft\\\\WAB\\\\DLLPath)) AND ((NOT Details: \"%CommonProgramFiles%\\\\System\\\\wab32.dll\" AND _exists_: Details))", queries.get(0).toString()); + } + + public void testConvertNotWithOrAndList() throws IOException, SigmaError { + OSQueryBackend queryBackend = testBackend(); + List queries = queryBackend.convertRule(SigmaRule.fromYaml( + " title: Test\n" + + " id: 39f919f3-980b-4e6f-a975-8af7e507ef2b\n" + + " status: test\n" + + " level: critical\n" + + " description: Detects QuarksPwDump clearing access history in hive\n" + + " author: Florian Roth\n" + + " date: 2017/05/15\n" + + " logsource:\n" + + " category: test_category\n" + + " product: test_product\n" + + " detection:\n" + + " sel1:\n" + + " field1: valueA1\n" + + " field2: valueA2\n" + + " field3: valueA3\n" + + " sel3:\n" + + " - resp_mime_types|contains: 'dosexec'\n" + + " - c-uri|endswith: '.exe'\n" + + " condition: not sel1 or sel3", false)); + Assert.assertEquals("((((NOT field1: \"valueA1\" AND _exists_: field1) OR (NOT field2: \"valueA2\" AND _exists_: field2) OR (NOT field3: \"valueA3\" AND _exists_: field3)))) OR ((resp_mime_types: *dosexec*) OR (c-uri: *.exe))", queries.get(0).toString()); + } + + public void testConvertNotWithNumAndBool() throws IOException, SigmaError { + OSQueryBackend queryBackend = testBackend(); + List queries = queryBackend.convertRule(SigmaRule.fromYaml( + " title: Test\n" + + " id: 39f919f3-980b-4e6f-a975-8af7e507ef2b\n" + + " status: test\n" + + " level: critical\n" + + " description: Detects QuarksPwDump clearing access history in hive\n" + + " author: Florian Roth\n" + + " date: 2017/05/15\n" + + " logsource:\n" + + " category: test_category\n" + + " product: test_product\n" + + " detection:\n" + + " sel1:\n" + + " field1: 1\n" + + " sel2:\n" + + " field2: true\n" + + " condition: not sel1 and not sel2", false)); + Assert.assertEquals("((NOT field1: 1 AND _exists_: field1)) AND ((NOT field2: true AND _exists_: field2))", queries.get(0).toString()); + } + + public void testConvertNotWithNull() throws IOException, SigmaError { + OSQueryBackend queryBackend = testBackend(); + List queries = queryBackend.convertRule(SigmaRule.fromYaml( + " title: Test\n" + + " id: 39f919f3-980b-4e6f-a975-8af7e507ef2b\n" + + " status: test\n" + + " level: critical\n" + + " description: Detects QuarksPwDump clearing access history in hive\n" + + " author: Florian Roth\n" + + " date: 2017/05/15\n" + + " logsource:\n" + + " category: test_category\n" + + " product: test_product\n" + + " detection:\n" + + " sel1:\n" + + " fieldA: null\n" + + " sel2:\n" + + " fieldB: true\n" + + " condition: not sel1", false)); + Assert.assertEquals("(NOT fieldA: (NOT [* TO *]) AND _exists_: fieldA)", queries.get(0).toString()); + } + + public void testConvertNotWithKeywords() throws IOException, SigmaError { + OSQueryBackend queryBackend = testBackend(); + List queries = queryBackend.convertRule(SigmaRule.fromYaml( + " title: Test\n" + + " id: 39f919f3-980b-4e6f-a975-8af7e507ef2b\n" + + " status: test\n" + + " level: critical\n" + + " description: Detects QuarksPwDump clearing access history in hive\n" + + " author: Florian Roth\n" + + " date: 2017/05/15\n" + + " logsource:\n" + + " category: test_category\n" + + " product: test_product\n" + + " detection:\n" + + " sel1:\n" + + " fieldA: value1\n" + + " sel2:\n" + + " fieldB: value2\n" + + " keywords:\n" + + " - test1\n" + + " - 123\n" + + " condition: not keywords", false)); + Assert.assertEquals("(((NOT \"test1\") AND (NOT \"123\")))", queries.get(0).toString()); } public void testConvertPrecedence() throws IOException, SigmaError { @@ -740,7 +910,7 @@ public void testConvertPrecedence() throws IOException, SigmaError { " sel4:\n" + " fieldD: value5\n" + " condition: (sel1 or sel2) and not (sel3 and sel4)", false)); - Assert.assertEquals("((fieldA: \"value1\") OR (mappedB: \"value2\")) AND ((NOT ((fieldC: \"value4\") AND (fieldD: \"value5\"))))", queries.get(0).toString()); + Assert.assertEquals("((fieldA: \"value1\") OR (mappedB: \"value2\")) AND ((((NOT fieldC: \"value4\" AND _exists_: fieldC) OR (NOT fieldD: \"value5\" AND _exists_: fieldD))))", queries.get(0).toString()); } public void testConvertMultiConditions() throws IOException, SigmaError { From 75c442902f2715e4b41f96e77230ac773d0eb2d2 Mon Sep 17 00:00:00 2001 From: Megha Goyal <56077967+goyamegh@users.noreply.github.com> Date: Fri, 8 Mar 2024 14:32:32 -0800 Subject: [PATCH 12/12] Add search request timeouts for correlations workflows (#893) * Reinstating more leaks plugged-in for correlations workflows Signed-off-by: Megha Goyal * Add search timeouts to all correlation searches Signed-off-by: Megha Goyal * Fix logging and exception messages Signed-off-by: Megha Goyal * Change search timeout to 30 seconds Signed-off-by: Megha Goyal --------- Signed-off-by: Megha Goyal --- .../correlation/JoinEngine.java | 7 ++ .../correlation/VectorEmbeddingsEngine.java | 21 ++++- .../TransportCorrelateFindingAction.java | 89 +++++++++---------- 3 files changed, 69 insertions(+), 48 deletions(-) diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java b/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java index b33c4d43b..3b4314e12 100644 --- a/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java +++ b/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java @@ -10,6 +10,7 @@ import org.apache.lucene.search.join.ScoreMode; import org.opensearch.OpenSearchStatusException; import org.opensearch.cluster.routing.Preference; +import org.opensearch.common.unit.TimeValue; import org.opensearch.commons.alerting.model.DocLevelQuery; import org.opensearch.core.action.ActionListener; import org.opensearch.action.search.MultiSearchRequest; @@ -132,6 +133,7 @@ private void generateAutoCorrelations(Detector detector, Finding finding) throws searchRequest.indices(DetectorMonitorConfig.getAllFindingsIndicesPattern(logTypeName)); searchRequest.source(sourceBuilder); searchRequest.preference(Preference.PRIMARY_FIRST.type()); + searchRequest.setCancelAfterTimeInterval(TimeValue.timeValueSeconds(30L)); mSearchRequest.add(searchRequest); } @@ -214,6 +216,7 @@ private void onAutoCorrelations(Detector detector, Finding finding, Map { if (response.isTimedOut()) { @@ -277,6 +280,7 @@ private void getValidDocuments(String detectorType, List indices, List searchRequest.indices(DetectorMonitorConfig.getAllFindingsIndicesPattern(relatedDocIds.getKey())); searchRequest.source(searchSourceBuilder); searchRequest.preference(Preference.PRIMARY_FIRST.type()); + searchRequest.setCancelAfterTimeInterval(TimeValue.timeValueSeconds(30L)); categories.add(relatedDocIds.getKey()); mSearchRequest.add(searchRequest); diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/VectorEmbeddingsEngine.java b/src/main/java/org/opensearch/securityanalytics/correlation/VectorEmbeddingsEngine.java index 86fc70bbd..78f7dc765 100644 --- a/src/main/java/org/opensearch/securityanalytics/correlation/VectorEmbeddingsEngine.java +++ b/src/main/java/org/opensearch/securityanalytics/correlation/VectorEmbeddingsEngine.java @@ -32,6 +32,7 @@ import org.opensearch.securityanalytics.transport.TransportCorrelateFindingAction; import org.opensearch.securityanalytics.util.CorrelationIndices; +import java.util.Arrays; import java.util.List; import java.util.Locale; import java.util.Map; @@ -94,6 +95,7 @@ public void insertCorrelatedFindings(String detectorType, Finding finding, Strin request.indices(CorrelationIndices.CORRELATION_HISTORY_INDEX_PATTERN_REGEXP); request.source(searchSourceBuilder); request.preference(Preference.PRIMARY_FIRST.type()); + request.setCancelAfterTimeInterval(TimeValue.timeValueSeconds(30L)); mSearchRequest.add(request); } @@ -195,6 +197,12 @@ public void insertCorrelatedFindings(String detectorType, Finding finding, Strin } public void insertOrphanFindings(String detectorType, Finding finding, float timestampFeature, Map logTypes) { + if (logTypes.get(detectorType) == null ) { + log.debug("Missing detector type {} in the log types index for finding id {}. Keys in the index: {}", + detectorType, finding.getId(), Arrays.toString(logTypes.keySet().toArray())); + onFailure(new OpenSearchStatusException("insertOrphanFindings null log types for detector type: " + detectorType, RestStatus.INTERNAL_SERVER_ERROR)); + } + SearchRequest searchRequest = getSearchMetadataIndexRequest(detectorType, finding, logTypes); Map tags = logTypes.get(detectorType).getTags(); String correlationId = tags.get("correlation_id").toString(); @@ -251,7 +259,8 @@ public void insertOrphanFindings(String detectorType, Finding finding, float tim onFailure(ex); } } else { - onFailure(new OpenSearchStatusException(indexResponse.toString(), RestStatus.INTERNAL_SERVER_ERROR)); + onFailure(new OpenSearchStatusException("Indexing failed with response {} ", + indexResponse.status(), indexResponse.toString())); } }, this::onFailure)); } else { @@ -297,7 +306,8 @@ public void insertOrphanFindings(String detectorType, Finding finding, float tim onFailure(ex); } } else { - onFailure(new OpenSearchStatusException(indexResponse.toString(), RestStatus.INTERNAL_SERVER_ERROR)); + onFailure(new OpenSearchStatusException("Indexing failed with response {} ", + indexResponse.status(), indexResponse.toString())); } }, this::onFailure)); } else { @@ -323,6 +333,7 @@ public void insertOrphanFindings(String detectorType, Finding finding, float tim request.indices(CorrelationIndices.CORRELATION_HISTORY_INDEX_PATTERN_REGEXP); request.source(searchSourceBuilder); request.preference(Preference.PRIMARY_FIRST.type()); + request.setCancelAfterTimeInterval(TimeValue.timeValueSeconds(30L)); client.search(request, ActionListener.wrap(searchResponse -> { if (searchResponse.isTimedOut()) { @@ -407,6 +418,9 @@ public void insertOrphanFindings(String detectorType, Finding finding, float tim } catch (Exception ex) { onFailure(ex); } + } else { + onFailure(new OpenSearchStatusException("Indexing failed with response {} ", + indexResponse.status(), indexResponse.toString())); } }, this::onFailure)); } catch (Exception ex) { @@ -432,7 +446,7 @@ private void indexCorrelatedFindings(XContentBuilder builder) { if (response.status().equals(RestStatus.CREATED)) { correlateFindingAction.onOperation(); } else { - onFailure(new OpenSearchStatusException(response.toString(), RestStatus.INTERNAL_SERVER_ERROR)); + onFailure(new OpenSearchStatusException("Indexing failed with response {} ", response.status(), response.toString())); } }, this::onFailure)); } @@ -454,6 +468,7 @@ private SearchRequest getSearchMetadataIndexRequest(String detectorType, Finding searchRequest.indices(CorrelationIndices.CORRELATION_METADATA_INDEX); searchRequest.source(searchSourceBuilder); searchRequest.preference(Preference.PRIMARY_FIRST.type()); + searchRequest.setCancelAfterTimeInterval(TimeValue.timeValueSeconds(30L)); return searchRequest; } diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java index d5e0eed32..910794556 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java @@ -172,13 +172,13 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { if (response.isTimedOut()) { @@ -245,8 +246,8 @@ void start() { ); Detector detector = Detector.docParse(xcp, hit.getId(), hit.getVersion()); joinEngine.onSearchDetectorResponse(detector, finding); - } catch (IOException e) { - log.error("IOException for request {}", searchRequest.toString(), e); + } catch (Exception e) { + log.error("Exception for request {}", searchRequest.toString(), e); onFailures(e); } } else { @@ -277,7 +278,7 @@ public void initCorrelationIndex(String detectorType, Map> } else { getTimestampFeature(detectorType, correlatedFindings, null, correlationRules); } - } catch (IOException ex) { + } catch (Exception ex) { onFailures(ex); } } @@ -353,7 +354,8 @@ public void getTimestampFeature(String detectorType, Map> c }, this::onFailures)); }, this::onFailures)); } else { - log.error(new OpenSearchStatusException("Failed to create correlation metadata Index", RestStatus.INTERNAL_SERVER_ERROR)); + Exception e = new OpenSearchStatusException("Failed to create correlation metadata Index", RestStatus.INTERNAL_SERVER_ERROR); + onFailures(e); } }, this::onFailures)); } else { @@ -364,54 +366,49 @@ public void getTimestampFeature(String detectorType, Map> c if (response.getHits().getHits().length == 0) { onFailures(new ResourceNotFoundException( "Failed to find hits in metadata index for finding id {}", request.getFinding().getId())); - } - - String id = response.getHits().getHits()[0].getId(); - Map hitSource = response.getHits().getHits()[0].getSourceAsMap(); - long scoreTimestamp = (long) hitSource.get("scoreTimestamp"); + } else { + String id = response.getHits().getHits()[0].getId(); + Map hitSource = response.getHits().getHits()[0].getSourceAsMap(); + long scoreTimestamp = (long) hitSource.get("scoreTimestamp"); - long newScoreTimestamp = findingTimestamp - CorrelationIndices.FIXED_HISTORICAL_INTERVAL; - if (newScoreTimestamp > scoreTimestamp) { - try { + long newScoreTimestamp = findingTimestamp - CorrelationIndices.FIXED_HISTORICAL_INTERVAL; + if (newScoreTimestamp > scoreTimestamp) { IndexRequest scoreIndexRequest = getCorrelationMetadataIndexRequest(id, newScoreTimestamp); client.index(scoreIndexRequest, ActionListener.wrap(indexResponse -> { - SearchRequest searchRequest = getSearchLogTypeIndexRequest(); + SearchRequest searchRequest = getSearchLogTypeIndexRequest(); client.search(searchRequest, ActionListener.wrap(searchResponse -> { - if (searchResponse.isTimedOut()) { - onFailures(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); - } + if (searchResponse.isTimedOut()) { + onFailures(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); + } - SearchHit[] hits = searchResponse.getHits().getHits(); - Map logTypes = new HashMap<>(); - for (SearchHit hit : hits) { - Map sourceMap = hit.getSourceAsMap(); - logTypes.put(sourceMap.get("name").toString(), - new CustomLogType(sourceMap)); - } + SearchHit[] hits = searchResponse.getHits().getHits(); + Map logTypes = new HashMap<>(); + for (SearchHit hit : hits) { + Map sourceMap = hit.getSourceAsMap(); + logTypes.put(sourceMap.get("name").toString(), new CustomLogType(sourceMap)); + } - if (correlatedFindings != null) { - if (correlatedFindings.isEmpty()) { - vectorEmbeddingsEngine.insertOrphanFindings(detectorType, request.getFinding(), Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue(), logTypes); - } - for (Map.Entry> correlatedFinding : correlatedFindings.entrySet()) { - vectorEmbeddingsEngine.insertCorrelatedFindings(detectorType, request.getFinding(), correlatedFinding.getKey(), correlatedFinding.getValue(), - Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue(), correlationRules, logTypes); - } - } else { - vectorEmbeddingsEngine.insertOrphanFindings(detectorType, orphanFinding, Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue(), logTypes); + if (correlatedFindings != null) { + if (correlatedFindings.isEmpty()) { + vectorEmbeddingsEngine.insertOrphanFindings(detectorType, request.getFinding(), Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue(), logTypes); } - }, this::onFailures)); + for (Map.Entry> correlatedFinding : correlatedFindings.entrySet()) { + vectorEmbeddingsEngine.insertCorrelatedFindings(detectorType, request.getFinding(), correlatedFinding.getKey(), correlatedFinding.getValue(), + Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue(), correlationRules, logTypes); + } + } else { + vectorEmbeddingsEngine.insertOrphanFindings(detectorType, orphanFinding, Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue(), logTypes); + } + }, this::onFailures)); }, this::onFailures)); - } catch (Exception ex) { - onFailures(ex); - } - } else { - float timestampFeature = Long.valueOf((findingTimestamp - scoreTimestamp) / 1000L).floatValue(); + } else { + float timestampFeature = Long.valueOf((findingTimestamp - scoreTimestamp) / 1000L).floatValue(); - SearchRequest searchRequest = getSearchLogTypeIndexRequest(); - insertFindings(timestampFeature, searchRequest, correlatedFindings, detectorType, correlationRules, orphanFinding); + SearchRequest searchRequest = getSearchLogTypeIndexRequest(); + insertFindings(timestampFeature, searchRequest, correlatedFindings, detectorType, correlationRules, orphanFinding); + } } }, this::onFailures)); } @@ -430,6 +427,7 @@ private SearchRequest getSearchLogTypeIndexRequest() { SearchRequest searchRequest = new SearchRequest(); searchRequest.indices(LogTypeService.LOG_TYPE_INDEX); searchRequest.source(searchSourceBuilder); + searchRequest.setCancelAfterTimeInterval(TimeValue.timeValueSeconds(30L)); return searchRequest; } @@ -439,13 +437,13 @@ private IndexRequest getCorrelationMetadataIndexRequest(String id, long newScore scoreBuilder.field("root", false); scoreBuilder.endObject(); - IndexRequest scoreIndexRequest = new IndexRequest(CorrelationIndices.CORRELATION_METADATA_INDEX) + return new IndexRequest(CorrelationIndices.CORRELATION_METADATA_INDEX) .id(id) .source(scoreBuilder) .timeout(indexTimeout) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - return scoreIndexRequest; } + private void insertFindings(float timestampFeature, SearchRequest searchRequest, Map> correlatedFindings, String detectorType, List correlationRules, Finding orphanFinding) { client.search(searchRequest, ActionListener.wrap(response -> { if (response.isTimedOut()) { @@ -485,6 +483,7 @@ private SearchRequest getSearchMetadataIndexRequest() { searchRequest.indices(CorrelationIndices.CORRELATION_METADATA_INDEX); searchRequest.source(searchSourceBuilder); searchRequest.preference(Preference.PRIMARY_FIRST.type()); + searchRequest.setCancelAfterTimeInterval(TimeValue.timeValueSeconds(30L)); return searchRequest; }