Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FDP-2758: New SignableMessageWrapper that doesn't need subclassing #31

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class MessageSigner(properties: MessageSigningProperties) {
}

fun canSignMessages(): Boolean {
return this.signingEnabled && this.signingKey != null
return signingEnabled && signingKey != null
}

/**
Expand All @@ -62,8 +62,8 @@ class MessageSigner(properties: MessageSigningProperties) {
* @throws UncheckedSecurityException if the signing process throws a SignatureException.
*/
fun <T> signUsingField(message: SignableMessageWrapper<T>): T {
if (this.signingEnabled) {
val signatureBytes = this.signature(message)
if (signingEnabled) {
val signatureBytes = signature(message)
message.setSignature(signatureBytes)
}
return message.message
Expand All @@ -82,8 +82,8 @@ class MessageSigner(properties: MessageSigningProperties) {
fun signUsingHeader(
producerRecord: ProducerRecord<String, out SpecificRecordBase>
): ProducerRecord<String, out SpecificRecordBase> {
if (this.signingEnabled) {
val signature = this.signature(producerRecord)
if (signingEnabled) {
val signature = signature(producerRecord)
producerRecord.headers().add(RECORD_HEADER_KEY_SIGNATURE, signature.array())
}
return producerRecord
Expand All @@ -103,18 +103,18 @@ class MessageSigner(properties: MessageSigningProperties) {
* @throws UncheckedSecurityException if the signing process throws a SignatureException.
*/
private fun signature(message: SignableMessageWrapper<*>): ByteBuffer {
check(this.canSignMessages()) {
check(canSignMessages()) {
"This MessageSigner is not configured for signing, it can only be used for verification"
}
val oldSignature = message.getSignature()
message.setSignature(null)
val byteBuffer = this.toByteBuffer(message)
message.clearSignature()
val byteBuffer = toByteBuffer(message)
try {
return signature(byteBuffer)
} catch (e: SignatureException) {
throw UncheckedSecurityException("Unable to sign message", e)
} finally {
message.setSignature(oldSignature)
oldSignature?.let { message.setSignature(it) }
}
}

Expand All @@ -132,13 +132,13 @@ class MessageSigner(properties: MessageSigningProperties) {
* @throws UncheckedSecurityException if the signing process throws a SignatureException.
*/
private fun signature(producerRecord: ProducerRecord<String, out SpecificRecordBase>): ByteBuffer {
check(this.canSignMessages()) {
check(canSignMessages()) {
"This MessageSigner is not configured for signing, it can only be used for verification"
}
val oldSignatureHeader = producerRecord.headers().lastHeader(RECORD_HEADER_KEY_SIGNATURE)
producerRecord.headers().remove(RECORD_HEADER_KEY_SIGNATURE)
val specificRecordBase = producerRecord.value()
val byteBuffer = this.toByteBuffer(specificRecordBase)
val byteBuffer = toByteBuffer(specificRecordBase)
try {
return signature(byteBuffer)
} catch (e: SignatureException) {
Expand All @@ -152,8 +152,8 @@ class MessageSigner(properties: MessageSigningProperties) {

private fun signature(byteBuffer: ByteBuffer): ByteBuffer {
val messageBytes: ByteBuffer =
if (this.stripAvroHeader) {
this.stripAvroHeader(byteBuffer)
if (stripAvroHeader) {
stripAvroHeader(byteBuffer)
} else {
byteBuffer
}
Expand All @@ -163,7 +163,7 @@ class MessageSigner(properties: MessageSigningProperties) {
}

fun canVerifyMessageSignatures(): Boolean {
return this.signingEnabled && this.verificationKey != null
return signingEnabled && verificationKey != null
}

/**
Expand All @@ -173,7 +173,7 @@ class MessageSigner(properties: MessageSigningProperties) {
* @return `true` if the signature of the given `message` was verified; `false` if not.
*/
fun <T> verifyUsingField(message: SignableMessageWrapper<T>): Boolean {
if (!this.canVerifyMessageSignatures()) {
if (!canVerifyMessageSignatures()) {
logger.error("This MessageSigner is not configured for verification, it can only be used for signing")
return false
}
Expand All @@ -186,8 +186,8 @@ class MessageSigner(properties: MessageSigningProperties) {
}

try {
message.setSignature(null)
return this.verifySignatureBytes(messageSignature, this.toByteBuffer(message))
message.clearSignature()
return verifySignatureBytes(messageSignature, toByteBuffer(message))
} catch (e: Exception) {
logger.error("Unable to verify message signature", e)
return false
Expand All @@ -203,7 +203,7 @@ class MessageSigner(properties: MessageSigningProperties) {
* @return `true` if the signature of the given `consumerRecord` was verified; `false` if not. SignatureException.
*/
fun verifyUsingHeader(consumerRecord: ConsumerRecord<String, out SpecificRecordBase>): Boolean {
if (!this.canVerifyMessageSignatures()) {
if (!canVerifyMessageSignatures()) {
logger.error("This MessageSigner is not configured for verification, it can only be used for signing")
return false
}
Expand All @@ -222,7 +222,7 @@ class MessageSigner(properties: MessageSigningProperties) {

try {
val specificRecordBase: SpecificRecordBase = consumerRecord.value()
return this.verifySignatureBytes(ByteBuffer.wrap(signatureBytes), this.toByteBuffer(specificRecordBase))
return verifySignatureBytes(ByteBuffer.wrap(signatureBytes), toByteBuffer(specificRecordBase))
} catch (e: Exception) {
logger.error("Unable to verify message signature", e)
return false
Expand All @@ -232,8 +232,8 @@ class MessageSigner(properties: MessageSigningProperties) {
@Throws(SignatureException::class)
private fun verifySignatureBytes(signatureBytes: ByteBuffer, messageByteBuffer: ByteBuffer): Boolean {
val messageBytes: ByteBuffer =
if (this.stripAvroHeader) {
this.stripAvroHeader(messageByteBuffer)
if (stripAvroHeader) {
stripAvroHeader(messageByteBuffer)
} else {
messageByteBuffer
}
Expand All @@ -249,7 +249,7 @@ class MessageSigner(properties: MessageSigningProperties) {
}

private fun stripAvroHeader(bytes: ByteBuffer): ByteBuffer {
if (this.hasAvroHeader(bytes)) {
if (hasAvroHeader(bytes)) {
return ByteBuffer.wrap(Arrays.copyOfRange(bytes.array(), AVRO_HEADER_LENGTH, bytes.array().size))
}
return bytes
Expand All @@ -274,11 +274,11 @@ class MessageSigner(properties: MessageSigningProperties) {
override fun toString(): String {
return String.format(
"MessageSigner[algorithm=\"%s\"-\"%s\", provider=\"%s\", sign=%b, verify=%b]",
this.signatureAlgorithm,
this.keyAlgorithm,
this.signatureProvider,
this.canSignMessages(),
this.canVerifyMessageSignatures())
signatureAlgorithm,
keyAlgorithm,
signatureProvider,
canSignMessages(),
canVerifyMessageSignatures())
}

companion object {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,25 @@ import java.nio.ByteBuffer
* Wrapper for signable messages. Because these messages are generated from Avro schemas, they can't be changed. This
* wrapper unifies them for the MessageSigner.
*/
abstract class SignableMessageWrapper<T>(val message: T) {
class SignableMessageWrapper<T>(
val message: T,
private val messageGetter: (T) -> ByteBuffer,
private val signatureGetter: (T) -> ByteBuffer?,
private val signatureSetter: (T, ByteBuffer?) -> Unit,
) {

/** @return ByteBuffer of the whole message */
@Throws(IOException::class) abstract fun toByteBuffer(): ByteBuffer
@Throws(IOException::class) internal fun toByteBuffer(): ByteBuffer = messageGetter(message)

/** @return ByteBuffer of the signature in the message */
abstract fun getSignature(): ByteBuffer?
internal fun getSignature(): ByteBuffer? = signatureGetter(message)

/** @param signature The signature in ByteBuffer form to be set on the message */
abstract fun setSignature(signature: ByteBuffer?)
internal fun setSignature(signature: ByteBuffer) {
signatureSetter(message, signature)
}

internal fun clearSignature() {
signatureSetter(message, null)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package com.gxf.utilities.kafka.message.signing

import com.gxf.utilities.kafka.message.wrapper.SignableMessageWrapper
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets
import java.security.SecureRandom
import java.util.Random
import java.util.function.Consumer
Expand Down Expand Up @@ -34,16 +33,17 @@ class MessageSignerTest {

@Test
fun signsMessageWithoutSignature() {
val messageWrapper: SignableMessageWrapper<*> = this.messageWrapper()
val messageWrapper: SignableMessageWrapper<TestableMessage> = messageWrapper()

messageSigner.signUsingField(messageWrapper)

assertThat(messageWrapper.getSignature()).isNotNull()
assertThat(messageWrapper.message.signature).isEqualTo(messageWrapper.getSignature())
}

@Test
fun signsRecordHeaderWithoutSignature() {
val record = this.producerRecord()
val record = producerRecord()

messageSigner.signUsingHeader(record)

Expand All @@ -52,8 +52,8 @@ class MessageSignerTest {

@Test
fun signsMessageReplacingSignature() {
val randomSignature = this.randomSignature()
val messageWrapper = this.messageWrapper()
val randomSignature = randomSignature()
val messageWrapper = messageWrapper()
messageWrapper.setSignature(randomSignature)

val actualSignatureBefore = messageWrapper.getSignature()
Expand All @@ -67,8 +67,8 @@ class MessageSignerTest {

@Test
fun signsRecordHeaderReplacingSignature() {
val randomSignature = this.randomSignature()
val record = this.producerRecord()
val randomSignature = randomSignature()
val record = producerRecord()
record.headers().add(MessageSigner.RECORD_HEADER_KEY_SIGNATURE, randomSignature.array())

val actualSignatureBefore = record.headers().lastHeader(MessageSigner.RECORD_HEADER_KEY_SIGNATURE).value()
Expand All @@ -82,7 +82,7 @@ class MessageSignerTest {

@Test
fun verifiesMessagesWithValidSignature() {
val message = this.properlySignedMessage()
val message = properlySignedMessage()

val signatureWasVerified = messageSigner.verifyUsingField(message)

Expand All @@ -91,7 +91,7 @@ class MessageSignerTest {

@Test
fun verifiesRecordsWithValidSignature() {
val signedRecord = this.properlySignedRecord()
val signedRecord = properlySignedRecord()

val result = messageSigner.verifyUsingHeader(signedRecord)

Expand All @@ -100,7 +100,7 @@ class MessageSignerTest {

@Test
fun doesNotVerifyMessagesWithoutSignature() {
val messageWrapper = this.messageWrapper()
val messageWrapper = messageWrapper()

val validSignature = messageSigner.verifyUsingField(messageWrapper)

Expand All @@ -109,7 +109,7 @@ class MessageSignerTest {

@Test
fun doesNotVerifyRecordsWithoutSignature() {
val consumerRecord = this.consumerRecord()
val consumerRecord = consumerRecord()

val validSignature = messageSigner.verifyUsingHeader(consumerRecord)

Expand All @@ -118,8 +118,8 @@ class MessageSignerTest {

@Test
fun doesNotVerifyMessagesWithIncorrectSignature() {
val randomSignature = this.randomSignature()
val messageWrapper = this.messageWrapper(randomSignature)
val randomSignature = randomSignature()
val messageWrapper = messageWrapper(randomSignature)

val validSignature = messageSigner.verifyUsingField(messageWrapper)

Expand All @@ -128,8 +128,8 @@ class MessageSignerTest {

@Test
fun doesNotVerifyRecordsWithIncorrectSignature() {
val consumerRecord = this.consumerRecord()
val randomSignature = this.randomSignature()
val consumerRecord = consumerRecord()
val randomSignature = randomSignature()
consumerRecord.headers().add(MessageSigner.RECORD_HEADER_KEY_SIGNATURE, randomSignature.array())

val validSignature = messageSigner.verifyUsingHeader(consumerRecord)
Expand All @@ -139,7 +139,7 @@ class MessageSignerTest {

@Test
fun verifiesMessagesPreservingTheSignatureAndItsProperties() {
val message = this.properlySignedMessage()
val message = properlySignedMessage()
val originalSignature = message.getSignature()

messageSigner.verifyUsingField(message)
Expand All @@ -157,26 +157,22 @@ class MessageSignerTest {
assertThat(messageSignerSigningDisabled.canVerifyMessageSignatures()).isFalse()
}

private fun messageWrapper(): TestableWrapper {
return TestableWrapper()
private fun messageWrapper(signature: ByteBuffer? = null): SignableMessageWrapper<TestableMessage> {
val testableMessage = TestableMessage(signature = signature)
return SignableMessageWrapper(
testableMessage, TestableMessage::getMsgBytes, TestableMessage::getSigBytes, TestableMessage::setSigBytes)
}

private fun messageWrapper(signature: ByteBuffer): TestableWrapper {
val testableWrapper = TestableWrapper()
testableWrapper.setSignature(signature)
return testableWrapper
}

private fun properlySignedMessage(): TestableWrapper {
val messageWrapper = this.messageWrapper()
private fun properlySignedMessage(): SignableMessageWrapper<TestableMessage> {
val messageWrapper = messageWrapper()
messageSigner.signUsingField(messageWrapper)
return messageWrapper
}

private fun properlySignedRecord(): ConsumerRecord<String, Message> {
val producerRecord = this.producerRecord()
val producerRecord = producerRecord()
messageSigner.signUsingHeader(producerRecord)
return this.producerRecordToConsumerRecord(producerRecord)
return producerRecordToConsumerRecord(producerRecord)
}

private fun <K, V> producerRecordToConsumerRecord(producerRecord: ProducerRecord<K, V>): ConsumerRecord<K, V> {
Expand All @@ -197,11 +193,11 @@ class MessageSignerTest {
}

private fun producerRecord(): ProducerRecord<String, Message> {
return ProducerRecord("topic", this.message())
return ProducerRecord("topic", message())
}

private fun consumerRecord(): ConsumerRecord<String, Message> {
return ConsumerRecord("topic", 0, 123L, null, this.message())
return ConsumerRecord("topic", 0, 123L, null, message())
}

private fun message(): Message {
Expand All @@ -221,23 +217,21 @@ class MessageSignerTest {
}

override fun put(field: Int, value: Any) {
this.message = value.toString()
message = value.toString()
}
}

private class TestableWrapper : SignableMessageWrapper<String>("Some test message") {
private var signature: ByteBuffer? = null
/**
* Object to test the wrapper with. Intentionally chose function names that are different from the ones in the
* wrapper class
*/
private class TestableMessage(var message: ByteBuffer = ByteBuffer.allocate(3), var signature: ByteBuffer? = null) {
fun getMsgBytes(): ByteBuffer = ByteBuffer.wrap(message.array())

override fun toByteBuffer(): ByteBuffer {
return ByteBuffer.wrap(message.toByteArray(StandardCharsets.UTF_8))
}

override fun getSignature(): ByteBuffer? {
return this.signature
}
fun getSigBytes(): ByteBuffer? = signature

override fun setSignature(signature: ByteBuffer?) {
this.signature = signature
fun setSigBytes(newSignature: ByteBuffer?) {
signature = newSignature
}
}
}
Loading