Skip to content

Commit

Permalink
bug fixes for correlation Alerts (#670)
Browse files Browse the repository at this point in the history
* bug fixes for correlation Alerts

Signed-off-by: Riya Saxena <[email protected]>

* fixing the tests

Signed-off-by: Riya Saxena <[email protected]>

---------

Signed-off-by: Riya Saxena <[email protected]>
  • Loading branch information
riysaxen-amzn authored Jun 10, 2024
1 parent acaa844 commit 480590d
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 73 deletions.
26 changes: 13 additions & 13 deletions src/main/kotlin/org/opensearch/commons/alerting/model/BaseAlert.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package org.opensearch.commons.alerting.model
import org.opensearch.common.lucene.uid.Versions
import org.opensearch.commons.alerting.util.IndexUtils.Companion.NO_SCHEMA_VERSION
import org.opensearch.commons.alerting.util.instant
import org.opensearch.commons.alerting.util.optionalTimeField
import org.opensearch.commons.alerting.util.optionalUserField
import org.opensearch.commons.authuser.User
import org.opensearch.core.common.io.stream.StreamInput
Expand Down Expand Up @@ -85,17 +84,17 @@ open class BaseAlert(

companion object {
const val ALERT_ID_FIELD = "id"
const val SCHEMA_VERSION_FIELD = "schemaVersion"
const val SCHEMA_VERSION_FIELD = "schema_version"
const val ALERT_VERSION_FIELD = "version"
const val USER_FIELD = "user"
const val TRIGGER_NAME_FIELD = "triggerName"
const val TRIGGER_NAME_FIELD = "trigger_name"
const val STATE_FIELD = "state"
const val START_TIME_FIELD = "startTime"
const val END_TIME_FIELD = "endTime"
const val ACKNOWLEDGED_TIME_FIELD = "acknowledgedTime"
const val ERROR_MESSAGE_FIELD = "errorMessage"
const val START_TIME_FIELD = "start_time"
const val END_TIME_FIELD = "end_time"
const val ACKNOWLEDGED_TIME_FIELD = "acknowledged_time"
const val ERROR_MESSAGE_FIELD = "error_message"
const val SEVERITY_FIELD = "severity"
const val ACTION_EXECUTION_RESULTS_FIELD = "actionExecutionResults"
const val ACTION_EXECUTION_RESULTS_FIELD = "action_execution_results"
const val NO_ID = ""
const val NO_VERSION = Versions.NOT_FOUND

Expand Down Expand Up @@ -138,7 +137,7 @@ open class BaseAlert(
}
}
START_TIME_FIELD -> startTime = requireNotNull(xcp.instant())
END_TIME_FIELD -> endTime = xcp.instant()
END_TIME_FIELD -> endTime = requireNotNull(xcp.instant())
ACKNOWLEDGED_TIME_FIELD -> acknowledgedTime = xcp.instant()
}
}
Expand Down Expand Up @@ -178,17 +177,18 @@ open class BaseAlert(
if (!secure) {
builder.optionalUserField(USER_FIELD, user)
}
builder.field(ALERT_ID_FIELD, id)
builder
.field(ALERT_ID_FIELD, id)
.field(ALERT_VERSION_FIELD, version)
.field(SCHEMA_VERSION_FIELD, schemaVersion)
.field(TRIGGER_NAME_FIELD, triggerName)
.field(STATE_FIELD, state)
.field(ERROR_MESSAGE_FIELD, errorMessage)
.field(SEVERITY_FIELD, severity)
.field(ACTION_EXECUTION_RESULTS_FIELD, actionExecutionResults.toTypedArray())
.optionalTimeField(START_TIME_FIELD, startTime)
.optionalTimeField(END_TIME_FIELD, endTime)
.optionalTimeField(ACKNOWLEDGED_TIME_FIELD, acknowledgedTime)
.field(START_TIME_FIELD, startTime)
.field(END_TIME_FIELD, endTime)
.field(ACKNOWLEDGED_TIME_FIELD, acknowledgedTime)
return builder
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package org.opensearch.commons.alerting.model
import org.opensearch.commons.authuser.User
import org.opensearch.core.common.io.stream.StreamInput
import org.opensearch.core.common.io.stream.StreamOutput
import org.opensearch.core.xcontent.ToXContent
import org.opensearch.core.xcontent.XContentBuilder
import org.opensearch.core.xcontent.XContentParser
import org.opensearch.core.xcontent.XContentParserUtils
Expand Down Expand Up @@ -59,7 +60,7 @@ class CorrelationAlert : BaseAlert {
}

// Override to include CorrelationAlert specific fields
fun toXContent(builder: XContentBuilder): XContentBuilder {
override fun toXContent(builder: XContentBuilder, params: ToXContent.Params): XContentBuilder {
builder.startObject()
.startArray(CORRELATED_FINDING_IDS)
correlatedFindingIds.forEach { id ->
Expand Down Expand Up @@ -90,9 +91,9 @@ class CorrelationAlert : BaseAlert {
return superTemplateArgs + correlationSpecificArgs
}
companion object {
const val CORRELATED_FINDING_IDS = "correlatedFindingIds"
const val CORRELATION_RULE_ID = "correlationRuleId"
const val CORRELATION_RULE_NAME = "correlationRuleName"
const val CORRELATED_FINDING_IDS = "correlated_finding_ids"
const val CORRELATION_RULE_ID = "correlation_rule_id"
const val CORRELATION_RULE_NAME = "correlation_rule_name"

@JvmStatic
@Throws(IOException::class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,9 @@ package org.opensearch.commons.alerting
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test
import org.opensearch.common.xcontent.LoggingDeprecationHandler
import org.opensearch.common.xcontent.XContentHelper
import org.opensearch.commons.alerting.model.Alert
import org.opensearch.commons.alerting.model.CorrelationAlert
import org.opensearch.commons.utils.getJsonString
import org.opensearch.commons.utils.recreateObject
import org.opensearch.core.common.bytes.BytesArray
import org.opensearch.core.common.bytes.BytesReference
import org.opensearch.core.common.io.stream.InputStreamStreamInput
import org.opensearch.core.xcontent.NamedXContentRegistry
import java.time.temporal.ChronoUnit

class CorrelationAlertTests {
Expand All @@ -26,17 +19,17 @@ class CorrelationAlertTests {
val templateArgs = createCorrelationAlertTemplateArgs(correlationAlert)

assertEquals(
templateArgs["correlatedFindingIds"],
templateArgs["correlated_finding_ids"],
correlationAlert.correlatedFindingIds,
"Template args correlatedFindingIds does not match"
)
assertEquals(
templateArgs["correlationRuleId"],
templateArgs["correlation_rule_id"],
correlationAlert.correlationRuleId,
"Template args correlationRuleId does not match"
)
assertEquals(
templateArgs["correlationRuleName"],
templateArgs["correlation_rule_name"],
correlationAlert.correlationRuleName,
"Template args correlationRuleName does not match"
)
Expand All @@ -46,26 +39,26 @@ class CorrelationAlertTests {
assertEquals(templateArgs["version"], correlationAlert.version, "Template args version does not match")
assertEquals(templateArgs["user"], correlationAlert.user, "Template args user does not match")
assertEquals(
templateArgs["triggerName"],
templateArgs["trigger_name"],
correlationAlert.triggerName,
"Template args triggerName does not match"
)
assertEquals(templateArgs["state"], correlationAlert.state, "Template args state does not match")
assertEquals(templateArgs["startTime"], correlationAlert.startTime, "Template args startTime does not match")
assertEquals(templateArgs["endTime"], correlationAlert.endTime, "Template args endTime does not match")
assertEquals(templateArgs["start_time"], correlationAlert.startTime, "Template args startTime does not match")
assertEquals(templateArgs["end_time"], correlationAlert.endTime, "Template args endTime does not match")
assertEquals(
templateArgs["acknowledgedTime"],
templateArgs["acknowledged_time"],
correlationAlert.acknowledgedTime,
"Template args acknowledgedTime does not match"
)
assertEquals(
templateArgs["errorMessage"],
templateArgs["error_message"],
correlationAlert.errorMessage,
"Template args errorMessage does not match"
)
assertEquals(templateArgs["severity"], correlationAlert.severity, "Template args severity does not match")
assertEquals(
templateArgs["actionExecutionResults"],
templateArgs["action_execution_results"],
correlationAlert.actionExecutionResults,
"Template args actionExecutionResults does not match"
)
Expand All @@ -80,37 +73,6 @@ class CorrelationAlertTests {
Assertions.assertFalse(activeCorrelationAlert.isAcknowledged(), "Alert is acknowledged")
}

@Test
fun `test correlation parse function`() {
// Generate a random CorrelationAlert object
val correlationAlert = randomCorrelationAlert("alertId1", Alert.State.ACTIVE)
val correlationAlertString = getJsonString(correlationAlert)

// Convert the JSON string to a BytesReference
val serializedBytes: BytesReference = BytesArray(correlationAlertString.toByteArray(Charsets.UTF_8))

// Deserialize the BytesReference into a CorrelationAlert object using the parse function
val recreatedAlert: CorrelationAlert = InputStreamStreamInput(serializedBytes.streamInput()).use { streamInput ->
XContentHelper.createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, serializedBytes).use { parser ->
parser.nextToken() // Move to the start of the content
CorrelationAlert.parse(parser)
}
}

// Assert that the deserialized object matches the original object
assertEquals(correlationAlert.correlatedFindingIds, recreatedAlert.correlatedFindingIds)
assertEquals(correlationAlert.correlationRuleId, recreatedAlert.correlationRuleId)
assertEquals(correlationAlert.correlationRuleName, recreatedAlert.correlationRuleName)
assertEquals(correlationAlert.triggerName, recreatedAlert.triggerName)
assertEquals(correlationAlert.state, recreatedAlert.state)
val expectedStartTime = correlationAlert.startTime.truncatedTo(ChronoUnit.MILLIS)
val actualStartTime = recreatedAlert.startTime.truncatedTo(ChronoUnit.MILLIS)
assertEquals(expectedStartTime, actualStartTime)
assertEquals(correlationAlert.severity, recreatedAlert.severity)
assertEquals(correlationAlert.id, recreatedAlert.id)
assertEquals(correlationAlert.actionExecutionResults, recreatedAlert.actionExecutionResults)
}

@Test
fun `Feature Correlation Alert serialize and deserialize should be equal`() {
val correlationAlert = randomCorrelationAlert("alertId1", Alert.State.ACTIVE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -667,9 +667,9 @@ fun createUnifiedAlertTemplateArgs(unifiedAlert: BaseAlert): Map<String, Any?> {
fun createCorrelationAlertTemplateArgs(correlationAlert: CorrelationAlert): Map<String, Any?> {
val unifiedAlertTemplateArgs = createUnifiedAlertTemplateArgs(correlationAlert)
return unifiedAlertTemplateArgs + mapOf(
"correlatedFindingIds" to correlationAlert.correlatedFindingIds,
"correlationRuleId" to correlationAlert.correlationRuleId,
"correlationRuleName" to correlationAlert.correlationRuleName
CorrelationAlert.CORRELATED_FINDING_IDS to correlationAlert.correlatedFindingIds,
CorrelationAlert.CORRELATION_RULE_ID to correlationAlert.correlationRuleId,
CorrelationAlert.CORRELATION_RULE_NAME to correlationAlert.correlationRuleName
)
}

Expand Down
7 changes: 1 addition & 6 deletions src/test/kotlin/org/opensearch/commons/utils/TestHelpers.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ package org.opensearch.commons.utils

import org.opensearch.common.xcontent.XContentFactory
import org.opensearch.common.xcontent.XContentType
import org.opensearch.commons.alerting.model.CorrelationAlert
import org.opensearch.core.xcontent.DeprecationHandler
import org.opensearch.core.xcontent.NamedXContentRegistry
import org.opensearch.core.xcontent.ToXContent
Expand All @@ -17,11 +16,7 @@ import java.io.ByteArrayOutputStream
fun getJsonString(xContent: ToXContent): String {
ByteArrayOutputStream().use { byteArrayOutputStream ->
val builder = XContentFactory.jsonBuilder(byteArrayOutputStream)
if (xContent is CorrelationAlert) {
xContent.toXContent(builder)
} else {
xContent.toXContent(builder, ToXContent.EMPTY_PARAMS)
}
xContent.toXContent(builder, ToXContent.EMPTY_PARAMS)
builder.close()
return byteArrayOutputStream.toString("UTF8")
}
Expand Down

0 comments on commit 480590d

Please sign in to comment.