Skip to content

Commit

Permalink
Write peer storage to DB
Browse files Browse the repository at this point in the history
  • Loading branch information
thomash-acinq committed Oct 21, 2024
1 parent 8454406 commit ed48cb0
Show file tree
Hide file tree
Showing 10 changed files with 179 additions and 24 deletions.
4 changes: 3 additions & 1 deletion eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ case class NodeParams(nodeKeyManager: NodeKeyManager,
revokedHtlcInfoCleanerConfig: RevokedHtlcInfoCleaner.Config,
willFundRates_opt: Option[LiquidityAds.WillFundRates],
peerWakeUpConfig: PeerReadyNotifier.WakeUpConfig,
onTheFlyFundingConfig: OnTheFlyFunding.Config) {
onTheFlyFundingConfig: OnTheFlyFunding.Config,
peerStorageWriteDelayMax: FiniteDuration) {
val privateKey: Crypto.PrivateKey = nodeKeyManager.nodeKey.privateKey

val nodeId: PublicKey = nodeKeyManager.nodeId
Expand Down Expand Up @@ -680,6 +681,7 @@ object NodeParams extends Logging {
onTheFlyFundingConfig = OnTheFlyFunding.Config(
proposalTimeout = FiniteDuration(config.getDuration("on-the-fly-funding.proposal-timeout").getSeconds, TimeUnit.SECONDS),
),
peerStorageWriteDelayMax = 1 minute,
)
}
}
11 changes: 11 additions & 0 deletions eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import fr.acinq.eclair.router.Router
import fr.acinq.eclair.wire.protocol.{ChannelAnnouncement, ChannelUpdate, NodeAddress, NodeAnnouncement}
import fr.acinq.eclair.{CltvExpiry, MilliSatoshi, Paginated, RealShortChannelId, ShortChannelId, TimestampMilli}
import grizzled.slf4j.Logging
import scodec.bits.ByteVector

import java.io.File
import java.util.UUID
Expand Down Expand Up @@ -292,6 +293,16 @@ case class DualPeersDb(primary: PeersDb, secondary: PeersDb) extends PeersDb {
runAsync(secondary.getRelayFees(nodeId))
primary.getRelayFees(nodeId)
}

override def updateStorage(nodeId: PublicKey, data: ByteVector): Unit = {
runAsync(secondary.updateStorage(nodeId, data))
primary.updateStorage(nodeId, data)
}

override def getStorage(nodeId: PublicKey): Option[ByteVector] = {
runAsync(secondary.getStorage(nodeId))
primary.getStorage(nodeId)
}
}

