Skip to content

Commit

Permalink
Make payment secret not optional (#2457)
Browse files Browse the repository at this point in the history
Payment secret is already required for doing anything but the code was still using an option.
  • Loading branch information
thomash-acinq authored Oct 18, 2022
1 parent dad0a51 commit 3b12475
Show file tree
Hide file tree
Showing 19 changed files with 157 additions and 167 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ case class Bolt11Invoice(prefix: String, amount_opt: Option[MilliSatoshi], creat
amount_opt.foreach(a => require(a > 0.msat, s"amount is not valid"))
require(tags.collect { case _: Bolt11Invoice.PaymentHash => }.size == 1, "there must be exactly one payment hash tag")
require(tags.collect { case Bolt11Invoice.Description(_) | Bolt11Invoice.DescriptionHash(_) => }.size == 1, "there must be exactly one description tag or one description hash tag")
require(tags.collect { case _: Bolt11Invoice.PaymentSecret => }.size == 1, "there must be exactly one payment secret tag")

{
val featuresErr = Features.validateFeatureGraph(features)
Expand All @@ -63,7 +64,7 @@ case class Bolt11Invoice(prefix: String, amount_opt: Option[MilliSatoshi], creat
/**
* @return the payment secret
*/
lazy val paymentSecret = tags.collectFirst { case p: Bolt11Invoice.PaymentSecret => p.secret }
lazy val paymentSecret = tags.collectFirst { case p: Bolt11Invoice.PaymentSecret => p.secret }.get

/**
* @return the description of the payment, or its hash
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import fr.acinq.eclair.crypto.Sphinx.RouteBlinding
import fr.acinq.eclair.wire.protocol.OfferTypes._
import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{InvalidTlvPayload, MissingRequiredTlv}
import fr.acinq.eclair.wire.protocol.{OfferCodecs, OfferTypes, TlvStream}
import fr.acinq.eclair.{CltvExpiryDelta, Features, InvoiceFeature, MilliSatoshi, MilliSatoshiLong, TimestampSecond, UInt64}
import fr.acinq.eclair.{CltvExpiryDelta, Features, InvoiceFeature, MilliSatoshi, MilliSatoshiLong, TimestampSecond, UInt64, randomBytes32}
import scodec.bits.ByteVector

import java.util.concurrent.TimeUnit
Expand All @@ -43,7 +43,7 @@ case class Bolt12Invoice(records: TlvStream[InvoiceTlv]) extends Invoice {
override val amount_opt: Option[MilliSatoshi] = Some(amount)
override val nodeId: Crypto.PublicKey = records.get[NodeId].get.publicKey
override val paymentHash: ByteVector32 = records.get[PaymentHash].get.hash
override val paymentSecret: Option[ByteVector32] = None
override val paymentSecret: ByteVector32 = randomBytes32()
override val paymentMetadata: Option[ByteVector] = None
override val description: Either[String, ByteVector32] = Left(records.get[Description].get.description)
override val extraEdges: Seq[Invoice.ExtraEdge] = Seq.empty // TODO: the blinded paths need to be converted to graph edges
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ trait Invoice {

val paymentHash: ByteVector32

val paymentSecret: Option[ByteVector32]
val paymentSecret: ByteVector32

val paymentMetadata: Option[ByteVector]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -410,10 +410,10 @@ object MultiPartHandler {
}

private def validatePaymentSecret(add: UpdateAddHtlc, payload: FinalPayload.Standard, invoice: Bolt11Invoice)(implicit log: LoggingAdapter): Boolean = {
if (payload.amount < payload.totalAmount && !invoice.paymentSecret.contains(payload.paymentSecret)) {
if (payload.amount < payload.totalAmount && invoice.paymentSecret != payload.paymentSecret) {
log.warning("received multi-part payment with invalid secret={} for amount={} totalAmount={}", payload.paymentSecret, add.amountMsat, payload.totalAmount)
false
} else if (!invoice.paymentSecret.contains(payload.paymentSecret)) {
} else if (invoice.paymentSecret != payload.paymentSecret) {
log.warning("received payment with invalid secret={} for amount={} totalAmount={}", payload.paymentSecret, add.amountMsat, payload.totalAmount)
false
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ object PaymentError {
sealed trait InvalidInvoice extends PaymentError
/** The invoice contains a feature we don't support. */
case class UnsupportedFeatures(features: Features[InvoiceFeature]) extends InvalidInvoice { override def getMessage: String = s"unsupported invoice features: ${features.toByteVector.toHex}" }
/** The invoice is missing a payment secret. */
case object PaymentSecretMissing extends InvalidInvoice { override def getMessage: String = "invalid invoice: payment secret is missing" }
// @formatter:on

// @formatter:off
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,20 +52,17 @@ class PaymentInitiator(nodeParams: NodeParams, outgoingPaymentFactory: PaymentIn
}
val paymentCfg = SendPaymentConfig(paymentId, paymentId, r.externalId, r.paymentHash, r.recipientAmount, r.recipientNodeId, Upstream.Local(paymentId), Some(r.invoice), storeInDb = true, publishEvent = true, recordPathFindingMetrics = true, Nil)
val finalExpiry = r.finalExpiry(nodeParams.currentBlockHeight)
r.invoice.paymentSecret match {
case _ if !nodeParams.features.invoiceFeatures().areSupported(r.invoice.features) =>
sender() ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(r.recipientAmount, Nil, UnsupportedFeatures(r.invoice.features)) :: Nil)
case None =>
sender() ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(r.recipientAmount, Nil, PaymentSecretMissing) :: Nil)
case Some(paymentSecret) if r.invoice.features.hasFeature(Features.BasicMultiPartPayment) && nodeParams.features.hasFeature(BasicMultiPartPayment) =>
val fsm = outgoingPaymentFactory.spawnOutgoingMultiPartPayment(context, paymentCfg)
fsm ! MultiPartPaymentLifecycle.SendMultiPartPayment(self, paymentSecret, r.recipientNodeId, r.recipientAmount, finalExpiry, r.maxAttempts, r.invoice.paymentMetadata, r.invoice.extraEdges, r.routeParams, userCustomTlvs = r.userCustomTlvs)
context become main(pending + (paymentId -> PendingPaymentToNode(sender(), r)))
case Some(paymentSecret) =>
val finalPayload = FinalPayload.Standard.createSinglePartPayload(r.recipientAmount, finalExpiry, paymentSecret, r.invoice.paymentMetadata, r.userCustomTlvs)
val fsm = outgoingPaymentFactory.spawnOutgoingPayment(context, paymentCfg)
fsm ! PaymentLifecycle.SendPaymentToNode(self, r.recipientNodeId, finalPayload, r.maxAttempts, r.invoice.extraEdges, r.routeParams)
context become main(pending + (paymentId -> PendingPaymentToNode(sender(), r)))
if (!nodeParams.features.invoiceFeatures().areSupported(r.invoice.features)) {
sender() ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(r.recipientAmount, Nil, UnsupportedFeatures(r.invoice.features)) :: Nil)
} else if (Features.canUseFeature(nodeParams.features.invoiceFeatures(), r.invoice.features, Features.BasicMultiPartPayment)) {
val fsm = outgoingPaymentFactory.spawnOutgoingMultiPartPayment(context, paymentCfg)
fsm ! MultiPartPaymentLifecycle.SendMultiPartPayment(self, r.invoice.paymentSecret, r.recipientNodeId, r.recipientAmount, finalExpiry, r.maxAttempts, r.invoice.paymentMetadata, r.invoice.extraEdges, r.routeParams, userCustomTlvs = r.userCustomTlvs)
context become main(pending + (paymentId -> PendingPaymentToNode(sender(), r)))
} else {
val finalPayload = FinalPayload.Standard.createSinglePartPayload(r.recipientAmount, finalExpiry, r.invoice.paymentSecret, r.invoice.paymentMetadata, r.userCustomTlvs)
val fsm = outgoingPaymentFactory.spawnOutgoingPayment(context, paymentCfg)
fsm ! PaymentLifecycle.SendPaymentToNode(self, r.recipientNodeId, finalPayload, r.maxAttempts, r.invoice.extraEdges, r.routeParams)
context become main(pending + (paymentId -> PendingPaymentToNode(sender(), r)))
}

case r: SendSpontaneousPayment =>
Expand Down Expand Up @@ -104,8 +101,6 @@ class PaymentInitiator(nodeParams: NodeParams, outgoingPaymentFactory: PaymentIn
val additionalHops = r.trampolineNodes.sliding(2).map(hop => NodeHop(hop.head, hop(1), CltvExpiryDelta(0), 0 msat)).toSeq
val paymentCfg = SendPaymentConfig(paymentId, parentPaymentId, r.externalId, r.paymentHash, r.recipientAmount, r.recipientNodeId, Upstream.Local(paymentId), Some(r.invoice), storeInDb = true, publishEvent = true, recordPathFindingMetrics = false, additionalHops)
r.trampolineNodes match {
case _ if r.invoice.paymentSecret.isEmpty =>
sender() ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(r.recipientAmount, Nil, PaymentSecretMissing) :: Nil)
case trampoline :: recipient :: Nil =>
log.info(s"sending trampoline payment to $recipient with trampoline=$trampoline, trampoline fees=${r.trampolineFees}, expiry delta=${r.trampolineExpiryDelta}")
buildTrampolinePayment(r, trampoline, r.trampolineFees, r.trampolineExpiryDelta) match {
Expand All @@ -123,7 +118,7 @@ class PaymentInitiator(nodeParams: NodeParams, outgoingPaymentFactory: PaymentIn
case Nil =>
sender() ! SendPaymentToRouteResponse(paymentId, parentPaymentId, None)
val payFsm = outgoingPaymentFactory.spawnOutgoingPayment(context, paymentCfg)
payFsm ! PaymentLifecycle.SendPaymentToRoute(self, Left(r.route), FinalPayload.Standard.createMultiPartPayload(r.amount, r.recipientAmount, finalExpiry, r.invoice.paymentSecret.get, r.invoice.paymentMetadata), r.invoice.extraEdges)
payFsm ! PaymentLifecycle.SendPaymentToRoute(self, Left(r.route), FinalPayload.Standard.createMultiPartPayload(r.amount, r.recipientAmount, finalExpiry, r.invoice.paymentSecret, r.invoice.paymentMetadata), r.invoice.extraEdges)
context become main(pending + (paymentId -> PendingPaymentToRoute(sender(), r)))
case _ =>
sender() ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(r.recipientAmount, Nil, TrampolineMultiNodeNotSupported) :: Nil)
Expand Down Expand Up @@ -197,9 +192,9 @@ class PaymentInitiator(nodeParams: NodeParams, outgoingPaymentFactory: PaymentIn
NodeHop(trampolineNodeId, r.recipientNodeId, trampolineExpiryDelta, trampolineFees) // for now we only use a single trampoline hop
)
val finalPayload = if (r.invoice.features.hasFeature(Features.BasicMultiPartPayment)) {
FinalPayload.Standard.createMultiPartPayload(r.recipientAmount, r.recipientAmount, r.finalExpiry(nodeParams.currentBlockHeight), r.invoice.paymentSecret.get, r.invoice.paymentMetadata)
FinalPayload.Standard.createMultiPartPayload(r.recipientAmount, r.recipientAmount, r.finalExpiry(nodeParams.currentBlockHeight), r.invoice.paymentSecret, r.invoice.paymentMetadata)
} else {
FinalPayload.Standard.createSinglePartPayload(r.recipientAmount, r.finalExpiry(nodeParams.currentBlockHeight), r.invoice.paymentSecret.get, r.invoice.paymentMetadata)
FinalPayload.Standard.createSinglePartPayload(r.recipientAmount, r.finalExpiry(nodeParams.currentBlockHeight), r.invoice.paymentSecret, r.invoice.paymentMetadata)
}
// We assume that the trampoline node supports multi-part payments (it should).
val trampolinePacket_opt = if (r.invoice.features.hasFeature(Features.TrampolinePaymentPrototype)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ object PaymentOnion {
val tlvs = Seq(
Some(AmountToForward(amount)),
Some(OutgoingCltv(expiry)),
invoice.paymentSecret.map(s => PaymentData(s, totalAmount)),
Some(PaymentData(invoice.paymentSecret, totalAmount)),
invoice.paymentMetadata.map(m => PaymentMetadata(m)),
Some(OutgoingNodeId(targetNodeId)),
Some(InvoiceFeatures(invoice.features.toByteVector)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class EclairImplSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with I

// with finalCltvExpiry
val externalId2 = "487da196-a4dc-4b1e-92b4-3e5e905e9f3f"
val invoice2 = Bolt11Invoice("lntb", Some(123 msat), TimestampSecond.now(), nodePrivKey.publicKey, List(Bolt11Invoice.MinFinalCltvExpiry(96), Bolt11Invoice.PaymentHash(ByteVector32.Zeroes), Bolt11Invoice.Description("description")), ByteVector.empty)
val invoice2 = Bolt11Invoice("lntb", Some(123 msat), TimestampSecond.now(), nodePrivKey.publicKey, List(Bolt11Invoice.MinFinalCltvExpiry(96), Bolt11Invoice.PaymentHash(ByteVector32.Zeroes), Bolt11Invoice.Description("description"), Bolt11Invoice.PaymentSecret(ByteVector32.One)), ByteVector.empty)
eclair.send(Some(externalId2), 123 msat, invoice2)
val send2 = paymentInitiator.expectMsgType[SendPaymentToNode]
assert(send2.externalId.contains(externalId2))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class FuzzySpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with Channe
paymentHandler ! ReceiveStandardPayment(Some(requiredAmount), Left("One coffee"))
context become {
case req: Invoice =>
sendChannel ! buildCmdAdd(req.paymentHash, req.nodeId, req.paymentSecret.get)
sendChannel ! buildCmdAdd(req.paymentHash, req.nodeId, req.paymentSecret)
context become {
case RES_SUCCESS(_: CMD_ADD_HTLC, _) => ()
case RES_ADD_SETTLED(_, htlc, _: HtlcResult.Fulfill) =>
Expand Down
Loading

0 comments on commit 3b12475

Please sign in to comment.