Skip to content

Commit

Permalink
Fix took time, filter out disabled detectors
Browse files Browse the repository at this point in the history
Signed-off-by: Chase Engelbrecht <[email protected]>
  • Loading branch information
engechas committed Feb 13, 2024
1 parent b30b9e9 commit 2624294
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ public class StreamingDetectorMetadataConverter {
public List<StreamingDetectorMetadata> convert(final List<Detector> detectors, final Map<String, List<DocData>> indexToDocData) {
return detectors.stream()
.peek(StreamingDetectorValidators::validateDetector)
.filter(Detector::getEnabled)
.filter(Detector::isStreamingDetector)
.filter(detector -> doesDetectorHaveIndexAsInput(detector, indexToDocData.keySet()))
.map(detector -> createStreamingDetectorMetadata(detector, indexToDocData))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ public TransportExecuteStreamingDetectorsAction(final TransportService transport
*/
@Override
protected void doExecute(final Task task, final BulkRequest bulkRequest, final ActionListener<BulkResponse> listener) {
final long operationStartTime = System.currentTimeMillis();

if (!validateUser()) {
listener.onFailure(SecurityAnalyticsException.wrap(
new OpenSearchStatusException("User is not authorized to to perform this action. Contact administrator", RestStatus.FORBIDDEN)
Expand All @@ -128,7 +130,7 @@ protected void doExecute(final Task task, final BulkRequest bulkRequest, final A
@Override
public void onResponse(final BulkResponse bulkResponse) {
logDuration(startTime, "Execute BulkRequest");
identifyDetectors(bulkRequest, bulkResponse, listener);
identifyDetectors(bulkRequest, bulkResponse, listener, operationStartTime);
}

@Override
Expand All @@ -145,7 +147,8 @@ private boolean validateUser() {
return user == null || isAdmin(user);
}

private void identifyDetectors(final BulkRequest bulkRequest, final BulkResponse bulkResponse, final ActionListener<BulkResponse> listener) {
private void identifyDetectors(final BulkRequest bulkRequest, final BulkResponse bulkResponse,
final ActionListener<BulkResponse> listener, final long operationStartTime) {
final Map<String, List<DocData>> indexToDocData = indexNameToDocDataConverter.convert(bulkRequest, bulkResponse);
final SearchRequest listDetectorsRequest = getListDetectorsRequest();

Expand All @@ -160,21 +163,21 @@ public void onResponse(final SearchResponse searchResponse) {
detectors = DetectorUtils.getDetectors(searchResponse, xContentRegistry);
} catch (final IOException e) {
handleAllDetectorsFailure(bulkResponse, indexToDocData, e);
listener.onResponse(bulkResponse);
listener.onResponse(recreateBulkResponseWithCorrectTookMillis(operationStartTime, bulkResponse));
return;
}

getMonitors(indexToDocData, detectors, listener, bulkResponse);
getMonitors(indexToDocData, detectors, listener, bulkResponse, operationStartTime);
}

@Override
public void onFailure(final Exception e) {
if (e instanceof IndexNotFoundException) {
log.warn("No detectors configured, skipping streaming detectors workflow");
listener.onResponse(bulkResponse);
listener.onResponse(recreateBulkResponseWithCorrectTookMillis(operationStartTime, bulkResponse));
} else {
handleAllDetectorsFailure(bulkResponse, indexToDocData, e);
listener.onResponse(bulkResponse);
listener.onResponse(recreateBulkResponseWithCorrectTookMillis(operationStartTime, bulkResponse));
}
}
});
Expand All @@ -192,11 +195,12 @@ private SearchRequest getListDetectorsRequest() {
private void getMonitors(final Map<String, List<DocData>> indexToDocData,
final List<Detector> detectors,
final ActionListener<BulkResponse> listener,
final BulkResponse bulkResponse) {
final BulkResponse bulkResponse,
final long operationStartTime) {
final List<StreamingDetectorMetadata> streamingDetectors = streamingDetectorMetadataConverter.convert(detectors, indexToDocData);
if (streamingDetectors.isEmpty()) {
log.debug("No streaming detectors identified for incoming data. Skipping streaming detectors workflow");
listener.onResponse(bulkResponse);
listener.onResponse(recreateBulkResponseWithCorrectTookMillis(operationStartTime, bulkResponse));
return;
}

Expand All @@ -212,7 +216,7 @@ private void getMonitors(final Map<String, List<DocData>> indexToDocData,
@Override
public void onResponse(final GetMonitorResponse getMonitorResponse) {
populateWorkflowIdToMetadata(monitorId, getMonitorResponse, monitorIdToMetadata, getMonitorCounter,
listener, bulkResponse, startTime);
listener, bulkResponse, startTime, operationStartTime);
}

@Override
Expand All @@ -221,7 +225,7 @@ public void onFailure(final Exception e) {

getMonitorCounter.incrementAndGet();
if (getMonitorCounter.get() == monitorIdToMetadata.size()) {
listener.onResponse(bulkResponse);
listener.onResponse(recreateBulkResponseWithCorrectTookMillis(operationStartTime, bulkResponse));
}
}
});
Expand All @@ -234,15 +238,16 @@ private void populateWorkflowIdToMetadata(final String monitorId,
final AtomicInteger getMonitorCounter,
final ActionListener<BulkResponse> listener,
final BulkResponse bulkResponse,
final long startTime) {
final long startTime,
final long operationStartTime) {
final StreamingDetectorMetadata metadata = monitorIdToMetadata.get(monitorId);
if (isMonitorValidForStreaming(getMonitorResponse)) {
populateQueryFields(getMonitorResponse.getMonitor(), metadata);

getMonitorCounter.incrementAndGet();
if (getMonitorCounter.get() == monitorIdToMetadata.size()) {
logDuration(startTime, "Get and Populate Query Fields");
executeWorkflows(monitorIdToMetadata.values(), listener, bulkResponse);
executeWorkflows(monitorIdToMetadata.values(), listener, bulkResponse, operationStartTime);
}
} else {
final String errorMsg = String.format("Monitor with ID %s is invalid for streaming.", monitorId);
Expand All @@ -251,7 +256,7 @@ private void populateWorkflowIdToMetadata(final String monitorId,

getMonitorCounter.incrementAndGet();
if (getMonitorCounter.get() == monitorIdToMetadata.size()) {
listener.onResponse(bulkResponse);
listener.onResponse(recreateBulkResponseWithCorrectTookMillis(operationStartTime, bulkResponse));
}
}
}
Expand All @@ -272,20 +277,21 @@ private void populateQueryFields(final Monitor monitor, final StreamingDetectorM

private void executeWorkflows(final Collection<StreamingDetectorMetadata> streamingDetectorMetadata,
final ActionListener<BulkResponse> listener,
final BulkResponse bulkResponse) {
final BulkResponse bulkResponse,
final long operationStartTime) {
final AtomicInteger workflowExecutionCounter = new AtomicInteger(0);
final long startTime = System.currentTimeMillis();
streamingDetectorMetadata.forEach(metadata -> {
final ExecuteStreamingWorkflowRequest executeWorkflowRequest = executeStreamingWorkflowRequestConverter.convert(metadata);
executeWorkflow(executeWorkflowRequest, metadata, workflowExecutionCounter, streamingDetectorMetadata.size(),
listener, bulkResponse, startTime);
listener, bulkResponse, startTime, operationStartTime);
});
}

private void executeWorkflow(final ExecuteStreamingWorkflowRequest executeWorkflowRequest, final StreamingDetectorMetadata metadata,
final AtomicInteger workflowExecutionCounter, final int workflowCount,
final ActionListener<BulkResponse> listener, final BulkResponse bulkResponse,
final long startTime) {
final long startTime, final long operationStartTime) {
AlertingPluginInterface.INSTANCE.executeStreamingWorkflow((NodeClient) client, executeWorkflowRequest, new ActionListener<>() {
@Override
public void onResponse(final ExecuteStreamingWorkflowResponse executeStreamingWorkflowResponse) {
Expand All @@ -294,7 +300,7 @@ public void onResponse(final ExecuteStreamingWorkflowResponse executeStreamingWo

if (workflowExecutionCounter.get() == workflowCount) {
logDuration(startTime, "Execute Workflows");
listener.onResponse(bulkResponse);
listener.onResponse(recreateBulkResponseWithCorrectTookMillis(operationStartTime, bulkResponse));
}
}

Expand All @@ -305,12 +311,17 @@ public void onFailure(final Exception e) {

workflowExecutionCounter.incrementAndGet();
if (workflowExecutionCounter.get() == workflowCount) {
listener.onResponse(bulkResponse);
listener.onResponse(recreateBulkResponseWithCorrectTookMillis(operationStartTime, bulkResponse));
}
}
});
}

private BulkResponse recreateBulkResponseWithCorrectTookMillis(final long startTime, final BulkResponse bulkResponse) {
final long tookMillis = System.currentTimeMillis() - startTime;
return new BulkResponse(bulkResponse.getItems(), tookMillis, bulkResponse.getIngestTookInMillis());
}

private void handleAllDetectorsFailure(final BulkResponse bulkResponse, final Map<String, List<DocData>> indexToDocData,
final Exception exception) {
log.error("Failed to run all detectors", exception);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class TransportIndexDetectorAction extends HandledTransportAction<IndexDetectorRequest, IndexDetectorResponse> implements SecureTransportAction {

Expand Down Expand Up @@ -391,7 +392,7 @@ public void onFailure(Exception e) {
actionListener.onFailure(e);
}
},
enableStreamingDetectors);
isDetectorEligibleForStreaming(monitorResponses));
} else {
actionListener.onResponse(monitorResponses);
}
Expand Down Expand Up @@ -631,7 +632,13 @@ public void onFailure(Exception e) {
listener.onFailure(e);
}
},
enableStreamingDetectors);
isDetectorEligibleForStreaming(
Stream.concat(
addNewMonitorsResponse.stream(),
updateMonitorResponse.stream()
).collect(Collectors.toList())
)
);
}
}

Expand Down Expand Up @@ -1031,7 +1038,6 @@ void createDetector() {
request.getDetector().setFindingsIndex(DetectorMonitorConfig.getFindingsIndex(ruleTopic));
request.getDetector().setFindingsIndexPattern(DetectorMonitorConfig.getFindingsIndexPattern(ruleTopic));
request.getDetector().setRuleIndex(DetectorMonitorConfig.getRuleIndex(ruleTopic));
request.getDetector().setStreamingDetector(enableStreamingDetectors);

User originalContextUser = this.user;
log.debug("user from original context is {}", originalContextUser);
Expand All @@ -1047,6 +1053,7 @@ public void onResponse(AcknowledgedResponse acknowledgedResponse) {
initRuleIndexAndImportRules(request, new ActionListener<>() {
@Override
public void onResponse(List<IndexMonitorResponse> monitorResponses) {
request.getDetector().setStreamingDetector(isDetectorEligibleForStreaming(monitorResponses));
request.getDetector().setMonitorIds(getMonitorIds(monitorResponses));
request.getDetector().setRuleIdMonitorIdMap(mapMonitorIds(monitorResponses));
try {
Expand Down Expand Up @@ -1144,7 +1151,6 @@ void onGetResponse(Detector currentDetector, User user) {
request.getDetector().setFindingsIndexPattern(DetectorMonitorConfig.getFindingsIndexPattern(ruleTopic));
request.getDetector().setRuleIndex(DetectorMonitorConfig.getRuleIndex(ruleTopic));
request.getDetector().setUser(user);
request.getDetector().setStreamingDetector(enableStreamingDetectors);

if (!detector.getInputs().isEmpty()) {
try {
Expand All @@ -1154,6 +1160,7 @@ public void onResponse(AcknowledgedResponse acknowledgedResponse) {
initRuleIndexAndImportRules(request, new ActionListener<>() {
@Override
public void onResponse(List<IndexMonitorResponse> monitorResponses) {
request.getDetector().setStreamingDetector(isDetectorEligibleForStreaming(monitorResponses));
request.getDetector().setMonitorIds(getMonitorIds(monitorResponses));
request.getDetector().setRuleIdMonitorIdMap(mapMonitorIds(monitorResponses));
try {
Expand Down Expand Up @@ -1577,4 +1584,13 @@ private void setEnabledWorkflowUsage(boolean enabledWorkflowUsage) {
private void setEnableStreamingDetectors(boolean enableStreamingDetectors) {
this.enableStreamingDetectors = enableStreamingDetectors;
}

private boolean isDetectorEligibleForStreaming(final List<IndexMonitorResponse> indexMonitorResponses) {
// Only doc level monitors are supported for streaming. If any of the monitors associated with the detector are
// non-doc level monitors, then the detector is not marked as streaming enabled.
return enableStreamingDetectors && indexMonitorResponses.stream()
.map(IndexMonitorResponse::getMonitor)
.map(Monitor::getMonitorType)
.allMatch(type -> type == MonitorType.DOC_LEVEL_MONITOR);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ public void setup() {
when(detector.getInputs()).thenReturn(List.of(detectorInput));
when(detector.getWorkflowIds()).thenReturn(List.of(WORKFLOW_ID));
when(detector.getMonitorIds()).thenReturn(List.of(MONITOR_ID));
when(detector.getEnabled()).thenReturn(true);
when(detector.isStreamingDetector()).thenReturn(true);
when(detectorInput.getIndices()).thenReturn(List.of(INDEX_NAME));
}
Expand All @@ -60,6 +61,13 @@ public void testInvalidDetectorThrows() {
assertThrows(SecurityAnalyticsException.class, () -> converter.convert(List.of(detector), Collections.emptyMap()));
}

public void testFiltersDisabledDetectors() {
when(detector.getEnabled()).thenReturn(false);

final List<StreamingDetectorMetadata> result = converter.convert(List.of(detector), getIndexToDocData(Set.of(INDEX_NAME)));
assertTrue(result.isEmpty());
}

public void testFiltersNonStreamingDetectors() {
when(detector.isStreamingDetector()).thenReturn(false);

Expand Down

0 comments on commit 2624294

Please sign in to comment.