diff --git a/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java b/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java index f97afcb60..2f5f97c50 100644 --- a/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java +++ b/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java @@ -57,10 +57,12 @@ import org.opensearch.securityanalytics.action.AckAlertsAction; import org.opensearch.securityanalytics.action.CreateIndexMappingsAction; import org.opensearch.securityanalytics.action.CorrelatedFindingAction; +import org.opensearch.securityanalytics.action.AckCorrelationAlertsAction; import org.opensearch.securityanalytics.action.DeleteCustomLogTypeAction; import org.opensearch.securityanalytics.action.DeleteDetectorAction; import org.opensearch.securityanalytics.action.DeleteRuleAction; import org.opensearch.securityanalytics.action.GetAllRuleCategoriesAction; +import org.opensearch.securityanalytics.action.GetCorrelationAlertsAction; import org.opensearch.securityanalytics.action.GetDetectorAction; import org.opensearch.securityanalytics.action.GetFindingsAction; import org.opensearch.securityanalytics.action.GetIndexMappingsAction; @@ -131,6 +133,8 @@ public class SecurityAnalyticsPlugin extends Plugin implements ActionPlugin, Map public static final String CORRELATION_RULES_BASE_URI = PLUGINS_BASE_URI + "/correlation/rules"; public static final String CUSTOM_LOG_TYPE_URI = PLUGINS_BASE_URI + "/logtype"; + + public static final String CORRELATIONS_ALERTS_BASE_URI = PLUGINS_BASE_URI + "/correlationAlerts"; public static final String JOB_INDEX_NAME = ".opensearch-sap--job"; public static final Map TIF_JOB_INDEX_SETTING = Map.of(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1, IndexMetadata.SETTING_AUTO_EXPAND_REPLICAS, "0-all", IndexMetadata.SETTING_INDEX_HIDDEN, true); @@ -238,7 +242,9 @@ public List getRestHandlers(Settings settings, new RestSearchCorrelationRuleAction(), new RestIndexCustomLogTypeAction(), new RestSearchCustomLogTypeAction(), - new RestDeleteCustomLogTypeAction() + new RestDeleteCustomLogTypeAction(), + new RestGetCorrelationsAlertsAction(), + new RestAcknowledgeCorrelationAlertsAction() ); } @@ -359,7 +365,9 @@ public List> getSettings() { new ActionHandler<>(IndexCustomLogTypeAction.INSTANCE, TransportIndexCustomLogTypeAction.class), new ActionHandler<>(SearchCustomLogTypeAction.INSTANCE, TransportSearchCustomLogTypeAction.class), new ActionHandler<>(DeleteCustomLogTypeAction.INSTANCE, TransportDeleteCustomLogTypeAction.class), - new ActionHandler<>(PutTIFJobAction.INSTANCE, TransportPutTIFJobAction.class) + new ActionHandler<>(PutTIFJobAction.INSTANCE, TransportPutTIFJobAction.class), + new ActionPlugin.ActionHandler<>(GetCorrelationAlertsAction.INSTANCE, TransportGetCorrelationAlertsAction.class), + new ActionPlugin.ActionHandler<>(AckCorrelationAlertsAction.INSTANCE, TransportAckCorrelationAlertsAction.class) ); } diff --git a/src/main/java/org/opensearch/securityanalytics/action/AckCorrelationAlertsAction.java b/src/main/java/org/opensearch/securityanalytics/action/AckCorrelationAlertsAction.java new file mode 100644 index 000000000..d85d23f1c --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/action/AckCorrelationAlertsAction.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.action; + +import org.opensearch.action.ActionType; + +/** + * Acknowledge Correlation Alert Action + */ +public class AckCorrelationAlertsAction extends ActionType { + public static final String NAME = "cluster:admin/opensearch/securityanalytics/correlationAlerts/ack"; + public static final AckCorrelationAlertsAction INSTANCE = new AckCorrelationAlertsAction(); + + public AckCorrelationAlertsAction() { + super(NAME, AckCorrelationAlertsResponse::new); + } +} + diff --git a/src/main/java/org/opensearch/securityanalytics/action/AckCorrelationAlertsRequest.java b/src/main/java/org/opensearch/securityanalytics/action/AckCorrelationAlertsRequest.java new file mode 100644 index 000000000..2183b4658 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/action/AckCorrelationAlertsRequest.java @@ -0,0 +1,56 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.action; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.action.ValidateActions; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; + +public class AckCorrelationAlertsRequest extends ActionRequest { + private final List correlationAlertIds; + + public AckCorrelationAlertsRequest(List correlationAlertIds) { + this.correlationAlertIds = correlationAlertIds; + } + + public AckCorrelationAlertsRequest(StreamInput in) throws IOException { + correlationAlertIds = Collections.unmodifiableList(in.readStringList()); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = null; + if(correlationAlertIds == null || correlationAlertIds.isEmpty()) { + validationException = ValidateActions.addValidationError("alert ids list cannot be empty", validationException); + } + return validationException; + } + + public void writeTo(StreamOutput out) throws IOException { + out.writeStringCollection(this.correlationAlertIds); + } + + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + return builder.startObject() + .field("correlation_alert_ids", correlationAlertIds) + .endObject(); + } + + public static AckAlertsRequest readFrom(StreamInput sin) throws IOException { + return new AckAlertsRequest(sin); + } + + public List getCorrelationAlertIds() { + return correlationAlertIds; + } +} diff --git a/src/main/java/org/opensearch/securityanalytics/action/AckCorrelationAlertsResponse.java b/src/main/java/org/opensearch/securityanalytics/action/AckCorrelationAlertsResponse.java new file mode 100644 index 000000000..a34ae6b74 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/action/AckCorrelationAlertsResponse.java @@ -0,0 +1,56 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.action; + +import org.opensearch.commons.alerting.model.CorrelationAlert; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; + +public class AckCorrelationAlertsResponse extends ActionResponse implements ToXContentObject { + + private final List acknowledged; + private final List failed; + + public AckCorrelationAlertsResponse(List acknowledged, List failed) { + this.acknowledged = acknowledged; + this.failed = failed; + } + + public AckCorrelationAlertsResponse(StreamInput sin) throws IOException { + this( + Collections.unmodifiableList(sin.readList(CorrelationAlert::new)), + Collections.unmodifiableList(sin.readList(CorrelationAlert::new)) + ); + } + + @Override + public void writeTo(StreamOutput streamOutput) throws IOException { + streamOutput.writeList(this.acknowledged); + streamOutput.writeList(this.failed); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject() + .field("acknowledged",this.acknowledged) + .field("failed",this.failed); + return builder.endObject(); + } + + public List getAcknowledged() { + return acknowledged; + } + + public List getFailed() { + return failed; + } +} diff --git a/src/main/java/org/opensearch/securityanalytics/action/GetCorrelationAlertsAction.java b/src/main/java/org/opensearch/securityanalytics/action/GetCorrelationAlertsAction.java new file mode 100644 index 000000000..d07fc7bfc --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/action/GetCorrelationAlertsAction.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.action; + +import org.opensearch.action.ActionType; + +public class GetCorrelationAlertsAction extends ActionType { + + public static final GetCorrelationAlertsAction INSTANCE = new GetCorrelationAlertsAction(); + public static final String NAME = "cluster:admin/opensearch/securityanalytics/correlationAlerts/get"; + + public GetCorrelationAlertsAction() { + super(NAME, GetCorrelationAlertsResponse::new); + } +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/securityanalytics/action/GetCorrelationAlertsRequest.java b/src/main/java/org/opensearch/securityanalytics/action/GetCorrelationAlertsRequest.java new file mode 100644 index 000000000..41213b3cd --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/action/GetCorrelationAlertsRequest.java @@ -0,0 +1,112 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.action; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.commons.alerting.model.Table; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import java.io.IOException; +import java.time.Instant; +import java.util.Locale; + +import static org.opensearch.action.ValidateActions.addValidationError; + +public class GetCorrelationAlertsRequest extends ActionRequest { + private String correlationRuleId; + private String correlationRuleName; + private Table table; + private String severityLevel; + private String alertState; + + private Instant startTime; + + private Instant endTime; + + public static final String CORRELATION_RULE_ID = "correlation_rule_id"; + + public GetCorrelationAlertsRequest( + String correlationRuleId, + String correlationRuleName, + Table table, + String severityLevel, + String alertState, + Instant startTime, + Instant endTime + ) { + super(); + this.correlationRuleId = correlationRuleId; + this.correlationRuleName = correlationRuleName; + this.table = table; + this.severityLevel = severityLevel; + this.alertState = alertState; + this.startTime = startTime; + this.endTime = endTime; + } + public GetCorrelationAlertsRequest(StreamInput sin) throws IOException { + this( + sin.readOptionalString(), + sin.readOptionalString(), + Table.readFrom(sin), + sin.readString(), + sin.readString(), + sin.readOptionalInstant(), + sin.readOptionalInstant() + ); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = null; + if ((correlationRuleId != null && correlationRuleId.isEmpty())) { + validationException = addValidationError(String.format(Locale.getDefault(), + "Correlation ruleId is empty or not valid", CORRELATION_RULE_ID), + validationException); + } + return validationException; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(correlationRuleId); + out.writeOptionalString(correlationRuleName); + table.writeTo(out); + out.writeString(severityLevel); + out.writeString(alertState); + out.writeOptionalInstant(startTime); + out.writeOptionalInstant(endTime); + } + + public String getCorrelationRuleId() { + return correlationRuleId; + } + + public Table getTable() { + return table; + } + + public String getSeverityLevel() { + return severityLevel; + } + + public String getAlertState() { + return alertState; + } + + public String getCorrelationRuleName() { + return correlationRuleName; + } + + public Instant getStartTime() { + return startTime; + } + + public Instant getEndTime() { + return endTime; + } +} + diff --git a/src/main/java/org/opensearch/securityanalytics/action/GetCorrelationAlertsResponse.java b/src/main/java/org/opensearch/securityanalytics/action/GetCorrelationAlertsResponse.java new file mode 100644 index 000000000..33ffc1e93 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/action/GetCorrelationAlertsResponse.java @@ -0,0 +1,54 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.action; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.commons.alerting.model.CorrelationAlert; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import java.io.IOException; +import java.util.Collections; +import java.util.List; + +public class GetCorrelationAlertsResponse extends ActionResponse implements ToXContentObject { + + private static final Logger log = LogManager.getLogger(GetCorrelationAlertsResponse.class); + private static final String CORRELATION_ALERTS_FIELD = "correlationAlerts"; + private static final String TOTAL_ALERTS_FIELD = "total_alerts"; + + private List alerts; + private Integer totalAlerts; + + public GetCorrelationAlertsResponse(List alerts, Integer totalAlerts) { + super(); + this.alerts = alerts; + this.totalAlerts = totalAlerts; + } + + public GetCorrelationAlertsResponse(StreamInput sin) throws IOException { + this( + Collections.unmodifiableList(sin.readList(CorrelationAlert::new)), + sin.readInt() + ); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeCollection(this.alerts); + out.writeInt(this.totalAlerts); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject() + .field(CORRELATION_ALERTS_FIELD, this.alerts) + .field(TOTAL_ALERTS_FIELD, this.totalAlerts); + return builder.endObject(); + } +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertService.java b/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertService.java index f7aeb4e4d..54c09d29a 100644 --- a/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertService.java +++ b/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertService.java @@ -6,15 +6,18 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.action.bulk.BulkItemResponse; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.bulk.BulkResponse; import org.opensearch.action.index.IndexRequest; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.update.UpdateRequest; import org.opensearch.client.Client; import org.opensearch.common.lucene.uid.Versions; -import org.opensearch.commons.alerting.model.ActionExecutionResult; import org.opensearch.commons.alerting.model.Alert; -import org.opensearch.commons.authuser.User; +import org.opensearch.commons.alerting.model.Table; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentFactory; @@ -23,18 +26,26 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.core.xcontent.XContentParserUtils; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermsQueryBuilder; +import org.opensearch.script.Script; +import org.opensearch.script.ScriptType; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.commons.alerting.model.CorrelationAlert; +import org.opensearch.search.sort.FieldSortBuilder; +import org.opensearch.search.sort.SortBuilders; +import org.opensearch.search.sort.SortOrder; +import org.opensearch.securityanalytics.action.AckCorrelationAlertsResponse; +import org.opensearch.securityanalytics.action.GetCorrelationAlertsResponse; import org.opensearch.securityanalytics.util.CorrelationIndices; import java.io.IOException; import java.time.Instant; import java.util.List; import java.util.ArrayList; import java.util.Collections; +import java.util.Map; public class CorrelationAlertService { private static final Logger log = LogManager.getLogger(CorrelationAlertService.class); @@ -70,7 +81,6 @@ public CorrelationAlertService(Client client, NamedXContentRegistry xContentRegi * * @param ruleId The correlation rule ID to filter the alerts * @param currentTime The current time of the search range - * @return The search response containing active alerts */ public void getActiveAlerts(String ruleId, long currentTime, ActionListener listener) { Instant currentTimeDate = Instant.ofEpochMilli(currentTime); @@ -139,119 +149,177 @@ public void indexCorrelationAlert(CorrelationAlert correlationAlert, TimeValue i } } + public void getCorrelationAlerts(String ruleId, Table tableProp, ActionListener listener) { + BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery(); + if (ruleId != null) { + queryBuilder = QueryBuilders.boolQuery() + .must(QueryBuilders.termQuery("correlation_rule_id", ruleId)); + } + + FieldSortBuilder sortBuilder = SortBuilders + .fieldSort(tableProp.getSortString()) + .order(SortOrder.fromString(tableProp.getSortOrder())); + if (tableProp.getMissing() != null && !tableProp.getMissing().isEmpty()) { + sortBuilder.missing(tableProp.getMissing()); + } + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() + .version(true) + .seqNoAndPrimaryTerm(true) + .query(queryBuilder) + .sort(sortBuilder) + .size(tableProp.getSize()) + .from(tableProp.getStartIndex()); + + SearchRequest searchRequest = new SearchRequest(CorrelationIndices.CORRELATION_ALERT_INDEX) + .source(searchSourceBuilder); + + client.search(searchRequest, ActionListener.wrap( + searchResponse -> { + if (searchResponse.getHits().getTotalHits().equals(0)) { + listener.onResponse(new GetCorrelationAlertsResponse(Collections.emptyList(), 0)); + } else { + listener.onResponse(new GetCorrelationAlertsResponse( + parseCorrelationAlerts(searchResponse), + searchResponse.getHits() != null && searchResponse.getHits().getTotalHits() != null ? + (int) searchResponse.getHits().getTotalHits().value : 0) + ); + } + }, + e -> { + log.error("Search request to fetch correlation alerts failed", e); + listener.onFailure(e); + } + )); + } + + public void acknowledgeAlerts(List alertIds, ActionListener listener) { + BulkRequest bulkRequest = new BulkRequest(); + List acknowledgedAlerts = new ArrayList<>(); + List failedAlerts = new ArrayList<>(); + + TermsQueryBuilder termsQueryBuilder = QueryBuilders.termsQuery("id", alertIds); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(termsQueryBuilder); + SearchRequest searchRequest = new SearchRequest(CorrelationIndices.CORRELATION_ALERT_INDEX) + .source(searchSourceBuilder); + + // Execute the search request + client.search(searchRequest, new ActionListener() { + @Override + public void onResponse(SearchResponse searchResponse) { + // Iterate through the search hits + for (SearchHit hit : searchResponse.getHits().getHits()) { + // Construct a script to update the document with the new state and acknowledgedTime + // Construct a script to update the document with the new state and acknowledgedTime + Script script = new Script(ScriptType.INLINE, "painless", + "ctx._source.state = params.state; ctx._source.acknowledged_time = params.time", + Map.of("state", Alert.State.ACKNOWLEDGED, "time", Instant.now())); + // Create an update request with the script + UpdateRequest updateRequest = new UpdateRequest(CorrelationIndices.CORRELATION_ALERT_INDEX, hit.getId()) + .script(script); + + // Add the update request to the bulk request + bulkRequest.add(updateRequest); + + // Add the current alert to the acknowledged alerts list + try { + acknowledgedAlerts.add(getParsedCorrelationAlert(hit)); + } catch (IOException e) { + log.error("Exception while acknowledging alerts: {}", e.toString()); + } + } + + // Check if there are any update requests in the bulk request + if (!bulkRequest.requests().isEmpty()) { + // Execute the bulk request asynchronously + client.bulk(bulkRequest, new ActionListener() { + @Override + public void onResponse(BulkResponse bulkResponse) { + // Iterate through the bulk response to identify failed updates + for (BulkItemResponse itemResponse : bulkResponse.getItems()) { + if (itemResponse.isFailed()) { + // If an update failed, add the corresponding alert to the failed alerts list + failedAlerts.add(acknowledgedAlerts.get(itemResponse.getItemId())); + } + } + // Create and pass the CorrelationAckAlertsResponse to the listener + listener.onResponse(new AckCorrelationAlertsResponse(acknowledgedAlerts, failedAlerts)); + } + + @Override + public void onFailure(Exception e) { + // Handle failure + listener.onFailure(e); + } + }); + } else { + // If there are no update requests, return an empty response + listener.onResponse(new AckCorrelationAlertsResponse(acknowledgedAlerts, failedAlerts)); + } + } + + @Override + public void onFailure(Exception e) { + // Handle failure + listener.onFailure(e); + } + }); + } + + public void updateCorrelationAlertsWithError(String correlationRuleId) { + BulkRequest bulkRequest = new BulkRequest(); + BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery() + .must(QueryBuilders.termQuery("correlation_rule_id", correlationRuleId)); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(queryBuilder); + SearchRequest searchRequest = new SearchRequest(CorrelationIndices.CORRELATION_ALERT_INDEX) + .source(searchSourceBuilder); + + // Execute the search request + client.search(searchRequest, new ActionListener() { + @Override + public void onResponse(SearchResponse searchResponse) { + // Iterate through the search hits + for (SearchHit hit : searchResponse.getHits().getHits()) { + // Construct a script to update the document with the new state and error_message + Script script = new Script(ScriptType.INLINE, "painless", + "ctx._source.state = params.state; ctx._source.error_message = params.error_message", + Map.of("state", Alert.State.ERROR, "error_message", "The rule associated to this Alert is deleted")); + // Create an update request with the script + UpdateRequest updateRequest = new UpdateRequest(CorrelationIndices.CORRELATION_ALERT_INDEX, hit.getId()) + .script(script); + // Add the update request to the bulk request + bulkRequest.add(updateRequest); + client.bulk(bulkRequest); + } + } + @Override + public void onFailure(Exception e) { + log.error("Error updating the alerts with Error message for correlation ruleId: {}", correlationRuleId); + } + }); + } + + public List parseCorrelationAlerts(final SearchResponse response) throws IOException { List alerts = new ArrayList<>(); for (SearchHit hit : response.getHits()) { - XContentParser xcp = XContentType.JSON.xContent().createParser( - xContentRegistry, - LoggingDeprecationHandler.INSTANCE, - hit.getSourceAsString() - ); - xcp.nextToken(); - CorrelationAlert correlationAlert = parse(xcp, hit.getId(), hit.getVersion()); + CorrelationAlert correlationAlert = getParsedCorrelationAlert(hit); alerts.add(correlationAlert); } return alerts; } - // logic will be moved to common-utils, once the parsing logic in common-utils is fixed - public static CorrelationAlert parse(XContentParser xcp, String id, long version) throws IOException { - // Parse additional CorrelationAlert-specific fields - List correlatedFindingIds = new ArrayList<>(); - String correlationRuleId = null; - String correlationRuleName = null; - User user = null; - int schemaVersion = 0; - String triggerName = null; - Alert.State state = null; - String errorMessage = null; - String severity = null; - List actionExecutionResults = new ArrayList<>(); - Instant startTime = null; - Instant endTime = null; - Instant acknowledgedTime = null; - - while (xcp.nextToken() != XContentParser.Token.END_OBJECT) { - String fieldName = xcp.currentName(); - xcp.nextToken(); - switch (fieldName) { - case CORRELATED_FINDING_IDS: - XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, xcp.currentToken(), xcp); - while (xcp.nextToken() != XContentParser.Token.END_ARRAY) { - correlatedFindingIds.add(xcp.text()); - } - break; - case CORRELATION_RULE_ID: - correlationRuleId = xcp.text(); - break; - case CORRELATION_RULE_NAME: - correlationRuleName = xcp.text(); - break; - case USER_FIELD: - user = (xcp.currentToken() == XContentParser.Token.VALUE_NULL) ? null : User.parse(xcp); - break; - case ALERT_ID_FIELD: - id = xcp.text(); - break; - case ALERT_VERSION_FIELD: - version = xcp.longValue(); - break; - case SCHEMA_VERSION_FIELD: - schemaVersion = xcp.intValue(); - break; - case TRIGGER_NAME_FIELD: - triggerName = xcp.text(); - break; - case STATE_FIELD: - state = Alert.State.valueOf(xcp.text()); - break; - case ERROR_MESSAGE_FIELD: - errorMessage = xcp.textOrNull(); - break; - case SEVERITY_FIELD: - severity = xcp.text(); - break; - case ACTION_EXECUTION_RESULTS_FIELD: - XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, xcp.currentToken(), xcp); - while (xcp.nextToken() != XContentParser.Token.END_ARRAY) { - actionExecutionResults.add(ActionExecutionResult.parse(xcp)); - } - break; - case START_TIME_FIELD: - startTime = Instant.parse(xcp.text()); - break; - case END_TIME_FIELD: - endTime = Instant.parse(xcp.text()); - break; - case ACKNOWLEDGED_TIME_FIELD: - if (xcp.currentToken() == XContentParser.Token.VALUE_NULL) { - acknowledgedTime = null; - } else { - acknowledgedTime = Instant.parse(xcp.text()); - } - break; - } - } - - // Create and return CorrelationAlert object - return new CorrelationAlert( - correlatedFindingIds, - correlationRuleId, - correlationRuleName, - id, - version, - schemaVersion, - user, - triggerName, - state, - startTime, - endTime, - acknowledgedTime, - errorMessage, - severity, - actionExecutionResults - ); + private CorrelationAlert getParsedCorrelationAlert(SearchHit hit) throws IOException { + XContentParser xcp = XContentType.JSON.xContent().createParser( + xContentRegistry, + LoggingDeprecationHandler.INSTANCE, + hit.getSourceAsString() + ); + xcp.nextToken(); + CorrelationAlert correlationAlert = CorrelationAlertsList.parse(xcp, hit.getId(), hit.getVersion()); + return correlationAlert; } + } diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertsList.java b/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertsList.java index a6cdda9a6..2770f3eaa 100644 --- a/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertsList.java +++ b/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertsList.java @@ -4,8 +4,17 @@ */ package org.opensearch.securityanalytics.correlation.alert; + +import org.opensearch.commons.alerting.model.ActionExecutionResult; +import org.opensearch.commons.alerting.model.Alert; import org.opensearch.commons.alerting.model.CorrelationAlert; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.core.xcontent.XContentParserUtils; +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; import java.util.List; /** @@ -22,6 +31,106 @@ public CorrelationAlertsList(List correlationAlertList, Intege this.totalAlerts = totalAlerts; } + // logic will be moved to common-utils, once the parsing logic in common-utils is fixed + public static CorrelationAlert parse(XContentParser xcp, String id, long version) throws IOException { + // Parse additional CorrelationAlert-specific fields + List correlatedFindingIds = new ArrayList<>(); + String correlationRuleId = null; + String correlationRuleName = null; + User user = null; + int schemaVersion = 0; + String triggerName = null; + Alert.State state = null; + String errorMessage = null; + String severity = null; + List actionExecutionResults = new ArrayList<>(); + Instant startTime = null; + Instant endTime = null; + Instant acknowledgedTime = null; + + while (xcp.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = xcp.currentName(); + xcp.nextToken(); + switch (fieldName) { + case CorrelationAlertService.CORRELATED_FINDING_IDS: + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, xcp.currentToken(), xcp); + while (xcp.nextToken() != XContentParser.Token.END_ARRAY) { + correlatedFindingIds.add(xcp.text()); + } + break; + case CorrelationAlertService.CORRELATION_RULE_ID: + correlationRuleId = xcp.text(); + break; + case CorrelationAlertService.CORRELATION_RULE_NAME: + correlationRuleName = xcp.text(); + break; + case CorrelationAlertService.USER_FIELD: + user = (xcp.currentToken() == XContentParser.Token.VALUE_NULL) ? null : User.parse(xcp); + break; + case CorrelationAlertService.ALERT_ID_FIELD: + id = xcp.text(); + break; + case CorrelationAlertService.ALERT_VERSION_FIELD: + version = xcp.longValue(); + break; + case CorrelationAlertService.SCHEMA_VERSION_FIELD: + schemaVersion = xcp.intValue(); + break; + case CorrelationAlertService.TRIGGER_NAME_FIELD: + triggerName = xcp.text(); + break; + case CorrelationAlertService.STATE_FIELD: + state = Alert.State.valueOf(xcp.text()); + break; + case CorrelationAlertService.ERROR_MESSAGE_FIELD: + errorMessage = xcp.textOrNull(); + break; + case CorrelationAlertService.SEVERITY_FIELD: + severity = xcp.text(); + break; + case CorrelationAlertService.ACTION_EXECUTION_RESULTS_FIELD: + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, xcp.currentToken(), xcp); + while (xcp.nextToken() != XContentParser.Token.END_ARRAY) { + actionExecutionResults.add(ActionExecutionResult.parse(xcp)); + } + break; + case CorrelationAlertService.START_TIME_FIELD: + startTime = Instant.parse(xcp.text()); + break; + case CorrelationAlertService.END_TIME_FIELD: + endTime = Instant.parse(xcp.text()); + break; + case CorrelationAlertService.ACKNOWLEDGED_TIME_FIELD: + if (xcp.currentToken() == XContentParser.Token.VALUE_NULL) { + acknowledgedTime = null; + } else { + acknowledgedTime = Instant.parse(xcp.text()); + } + break; + } + } + + // Create and return CorrelationAlert object + return new CorrelationAlert( + correlatedFindingIds, + correlationRuleId, + correlationRuleName, + id, + version, + schemaVersion, + user, + triggerName, + state, + startTime, + endTime, + acknowledgedTime, + errorMessage, + severity, + actionExecutionResults + ); + } + + public List getCorrelationAlertList() { return correlationAlertList; } diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationRuleScheduler.java b/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationRuleScheduler.java index ba42e252b..10e61857b 100644 --- a/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationRuleScheduler.java +++ b/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationRuleScheduler.java @@ -1,3 +1,7 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ package org.opensearch.securityanalytics.correlation.alert; import org.apache.logging.log4j.LogManager; diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/alert/notifications/CorrelationAlertContext.java b/src/main/java/org/opensearch/securityanalytics/correlation/alert/notifications/CorrelationAlertContext.java index 90b8ded25..148da9a50 100644 --- a/src/main/java/org/opensearch/securityanalytics/correlation/alert/notifications/CorrelationAlertContext.java +++ b/src/main/java/org/opensearch/securityanalytics/correlation/alert/notifications/CorrelationAlertContext.java @@ -1,7 +1,9 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ package org.opensearch.securityanalytics.correlation.alert.notifications; -import org.opensearch.securityanalytics.model.CorrelationRule; - import java.util.HashMap; import java.util.List; import java.util.Map; diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/alert/notifications/NotificationService.java b/src/main/java/org/opensearch/securityanalytics/correlation/alert/notifications/NotificationService.java index 34478a063..be141a1d9 100644 --- a/src/main/java/org/opensearch/securityanalytics/correlation/alert/notifications/NotificationService.java +++ b/src/main/java/org/opensearch/securityanalytics/correlation/alert/notifications/NotificationService.java @@ -1,3 +1,7 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ package org.opensearch.securityanalytics.correlation.alert.notifications; import org.apache.logging.log4j.LogManager; diff --git a/src/main/java/org/opensearch/securityanalytics/resthandler/RestAcknowledgeCorrelationAlertsAction.java b/src/main/java/org/opensearch/securityanalytics/resthandler/RestAcknowledgeCorrelationAlertsAction.java new file mode 100644 index 000000000..1a1eede54 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/resthandler/RestAcknowledgeCorrelationAlertsAction.java @@ -0,0 +1,72 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.resthandler; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.securityanalytics.SecurityAnalyticsPlugin; +import org.opensearch.securityanalytics.action.AckCorrelationAlertsAction; +import org.opensearch.securityanalytics.action.AckCorrelationAlertsRequest; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Locale; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + + +/** + * Acknowledge list of correlation alerts generated by correlation rules. + */ +public class RestAcknowledgeCorrelationAlertsAction extends BaseRestHandler { + @Override + public String getName() { + return "ack_correlation_alerts_action"; + } + + @Override + public List routes() { + return Collections.singletonList( + new Route(RestRequest.Method.POST, String.format( + Locale.getDefault(), + "%s/_acknowledge/correlationAlerts", + SecurityAnalyticsPlugin.PLUGINS_BASE_URI) + )); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient nodeClient) throws IOException { + List alertIds = getAlertIds(request.contentParser()); + AckCorrelationAlertsRequest CorrelationAckAlertsRequest = new AckCorrelationAlertsRequest(alertIds); + return channel -> nodeClient.execute( + AckCorrelationAlertsAction.INSTANCE, + CorrelationAckAlertsRequest, + new RestToXContentListener<>(channel) + ); + } + + private List getAlertIds(XContentParser xcp) throws IOException { + List ids = new ArrayList<>(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, xcp.nextToken(), xcp); + while (xcp.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = xcp.currentName(); + xcp.nextToken(); + if (fieldName.equals("alertIds")) { + ensureExpectedToken(XContentParser.Token.START_ARRAY, xcp.currentToken(), xcp); + while (xcp.nextToken() != XContentParser.Token.END_ARRAY) { + ids.add(xcp.text()); + } + } + + } + return ids; + } +} + diff --git a/src/main/java/org/opensearch/securityanalytics/resthandler/RestGetCorrelationsAlertsAction.java b/src/main/java/org/opensearch/securityanalytics/resthandler/RestGetCorrelationsAlertsAction.java new file mode 100644 index 000000000..471c26915 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/resthandler/RestGetCorrelationsAlertsAction.java @@ -0,0 +1,96 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.resthandler; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.commons.alerting.model.Table; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.securityanalytics.SecurityAnalyticsPlugin; +import org.opensearch.securityanalytics.action.GetCorrelationAlertsAction; +import org.opensearch.securityanalytics.action.GetCorrelationAlertsRequest; + +import java.io.IOException; +import java.time.DateTimeException; +import java.time.Instant; +import java.util.List; + +import static java.util.Collections.singletonList; +import static org.opensearch.rest.RestRequest.Method.GET; + +public class RestGetCorrelationsAlertsAction extends BaseRestHandler { + + @Override + public String getName() { + return "get_correlation_alerts_action_sa"; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + + String correlationRuleId = request.param("correlation_rule_id", null); + String correlationRuleName = request.param("correlation_rule_name", null); + String severityLevel = request.param("severityLevel", "ALL"); + String alertState = request.param("alertState", "ALL"); + // Table params + String sortString = request.param("sortString", "start_time"); + String sortOrder = request.param("sortOrder", "asc"); + String missing = request.param("missing"); + int size = request.paramAsInt("size", 20); + int startIndex = request.paramAsInt("startIndex", 0); + String searchString = request.param("searchString", ""); + + Instant startTime = null; + String startTimeParam = request.param("startTime"); + if (startTimeParam != null && !startTimeParam.isEmpty()) { + try { + startTime = Instant.ofEpochMilli(Long.parseLong(startTimeParam)); + } catch (NumberFormatException | NullPointerException | DateTimeException e) { + startTime = Instant.now(); + } + } + + Instant endTime = null; + String endTimeParam = request.param("endTime"); + if (endTimeParam != null && !endTimeParam.isEmpty()) { + try { + endTime = Instant.ofEpochMilli(Long.parseLong(endTimeParam)); + } catch (NumberFormatException | NullPointerException | DateTimeException e) { + endTime = Instant.now(); + } + } + + Table table = new Table( + sortOrder, + sortString, + missing, + size, + startIndex, + searchString + ); + + GetCorrelationAlertsRequest req = new GetCorrelationAlertsRequest( + correlationRuleId, + correlationRuleName, + table, + severityLevel, + alertState, + startTime, + endTime + ); + + return channel -> client.execute( + GetCorrelationAlertsAction.INSTANCE, + req, + new RestToXContentListener<>(channel) + ); + } + + @Override + public List routes() { + return singletonList(new Route(GET, SecurityAnalyticsPlugin.CORRELATIONS_ALERTS_BASE_URI)); + } +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportAckCorrelationAlertsAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportAckCorrelationAlertsAction.java new file mode 100644 index 000000000..917d0349c --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportAckCorrelationAlertsAction.java @@ -0,0 +1,81 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.securityanalytics.transport; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.securityanalytics.action.AckCorrelationAlertsAction; +import org.opensearch.securityanalytics.action.AckCorrelationAlertsRequest; +import org.opensearch.securityanalytics.action.AckCorrelationAlertsResponse; +import org.opensearch.securityanalytics.correlation.alert.CorrelationAlertService; +import org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class TransportAckCorrelationAlertsAction extends HandledTransportAction implements SecureTransportAction { + + private final NamedXContentRegistry xContentRegistry; + + private final ClusterService clusterService; + + private final Settings settings; + + private final ThreadPool threadPool; + + private final CorrelationAlertService correlationAlertService; + + private volatile Boolean filterByEnabled; + + private static final Logger log = LogManager.getLogger(TransportGetCorrelationAlertsAction.class); + + + @Inject + public TransportAckCorrelationAlertsAction(TransportService transportService, CorrelationAlertService correlationAlertService, ActionFilters actionFilters, ClusterService clusterService, AckCorrelationAlertsAction correlationAckAlertsAction, ThreadPool threadPool, Settings settings, NamedXContentRegistry xContentRegistry, Client client) { + super(correlationAckAlertsAction.NAME, transportService, actionFilters, AckCorrelationAlertsRequest::new); + this.xContentRegistry = xContentRegistry; + this.correlationAlertService = correlationAlertService; + this.clusterService = clusterService; + this.threadPool = threadPool; + this.settings = settings; + this.filterByEnabled = SecurityAnalyticsSettings.FILTER_BY_BACKEND_ROLES.get(this.settings); + this.clusterService.getClusterSettings().addSettingsUpdateConsumer(SecurityAnalyticsSettings.FILTER_BY_BACKEND_ROLES, this::setFilterByEnabled); + } + + @Override + protected void doExecute(Task task, AckCorrelationAlertsRequest request, ActionListener actionListener) { + + User user = readUserFromThreadContext(this.threadPool); + + String validateBackendRoleMessage = validateUserBackendRoles(user, this.filterByEnabled); + if (!"".equals(validateBackendRoleMessage)) { + actionListener.onFailure(new OpenSearchStatusException("Do not have permissions to resource", RestStatus.FORBIDDEN)); + return; + } + + if (!request.getCorrelationAlertIds().isEmpty()) { + correlationAlertService.acknowledgeAlerts( + request.getCorrelationAlertIds(), + actionListener + ); + } + } + + private void setFilterByEnabled(boolean filterByEnabled) { + this.filterByEnabled = filterByEnabled; + } +} diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportDeleteCorrelationRuleAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportDeleteCorrelationRuleAction.java index d3c21cf1c..c8f9273e9 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportDeleteCorrelationRuleAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportDeleteCorrelationRuleAction.java @@ -8,6 +8,7 @@ package org.opensearch.securityanalytics.transport; +import java.util.Collections; import java.util.Locale; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -19,6 +20,7 @@ import org.opensearch.action.support.master.AcknowledgedResponse; import org.opensearch.client.Client; import org.opensearch.common.inject.Inject; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.reindex.BulkByScrollResponse; import org.opensearch.index.reindex.DeleteByQueryAction; @@ -26,6 +28,7 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.securityanalytics.action.DeleteCorrelationRuleAction; import org.opensearch.securityanalytics.action.DeleteCorrelationRuleRequest; +import org.opensearch.securityanalytics.correlation.alert.CorrelationAlertService; import org.opensearch.securityanalytics.model.CorrelationRule; import org.opensearch.securityanalytics.util.SecurityAnalyticsException; import org.opensearch.tasks.Task; @@ -37,14 +40,19 @@ public class TransportDeleteCorrelationRuleAction extends HandledTransportAction private final Client client; + private CorrelationAlertService correlationAlertService; + + @Inject public TransportDeleteCorrelationRuleAction( TransportService transportService, Client client, - ActionFilters actionFilters + ActionFilters actionFilters, + CorrelationAlertService correlationAlertService ) { super(DeleteCorrelationRuleAction.NAME, transportService, actionFilters, DeleteCorrelationRuleRequest::new); this.client = client; + this.correlationAlertService = correlationAlertService; } @Override @@ -72,6 +80,9 @@ public void onResponse(BulkByScrollResponse response) { ); return; } + // update the alerts assosciated with correlation Rules, with error STATE and errorMessage + log.debug("Updating Correlation Alerts with error Message for ruleId: " + correlationRuleId); + correlationAlertService.updateCorrelationAlertsWithError(correlationRuleId); listener.onResponse(new AcknowledgedResponse(true)); } diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportGetCorrelationAlertsAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportGetCorrelationAlertsAction.java new file mode 100644 index 000000000..cdca86a23 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportGetCorrelationAlertsAction.java @@ -0,0 +1,85 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.transport; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.securityanalytics.action.*; +import org.opensearch.securityanalytics.correlation.alert.CorrelationAlertService; +import org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class TransportGetCorrelationAlertsAction extends HandledTransportAction implements SecureTransportAction { + + private final NamedXContentRegistry xContentRegistry; + + private final ClusterService clusterService; + + private final Settings settings; + + private final ThreadPool threadPool; + + private final CorrelationAlertService correlationAlertService; + + private volatile Boolean filterByEnabled; + + private static final Logger log = LogManager.getLogger(TransportGetCorrelationAlertsAction.class); + + + @Inject + public TransportGetCorrelationAlertsAction(TransportService transportService, CorrelationAlertService correlationAlertService, ActionFilters actionFilters, ClusterService clusterService, GetCorrelationAlertsAction getCorrelationAlertsAction, ThreadPool threadPool, Settings settings, NamedXContentRegistry xContentRegistry, Client client) { + super(getCorrelationAlertsAction.NAME, transportService, actionFilters, GetCorrelationAlertsRequest::new); + this.xContentRegistry = xContentRegistry; + this.correlationAlertService = correlationAlertService; + this.clusterService = clusterService; + this.threadPool = threadPool; + this.settings = settings; + this.filterByEnabled = SecurityAnalyticsSettings.FILTER_BY_BACKEND_ROLES.get(this.settings); + this.clusterService.getClusterSettings().addSettingsUpdateConsumer(SecurityAnalyticsSettings.FILTER_BY_BACKEND_ROLES, this::setFilterByEnabled); + } + + @Override + protected void doExecute(Task task, GetCorrelationAlertsRequest request, ActionListener actionListener) { + + User user = readUserFromThreadContext(this.threadPool); + + String validateBackendRoleMessage = validateUserBackendRoles(user, this.filterByEnabled); + if (!"".equals(validateBackendRoleMessage)) { + actionListener.onFailure(new OpenSearchStatusException("Do not have permissions to resource", RestStatus.FORBIDDEN)); + return; + } + + if (request.getCorrelationRuleId() != null) { + correlationAlertService.getCorrelationAlerts( + request.getCorrelationRuleId(), + request.getTable(), + actionListener + ); + } else { + correlationAlertService.getCorrelationAlerts( + null, + request.getTable(), + actionListener + ); + } + } + + private void setFilterByEnabled(boolean filterByEnabled) { + this.filterByEnabled = filterByEnabled; + } +} \ No newline at end of file diff --git a/src/main/resources/mappings/correlation_alert_mapping.json b/src/main/resources/mappings/correlation_alert_mapping.json index 585a036c6..5245edba8 100644 --- a/src/main/resources/mappings/correlation_alert_mapping.json +++ b/src/main/resources/mappings/correlation_alert_mapping.json @@ -99,4 +99,4 @@ "type": "date" } } -} \ No newline at end of file +} diff --git a/src/test/java/org/opensearch/securityanalytics/TestHelpers.java b/src/test/java/org/opensearch/securityanalytics/TestHelpers.java index 1352109b0..ad50a0dde 100644 --- a/src/test/java/org/opensearch/securityanalytics/TestHelpers.java +++ b/src/test/java/org/opensearch/securityanalytics/TestHelpers.java @@ -2705,4 +2705,4 @@ public static NamedXContentRegistry xContentRegistry() { public static XContentBuilder builder() throws IOException { return XContentBuilder.builder(XContentType.JSON.xContent()); } -} \ No newline at end of file +}