case class DualPaymentsDb(primary: PaymentsDb, secondary: PaymentsDb) extends PaymentsDb {
Expand Down
5 changes: 5 additions & 0 deletions eclair-core/src/main/scala/fr/acinq/eclair/db/PeersDb.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package fr.acinq.eclair.db
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.payment.relay.Relayer.RelayFees
import fr.acinq.eclair.wire.protocol.NodeAddress
import scodec.bits.ByteVector

trait PeersDb {

Expand All @@ -34,4 +35,8 @@ trait PeersDb {

def getRelayFees(nodeId: PublicKey): Option[RelayFees]

def updateStorage(nodeId: PublicKey, data: ByteVector): Unit

def getStorage(nodeId: PublicKey): Option[ByteVector]

}
47 changes: 43 additions & 4 deletions eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPeersDb.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ import fr.acinq.eclair.db.pg.PgUtils.PgLock
import fr.acinq.eclair.payment.relay.Relayer.RelayFees
import fr.acinq.eclair.wire.protocol._
import grizzled.slf4j.Logging
import scodec.bits.BitVector
import scodec.bits.{BitVector, ByteVector}

import java.sql.Statement
import javax.sql.DataSource

object PgPeersDb {
val DB_NAME = "peers"
val CURRENT_VERSION = 3
val CURRENT_VERSION = 4
}

class PgPeersDb(implicit ds: DataSource, lock: PgLock) extends PeersDb with Logging {
Expand All @@ -54,20 +54,28 @@ class PgPeersDb(implicit ds: DataSource, lock: PgLock) extends PeersDb with Logg
statement.executeUpdate("CREATE TABLE local.relay_fees (node_id TEXT NOT NULL PRIMARY KEY, fee_base_msat BIGINT NOT NULL, fee_proportional_millionths BIGINT NOT NULL)")
}

def migration34(statement: Statement): Unit = {
statement.executeUpdate("CREATE TABLE local.peer_storage (node_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL)")
}

using(pg.createStatement()) { statement =>
getVersion(statement, DB_NAME) match {
case None =>
statement.executeUpdate("CREATE SCHEMA IF NOT EXISTS local")
statement.executeUpdate("CREATE TABLE local.peers (node_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL)")
statement.executeUpdate("CREATE TABLE local.peers (node_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL, storage BYTEA)")
statement.executeUpdate("CREATE TABLE local.relay_fees (node_id TEXT NOT NULL PRIMARY KEY, fee_base_msat BIGINT NOT NULL, fee_proportional_millionths BIGINT NOT NULL)")
case Some(v@(1 | 2)) =>
statement.executeUpdate("CREATE TABLE local.peer_storage (node_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL)")
case Some(v@(1 | 2 | 3)) =>
logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION")
if (v < 2) {
migration12(statement)
}
if (v < 3) {
migration23(statement)
}
if (v < 4) {
migration34(statement)
}
case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do
case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion")
}
Expand Down Expand Up @@ -98,6 +106,10 @@ class PgPeersDb(implicit ds: DataSource, lock: PgLock) extends PeersDb with Logg
statement.setString(1, nodeId.value.toHex)
statement.executeUpdate()
}
using(pg.prepareStatement("DELETE FROM local.peer_storage WHERE node_id = ?")) { statement =>
statement.setString(1, nodeId.value.toHex)
statement.executeUpdate()
}
}
}

Expand Down Expand Up @@ -155,4 +167,31 @@ class PgPeersDb(implicit ds: DataSource, lock: PgLock) extends PeersDb with Logg
}
}
}

override def updateStorage(nodeId: PublicKey, data: ByteVector): Unit = withMetrics("peers/update-storage", DbBackends.Postgres) {
withLock { pg =>
using(pg.prepareStatement(
"""
INSERT INTO local.peer_storage (node_id, data)
VALUES (?, ?)
ON CONFLICT (node_id)
DO UPDATE SET data = EXCLUDED.data
""")) { statement =>
statement.setString(1, nodeId.value.toHex)
statement.setBytes(2, data.toArray)
statement.executeUpdate()
}
}
}

