From 15b00e5fa5277ced3647c4eab12ce62e83dd854d Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Tue, 11 Jun 2024 16:54:28 -0700 Subject: [PATCH] CorrelationAlert model added (#631) (#679) * CorrelationALert model added * fix klint errors * address the comments * fix klint errors --------- (cherry picked from commit e060f5ea7059c698b8dc1ce31a225b91d8229860) Signed-off-by: Riya Saxena Signed-off-by: github-actions[bot] Co-authored-by: github-actions[bot] --- .../commons/alerting/model/BaseAlert.kt | 208 ++++++++++++++++++ .../alerting/model/CorrelationAlert.kt | 141 ++++++++++++ .../commons/alerting/CorrelationAlertTests.kt | 130 +++++++++++ .../commons/alerting/TestHelpers.kt | 56 +++++ .../opensearch/commons/utils/TestHelpers.kt | 7 +- 5 files changed, 541 insertions(+), 1 deletion(-) create mode 100644 src/main/kotlin/org/opensearch/commons/alerting/model/BaseAlert.kt create mode 100644 src/main/kotlin/org/opensearch/commons/alerting/model/CorrelationAlert.kt create mode 100644 src/test/kotlin/org/opensearch/commons/alerting/CorrelationAlertTests.kt diff --git a/src/main/kotlin/org/opensearch/commons/alerting/model/BaseAlert.kt b/src/main/kotlin/org/opensearch/commons/alerting/model/BaseAlert.kt new file mode 100644 index 00000000..10c5a0ae --- /dev/null +++ b/src/main/kotlin/org/opensearch/commons/alerting/model/BaseAlert.kt @@ -0,0 +1,208 @@ +package org.opensearch.commons.alerting.model + +import org.opensearch.common.lucene.uid.Versions +import org.opensearch.commons.alerting.util.IndexUtils.Companion.NO_SCHEMA_VERSION +import org.opensearch.commons.alerting.util.instant +import org.opensearch.commons.alerting.util.optionalTimeField +import org.opensearch.commons.alerting.util.optionalUserField +import org.opensearch.commons.authuser.User +import org.opensearch.core.common.io.stream.StreamInput +import org.opensearch.core.common.io.stream.StreamOutput +import org.opensearch.core.common.io.stream.Writeable +import org.opensearch.core.xcontent.ToXContent +import org.opensearch.core.xcontent.XContentBuilder +import org.opensearch.core.xcontent.XContentParser +import org.opensearch.core.xcontent.XContentParserUtils +import java.io.IOException +import java.time.Instant + +/** CorrelationAlert and Alert can extend the UnifiedAlert class to inherit the common fields and behavior + * of UnifiedAlert class. + */ +open class BaseAlert( + open val id: String = Alert.NO_ID, + open val version: Long = Alert.NO_VERSION, + open val schemaVersion: Int = NO_SCHEMA_VERSION, + open val user: User?, + open val triggerName: String, + + // State will be later moved to this Class (after `monitorBasedAlerts` extend this Class) + open val state: Alert.State, + open val startTime: Instant, + open val endTime: Instant? = null, + open val acknowledgedTime: Instant? = null, + open val errorMessage: String? = null, + open val severity: String, + open val actionExecutionResults: List +) : Writeable, ToXContent { + + init { + if (errorMessage != null) { + require((state == Alert.State.DELETED) || (state == Alert.State.ERROR) || (state == Alert.State.AUDIT)) { + "Attempt to create an alert with an error in state: $state" + } + } + } + + @Throws(IOException::class) + constructor(sin: StreamInput) : this( + id = sin.readString(), + version = sin.readLong(), + schemaVersion = sin.readInt(), + user = if (sin.readBoolean()) { + User(sin) + } else { + null + }, + triggerName = sin.readString(), + state = sin.readEnum(Alert.State::class.java), + startTime = sin.readInstant(), + endTime = sin.readOptionalInstant(), + acknowledgedTime = sin.readOptionalInstant(), + errorMessage = sin.readOptionalString(), + severity = sin.readString(), + actionExecutionResults = sin.readList(::ActionExecutionResult) + ) + + fun isAcknowledged(): Boolean = (state == Alert.State.ACKNOWLEDGED) + + @Throws(IOException::class) + override fun writeTo(out: StreamOutput) { + out.writeString(id) + out.writeLong(version) + out.writeInt(schemaVersion) + out.writeBoolean(user != null) + user?.writeTo(out) + out.writeString(triggerName) + out.writeEnum(state) + out.writeInstant(startTime) + out.writeOptionalInstant(endTime) + out.writeOptionalInstant(acknowledgedTime) + out.writeOptionalString(errorMessage) + out.writeString(severity) + out.writeCollection(actionExecutionResults) + } + + companion object { + const val ALERT_ID_FIELD = "id" + const val SCHEMA_VERSION_FIELD = "schemaVersion" + const val ALERT_VERSION_FIELD = "version" + const val USER_FIELD = "user" + const val TRIGGER_NAME_FIELD = "triggerName" + const val STATE_FIELD = "state" + const val START_TIME_FIELD = "startTime" + const val END_TIME_FIELD = "endTime" + const val ACKNOWLEDGED_TIME_FIELD = "acknowledgedTime" + const val ERROR_MESSAGE_FIELD = "errorMessage" + const val SEVERITY_FIELD = "severity" + const val ACTION_EXECUTION_RESULTS_FIELD = "actionExecutionResults" + const val NO_ID = "" + const val NO_VERSION = Versions.NOT_FOUND + + @JvmStatic + @JvmOverloads + @Throws(IOException::class) + fun parse(xcp: XContentParser, version: Long = NO_VERSION): BaseAlert { + lateinit var id: String + var schemaVersion = NO_SCHEMA_VERSION + var version: Long = Versions.NOT_FOUND + var user: User? = null + lateinit var triggerName: String + lateinit var state: Alert.State + lateinit var startTime: Instant + lateinit var severity: String + var endTime: Instant? = null + var acknowledgedTime: Instant? = null + var errorMessage: String? = null + val actionExecutionResults: MutableList = mutableListOf() + while (xcp.nextToken() != XContentParser.Token.END_OBJECT) { + val fieldName = xcp.currentName() + xcp.nextToken() + when (fieldName) { + USER_FIELD -> user = if (xcp.currentToken() == XContentParser.Token.VALUE_NULL) null else User.parse(xcp) + ALERT_ID_FIELD -> id = xcp.text() + ALERT_VERSION_FIELD -> version = xcp.longValue() + SCHEMA_VERSION_FIELD -> schemaVersion = xcp.intValue() + TRIGGER_NAME_FIELD -> triggerName = xcp.text() + STATE_FIELD -> state = Alert.State.valueOf(xcp.text()) + ERROR_MESSAGE_FIELD -> errorMessage = xcp.textOrNull() + SEVERITY_FIELD -> severity = xcp.text() + ACTION_EXECUTION_RESULTS_FIELD -> { + XContentParserUtils.ensureExpectedToken( + XContentParser.Token.START_ARRAY, + xcp.currentToken(), + xcp + ) + while (xcp.nextToken() != XContentParser.Token.END_ARRAY) { + actionExecutionResults.add(ActionExecutionResult.parse(xcp)) + } + } + START_TIME_FIELD -> startTime = requireNotNull(xcp.instant()) + END_TIME_FIELD -> endTime = xcp.instant() + ACKNOWLEDGED_TIME_FIELD -> acknowledgedTime = xcp.instant() + } + } + + return BaseAlert( + id = id, + startTime = requireNotNull(startTime), + endTime = endTime, + state = requireNotNull(state), + version = version, + errorMessage = errorMessage, + actionExecutionResults = actionExecutionResults, + schemaVersion = schemaVersion, + user = user, + triggerName = requireNotNull(triggerName), + severity = severity, + acknowledgedTime = acknowledgedTime + ) + } + + @JvmStatic + @Throws(IOException::class) + fun readFrom(sin: StreamInput): Alert { + return Alert(sin) + } + } + + override fun toXContent(builder: XContentBuilder, params: ToXContent.Params): XContentBuilder { + return createXContentBuilder(builder, true) + } + + fun toXContentWithUser(builder: XContentBuilder): XContentBuilder { + return createXContentBuilder(builder, false) + } + + fun createXContentBuilder(builder: XContentBuilder, secure: Boolean): XContentBuilder { + if (!secure) { + builder.optionalUserField(USER_FIELD, user) + } + builder.field(ALERT_ID_FIELD, id) + .field(ALERT_VERSION_FIELD, version) + .field(SCHEMA_VERSION_FIELD, schemaVersion) + .field(TRIGGER_NAME_FIELD, triggerName) + .field(STATE_FIELD, state) + .field(ERROR_MESSAGE_FIELD, errorMessage) + .field(SEVERITY_FIELD, severity) + .field(ACTION_EXECUTION_RESULTS_FIELD, actionExecutionResults.toTypedArray()) + .optionalTimeField(START_TIME_FIELD, startTime) + .optionalTimeField(END_TIME_FIELD, endTime) + .optionalTimeField(ACKNOWLEDGED_TIME_FIELD, acknowledgedTime) + return builder + } + + open fun asTemplateArg(): Map { + return mapOf( + ACKNOWLEDGED_TIME_FIELD to acknowledgedTime?.toEpochMilli(), + ALERT_ID_FIELD to id, + ALERT_VERSION_FIELD to version, + END_TIME_FIELD to endTime?.toEpochMilli(), + ERROR_MESSAGE_FIELD to errorMessage, + SEVERITY_FIELD to severity, + START_TIME_FIELD to startTime.toEpochMilli(), + STATE_FIELD to state.toString(), + TRIGGER_NAME_FIELD to triggerName + ) + } +} diff --git a/src/main/kotlin/org/opensearch/commons/alerting/model/CorrelationAlert.kt b/src/main/kotlin/org/opensearch/commons/alerting/model/CorrelationAlert.kt new file mode 100644 index 00000000..aa8f7f85 --- /dev/null +++ b/src/main/kotlin/org/opensearch/commons/alerting/model/CorrelationAlert.kt @@ -0,0 +1,141 @@ +package org.opensearch.commons.alerting.model + +import org.opensearch.commons.authuser.User +import org.opensearch.core.common.io.stream.StreamInput +import org.opensearch.core.common.io.stream.StreamOutput +import org.opensearch.core.xcontent.XContentBuilder +import org.opensearch.core.xcontent.XContentParser +import org.opensearch.core.xcontent.XContentParserUtils +import java.io.IOException +import java.time.Instant + +class CorrelationAlert : BaseAlert { + + // CorrelationAlert-specific properties + val correlatedFindingIds: List + val correlationRuleId: String + val correlationRuleName: String + + constructor( + correlatedFindingIds: List, + correlationRuleId: String, + correlationRuleName: String, + id: String, + version: Long, + schemaVersion: Int, + user: User?, + triggerName: String, + state: Alert.State, + startTime: Instant, + endTime: Instant?, + acknowledgedTime: Instant?, + errorMessage: String?, + severity: String, + actionExecutionResults: List + ) : super( + id = id, + version = version, + schemaVersion = schemaVersion, + user = user, + triggerName = triggerName, + state = state, + startTime = startTime, + endTime = endTime, + acknowledgedTime = acknowledgedTime, + errorMessage = errorMessage, + severity = severity, + actionExecutionResults = actionExecutionResults + ) { + this.correlatedFindingIds = correlatedFindingIds + this.correlationRuleId = correlationRuleId + this.correlationRuleName = correlationRuleName + } + + @Throws(IOException::class) + constructor(sin: StreamInput) : super(sin) { + correlatedFindingIds = sin.readStringList() + correlationRuleId = sin.readString() + correlationRuleName = sin.readString() + } + + // Override to include CorrelationAlert specific fields + fun toXContent(builder: XContentBuilder): XContentBuilder { + builder.startObject() + .startArray(CORRELATED_FINDING_IDS) + correlatedFindingIds.forEach { id -> + builder.value(id) + } + builder.endArray() + .field(CORRELATION_RULE_ID, correlationRuleId) + .field(CORRELATION_RULE_NAME, correlationRuleName) + super.toXContentWithUser(builder) + builder.endObject() + return builder + } + + @Throws(IOException::class) + override fun writeTo(out: StreamOutput) { + super.writeTo(out) + out.writeStringCollection(correlatedFindingIds) + out.writeString(correlationRuleId) + out.writeString(correlationRuleName) + } + override fun asTemplateArg(): Map { + val superTemplateArgs = super.asTemplateArg() + val correlationSpecificArgs = mapOf( + CORRELATED_FINDING_IDS to correlatedFindingIds, + CORRELATION_RULE_ID to correlationRuleId, + CORRELATION_RULE_NAME to correlationRuleName + ) + return superTemplateArgs + correlationSpecificArgs + } + companion object { + const val CORRELATED_FINDING_IDS = "correlatedFindingIds" + const val CORRELATION_RULE_ID = "correlationRuleId" + const val CORRELATION_RULE_NAME = "correlationRuleName" + + @JvmStatic + @Throws(IOException::class) + fun parse(xcp: XContentParser, id: String = NO_ID, version: Long = NO_VERSION): CorrelationAlert { + // Parse additional CorrelationAlert-specific fields + val correlatedFindingIds: MutableList = mutableListOf() + var correlationRuleId: String? = null + var correlationRuleName: String? = null + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, xcp.currentToken(), xcp) + while (xcp.nextToken() != XContentParser.Token.END_OBJECT) { + val fieldName = xcp.currentName() + xcp.nextToken() + + when (fieldName) { + CORRELATED_FINDING_IDS -> { + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, xcp.currentToken(), xcp) + while (xcp.nextToken() != XContentParser.Token.END_ARRAY) { + correlatedFindingIds.add(xcp.text()) + } + } + CORRELATION_RULE_ID -> correlationRuleId = xcp.text() + CORRELATION_RULE_NAME -> correlationRuleName = xcp.text() + } + } + + val unifiedAlert = parse(xcp, version) + return CorrelationAlert( + correlatedFindingIds = correlatedFindingIds, + correlationRuleId = requireNotNull(correlationRuleId), + correlationRuleName = requireNotNull(correlationRuleName), + id = requireNotNull(unifiedAlert.id), + version = requireNotNull(unifiedAlert.version), + schemaVersion = requireNotNull(unifiedAlert.schemaVersion), + user = unifiedAlert.user, + triggerName = requireNotNull(unifiedAlert.triggerName), + state = requireNotNull(unifiedAlert.state), + startTime = requireNotNull(unifiedAlert.startTime), + endTime = unifiedAlert.endTime, + acknowledgedTime = unifiedAlert.acknowledgedTime, + errorMessage = unifiedAlert.errorMessage, + severity = requireNotNull(unifiedAlert.severity), + actionExecutionResults = unifiedAlert.actionExecutionResults + ) + } + } +} diff --git a/src/test/kotlin/org/opensearch/commons/alerting/CorrelationAlertTests.kt b/src/test/kotlin/org/opensearch/commons/alerting/CorrelationAlertTests.kt new file mode 100644 index 00000000..fdc7f068 --- /dev/null +++ b/src/test/kotlin/org/opensearch/commons/alerting/CorrelationAlertTests.kt @@ -0,0 +1,130 @@ +package org.opensearch.commons.alerting + +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Test +import org.opensearch.common.xcontent.LoggingDeprecationHandler +import org.opensearch.common.xcontent.XContentHelper +import org.opensearch.commons.alerting.model.Alert +import org.opensearch.commons.alerting.model.CorrelationAlert +import org.opensearch.commons.utils.getJsonString +import org.opensearch.commons.utils.recreateObject +import org.opensearch.core.common.bytes.BytesArray +import org.opensearch.core.common.bytes.BytesReference +import org.opensearch.core.common.io.stream.InputStreamStreamInput +import org.opensearch.core.xcontent.NamedXContentRegistry +import java.time.temporal.ChronoUnit + +class CorrelationAlertTests { + + @Test + fun `test correlation alert as template args`() { + // Create sample data for CorrelationAlert + val correlationAlert = randomCorrelationAlert("alertId1", Alert.State.ACTIVE) + + // Generate template args using asTemplateArg() function + val templateArgs = createCorrelationAlertTemplateArgs(correlationAlert) + + assertEquals( + templateArgs["correlatedFindingIds"], + correlationAlert.correlatedFindingIds, + "Template args correlatedFindingIds does not match" + ) + assertEquals( + templateArgs["correlationRuleId"], + correlationAlert.correlationRuleId, + "Template args correlationRuleId does not match" + ) + assertEquals( + templateArgs["correlationRuleName"], + correlationAlert.correlationRuleName, + "Template args correlationRuleName does not match" + ) + + // Verify inherited properties from BaseAlert + assertEquals(templateArgs["id"], correlationAlert.id, "alertId1") + assertEquals(templateArgs["version"], correlationAlert.version, "Template args version does not match") + assertEquals(templateArgs["user"], correlationAlert.user, "Template args user does not match") + assertEquals( + templateArgs["triggerName"], + correlationAlert.triggerName, + "Template args triggerName does not match" + ) + assertEquals(templateArgs["state"], correlationAlert.state, "Template args state does not match") + assertEquals(templateArgs["startTime"], correlationAlert.startTime, "Template args startTime does not match") + assertEquals(templateArgs["endTime"], correlationAlert.endTime, "Template args endTime does not match") + assertEquals( + templateArgs["acknowledgedTime"], + correlationAlert.acknowledgedTime, + "Template args acknowledgedTime does not match" + ) + assertEquals( + templateArgs["errorMessage"], + correlationAlert.errorMessage, + "Template args errorMessage does not match" + ) + assertEquals(templateArgs["severity"], correlationAlert.severity, "Template args severity does not match") + assertEquals( + templateArgs["actionExecutionResults"], + correlationAlert.actionExecutionResults, + "Template args actionExecutionResults does not match" + ) + } + + @Test + fun `test alert acknowledged`() { + val ackCorrelationAlert = randomCorrelationAlert("alertId1", Alert.State.ACKNOWLEDGED) + Assertions.assertTrue(ackCorrelationAlert.isAcknowledged(), "Alert is not acknowledged") + + val activeCorrelationAlert = randomCorrelationAlert("alertId1", Alert.State.ACTIVE) + Assertions.assertFalse(activeCorrelationAlert.isAcknowledged(), "Alert is acknowledged") + } + + @Test + fun `test correlation parse function`() { + // Generate a random CorrelationAlert object + val correlationAlert = randomCorrelationAlert("alertId1", Alert.State.ACTIVE) + val correlationAlertString = getJsonString(correlationAlert) + + // Convert the JSON string to a BytesReference + val serializedBytes: BytesReference = BytesArray(correlationAlertString.toByteArray(Charsets.UTF_8)) + + // Deserialize the BytesReference into a CorrelationAlert object using the parse function + val recreatedAlert: CorrelationAlert = InputStreamStreamInput(serializedBytes.streamInput()).use { streamInput -> + XContentHelper.createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, serializedBytes).use { parser -> + parser.nextToken() // Move to the start of the content + CorrelationAlert.parse(parser) + } + } + + // Assert that the deserialized object matches the original object + assertEquals(correlationAlert.correlatedFindingIds, recreatedAlert.correlatedFindingIds) + assertEquals(correlationAlert.correlationRuleId, recreatedAlert.correlationRuleId) + assertEquals(correlationAlert.correlationRuleName, recreatedAlert.correlationRuleName) + assertEquals(correlationAlert.triggerName, recreatedAlert.triggerName) + assertEquals(correlationAlert.state, recreatedAlert.state) + val expectedStartTime = correlationAlert.startTime.truncatedTo(ChronoUnit.MILLIS) + val actualStartTime = recreatedAlert.startTime.truncatedTo(ChronoUnit.MILLIS) + assertEquals(expectedStartTime, actualStartTime) + assertEquals(correlationAlert.severity, recreatedAlert.severity) + assertEquals(correlationAlert.id, recreatedAlert.id) + assertEquals(correlationAlert.actionExecutionResults, recreatedAlert.actionExecutionResults) + } + + @Test + fun `Feature Correlation Alert serialize and deserialize should be equal`() { + val correlationAlert = randomCorrelationAlert("alertId1", Alert.State.ACTIVE) + val recreatedAlert = recreateObject(correlationAlert) { CorrelationAlert(it) } + assertEquals(correlationAlert.correlatedFindingIds, recreatedAlert.correlatedFindingIds) + assertEquals(correlationAlert.correlationRuleId, recreatedAlert.correlationRuleId) + assertEquals(correlationAlert.correlationRuleName, recreatedAlert.correlationRuleName) + assertEquals(correlationAlert.triggerName, recreatedAlert.triggerName) + assertEquals(correlationAlert.state, recreatedAlert.state) + val expectedStartTime = correlationAlert.startTime.truncatedTo(ChronoUnit.MILLIS) + val actualStartTime = recreatedAlert.startTime.truncatedTo(ChronoUnit.MILLIS) + assertEquals(expectedStartTime, actualStartTime) + assertEquals(correlationAlert.severity, recreatedAlert.severity) + assertEquals(correlationAlert.id, recreatedAlert.id) + assertEquals(correlationAlert.actionExecutionResults, recreatedAlert.actionExecutionResults) + } +} diff --git a/src/test/kotlin/org/opensearch/commons/alerting/TestHelpers.kt b/src/test/kotlin/org/opensearch/commons/alerting/TestHelpers.kt index ca193224..85489316 100644 --- a/src/test/kotlin/org/opensearch/commons/alerting/TestHelpers.kt +++ b/src/test/kotlin/org/opensearch/commons/alerting/TestHelpers.kt @@ -20,11 +20,13 @@ import org.opensearch.commons.alerting.aggregation.bucketselectorext.BucketSelec import org.opensearch.commons.alerting.model.ActionExecutionResult import org.opensearch.commons.alerting.model.AggregationResultBucket import org.opensearch.commons.alerting.model.Alert +import org.opensearch.commons.alerting.model.BaseAlert import org.opensearch.commons.alerting.model.BucketLevelTrigger import org.opensearch.commons.alerting.model.ChainedAlertTrigger import org.opensearch.commons.alerting.model.ChainedMonitorFindings import org.opensearch.commons.alerting.model.ClusterMetricsInput import org.opensearch.commons.alerting.model.CompositeInput +import org.opensearch.commons.alerting.model.CorrelationAlert import org.opensearch.commons.alerting.model.Delegate import org.opensearch.commons.alerting.model.DocLevelMonitorInput import org.opensearch.commons.alerting.model.DocLevelQuery @@ -601,3 +603,57 @@ fun randomFinding( timestamp = timestamp ) } + +fun randomCorrelationAlert( + id: String, + state: Alert.State +): CorrelationAlert { + val correlatedFindingIds = listOf("finding1", "finding2") + val correlationRuleId = "rule1" + val correlationRuleName = "Rule 1" + val id = id + val version = 1L + val schemaVersion = 1 + val user = randomUser() + val triggerName = "Trigger 1" + val state = state + val startTime = Instant.now() + val endTime: Instant? = null + val acknowledgedTime: Instant? = null + val errorMessage: String? = null + val severity = "high" + val actionExecutionResults = listOf(randomActionExecutionResult()) + + return CorrelationAlert( + correlatedFindingIds, correlationRuleId, correlationRuleName, + id, version, schemaVersion, user, triggerName, state, + startTime, endTime, acknowledgedTime, errorMessage, severity, + actionExecutionResults + ) +} + +fun createUnifiedAlertTemplateArgs(unifiedAlert: BaseAlert): Map { + return mapOf( + BaseAlert.ALERT_ID_FIELD to unifiedAlert.id, + BaseAlert.ALERT_VERSION_FIELD to unifiedAlert.version, + BaseAlert.SCHEMA_VERSION_FIELD to unifiedAlert.schemaVersion, + BaseAlert.USER_FIELD to unifiedAlert.user, + BaseAlert.TRIGGER_NAME_FIELD to unifiedAlert.triggerName, + BaseAlert.STATE_FIELD to unifiedAlert.state, + BaseAlert.START_TIME_FIELD to unifiedAlert.startTime, + BaseAlert.END_TIME_FIELD to unifiedAlert.endTime, + BaseAlert.ACKNOWLEDGED_TIME_FIELD to unifiedAlert.acknowledgedTime, + BaseAlert.ERROR_MESSAGE_FIELD to unifiedAlert.errorMessage, + BaseAlert.SEVERITY_FIELD to unifiedAlert.severity, + BaseAlert.ACTION_EXECUTION_RESULTS_FIELD to unifiedAlert.actionExecutionResults + ) +} + +fun createCorrelationAlertTemplateArgs(correlationAlert: CorrelationAlert): Map { + val unifiedAlertTemplateArgs = createUnifiedAlertTemplateArgs(correlationAlert) + return unifiedAlertTemplateArgs + mapOf( + "correlatedFindingIds" to correlationAlert.correlatedFindingIds, + "correlationRuleId" to correlationAlert.correlationRuleId, + "correlationRuleName" to correlationAlert.correlationRuleName + ) +} diff --git a/src/test/kotlin/org/opensearch/commons/utils/TestHelpers.kt b/src/test/kotlin/org/opensearch/commons/utils/TestHelpers.kt index 1170851b..9481a1eb 100644 --- a/src/test/kotlin/org/opensearch/commons/utils/TestHelpers.kt +++ b/src/test/kotlin/org/opensearch/commons/utils/TestHelpers.kt @@ -7,6 +7,7 @@ package org.opensearch.commons.utils import org.opensearch.common.xcontent.XContentFactory import org.opensearch.common.xcontent.XContentType +import org.opensearch.commons.alerting.model.CorrelationAlert import org.opensearch.core.xcontent.DeprecationHandler import org.opensearch.core.xcontent.NamedXContentRegistry import org.opensearch.core.xcontent.ToXContent @@ -16,7 +17,11 @@ import java.io.ByteArrayOutputStream fun getJsonString(xContent: ToXContent): String { ByteArrayOutputStream().use { byteArrayOutputStream -> val builder = XContentFactory.jsonBuilder(byteArrayOutputStream) - xContent.toXContent(builder, ToXContent.EMPTY_PARAMS) + if (xContent is CorrelationAlert) { + xContent.toXContent(builder) + } else { + xContent.toXContent(builder, ToXContent.EMPTY_PARAMS) + } builder.close() return byteArrayOutputStream.toString("UTF8") }