Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

changes to add start_time and end_time filters to GetAlertsRequest #1039

Merged
merged 1 commit into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package org.opensearch.securityanalytics.action;

import java.io.IOException;
import java.time.Instant;
import java.util.Locale;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
Expand All @@ -24,29 +25,39 @@ public class GetAlertsRequest extends ActionRequest {
private String severityLevel;
private String alertState;

private Instant startTime;

private Instant endTime;

public static final String DETECTOR_ID = "detector_id";

public GetAlertsRequest(
String detectorId,
String logType,
Table table,
String severityLevel,
String alertState
String alertState,
Instant startTime,
Instant endTime
) {
super();
this.detectorId = detectorId;
this.logType = logType;
this.table = table;
this.severityLevel = severityLevel;
this.alertState = alertState;
this.startTime = startTime;
this.endTime = endTime;
}
public GetAlertsRequest(StreamInput sin) throws IOException {
this(
sin.readOptionalString(),
sin.readOptionalString(),
Table.readFrom(sin),
sin.readString(),
sin.readString()
sin.readString(),
sin.readOptionalInstant(),
sin.readOptionalInstant()
);
}

Expand All @@ -68,6 +79,8 @@ public void writeTo(StreamOutput out) throws IOException {
table.writeTo(out);
out.writeString(severityLevel);
out.writeString(alertState);
out.writeOptionalInstant(startTime);
out.writeOptionalInstant(endTime);
}

public String getDetectorId() {
Expand All @@ -89,4 +102,12 @@ public String getAlertState() {
public String getLogType() {
return logType;
}

public Instant getStartTime() {
return startTime;
}

public Instant getEndTime() {
return endTime;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
import org.opensearch.commons.alerting.model.Alert;
import org.opensearch.commons.alerting.model.Table;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.securityanalytics.action.AckAlertsResponse;
import org.opensearch.securityanalytics.action.AlertDto;
import org.opensearch.securityanalytics.action.GetAlertsResponse;
Expand All @@ -29,6 +32,7 @@
import org.opensearch.securityanalytics.model.Detector;
import org.opensearch.securityanalytics.util.SecurityAnalyticsException;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
Expand Down Expand Up @@ -66,6 +70,8 @@ public void getAlertsByDetectorId(
Table table,
String severityLevel,
String alertState,
Instant startTime,
Instant endTime,
ActionListener<GetAlertsResponse> listener
) {
this.client.execute(GetDetectorAction.INSTANCE, new GetDetectorRequest(detectorId, -3L), new ActionListener<>() {
Expand All @@ -88,6 +94,8 @@ public void onResponse(GetDetectorResponse getDetectorResponse) {
table,
severityLevel,
alertState,
startTime,
endTime,
new ActionListener<>() {
@Override
public void onResponse(GetAlertsResponse getAlertsResponse) {
Expand Down Expand Up @@ -129,9 +137,11 @@ public void getAlertsByMonitorIds(
Table table,
String severityLevel,
String alertState,
Instant startTime,
Instant endTime,
ActionListener<GetAlertsResponse> listener
) {

BoolQueryBuilder boolQueryBuilder = getBoolQueryBuilder(startTime, endTime);
org.opensearch.commons.alerting.action.GetAlertsRequest req =
new org.opensearch.commons.alerting.action.GetAlertsRequest(
table,
Expand All @@ -141,7 +151,8 @@ public void getAlertsByMonitorIds(
alertIndex,
monitorIds,
null,
null
null,
boolQueryBuilder
);

AlertingPluginInterface.INSTANCE.getAlerts((NodeClient) client, req, new ActionListener<>() {
Expand Down Expand Up @@ -178,6 +189,8 @@ public void getAlerts(
Table table,
String severityLevel,
String alertState,
Instant startTime,
Instant endTime,
ActionListener<GetAlertsResponse> listener
) {
if (detectors.size() == 0) {
Expand All @@ -204,6 +217,8 @@ public void getAlerts(
table,
severityLevel,
alertState,
startTime,
endTime,
new ActionListener<>() {
@Override
public void onResponse(GetAlertsResponse getAlertsResponse) {
Expand Down Expand Up @@ -246,7 +261,10 @@ private AlertDto mapAlertToAlertDto(Alert alert, String detectorId) {
public void getAlerts(List<String> alertIds,
Detector detector,
Table table,
Instant startTime,
Instant endTime,
ActionListener<org.opensearch.commons.alerting.action.GetAlertsResponse> actionListener) {
BoolQueryBuilder boolQueryBuilder = getBoolQueryBuilder(startTime, endTime);
GetAlertsRequest request = new GetAlertsRequest(
table,
"ALL",
Expand All @@ -255,7 +273,8 @@ public void getAlerts(List<String> alertIds,
DetectorMonitorConfig.getAllAlertsIndicesPattern(detector.getDetectorType()),
null,
null,
alertIds);
alertIds,
boolQueryBuilder);
AlertingPluginInterface.INSTANCE.getAlerts(
(NodeClient) client,
request, actionListener);
Expand Down Expand Up @@ -305,4 +324,17 @@ public void onFailure(Exception e) {
}

}

private static BoolQueryBuilder getBoolQueryBuilder(Instant startTime, Instant endTime) {
amsiglan marked this conversation as resolved.
Show resolved Hide resolved
BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery();
if (startTime != null && endTime != null) {
long startTimeMillis = startTime.toEpochMilli();
long endTimeMillis = endTime.toEpochMilli();
QueryBuilder timeRangeQuery = QueryBuilders.rangeQuery("start_time")
.from(startTimeMillis) // Greater than or equal to start time
.to(endTimeMillis); // Less than or equal to end time
boolQueryBuilder.filter(timeRangeQuery);
}
return boolQueryBuilder;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
package org.opensearch.securityanalytics.resthandler;

import java.io.IOException;
import java.time.DateTimeException;
import java.time.Instant;
import java.util.List;
import java.util.Locale;
import org.opensearch.client.node.NodeClient;
Expand Down Expand Up @@ -45,6 +47,26 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
int startIndex = request.paramAsInt("startIndex", 0);
String searchString = request.param("searchString", "");

Instant startTime = null;
String startTimeParam = request.param("startTime");
if (startTimeParam != null && !startTimeParam.isEmpty()) {
try {
startTime = Instant.ofEpochMilli(Long.parseLong(startTimeParam));
} catch (NumberFormatException | NullPointerException | DateTimeException e) {
startTime = Instant.now();
}
}

Instant endTime = null;
String endTimeParam = request.param("endTime");
if (endTimeParam != null && !endTimeParam.isEmpty()) {
try {
endTime = Instant.ofEpochMilli(Long.parseLong(endTimeParam));
} catch (NumberFormatException | NullPointerException | DateTimeException e) {
endTime = Instant.now();
}
}

Table table = new Table(
sortOrder,
sortString,
Expand All @@ -59,7 +81,9 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
detectorType,
table,
severityLevel,
alertState
alertState,
startTime,
endTime
);

return channel -> client.execute(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ public void onResponse(GetDetectorResponse getDetectorResponse) {
request.getAlertIds(),
getDetectorResponse.getDetector(),
new Table("asc", "id", null, 10000, 0, null),
null,
null,
getAlertsResponseStepListener
);
getAlertsResponseStepListener.whenComplete(getAlertsResponse -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ protected void doExecute(Task task, GetAlertsRequest request, ActionListener<Get
request.getTable(),
request.getSeverityLevel(),
request.getAlertState(),
request.getStartTime(),
request.getEndTime(),
actionListener
);
} else {
Expand Down Expand Up @@ -135,6 +137,8 @@ public void onResponse(SearchResponse searchResponse) {
request.getTable(),
request.getSeverityLevel(),
request.getAlertState(),
request.getStartTime(),
request.getEndTime(),
actionListener
);
} catch (IOException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,10 @@ public void testGetAlerts_success() {
);

doAnswer(invocation -> {
ActionListener l = invocation.getArgument(6);
ActionListener l = invocation.getArgument(8);
l.onResponse(getAlertsResponse);
return null;
}).when(alertssService).getAlertsByMonitorIds(any(), any(), anyString(), any(Table.class), anyString(), anyString(), any(ActionListener.class));
}).when(alertssService).getAlertsByMonitorIds(any(), any(), anyString(), any(Table.class), anyString(), anyString(), any(), any(), any(ActionListener.class));

// Call getFindingsByDetectorId
Table table = new Table(
Expand All @@ -205,7 +205,8 @@ public void testGetAlerts_success() {
0,
null
);
alertssService.getAlertsByDetectorId("detector_id123", table, "severity_low", Alert.State.COMPLETED.toString(), new ActionListener<>() {
alertssService.getAlertsByDetectorId("detector_id123", table, "severity_low", Alert.State.COMPLETED.toString(), null, null,
new ActionListener<>() {
@Override
public void onResponse(GetAlertsResponse getAlertsResponse) {
assertEquals(2, (int)getAlertsResponse.getTotalAlerts());
Expand Down Expand Up @@ -258,10 +259,10 @@ public void testGetFindings_getFindingsByMonitorIdFailures() {
}).when(client).execute(eq(GetDetectorAction.INSTANCE), any(GetDetectorRequest.class), any(ActionListener.class));

doAnswer(invocation -> {
ActionListener l = invocation.getArgument(6);
ActionListener l = invocation.getArgument(8);
l.onFailure(new IllegalArgumentException("Error getting findings"));
return null;
}).when(alertssService).getAlertsByMonitorIds(any(), any(), anyString(), any(Table.class), anyString(), anyString(), any(ActionListener.class));
}).when(alertssService).getAlertsByMonitorIds(any(), any(), anyString(), any(Table.class), anyString(), anyString(), any(), any(), any(ActionListener.class));

// Call getFindingsByDetectorId
Table table = new Table(
Expand All @@ -272,7 +273,8 @@ public void testGetFindings_getFindingsByMonitorIdFailures() {
0,
null
);
alertssService.getAlertsByDetectorId("detector_id123", table, "severity_low", Alert.State.COMPLETED.toString(), new ActionListener<>() {
alertssService.getAlertsByDetectorId("detector_id123", table, "severity_low", Alert.State.COMPLETED.toString(), null, null,
new ActionListener<>() {
@Override
public void onResponse(GetAlertsResponse getAlertsResponse) {
fail("this test should've failed");
Expand Down Expand Up @@ -307,7 +309,8 @@ public void testGetFindings_getDetectorFailure() {
0,
null
);
alertssService.getAlertsByDetectorId("detector_id123", table, "severity_low", Alert.State.COMPLETED.toString(), new ActionListener<>() {
alertssService.getAlertsByDetectorId("detector_id123", table, "severity_low", Alert.State.COMPLETED.toString(), null, null,
new ActionListener<>() {
@Override
public void onResponse(GetAlertsResponse getAlertsResponse) {
fail("this test should've failed");
Expand Down
Loading
Loading