Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Integrate service messages #48

Merged
merged 3 commits into from
Mar 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ dependencies {
implementation "org.jetbrains.kotlinx:kotlinx-coroutines-core:$kotlinCoroutinesVersion"

// Relaynet
api 'tech.relaycorp:relaynet:[1.47.0,2.0.0)'
api 'tech.relaycorp:relaynet:[1.47.1,2.0.0)'
implementation 'tech.relaycorp:poweb:1.5.15'

// Security
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ package tech.relaycorp.relaydroid.endpoint
import tech.relaycorp.relaydroid.Storage
import tech.relaycorp.relaydroid.storage.persistence.PersistenceException
import tech.relaycorp.relaynet.RelaynetException
import tech.relaycorp.relaynet.messages.payloads.ServiceMessage
import tech.relaycorp.relaynet.wrappers.x509.Certificate
import tech.relaycorp.relaynet.wrappers.x509.CertificateException

public sealed class ThirdPartyEndpoint(
override val address: String
override val address: String,
public val identityCertificate: Certificate
) : Endpoint {

public val thirdPartyAddress: String get() = address
Expand All @@ -26,8 +28,8 @@ public class PrivateThirdPartyEndpoint(
public val firstPartyAddress: String,
thirdPartyAddress: String,
public val authorization: Certificate,
public val identity: Certificate
) : ThirdPartyEndpoint(thirdPartyAddress) {
identityCertificate: Certificate
) : ThirdPartyEndpoint(thirdPartyAddress, identityCertificate) {

public companion object {

Expand Down Expand Up @@ -76,8 +78,8 @@ public class PrivateThirdPartyEndpoint(

public class PublicThirdPartyEndpoint(
thirdPartyAddress: String,
public val certificate: Certificate
) : ThirdPartyEndpoint(thirdPartyAddress) {
identityCertificate: Certificate
) : ThirdPartyEndpoint(thirdPartyAddress, identityCertificate) {

public companion object {
@Throws(PersistenceException::class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,29 @@ import tech.relaycorp.relaydroid.endpoint.UnknownThirdPartyEndpointException
import tech.relaycorp.relaydroid.storage.persistence.PersistenceException
import tech.relaycorp.relaynet.messages.Parcel
import java.time.ZonedDateTime
import tech.relaycorp.relaynet.messages.InvalidMessageException
import tech.relaycorp.relaynet.wrappers.cms.EnvelopedDataException

public class IncomingMessage internal constructor(
id: MessageId,
payload: ByteArray,
public val type: String,
public val content: ByteArray,
public val senderEndpoint: ThirdPartyEndpoint,
public val recipientEndpoint: FirstPartyEndpoint,
creationDate: ZonedDateTime,
expiryDate: ZonedDateTime,
public val ack: suspend () -> Unit
) : Message(
id, payload, senderEndpoint, recipientEndpoint, creationDate, expiryDate
id, senderEndpoint, recipientEndpoint, creationDate, expiryDate
) {

internal companion object {
@Throws(
UnknownFirstPartyEndpointException::class,
UnknownThirdPartyEndpointException::class,
PersistenceException::class
PersistenceException::class,
EnvelopedDataException::class,
InvalidMessageException::class
)
internal suspend fun build(parcel: Parcel, ack: suspend () -> Unit): IncomingMessage {
val recipientEndpoint = FirstPartyEndpoint.load(parcel.recipientAddress)
Expand All @@ -41,9 +46,11 @@ public class IncomingMessage internal constructor(
"for first party endpoint ${parcel.recipientAddress}"
)

val serviceMessage = parcel.unwrapPayload(recipientEndpoint.keyPair.private)
return IncomingMessage(
id = MessageId(parcel.id),
payload = parcel.payload,
type = serviceMessage.type,
content = serviceMessage.content,
senderEndpoint = sender,
recipientEndpoint = recipientEndpoint,
creationDate = parcel.creationDate,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import java.time.ZonedDateTime

public abstract class Message(
public val id: MessageId,
public val payload: ByteArray,
senderEndpoint: Endpoint,
recipientEndpoint: Endpoint,
public val creationDate: ZonedDateTime = ZonedDateTime.now(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,35 +10,36 @@ import tech.relaycorp.relaynet.messages.Parcel
import tech.relaycorp.relaynet.ramf.RAMFException
import tech.relaycorp.relaynet.wrappers.x509.Certificate
import java.time.ZonedDateTime
import tech.relaycorp.relaynet.messages.payloads.ServiceMessage

public class OutgoingMessage
private constructor(
payload: ByteArray,
public val senderEndpoint: FirstPartyEndpoint,
public val recipientEndpoint: ThirdPartyEndpoint,
creationDate: ZonedDateTime = ZonedDateTime.now(),
expiryDate: ZonedDateTime = maxExpiryDate(),
id: MessageId = MessageId.generate()
) : Message(
id, payload, senderEndpoint, recipientEndpoint, creationDate, expiryDate
id, senderEndpoint, recipientEndpoint, creationDate, expiryDate
) {

internal lateinit var parcel: Parcel
private set

public companion object {
public suspend fun build(
payload: ByteArray,
type: String,
content: ByteArray,
senderEndpoint: FirstPartyEndpoint,
recipientEndpoint: ThirdPartyEndpoint,
creationDate: ZonedDateTime = ZonedDateTime.now(),
expiryDate: ZonedDateTime = maxExpiryDate(),
id: MessageId = MessageId.generate()
): OutgoingMessage {
val message = OutgoingMessage(
payload, senderEndpoint, recipientEndpoint, creationDate, expiryDate, id
senderEndpoint, recipientEndpoint, creationDate, expiryDate, id
)
message.parcel = message.buildParcel()
message.parcel = message.buildParcel(type, content)
try {
message.parcel.validate(null)
} catch (exp: RAMFException) {
Expand All @@ -48,19 +49,25 @@ private constructor(
}
}

private suspend fun buildParcel() = Parcel(
recipientAddress = if (recipientEndpoint is PublicThirdPartyEndpoint) {
"https://" + recipientEndpoint.address
} else {
recipientEndpoint.address
},
payload = payload,
senderCertificate = getSenderCertificate(),
messageId = id.value,
creationDate = creationDate,
ttl = ttl,
senderCertificateChain = getSenderCertificateChain()
)
private suspend fun buildParcel(
serviceMessageType: String,
serviceMessageContent: ByteArray
): Parcel {
val serviceMessage = ServiceMessage(serviceMessageType, serviceMessageContent)
return Parcel(
recipientAddress = if (recipientEndpoint is PublicThirdPartyEndpoint) {
"https://" + recipientEndpoint.address
} else {
recipientEndpoint.address
},
payload = serviceMessage.encrypt(recipientEndpoint.identityCertificate),
senderCertificate = getSenderCertificate(),
messageId = id.value,
creationDate = creationDate,
ttl = ttl,
senderCertificateChain = getSenderCertificateChain()
)
}

private suspend fun getSenderCertificate() =
when (recipientEndpoint) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package tech.relaycorp.relaydroid.messaging

import java.util.logging.Level
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.asFlow
import kotlinx.coroutines.flow.flatMapLatest
Expand All @@ -15,12 +16,13 @@ import tech.relaycorp.relaydroid.storage.persistence.PersistenceException
import tech.relaycorp.relaynet.bindings.pdc.ClientBindingException
import tech.relaycorp.relaynet.bindings.pdc.NonceSignerException
import tech.relaycorp.relaynet.bindings.pdc.PDCClient
import tech.relaycorp.relaynet.bindings.pdc.ParcelCollection
import tech.relaycorp.relaynet.bindings.pdc.ServerException
import tech.relaycorp.relaynet.bindings.pdc.Signer
import tech.relaycorp.relaynet.bindings.pdc.StreamingMode
import tech.relaycorp.relaynet.messages.InvalidMessageException
import tech.relaycorp.relaynet.ramf.RAMFException
import java.util.logging.Level
import tech.relaycorp.relaynet.wrappers.cms.EnvelopedDataException

internal class ReceiveMessages(
private val pdcClientBuilder: () -> PDCClient = { PoWebClient.initLocal(Relaynet.POWEB_PORT) }
Expand Down Expand Up @@ -69,27 +71,40 @@ internal class ReceiveMessages(
val parcel = try {
parcelCollection.deserializeAndValidateParcel()
} catch (exp: RAMFException) {
logger.log(Level.WARNING, "Malformed incoming parcel", exp)
parcelCollection.ack()
parcelCollection.disregard("Malformed incoming parcel", exp)
return@mapNotNull null
} catch (exp: InvalidMessageException) {
logger.log(Level.WARNING, "Invalid incoming parcel", exp)
parcelCollection.ack()
parcelCollection.disregard("Invalid incoming parcel", exp)
return@mapNotNull null
}
try {
IncomingMessage.build(parcel) { parcelCollection.ack() }
} catch (exp: UnknownFirstPartyEndpointException) {
logger.log(Level.WARNING, "Incoming parcel with invalid recipient", exp)
parcelCollection.ack()
parcelCollection.disregard("Incoming parcel with invalid recipient", exp)
return@mapNotNull null
} catch (exp: UnknownFirstPartyEndpointException) {
logger.log(Level.WARNING, "Incoming parcel issues with invalid sender", exp)
parcelCollection.ack()
parcelCollection.disregard("Incoming parcel issues with invalid sender", exp)
return@mapNotNull null
} catch (exp: EnvelopedDataException) {
parcelCollection.disregard(
"Failed to decrypt parcel; sender might have used wrong key",
exp
)
return@mapNotNull null
} catch (exp: InvalidMessageException) {
parcelCollection.disregard(
"Incoming parcel did not encapsulate a valid service message",
exp
)
return@mapNotNull null
}
}
}

private suspend fun ParcelCollection.disregard(reason: String, exc: Throwable) {
logger.log(Level.WARNING, reason, exc)
ack()
}

public class ReceiveMessagesException(message: String, throwable: Throwable? = null)
: GatewayException(message, throwable)
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ internal class PrivateThirdPartyEndpointTest {
assertEquals(firstAddress, firstPartyAddress)
assertEquals(thirdAddress, address)
assertEquals(PDACertPath.PRIVATE_ENDPOINT, authorization)
assertEquals(PDACertPath.PRIVATE_ENDPOINT, identity)
assertEquals(PDACertPath.PRIVATE_ENDPOINT, identityCertificate)
}

verify(storage.thirdPartyAuthorization).get("${firstAddress}_$thirdAddress")
Expand Down Expand Up @@ -93,7 +93,7 @@ internal class PrivateThirdPartyEndpointTest {
)
assertEquals(
PDACertPath.PRIVATE_ENDPOINT,
endpoint.identity
endpoint.identityCertificate
)

verify(storage.identityCertificate).get(firstPartyAddress)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ internal class PublicThirdPartyEndpointTest {

with(PublicThirdPartyEndpoint.load(address)!!) {
assertEquals(address, this.address)
assertEquals(PDACertPath.PUBLIC_GW, certificate)
assertEquals(PDACertPath.PUBLIC_GW, identityCertificate)
}
}

Expand All @@ -49,7 +49,7 @@ internal class PublicThirdPartyEndpointTest {
fun import_successful() = runBlockingTest {
with(PublicThirdPartyEndpoint.import(PDACertPath.PUBLIC_GW)) {
assertEquals(address, this.address)
assertEquals(PDACertPath.PUBLIC_GW, certificate)
assertEquals(PDACertPath.PUBLIC_GW, identityCertificate)
}

verify(storage.publicThirdPartyCertificate).set(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import tech.relaycorp.relaynet.messages.Parcel
import tech.relaycorp.relaynet.testing.pki.KeyPairSet
import tech.relaycorp.relaynet.testing.pki.PDACertPath
import java.util.UUID
import tech.relaycorp.relaynet.messages.payloads.ServiceMessage

internal class IncomingMessageTest {

Expand All @@ -35,18 +36,20 @@ internal class IncomingMessageTest {

@Test
fun buildFromParcel() = runBlockingTest {
val serviceMessage = ServiceMessage("the type", "the content".toByteArray())
val parcel = Parcel(
recipientAddress = UUID.randomUUID().toString(),
payload = "1234".toByteArray(),
senderCertificate = PDACertPath.PRIVATE_ENDPOINT
payload = serviceMessage.encrypt(PDACertPath.PRIVATE_ENDPOINT),
senderCertificate = PDACertPath.PDA
)

val message = IncomingMessage.build(parcel) {}

verify(Relaynet.storage.identityCertificate).get(eq(parcel.recipientAddress))

assertEquals(PDACertPath.PRIVATE_ENDPOINT, message.recipientEndpoint.identityCertificate)
assertArrayEquals(parcel.payload, message.payload)
assertEquals(serviceMessage.type, message.type)
assertArrayEquals(serviceMessage.content, message.content)
assertEquals(parcel.id, message.id.value)
assertSameDateTime(parcel.creationDate, message.creationDate)
assertSameDateTime(parcel.expiryDate, message.expiryDate)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ internal class MessageTest {
fun ttl() = runBlockingTest {
val creationDate = ZonedDateTime.now()
val message = OutgoingMessage.build(
"the type",
Random.Default.nextBytes(10),
senderEndpoint = senderEndpoint,
recipientEndpoint = recipientEndpoint,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ import tech.relaycorp.relaynet.messages.InvalidMessageException
import tech.relaycorp.relaynet.ramf.RecipientAddressType
import tech.relaycorp.relaynet.testing.pki.PDACertPath
import java.time.ZonedDateTime
import tech.relaycorp.relaynet.testing.pki.KeyPairSet

internal class OutgoingMessageTest {

@Test(expected = InvalidMessageException::class)
internal fun buildInvalidMessage() = runBlockingTest {
OutgoingMessage.build(
"the type",
ByteArray(0),
FirstPartyEndpointFactory.build(),
PublicThirdPartyEndpoint("example.org", PDACertPath.PUBLIC_GW),
Expand All @@ -27,17 +29,26 @@ internal class OutgoingMessageTest {
}

@Test
internal fun buildForPublicRecipient_checkBaseValues() = runBlockingTest {
fun buildForPublicRecipient_checkBaseValues() = runBlockingTest {
val message = MessageFactory.buildOutgoing(RecipientAddressType.PUBLIC)
val parcel = message.parcel

assertEquals("https://" + message.recipientEndpoint.address, parcel.recipientAddress)
assertArrayEquals(message.payload, parcel.payload)
assertEquals(message.id.value, parcel.id)
assertSameDateTime(message.creationDate, parcel.creationDate)
assertEquals(message.ttl, parcel.ttl)
}

@Test
fun buildForPublicRecipient_checkServiceMessage() = runBlockingTest {
val message = MessageFactory.buildOutgoing(RecipientAddressType.PUBLIC)
val parcel = message.parcel

val serviceMessageDecrypted = parcel.unwrapPayload(KeyPairSet.PUBLIC_GW.private)
assertEquals(MessageFactory.serviceMessage.type, serviceMessageDecrypted.type)
assertArrayEquals(MessageFactory.serviceMessage.content, serviceMessageDecrypted.content)
}

@Test
internal fun buildForPublicRecipient_checkSenderCertificate() = runBlockingTest {
val message = MessageFactory.buildOutgoing(RecipientAddressType.PUBLIC)
Expand Down
Loading