override def getStorage(nodeId: PublicKey): Option[ByteVector] = withMetrics("peers/get-storage", DbBackends.Postgres) {
withLock { pg =>
using(pg.prepareStatement("SELECT data FROM local.peer_storage WHERE node_id = ?")) { statement =>
statement.setString(1, nodeId.value.toHex)
statement.executeQuery()
.headOption
.map(rs => ByteVector(rs.getBytes("data")))
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@ import fr.acinq.eclair.db.sqlite.SqliteUtils.{getVersion, setVersion, using}
import fr.acinq.eclair.payment.relay.Relayer.RelayFees
import fr.acinq.eclair.wire.protocol._
import grizzled.slf4j.Logging
import scodec.bits.BitVector
import scodec.bits.{BitVector, ByteVector}

import java.sql.{Connection, Statement}

object SqlitePeersDb {
val DB_NAME = "peers"
val CURRENT_VERSION = 2
val CURRENT_VERSION = 3
}

class SqlitePeersDb(val sqlite: Connection) extends PeersDb with Logging {
Expand All @@ -46,13 +46,23 @@ class SqlitePeersDb(val sqlite: Connection) extends PeersDb with Logging {
statement.executeUpdate("CREATE TABLE relay_fees (node_id BLOB NOT NULL PRIMARY KEY, fee_base_msat INTEGER NOT NULL, fee_proportional_millionths INTEGER NOT NULL)")
}

def migration23(statement: Statement): Unit = {
statement.executeUpdate("CREATE TABLE peer_storage (node_id BLOB NOT NULL PRIMARY KEY, data NOT NULL)")
}

getVersion(statement, DB_NAME) match {
case None =>
statement.executeUpdate("CREATE TABLE peers (node_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL)")
statement.executeUpdate("CREATE TABLE relay_fees (node_id BLOB NOT NULL PRIMARY KEY, fee_base_msat INTEGER NOT NULL, fee_proportional_millionths INTEGER NOT NULL)")
case Some(v@1) =>
statement.executeUpdate("CREATE TABLE peer_storage (node_id BLOB NOT NULL PRIMARY KEY, data NOT NULL)")
case Some(v@(1 | 2)) =>
logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION")
migration12(statement)
if (v < 2) {
migration12(statement)
}
if (v < 3) {
migration23(statement)
}
case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do
case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion")
}
Expand Down Expand Up @@ -128,4 +138,27 @@ class SqlitePeersDb(val sqlite: Connection) extends PeersDb with Logging {
)
}
}

override def updateStorage(nodeId: PublicKey, data: ByteVector): Unit = withMetrics("peers/update-storage", DbBackends.Sqlite) {
using(sqlite.prepareStatement("UPDATE peer_storage SET data = ? WHERE node_id = ?")) { update =>
update.setBytes(1, data.toArray)
update.setBytes(2, nodeId.value.toArray)
if (update.executeUpdate() == 0) {
using(sqlite.prepareStatement("INSERT INTO peer_storage VALUES (?, ?)")) { statement =>
statement.setBytes(1, nodeId.value.toArray)
statement.setBytes(2, data.toArray)
statement.executeUpdate()
}
}
}
}

override def getStorage(nodeId: PublicKey): Option[ByteVector] = withMetrics("peers/get-storage", DbBackends.Sqlite) {
using(sqlite.prepareStatement("SELECT data FROM peer_storage WHERE node_id = ?")) { statement =>
statement.setBytes(1, nodeId.value.toArray)
statement.executeQuery()
.headOption
.map(rs => ByteVector(rs.getBytes("data")))
}
}
}
33 changes: 25 additions & 8 deletions eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import fr.acinq.eclair.router.Router
import fr.acinq.eclair.wire.protocol
import fr.acinq.eclair.wire.protocol.FailureMessageCodecs.createBadOnionFailure
import fr.acinq.eclair.wire.protocol.{AddFeeCredit, ChannelTlv, CurrentFeeCredit, Error, HasChannelId, HasTemporaryChannelId, LightningMessage, LiquidityAds, NodeAddress, OnTheFlyFundingFailureMessage, OnionMessage, OnionRoutingPacket, PeerStorageRetrieval, PeerStorageStore, RecommendedFeerates, RoutingMessage, SpliceInit, TlvStream, TxAbort, UnknownMessage, Warning, WillAddHtlc, WillFailHtlc, WillFailMalformedHtlc}
import fr.acinq.eclair.wire.protocol.{AddFeeCredit, ChannelTlv, CurrentFeeCredit, Error, HasChannelId, HasTemporaryChannelId, LightningMessage, LiquidityAds, NodeAddress, OnTheFlyFundingFailureMessage, OnionMessage, OnionRoutingPacket, PeerStorageRetrieval, PeerStorageStore, RoutingMessage, SpliceInit, TlvStream, UnknownMessage, Warning, WillAddHtlc, WillFailHtlc, WillFailMalformedHtlc}
import scodec.bits.ByteVector

/**
Expand Down Expand Up @@ -85,7 +86,7 @@ class Peer(val nodeParams: NodeParams,
FinalChannelId(state.channelId) -> channel
}.toMap
context.system.eventStream.publish(PeerCreated(self, remoteNodeId))
goto(DISCONNECTED) using DisconnectedData(channels, None) // when we restart, we will attempt to reconnect right away, but then we'll wait
goto(DISCONNECTED) using DisconnectedData(channels, PeerStorage(nodeParams.db.peers.getStorage(remoteNodeId), written = true, TimestampMilli.min)) // when we restart, we will attempt to reconnect right away, but then we'll wait
}

when(DISCONNECTED) {
Expand Down Expand Up @@ -515,7 +516,19 @@ class Peer(val nodeParams: NodeParams,
stay()

case Event(store: PeerStorageStore, d: ConnectedData) if nodeParams.features.hasFeature(Features.ProvideStorage) && d.channels.nonEmpty =>
stay() using d.copy(peerStorage = Some(store.blob))
val timeSinceLastWrite = TimestampMilli.now() - d.peerStorage.lastWrite
val peerStorage = if (timeSinceLastWrite >= nodeParams.peerStorageWriteDelayMax) {
nodeParams.db.peers.updateStorage(remoteNodeId, store.blob)
PeerStorage(Some(store.blob), written = true, TimestampMilli.now())
} else {
startSingleTimer("peer-storage-write", WritePeerStorage, nodeParams.peerStorageWriteDelayMax - timeSinceLastWrite)
PeerStorage(Some(store.blob), written = false, d.peerStorage.lastWrite)
}
stay() using d.copy(peerStorage = peerStorage)

case Event(WritePeerStorage, d: ConnectedData) =>
d.peerStorage.data.foreach(nodeParams.db.peers.updateStorage(remoteNodeId, _))
stay() using d.copy(peerStorage = PeerStorage(d.peerStorage.data, written = true, TimestampMilli.now()))

case Event(unhandledMsg: LightningMessage, _) =>
log.warning("ignoring message {}", unhandledMsg)
Expand Down Expand Up @@ -748,7 +761,7 @@ class Peer(val nodeParams: NodeParams,
context.system.eventStream.publish(PeerDisconnected(self, remoteNodeId))
}

private def gotoConnected(connectionReady: PeerConnection.ConnectionReady, channels: Map[ChannelId, ActorRef], peerStorage: Option[ByteVector]): State = {
private def gotoConnected(connectionReady: PeerConnection.ConnectionReady, channels: Map[ChannelId, ActorRef], peerStorage: PeerStorage): State = {
require(remoteNodeId == connectionReady.remoteNodeId, s"invalid nodeId: $remoteNodeId != ${connectionReady.remoteNodeId}")
log.debug("got authenticated connection to address {}", connectionReady.address)

Expand All @@ -759,7 +772,7 @@ class Peer(val nodeParams: NodeParams,
}

// If we have some data stored from our peer, we send it to them before doing anything else.
peerStorage.foreach(connectionReady.peerConnection ! PeerStorageRetrieval(_))
peerStorage.data.foreach(connectionReady.peerConnection ! PeerStorageRetrieval(_))

// let's bring existing/requested channels online
channels.values.toSet[ActorRef].foreach(_ ! INPUT_RECONNECTED(connectionReady.peerConnection, connectionReady.localInit, connectionReady.remoteInit)) // we deduplicate with toSet because there might be two entries per channel (tmp id and final id)
Expand Down Expand Up @@ -913,16 +926,18 @@ object Peer {
case class TemporaryChannelId(id: ByteVector32) extends ChannelId
case class FinalChannelId(id: ByteVector32) extends ChannelId

case class PeerStorage(data: Option[ByteVector], written: Boolean, lastWrite: TimestampMilli)

sealed trait Data {
def channels: Map[_ <: ChannelId, ActorRef] // will be overridden by Map[FinalChannelId, ActorRef] or Map[ChannelId, ActorRef]
def peerStorage: Option[ByteVector]
def peerStorage: PeerStorage
}
case object Nothing extends Data {
override def channels = Map.empty
override def peerStorage: Option[ByteVector] = None
override def peerStorage: PeerStorage = PeerStorage(None, written = true, TimestampMilli.min)
}
case class DisconnectedData(channels: Map[FinalChannelId, ActorRef], peerStorage: Option[ByteVector]) extends Data
case class ConnectedData(address: NodeAddress, peerConnection: ActorRef, localInit: protocol.Init, remoteInit: protocol.Init, channels: Map[ChannelId, ActorRef], currentFeerates: RecommendedFeerates, previousFeerates_opt: Option[RecommendedFeerates], peerStorage: Option[ByteVector]) extends Data {
case class DisconnectedData(channels: Map[FinalChannelId, ActorRef], peerStorage: PeerStorage) extends Data
case class ConnectedData(address: NodeAddress, peerConnection: ActorRef, localInit: protocol.Init, remoteInit: protocol.Init, channels: Map[ChannelId, ActorRef], currentFeerates: RecommendedFeerates, previousFeerates_opt: Option[RecommendedFeerates], peerStorage: PeerStorage) extends Data {
val connectionInfo: ConnectionInfo = ConnectionInfo(address, peerConnection, localInit, remoteInit)
def localFeatures: Features[InitFeature] = localInit.features
def remoteFeatures: Features[InitFeature] = remoteInit.features
Expand Down Expand Up @@ -1035,5 +1050,7 @@ object Peer {
case class RelayOnionMessage(messageId: ByteVector32, msg: OnionMessage, replyTo_opt: Option[typed.ActorRef[Status]])

case class RelayUnknownMessage(unknownMessage: UnknownMessage)

case object WritePeerStorage
// @formatter:on
}
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ object TestConstants {
Features.StaticRemoteKey -> FeatureSupport.Mandatory,
Features.Quiescence -> FeatureSupport.Optional,
Features.SplicePrototype -> FeatureSupport.Optional,
Features.ProvideStorage -> FeatureSupport.Optional,
),
unknown = Set(UnknownFeature(TestFeature.optional))
),
Expand Down Expand Up @@ -240,6 +241,7 @@ object TestConstants {
willFundRates_opt = Some(defaultLiquidityRates),
peerWakeUpConfig = PeerReadyNotifier.WakeUpConfig(enabled = false, timeout = 30 seconds),
onTheFlyFundingConfig = OnTheFlyFunding.Config(proposalTimeout = 90 seconds),
peerStorageWriteDelayMax = 5 seconds,
)

def channelParams: LocalParams = OpenChannelInterceptor.makeChannelParams(
Expand Down Expand Up @@ -416,6 +418,7 @@ object TestConstants {
willFundRates_opt = Some(defaultLiquidityRates),
peerWakeUpConfig = PeerReadyNotifier.WakeUpConfig(enabled = false, timeout = 30 seconds),
onTheFlyFundingConfig = OnTheFlyFunding.Config(proposalTimeout = 90 seconds),
peerStorageWriteDelayMax = 5 seconds,
)

def channelParams: LocalParams = OpenChannelInterceptor.makeChannelParams(
Expand Down
21 changes: 21 additions & 0 deletions eclair-core/src/test/scala/fr/acinq/eclair/db/PeersDbSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import fr.acinq.eclair.payment.relay.Relayer.RelayFees
import fr.acinq.eclair._
import fr.acinq.eclair.wire.protocol.{NodeAddress, Tor2, Tor3}
import org.scalatest.funsuite.AnyFunSuite
import scodec.bits.HexStringSyntax

import java.util.concurrent.Executors
import scala.concurrent.duration._
Expand Down Expand Up @@ -107,4 +108,24 @@ class PeersDbSpec extends AnyFunSuite {
}
}

test("peer storage") {
forAllDbs { dbs =>
val db = dbs.peers

val a = randomKey().publicKey
val b = randomKey().publicKey

assert(db.getStorage(a) == None)
assert(db.getStorage(b) == None)
db.updateStorage(a, hex"012345")
assert(db.getStorage(a) == Some(hex"012345"))
assert(db.getStorage(b) == None)
db.updateStorage(a, hex"6789")
assert(db.getStorage(a) == Some(hex"6789"))
assert(db.getStorage(b) == None)
db.updateStorage(b, hex"abcd")
assert(db.getStorage(a) == Some(hex"6789"))
assert(db.getStorage(b) == Some(hex"abcd"))
}
}
}
Loading

0 comments on commit ed48cb0

Please sign in to comment.