Skip to content

Commit

Permalink
feat: Integrate service messages (#48)
Browse files Browse the repository at this point in the history
* feat: Integrate service messages

Each parcel encapsulates exactly one Relaynet _service message_ encrypted. This fixes #42.

The following changes were also necessary:

- Drop `Message.payload`. Its equivalent in a service message is called _content_.
  - `IncomingMessage.content` supersedes the old `Message.payload`.
  - `OutgoingMessage` doesn't have an equivalent to the old `.payload`
- Capture the service message _type_ in the `Message`s. This type is to be used by app developers to determine how to process the `content`; for example, something like `image/png` could be rendered on screen whilst something like `application/json; type=Tweet` could be JSON-deserialised and verified against a schema.
  - `OutgoingMessage.build()` now requires the `type`.
  - `IncomingMessage.type` exposes the `type`.
- Promote the identity certificate from the children of `ThirdPartyEndpoint` to the parent class. This is needed to encrypt messages regardless of whether the endpoint is private or public.

* fix typo

* remove redundant code
  • Loading branch information
gnarea authored Mar 9, 2021
1 parent c5d661b commit 67c8cb9
Show file tree
Hide file tree
Showing 13 changed files with 152 additions and 55 deletions.
2 changes: 1 addition & 1 deletion lib/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ dependencies {
implementation "org.jetbrains.kotlinx:kotlinx-coroutines-core:$kotlinCoroutinesVersion"

// Relaynet
api 'tech.relaycorp:relaynet:[1.47.0,2.0.0)'
api 'tech.relaycorp:relaynet:[1.47.1,2.0.0)'
implementation 'tech.relaycorp:poweb:1.5.15'

// Security
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ package tech.relaycorp.relaydroid.endpoint
import tech.relaycorp.relaydroid.Storage
import tech.relaycorp.relaydroid.storage.persistence.PersistenceException
import tech.relaycorp.relaynet.RelaynetException
import tech.relaycorp.relaynet.messages.payloads.ServiceMessage
import tech.relaycorp.relaynet.wrappers.x509.Certificate
import tech.relaycorp.relaynet.wrappers.x509.CertificateException

public sealed class ThirdPartyEndpoint(
override val address: String
override val address: String,
public val identityCertificate: Certificate
) : Endpoint {

public val thirdPartyAddress: String get() = address
Expand All @@ -26,8 +28,8 @@ public class PrivateThirdPartyEndpoint(
public val firstPartyAddress: String,
thirdPartyAddress: String,
public val authorization: Certificate,
public val identity: Certificate
) : ThirdPartyEndpoint(thirdPartyAddress) {
identityCertificate: Certificate
) : ThirdPartyEndpoint(thirdPartyAddress, identityCertificate) {

public companion object {

Expand Down Expand Up @@ -76,8 +78,8 @@ public class PrivateThirdPartyEndpoint(

public class PublicThirdPartyEndpoint(
thirdPartyAddress: String,
public val certificate: Certificate
) : ThirdPartyEndpoint(thirdPartyAddress) {
identityCertificate: Certificate
) : ThirdPartyEndpoint(thirdPartyAddress, identityCertificate) {

public companion object {
@Throws(PersistenceException::class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,29 @@ import tech.relaycorp.relaydroid.endpoint.UnknownThirdPartyEndpointException
import tech.relaycorp.relaydroid.storage.persistence.PersistenceException
import tech.relaycorp.relaynet.messages.Parcel
import java.time.ZonedDateTime
import tech.relaycorp.relaynet.messages.InvalidMessageException
import tech.relaycorp.relaynet.wrappers.cms.EnvelopedDataException

public class IncomingMessage internal constructor(
id: MessageId,
payload: ByteArray,
public val type: String,
public val content: ByteArray,
public val senderEndpoint: ThirdPartyEndpoint,
public val recipientEndpoint: FirstPartyEndpoint,
creationDate: ZonedDateTime,
expiryDate: ZonedDateTime,
public val ack: suspend () -> Unit
) : Message(
id, payload, senderEndpoint, recipientEndpoint, creationDate, expiryDate
id, senderEndpoint, recipientEndpoint, creationDate, expiryDate
) {

internal companion object {
@Throws(
UnknownFirstPartyEndpointException::class,
UnknownThirdPartyEndpointException::class,
PersistenceException::class
PersistenceException::class,
EnvelopedDataException::class,
InvalidMessageException::class
)
internal suspend fun build(parcel: Parcel, ack: suspend () -> Unit): IncomingMessage {
val recipientEndpoint = FirstPartyEndpoint.load(parcel.recipientAddress)
Expand All @@ -41,9 +46,11 @@ public class IncomingMessage internal constructor(
"for first party endpoint ${parcel.recipientAddress}"
)

val serviceMessage = parcel.unwrapPayload(recipientEndpoint.keyPair.private)
return IncomingMessage(
id = MessageId(parcel.id),
payload = parcel.payload,
type = serviceMessage.type,
content = serviceMessage.content,
senderEndpoint = sender,
recipientEndpoint = recipientEndpoint,
creationDate = parcel.creationDate,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import java.time.ZonedDateTime

public abstract class Message(
public val id: MessageId,
public val payload: ByteArray,
senderEndpoint: Endpoint,
recipientEndpoint: Endpoint,
public val creationDate: ZonedDateTime = ZonedDateTime.now(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,35 +10,36 @@ import tech.relaycorp.relaynet.messages.Parcel
import tech.relaycorp.relaynet.ramf.RAMFException
import tech.relaycorp.relaynet.wrappers.x509.Certificate
import java.time.ZonedDateTime
import tech.relaycorp.relaynet.messages.payloads.ServiceMessage

public class OutgoingMessage
private constructor(
payload: ByteArray,
public val senderEndpoint: FirstPartyEndpoint,
public val recipientEndpoint: ThirdPartyEndpoint,
creationDate: ZonedDateTime = ZonedDateTime.now(),
expiryDate: ZonedDateTime = maxExpiryDate(),
id: MessageId = MessageId.generate()
) : Message(
id, payload, senderEndpoint, recipientEndpoint, creationDate, expiryDate
id, senderEndpoint, recipientEndpoint, creationDate, expiryDate
) {

internal lateinit var parcel: Parcel
private set

public companion object {
public suspend fun build(
payload: ByteArray,
type: String,
content: ByteArray,
senderEndpoint: FirstPartyEndpoint,
recipientEndpoint: ThirdPartyEndpoint,
creationDate: ZonedDateTime = ZonedDateTime.now(),
expiryDate: ZonedDateTime = maxExpiryDate(),
id: MessageId = MessageId.generate()
): OutgoingMessage {
val message = OutgoingMessage(
payload, senderEndpoint, recipientEndpoint, creationDate, expiryDate, id
senderEndpoint, recipientEndpoint, creationDate, expiryDate, id
)
message.parcel = message.buildParcel()
message.parcel = message.buildParcel(type, content)
try {
message.parcel.validate(null)
} catch (exp: RAMFException) {
Expand All @@ -48,19 +49,25 @@ private constructor(
}
}

private suspend fun buildParcel() = Parcel(
recipientAddress = if (recipientEndpoint is PublicThirdPartyEndpoint) {
"https://" + recipientEndpoint.address
} else {
recipientEndpoint.address
},
payload = payload,
senderCertificate = getSenderCertificate(),
messageId = id.value,
creationDate = creationDate,
ttl = ttl,
senderCertificateChain = getSenderCertificateChain()
)
private suspend fun buildParcel(
serviceMessageType: String,
serviceMessageContent: ByteArray
): Parcel {
val serviceMessage = ServiceMessage(serviceMessageType, serviceMessageContent)
return Parcel(
recipientAddress = if (recipientEndpoint is PublicThirdPartyEndpoint) {
"https://" + recipientEndpoint.address
} else {
recipientEndpoint.address
},
payload = serviceMessage.encrypt(recipientEndpoint.identityCertificate),
senderCertificate = getSenderCertificate(),
messageId = id.value,
creationDate = creationDate,
ttl = ttl,
senderCertificateChain = getSenderCertificateChain()
)
}

private suspend fun getSenderCertificate() =
when (recipientEndpoint) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package tech.relaycorp.relaydroid.messaging

import java.util.logging.Level
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.asFlow
import kotlinx.coroutines.flow.flatMapLatest
Expand All @@ -15,12 +16,13 @@ import tech.relaycorp.relaydroid.storage.persistence.PersistenceException
import tech.relaycorp.relaynet.bindings.pdc.ClientBindingException
import tech.relaycorp.relaynet.bindings.pdc.NonceSignerException
import tech.relaycorp.relaynet.bindings.pdc.PDCClient
import tech.relaycorp.relaynet.bindings.pdc.ParcelCollection
import tech.relaycorp.relaynet.bindings.pdc.ServerException
import tech.relaycorp.relaynet.bindings.pdc.Signer
import tech.relaycorp.relaynet.bindings.pdc.StreamingMode
import tech.relaycorp.relaynet.messages.InvalidMessageException
import tech.relaycorp.relaynet.ramf.RAMFException
import java.util.logging.Level
import tech.relaycorp.relaynet.wrappers.cms.EnvelopedDataException

internal class ReceiveMessages(
private val pdcClientBuilder: () -> PDCClient = { PoWebClient.initLocal(Relaynet.POWEB_PORT) }
Expand Down Expand Up @@ -69,27 +71,40 @@ internal class ReceiveMessages(
val parcel = try {
parcelCollection.deserializeAndValidateParcel()
} catch (exp: RAMFException) {
logger.log(Level.WARNING, "Malformed incoming parcel", exp)
parcelCollection.ack()
parcelCollection.disregard("Malformed incoming parcel", exp)
return@mapNotNull null
} catch (exp: InvalidMessageException) {
logger.log(Level.WARNING, "Invalid incoming parcel", exp)
parcelCollection.ack()
parcelCollection.disregard("Invalid incoming parcel", exp)
return@mapNotNull null
}
try {
IncomingMessage.build(parcel) { parcelCollection.ack() }
} catch (exp: UnknownFirstPartyEndpointException) {
logger.log(Level.WARNING, "Incoming parcel with invalid recipient", exp)
parcelCollection.ack()
parcelCollection.disregard("Incoming parcel with invalid recipient", exp)
return@mapNotNull null
} catch (exp: UnknownFirstPartyEndpointException) {
logger.log(Level.WARNING, "Incoming parcel issues with invalid sender", exp)
parcelCollection.ack()
parcelCollection.disregard("Incoming parcel issues with invalid sender", exp)
return@mapNotNull null
} catch (exp: EnvelopedDataException) {
parcelCollection.disregard(
"Failed to decrypt parcel; sender might have used wrong key",
exp
)
return@mapNotNull null
} catch (exp: InvalidMessageException) {
parcelCollection.disregard(
"Incoming parcel did not encapsulate a valid service message",
exp
)
return@mapNotNull null
}
}
}

private suspend fun ParcelCollection.disregard(reason: String, exc: Throwable) {
logger.log(Level.WARNING, reason, exc)
ack()
}

public class ReceiveMessagesException(message: String, throwable: Throwable? = null)
: GatewayException(message, throwable)
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ internal class PrivateThirdPartyEndpointTest {
assertEquals(firstAddress, firstPartyAddress)
assertEquals(thirdAddress, address)
assertEquals(PDACertPath.PRIVATE_ENDPOINT, authorization)
assertEquals(PDACertPath.PRIVATE_ENDPOINT, identity)
assertEquals(PDACertPath.PRIVATE_ENDPOINT, identityCertificate)
}

verify(storage.thirdPartyAuthorization).get("${firstAddress}_$thirdAddress")
Expand Down Expand Up @@ -93,7 +93,7 @@ internal class PrivateThirdPartyEndpointTest {
)
assertEquals(
PDACertPath.PRIVATE_ENDPOINT,
endpoint.identity
endpoint.identityCertificate
)

verify(storage.identityCertificate).get(firstPartyAddress)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ internal class PublicThirdPartyEndpointTest {

with(PublicThirdPartyEndpoint.load(address)!!) {
assertEquals(address, this.address)
assertEquals(PDACertPath.PUBLIC_GW, certificate)
assertEquals(PDACertPath.PUBLIC_GW, identityCertificate)
}
}

Expand All @@ -49,7 +49,7 @@ internal class PublicThirdPartyEndpointTest {
fun import_successful() = runBlockingTest {
with(PublicThirdPartyEndpoint.import(PDACertPath.PUBLIC_GW)) {
assertEquals(address, this.address)
assertEquals(PDACertPath.PUBLIC_GW, certificate)
assertEquals(PDACertPath.PUBLIC_GW, identityCertificate)
}

verify(storage.publicThirdPartyCertificate).set(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import tech.relaycorp.relaynet.messages.Parcel
import tech.relaycorp.relaynet.testing.pki.KeyPairSet
import tech.relaycorp.relaynet.testing.pki.PDACertPath
import java.util.UUID
import tech.relaycorp.relaynet.messages.payloads.ServiceMessage

internal class IncomingMessageTest {

Expand All @@ -35,18 +36,20 @@ internal class IncomingMessageTest {

@Test
fun buildFromParcel() = runBlockingTest {
val serviceMessage = ServiceMessage("the type", "the content".toByteArray())
val parcel = Parcel(
recipientAddress = UUID.randomUUID().toString(),
payload = "1234".toByteArray(),
senderCertificate = PDACertPath.PRIVATE_ENDPOINT
payload = serviceMessage.encrypt(PDACertPath.PRIVATE_ENDPOINT),
senderCertificate = PDACertPath.PDA
)

val message = IncomingMessage.build(parcel) {}

verify(Relaynet.storage.identityCertificate).get(eq(parcel.recipientAddress))

assertEquals(PDACertPath.PRIVATE_ENDPOINT, message.recipientEndpoint.identityCertificate)
assertArrayEquals(parcel.payload, message.payload)
assertEquals(serviceMessage.type, message.type)
assertArrayEquals(serviceMessage.content, message.content)
assertEquals(parcel.id, message.id.value)
assertSameDateTime(parcel.creationDate, message.creationDate)
assertSameDateTime(parcel.expiryDate, message.expiryDate)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ internal class MessageTest {
fun ttl() = runBlockingTest {
val creationDate = ZonedDateTime.now()
val message = OutgoingMessage.build(
"the type",
Random.Default.nextBytes(10),
senderEndpoint = senderEndpoint,
recipientEndpoint = recipientEndpoint,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ import tech.relaycorp.relaynet.messages.InvalidMessageException
import tech.relaycorp.relaynet.ramf.RecipientAddressType
import tech.relaycorp.relaynet.testing.pki.PDACertPath
import java.time.ZonedDateTime
import tech.relaycorp.relaynet.testing.pki.KeyPairSet

internal class OutgoingMessageTest {

@Test(expected = InvalidMessageException::class)
internal fun buildInvalidMessage() = runBlockingTest {
OutgoingMessage.build(
"the type",
ByteArray(0),
FirstPartyEndpointFactory.build(),
PublicThirdPartyEndpoint("example.org", PDACertPath.PUBLIC_GW),
Expand All @@ -27,17 +29,26 @@ internal class OutgoingMessageTest {
}

@Test
internal fun buildForPublicRecipient_checkBaseValues() = runBlockingTest {
fun buildForPublicRecipient_checkBaseValues() = runBlockingTest {
val message = MessageFactory.buildOutgoing(RecipientAddressType.PUBLIC)
val parcel = message.parcel

assertEquals("https://" + message.recipientEndpoint.address, parcel.recipientAddress)
assertArrayEquals(message.payload, parcel.payload)
assertEquals(message.id.value, parcel.id)
assertSameDateTime(message.creationDate, parcel.creationDate)
assertEquals(message.ttl, parcel.ttl)
}

@Test
fun buildForPublicRecipient_checkServiceMessage() = runBlockingTest {
val message = MessageFactory.buildOutgoing(RecipientAddressType.PUBLIC)
val parcel = message.parcel

val serviceMessageDecrypted = parcel.unwrapPayload(KeyPairSet.PUBLIC_GW.private)
assertEquals(MessageFactory.serviceMessage.type, serviceMessageDecrypted.type)
assertArrayEquals(MessageFactory.serviceMessage.content, serviceMessageDecrypted.content)
}

@Test
internal fun buildForPublicRecipient_checkSenderCertificate() = runBlockingTest {
val message = MessageFactory.buildOutgoing(RecipientAddressType.PUBLIC)
Expand Down
Loading

0 comments on commit 67c8cb9

Please sign in to comment.