From 8f7a415f5a83aa2989e6ba1fc6641e97559964ce Mon Sep 17 00:00:00 2001 From: Pierre-Marie Padiou Date: Wed, 28 Aug 2019 16:58:49 +0200 Subject: [PATCH] Rework router data structures (#902) Instead of using two separate maps (for channels and channel_updates), we now use a single map, which groups channel+channel_updates. This is also true for data storage, resulting in the removal of the channel_updates table. --- .../fr/acinq/eclair/DBCompatChecker.scala | 2 +- .../scala/fr/acinq/eclair/db/NetworkDb.scala | 20 +- .../eclair/db/sqlite/SqliteNetworkDb.scala | 97 +++--- .../acinq/eclair/db/sqlite/SqliteUtils.scala | 2 + .../main/scala/fr/acinq/eclair/io/Peer.scala | 6 +- .../fr/acinq/eclair/payment/Autoprobe.scala | 7 +- .../scala/fr/acinq/eclair/router/Graph.scala | 32 +- .../scala/fr/acinq/eclair/router/Router.scala | 320 +++++++++--------- .../eclair/wire/LightningMessageTypes.scala | 5 +- .../scala/fr/acinq/eclair/wire/TlvTypes.scala | 4 +- .../acinq/eclair/db/SqliteNetworkDbSpec.scala | 118 +++++-- .../eclair/integration/IntegrationSpec.scala | 18 +- .../scala/fr/acinq/eclair/io/PeerSpec.scala | 6 +- .../router/ChannelRangeQueriesSpec.scala | 37 +- .../fr/acinq/eclair/router/GraphSpec.scala | 2 +- .../eclair/router/RouteCalculationSpec.scala | 84 +++-- .../fr/acinq/eclair/router/RouterSpec.scala | 2 +- .../acinq/eclair/router/RoutingSyncSpec.scala | 79 ++--- .../gui/controllers/AboutController.scala | 3 +- .../controllers/ChannelPaneController.scala | 8 +- .../gui/controllers/NodeInfoController.scala | 7 +- .../controllers/NotificationsController.scala | 5 +- .../ReceivePaymentController.scala | 11 +- .../controllers/SendPaymentController.scala | 9 +- .../gui/controllers/SplashController.scala | 3 +- 25 files changed, 486 insertions(+), 401 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/DBCompatChecker.scala b/eclair-core/src/main/scala/fr/acinq/eclair/DBCompatChecker.scala index b6f9c50da7..9a3b7ef8e3 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/DBCompatChecker.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/DBCompatChecker.scala @@ -39,7 +39,7 @@ object DBCompatChecker extends Logging { * @param nodeParams */ def checkNetworkDBCompatibility(nodeParams: NodeParams): Unit = - Try(nodeParams.db.network.listChannels(), nodeParams.db.network.listNodes(), nodeParams.db.network.listChannelUpdates()) match { + Try(nodeParams.db.network.listChannels(), nodeParams.db.network.listNodes()) match { case Success(_) => {} case Failure(_) => throw IncompatibleNetworkDBException } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/NetworkDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/NetworkDb.scala index 546516785f..bd72a235ad 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/NetworkDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/NetworkDb.scala @@ -19,8 +19,11 @@ package fr.acinq.eclair.db import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.bitcoin.{ByteVector32, Satoshi} import fr.acinq.eclair.ShortChannelId +import fr.acinq.eclair.router.PublicChannel import fr.acinq.eclair.wire.{ChannelAnnouncement, ChannelUpdate, NodeAnnouncement} +import scala.collection.immutable.SortedMap + trait NetworkDb { def addNode(n: NodeAnnouncement) @@ -35,22 +38,13 @@ trait NetworkDb { def addChannel(c: ChannelAnnouncement, txid: ByteVector32, capacity: Satoshi) - def removeChannel(shortChannelId: ShortChannelId) = removeChannels(Seq(shortChannelId)) - - /** - * This method removes channel announcements and associated channel updates for a list of channel ids - * - * @param shortChannelIds list of short channel ids - */ - def removeChannels(shortChannelIds: Iterable[ShortChannelId]) + def updateChannel(u: ChannelUpdate) - def listChannels(): Map[ChannelAnnouncement, (ByteVector32, Satoshi)] + def removeChannel(shortChannelId: ShortChannelId) = removeChannels(Set(shortChannelId)) - def addChannelUpdate(u: ChannelUpdate) - - def updateChannelUpdate(u: ChannelUpdate) + def removeChannels(shortChannelIds: Iterable[ShortChannelId]) - def listChannelUpdates(): Seq[ChannelUpdate] + def listChannels(): SortedMap[ShortChannelId, PublicChannel] def addToPruned(shortChannelIds: Iterable[ShortChannelId]): Unit diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala index e5a1a83636..2d2e0a1023 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala @@ -21,25 +21,39 @@ import java.sql.Connection import fr.acinq.bitcoin.{ByteVector32, Crypto, Satoshi} import fr.acinq.eclair.ShortChannelId import fr.acinq.eclair.db.NetworkDb -import fr.acinq.eclair.router.Announcements +import fr.acinq.eclair.router.PublicChannel import fr.acinq.eclair.wire.LightningMessageCodecs.{channelAnnouncementCodec, channelUpdateCodec, nodeAnnouncementCodec} import fr.acinq.eclair.wire.{ChannelAnnouncement, ChannelUpdate, NodeAnnouncement} -import scodec.bits.BitVector +import grizzled.slf4j.Logging -class SqliteNetworkDb(sqlite: Connection) extends NetworkDb { +import scala.collection.immutable.SortedMap + +class SqliteNetworkDb(sqlite: Connection) extends NetworkDb with Logging { import SqliteUtils._ + import SqliteUtils.ExtendedResultSet._ val DB_NAME = "network" - val CURRENT_VERSION = 1 + val CURRENT_VERSION = 2 using(sqlite.createStatement()) { statement => - require(getVersion(statement, DB_NAME, CURRENT_VERSION) == CURRENT_VERSION, s"incompatible version of $DB_NAME DB found") // there is only one version currently deployed - statement.execute("PRAGMA foreign_keys = ON") + getVersion(statement, DB_NAME, CURRENT_VERSION) match { + case 1 => + // channel_update are cheap to retrieve, so let's just wipe them out and they'll get resynced + statement.execute("PRAGMA foreign_keys = ON") + logger.warn("migrating network db version 1->2") + statement.executeUpdate("ALTER TABLE channels RENAME COLUMN data TO channel_announcement") + statement.executeUpdate("ALTER TABLE channels ADD COLUMN channel_update_1 BLOB NULL") + statement.executeUpdate("ALTER TABLE channels ADD COLUMN channel_update_2 BLOB NULL") + statement.executeUpdate("DROP TABLE channel_updates") + statement.execute("PRAGMA foreign_keys = OFF") + setVersion(statement, DB_NAME, CURRENT_VERSION) + logger.warn("migration complete") + case 2 => () // nothing to do + case unknown => throw new IllegalArgumentException(s"unknown version $unknown for network db") + } statement.executeUpdate("CREATE TABLE IF NOT EXISTS nodes (node_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS channels (short_channel_id INTEGER NOT NULL PRIMARY KEY, txid STRING NOT NULL, data BLOB NOT NULL, capacity_sat INTEGER NOT NULL)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_updates (short_channel_id INTEGER NOT NULL, node_flag INTEGER NOT NULL, data BLOB NOT NULL, PRIMARY KEY(short_channel_id, node_flag), FOREIGN KEY(short_channel_id) REFERENCES channels(short_channel_id))") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_updates_idx ON channel_updates(short_channel_id)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS channels (short_channel_id INTEGER NOT NULL PRIMARY KEY, txid STRING NOT NULL, channel_announcement BLOB NOT NULL, capacity_sat INTEGER NOT NULL, channel_update_1 BLOB NULL, channel_update_2 BLOB NULL)") statement.executeUpdate("CREATE TABLE IF NOT EXISTS pruned (short_channel_id INTEGER NOT NULL PRIMARY KEY)") } @@ -82,7 +96,7 @@ class SqliteNetworkDb(sqlite: Connection) extends NetworkDb { } override def addChannel(c: ChannelAnnouncement, txid: ByteVector32, capacity: Satoshi): Unit = { - using(sqlite.prepareStatement("INSERT OR IGNORE INTO channels VALUES (?, ?, ?, ?)")) { statement => + using(sqlite.prepareStatement("INSERT OR IGNORE INTO channels VALUES (?, ?, ?, ?, NULL, NULL)")) { statement => statement.setLong(1, c.shortChannelId.toLong) statement.setString(2, txid.toHex) statement.setBytes(3, channelAnnouncementCodec.encode(c).require.toByteArray) @@ -91,56 +105,39 @@ class SqliteNetworkDb(sqlite: Connection) extends NetworkDb { } } - override def removeChannels(shortChannelIds: Iterable[ShortChannelId]): Unit = { - - def removeChannelsInternal(shortChannelIds: Iterable[ShortChannelId]): Unit = { - val ids = shortChannelIds.map(_.toLong).mkString(",") - using(sqlite.createStatement) { statement => - statement.execute("BEGIN TRANSACTION") - statement.executeUpdate(s"DELETE FROM channel_updates WHERE short_channel_id IN ($ids)") - statement.executeUpdate(s"DELETE FROM channels WHERE short_channel_id IN ($ids)") - statement.execute("COMMIT TRANSACTION") - } + override def updateChannel(u: ChannelUpdate): Unit = { + val column = if (u.isNode1) "channel_update_1" else "channel_update_2" + using(sqlite.prepareStatement(s"UPDATE channels SET $column=? WHERE short_channel_id=?")) { statement => + statement.setBytes(1, channelUpdateCodec.encode(u).require.toByteArray) + statement.setLong(2, u.shortChannelId.toLong) + statement.executeUpdate() } - - // remove channels by batch of 1000 - shortChannelIds.grouped(1000).foreach(removeChannelsInternal) } - override def listChannels(): Map[ChannelAnnouncement, (ByteVector32, Satoshi)] = { + override def listChannels(): SortedMap[ShortChannelId, PublicChannel] = { using(sqlite.createStatement()) { statement => - val rs = statement.executeQuery("SELECT data, txid, capacity_sat FROM channels") - var m: Map[ChannelAnnouncement, (ByteVector32, Satoshi)] = Map() + val rs = statement.executeQuery("SELECT channel_announcement, txid, capacity_sat, channel_update_1, channel_update_2 FROM channels") + var m = SortedMap.empty[ShortChannelId, PublicChannel] while (rs.next()) { - m += (channelAnnouncementCodec.decode(BitVector(rs.getBytes("data"))).require.value -> - (ByteVector32.fromValidHex(rs.getString("txid")), Satoshi(rs.getLong("capacity_sat")))) + val ann = channelAnnouncementCodec.decode(rs.getBitVectorOpt("channel_announcement").get).require.value + val txId = ByteVector32.fromValidHex(rs.getString("txid")) + val capacity = rs.getLong("capacity_sat") + val channel_update_1_opt = rs.getBitVectorOpt("channel_update_1").map(channelUpdateCodec.decode(_).require.value) + val channel_update_2_opt = rs.getBitVectorOpt("channel_update_2").map(channelUpdateCodec.decode(_).require.value) + m = m + (ann.shortChannelId -> PublicChannel(ann, txId, Satoshi(capacity), channel_update_1_opt, channel_update_2_opt)) } m } } - override def addChannelUpdate(u: ChannelUpdate): Unit = { - using(sqlite.prepareStatement("INSERT OR IGNORE INTO channel_updates VALUES (?, ?, ?)")) { statement => - statement.setLong(1, u.shortChannelId.toLong) - statement.setBoolean(2, Announcements.isNode1(u.channelFlags)) - statement.setBytes(3, channelUpdateCodec.encode(u).require.toByteArray) - statement.executeUpdate() - } - } - - override def updateChannelUpdate(u: ChannelUpdate): Unit = { - using(sqlite.prepareStatement("UPDATE channel_updates SET data=? WHERE short_channel_id=? AND node_flag=?")) { statement => - statement.setBytes(1, channelUpdateCodec.encode(u).require.toByteArray) - statement.setLong(2, u.shortChannelId.toLong) - statement.setBoolean(3, Announcements.isNode1(u.channelFlags)) - statement.executeUpdate() - } - } - - override def listChannelUpdates(): Seq[ChannelUpdate] = { - using(sqlite.createStatement()) { statement => - val rs = statement.executeQuery("SELECT data FROM channel_updates") - codecSequence(rs, channelUpdateCodec) + override def removeChannels(shortChannelIds: Iterable[ShortChannelId]): Unit = { + using(sqlite.createStatement) { statement => + shortChannelIds + .grouped(1000) // remove channels by batch of 1000 + .foreach {group => + val ids = shortChannelIds.map(_.toLong).mkString(",") + statement.executeUpdate(s"DELETE FROM channels WHERE short_channel_id IN ($ids)") + } } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteUtils.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteUtils.scala index dda0510dee..35af8fc8dd 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteUtils.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteUtils.scala @@ -123,6 +123,8 @@ object SqliteUtils { case class ExtendedResultSet(rs: ResultSet) { + def getBitVectorOpt(columnLabel: String): Option[BitVector] = Option(rs.getBytes(columnLabel)).map(BitVector(_)) + def getByteVector(columnLabel: String): ByteVector = ByteVector(rs.getBytes(columnLabel)) def getByteVector32(columnLabel: String): ByteVector32 = ByteVector32(ByteVector(rs.getBytes(columnLabel))) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala index 695afc06b9..9f002c0c0a 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala @@ -334,7 +334,7 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, authenticator: A // we won't clean it up, but we won't remember the temporary id on channel termination stay using d.copy(channels = d.channels + (FinalChannelId(channelId) -> channel)) - case Event(RoutingState(channels, updates, nodes), d: ConnectedData) => + case Event(RoutingState(channels, nodes), d: ConnectedData) => // let's send the messages def send(announcements: Iterable[_ <: LightningMessage]) = announcements.foldLeft(0) { case (c, ann) => @@ -343,9 +343,9 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, authenticator: A } log.info(s"sending all announcements to {}", remoteNodeId) - val channelsSent = send(channels) + val channelsSent = send(channels.map(_.ann)) val nodesSent = send(nodes) - val updatesSent = send(updates) + val updatesSent = send(channels.flatMap(c => c.update_1_opt.toSeq ++ c.update_2_opt.toSeq)) log.info(s"sent all announcements to {}: channels={} updates={} nodes={}", remoteNodeId, channelsSent, updatesSent, nodesSent) stay diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Autoprobe.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Autoprobe.scala index 71396e096e..f9bfdae46a 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Autoprobe.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Autoprobe.scala @@ -20,7 +20,7 @@ import akka.actor.{Actor, ActorLogging, ActorRef, Props} import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.eclair.crypto.Sphinx.DecryptedFailurePacket import fr.acinq.eclair.payment.PaymentLifecycle.{PaymentFailed, PaymentResult, RemoteFailure, SendPayment} -import fr.acinq.eclair.router.{Announcements, Data} +import fr.acinq.eclair.router.{Announcements, Data, PublicChannel} import fr.acinq.eclair.wire.IncorrectOrUnknownPaymentDetails import fr.acinq.eclair.{MilliSatoshi, NodeParams, randomBytes32, secureRandom} @@ -89,9 +89,10 @@ object Autoprobe { def pickPaymentDestination(nodeId: PublicKey, routingData: Data): Option[PublicKey] = { // we only pick direct peers with enabled public channels - val peers = routingData.updates + val peers = routingData.channels .collect { - case (desc, u) if desc.a == nodeId && Announcements.isEnabled(u.channelFlags) && routingData.channels.contains(u.shortChannelId) => desc.b // we only consider outgoing channels that are enabled and announced + case (shortChannelId, c@PublicChannel(ann, _, _, Some(u1), _)) + if c.getNodeIdSameSideAs(u1) == nodeId && Announcements.isEnabled(u1.channelFlags) && routingData.channels.exists(_._1 == shortChannelId) => ann.nodeId2 // we only consider outgoing channels that are enabled and announced } if (peers.isEmpty) { None diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala index 153f2d0d26..1fc568bc1a 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala @@ -23,6 +23,7 @@ import fr.acinq.eclair.router.Graph.GraphStructure.{DirectedGraph, GraphEdge} import fr.acinq.eclair.router.Router._ import fr.acinq.eclair.wire.ChannelUpdate +import scala.collection.immutable.SortedMap import scala.collection.mutable object Graph { @@ -74,6 +75,7 @@ object Graph { targetNode: PublicKey, amount: MilliSatoshi, ignoredEdges: Set[ChannelDesc], + ignoredVertices: Set[PublicKey], extraEdges: Set[GraphEdge], pathsToFind: Int, wr: Option[WeightRatios], @@ -89,7 +91,7 @@ object Graph { // find the shortest path, k = 0 val initialWeight = RichWeight(cost = amount, 0, CltvExpiryDelta(0), 0) - val shortestPath = dijkstraShortestPath(graph, sourceNode, targetNode, ignoredEdges, extraEdges, initialWeight, boundaries, currentBlockHeight, wr) + val shortestPath = dijkstraShortestPath(graph, sourceNode, targetNode, ignoredEdges, ignoredVertices, extraEdges, initialWeight, boundaries, currentBlockHeight, wr) shortestPaths += WeightedPath(shortestPath, pathWeight(shortestPath, amount, isPartial = false, currentBlockHeight, wr)) // avoid returning a list with an empty path @@ -125,7 +127,7 @@ object Graph { val returningEdges = rootPathEdges.lastOption.map(last => graph.getEdgesBetween(last.desc.b, last.desc.a)).toSeq.flatten.map(_.desc) // find the "spur" path, a sub-path going from the spur edge to the target avoiding previously found sub-paths - val spurPath = dijkstraShortestPath(graph, spurEdge.desc.a, targetNode, ignoredEdges ++ edgesToIgnore.toSet ++ returningEdges.toSet, extraEdges, rootPathWeight, boundaries, currentBlockHeight, wr) + val spurPath = dijkstraShortestPath(graph, spurEdge.desc.a, targetNode, ignoredEdges ++ edgesToIgnore.toSet ++ returningEdges.toSet, ignoredVertices, extraEdges, rootPathWeight, boundaries, currentBlockHeight, wr) // if there wasn't a path the spur will be empty if (spurPath.nonEmpty) { @@ -178,6 +180,7 @@ object Graph { sourceNode: PublicKey, targetNode: PublicKey, ignoredEdges: Set[ChannelDesc], + ignoredVertices: Set[PublicKey], extraEdges: Set[GraphEdge], initialWeight: RichWeight, boundaries: RichWeight => Boolean, @@ -232,7 +235,7 @@ object Graph { if (edge.update.htlcMaximumMsat.forall(newMinimumKnownWeight.cost <= _) && newMinimumKnownWeight.cost >= edge.update.htlcMinimumMsat && boundaries(newMinimumKnownWeight) && // check if this neighbor edge would break off the 'boundaries' - !ignoredEdges.contains(edge.desc) + !ignoredEdges.contains(edge.desc) && !ignoredVertices.contains(neighbor) ) { // we call containsKey first because "getOrDefault" is not available in JDK7 @@ -533,21 +536,32 @@ object Graph { def apply(edge: GraphEdge): DirectedGraph = new DirectedGraph(Map()).addEdge(edge.desc, edge.update) def apply(edges: Seq[GraphEdge]): DirectedGraph = { - makeGraph(edges.map(e => e.desc -> e.update).toMap) + DirectedGraph().addEdges(edges.map(e => (e.desc, e.update))) } // optimized constructor - def makeGraph(descAndUpdates: Map[ChannelDesc, ChannelUpdate]): DirectedGraph = { + def makeGraph(channels: SortedMap[ShortChannelId, PublicChannel]): DirectedGraph = { // initialize the map with the appropriate size to avoid resizing during the graph initialization val mutableMap = new {} with mutable.HashMap[PublicKey, List[GraphEdge]] { - override def initialSize: Int = descAndUpdates.size + 1 + override def initialSize: Int = channels.size + 1 } // add all the vertices and edges in one go - descAndUpdates.foreach { case (desc, update) => - // create or update vertex (desc.b) and update its neighbor - mutableMap.put(desc.b, GraphEdge(desc, update) +: mutableMap.getOrElse(desc.b, List.empty[GraphEdge])) + channels.values.foreach { channel => + channel.update_1_opt.foreach { u1 => + val desc1 = Router.getDesc(u1, channel.ann) + addDescToMap(desc1, u1) + } + + channel.update_2_opt.foreach { u2 => + val desc2 = Router.getDesc(u2, channel.ann) + addDescToMap(desc2, u2) + } + } + + def addDescToMap(desc: ChannelDesc, u: ChannelUpdate) = { + mutableMap.put(desc.b, GraphEdge(desc, u) +: mutableMap.getOrElse(desc.b, List.empty[GraphEdge])) mutableMap.get(desc.a) match { case None => mutableMap += desc.a -> List.empty[GraphEdge] case _ => diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala index dd4eee4e01..0127f04580 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala @@ -39,7 +39,7 @@ import fr.acinq.eclair.wire._ import shapeless.HNil import scala.annotation.tailrec -import scala.collection.immutable.{SortedMap, TreeMap} +import scala.collection.immutable.SortedMap import scala.collection.{SortedSet, mutable} import scala.compat.Platform import scala.concurrent.duration._ @@ -63,8 +63,34 @@ case class RouterConf(randomizeRouteSelection: Boolean, searchRatioChannelCapacity: Double) case class ChannelDesc(shortChannelId: ShortChannelId, a: PublicKey, b: PublicKey) + +case class PublicChannel(ann: ChannelAnnouncement, fundingTxid: ByteVector32, capacity: Satoshi, update_1_opt: Option[ChannelUpdate], update_2_opt: Option[ChannelUpdate]) { + update_1_opt.foreach(u => assert(Announcements.isNode1(u.channelFlags))) + update_2_opt.foreach(u => assert(!Announcements.isNode1(u.channelFlags))) + + def getNodeIdSameSideAs(u: ChannelUpdate): PublicKey = if (Announcements.isNode1(u.channelFlags)) ann.nodeId1 else ann.nodeId2 + + def getChannelUpdateSameSideAs(u: ChannelUpdate): Option[ChannelUpdate] = if (Announcements.isNode1(u.channelFlags)) update_1_opt else update_2_opt + + def updateChannelUpdateSameSideAs(u: ChannelUpdate): PublicChannel = if (Announcements.isNode1(u.channelFlags)) copy(update_1_opt = Some(u)) else copy(update_2_opt = Some(u)) +} + +case class PrivateChannel(localNodeId: PublicKey, remoteNodeId: PublicKey, update_1_opt: Option[ChannelUpdate], update_2_opt: Option[ChannelUpdate]) { + val (nodeId1, nodeId2) = if (Announcements.isNode1(localNodeId, remoteNodeId)) (localNodeId, remoteNodeId) else (remoteNodeId, localNodeId) + + def getNodeIdSameSideAs(u: ChannelUpdate): PublicKey = if (Announcements.isNode1(u.channelFlags)) nodeId1 else nodeId2 + + def getChannelUpdateSameSideAs(u: ChannelUpdate): Option[ChannelUpdate] = if (Announcements.isNode1(u.channelFlags)) update_1_opt else update_2_opt + + def updateChannelUpdateSameSideAs(u: ChannelUpdate): PrivateChannel = if (Announcements.isNode1(u.channelFlags)) copy(update_1_opt = Some(u)) else copy(update_2_opt = Some(u)) +} + +case class AssistedChannel(extraHop: ExtraHop, nextNodeId: PublicKey) + case class Hop(nodeId: PublicKey, nextNodeId: PublicKey, lastUpdate: ChannelUpdate) + case class RouteParams(randomize: Boolean, maxFeeBase: MilliSatoshi, maxFeePct: Double, routeMaxLength: Int, routeMaxCltv: CltvExpiryDelta, ratios: Option[WeightRatios]) + case class RouteRequest(source: PublicKey, target: PublicKey, amount: MilliSatoshi, @@ -73,16 +99,23 @@ case class RouteRequest(source: PublicKey, ignoreChannels: Set[ChannelDesc] = Set.empty, routeParams: Option[RouteParams] = None) -case class FinalizeRoute(hops:Seq[PublicKey]) +case class FinalizeRoute(hops: Seq[PublicKey]) + case class RouteResponse(hops: Seq[Hop], ignoreNodes: Set[PublicKey], ignoreChannels: Set[ChannelDesc]) { require(hops.nonEmpty, "route cannot be empty") } + case class ExcludeChannel(desc: ChannelDesc) // this is used when we get a TemporaryChannelFailure, to give time for the channel to recover (note that exclusions are directed) case class LiftChannelExclusion(desc: ChannelDesc) + case class SendChannelQuery(remoteNodeId: PublicKey, to: ActorRef, flags_opt: Option[QueryChannelRangeTlv]) + case object GetRoutingState -case class RoutingState(channels: Iterable[ChannelAnnouncement], updates: Iterable[ChannelUpdate], nodes: Iterable[NodeAnnouncement]) + +case class RoutingState(channels: Iterable[PublicChannel], nodes: Iterable[NodeAnnouncement]) + case class Stash(updates: Map[ChannelUpdate, Set[ActorRef]], nodes: Map[NodeAnnouncement, Set[ActorRef]]) + case class Rebroadcast(channels: Map[ChannelAnnouncement, Set[ActorRef]], updates: Map[ChannelUpdate, Set[ActorRef]], nodes: Map[NodeAnnouncement, Set[ActorRef]]) case class ShortChannelIdAndFlag(shortChannelId: ShortChannelId, flag: Long) @@ -90,13 +123,11 @@ case class ShortChannelIdAndFlag(shortChannelId: ShortChannelId, flag: Long) case class Sync(pending: List[RoutingMessage], total: Int) case class Data(nodes: Map[PublicKey, NodeAnnouncement], - channels: SortedMap[ShortChannelId, ChannelAnnouncement], - updates: Map[ChannelDesc, ChannelUpdate], + channels: SortedMap[ShortChannelId, PublicChannel], stash: Stash, rebroadcast: Rebroadcast, awaiting: Map[ChannelAnnouncement, Seq[ActorRef]], // note: this is a seq because we want to preserve order: first actor is the one who we need to send a tcp-ack when validation is done - privateChannels: Map[ShortChannelId, PublicKey], // short_channel_id -> node_id - privateUpdates: Map[ChannelDesc, ChannelUpdate], + privateChannels: Map[ShortChannelId, PrivateChannel], // short_channel_id -> node_id excludedChannels: Set[ChannelDesc], // those channels are temporarily excluded from route calculation, because their node returned a TemporaryChannelFailure graph: DirectedGraph, sync: Map[PublicKey, Sync] // keep tracks of channel range queries sent to each peer. If there is an entry in the map, it means that there is an ongoing query @@ -104,9 +135,11 @@ case class Data(nodes: Map[PublicKey, NodeAnnouncement], ) sealed trait State + case object NORMAL extends State case object TickBroadcast + case object TickPruneStaleChannels // @formatter:on @@ -140,28 +173,23 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ log.info("loading network announcements from db...") val channels = db.listChannels() val nodes = db.listNodes() - val updates = db.listChannelUpdates() - log.info("loaded from db: channels={} nodes={} updates={}", channels.size, nodes.size, updates.size) - val initChannels = channels.keys.foldLeft(TreeMap.empty[ShortChannelId, ChannelAnnouncement]) { case (m, c) => m + (c.shortChannelId -> c) } - val initChannelUpdates = updates.map { u => - val desc = getDesc(u, initChannels(u.shortChannelId)) - desc -> u - }.toMap + log.info("loaded from db: channels={} nodes={}", channels.size, nodes.size) + val initChannels = channels // this will be used to calculate routes - val graph = DirectedGraph.makeGraph(initChannelUpdates) + val graph = DirectedGraph.makeGraph(initChannels) val initNodes = nodes.map(n => (n.nodeId -> n)).toMap // send events for remaining channels/nodes - context.system.eventStream.publish(ChannelsDiscovered(initChannels.values.map(c => SingleChannelDiscovered(c, channels(c)._2)))) - context.system.eventStream.publish(ChannelUpdatesReceived(initChannelUpdates.values)) + context.system.eventStream.publish(ChannelsDiscovered(initChannels.values.map(pc => SingleChannelDiscovered(pc.ann, pc.capacity)))) + context.system.eventStream.publish(ChannelUpdatesReceived(initChannels.values.flatMap(pc => pc.update_1_opt ++ pc.update_2_opt ++ Nil))) context.system.eventStream.publish(NodesDiscovered(initNodes.values)) // watch the funding tx of all these channels // note: some of them may already have been spent, in that case we will receive the watch event immediately - initChannels.values.foreach { c => - val txid = channels(c)._1 - val TxCoordinates(_, _, outputIndex) = ShortChannelId.coordinates(c.shortChannelId) - val fundingOutputScript = write(pay2wsh(Scripts.multiSig2of2(c.bitcoinKey1, c.bitcoinKey2))) - watcher ! WatchSpentBasic(self, txid, outputIndex, fundingOutputScript, BITCOIN_FUNDING_EXTERNAL_CHANNEL_SPENT(c.shortChannelId)) + initChannels.values.foreach { pc => + val txid = pc.fundingTxid + val TxCoordinates(_, _, outputIndex) = ShortChannelId.coordinates(pc.ann.shortChannelId) + val fundingOutputScript = write(pay2wsh(Scripts.multiSig2of2(pc.ann.bitcoinKey1, pc.ann.bitcoinKey2))) + watcher ! WatchSpentBasic(self, txid, outputIndex, fundingOutputScript, BITCOIN_FUNDING_EXTERNAL_CHANNEL_SPENT(pc.ann.shortChannelId)) } // on restart we update our node announcement @@ -171,7 +199,7 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ log.info(s"initialization completed, ready to process messages") Try(initialized.map(_.success(Done))) - startWith(NORMAL, Data(initNodes, initChannels, initChannelUpdates, Stash(Map.empty, Map.empty), rebroadcast = Rebroadcast(channels = Map.empty, updates = Map.empty, nodes = Map.empty), awaiting = Map.empty, privateChannels = Map.empty, privateUpdates = Map.empty, excludedChannels = Set.empty, graph, sync = Map.empty)) + startWith(NORMAL, Data(initNodes, initChannels, Stash(Map.empty, Map.empty), rebroadcast = Rebroadcast(channels = Map.empty, updates = Map.empty, nodes = Map.empty), awaiting = Map.empty, privateChannels = Map.empty, excludedChannels = Set.empty, graph, sync = Map.empty)) } when(NORMAL) { @@ -199,7 +227,7 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ // channel isn't announced and we never heard of it (maybe it is a private channel or maybe it is a public channel that doesn't yet have 6 confirmations) // let's create a corresponding private channel and process the channel_update log.info("adding unannounced local channel to remote={} shortChannelId={}", remoteNodeId, shortChannelId) - stay using handle(u, self, d.copy(privateChannels = d.privateChannels + (shortChannelId -> remoteNodeId))) + stay using handle(u, self, d.copy(privateChannels = d.privateChannels + (shortChannelId -> PrivateChannel(nodeParams.nodeId, remoteNodeId, None, None)))) } } @@ -219,14 +247,14 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ .removeEdge(desc1) .removeEdge(desc2) // and we remove the channel and channel_update from our state - stay using d.copy(privateChannels = d.privateChannels - shortChannelId, privateUpdates = d.privateUpdates - desc1 - desc2, graph = graph1) + stay using d.copy(privateChannels = d.privateChannels - shortChannelId, graph = graph1) } else { stay } case Event(GetRoutingState, d: Data) => log.info(s"getting valid announcements for $sender") - sender ! RoutingState(d.channels.values, d.updates.values, d.nodes.values) + sender ! RoutingState(d.channels.values, d.nodes.values) stay case Event(v@ValidateResult(c, _), d0) => @@ -235,10 +263,10 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ case _ => () } log.info("got validation result for shortChannelId={} (awaiting={} stash.nodes={} stash.updates={})", c.shortChannelId, d0.awaiting.size, d0.stash.nodes.size, d0.stash.updates.size) - val success = v match { + val publicChannel_opt = v match { case ValidateResult(c, Left(t)) => log.warning("validation failure for shortChannelId={} reason={}", c.shortChannelId, t.getMessage) - false + None case ValidateResult(c, Right((tx, UtxoStatus.Unspent))) => val TxCoordinates(_, _, outputIndex) = ShortChannelId.coordinates(c.shortChannelId) // let's check that the output is indeed a P2WSH multisig 2-of-2 of nodeid1 and nodeid2) @@ -249,7 +277,7 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ case Some(origins) => origins.foreach(_ ! InvalidAnnouncement(c)) case _ => () } - false + None } else { watcher ! WatchSpentBasic(self, tx, outputIndex, BITCOIN_FUNDING_EXTERNAL_CHANNEL_SPENT(c.shortChannelId)) // TODO: check feature bit set @@ -264,7 +292,7 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ val nodeAnn = Announcements.makeNodeAnnouncement(nodeParams.privateKey, nodeParams.alias, nodeParams.color, nodeParams.publicAddresses) self ! nodeAnn } - true + Some(PublicChannel(c, tx.txid, capacity, None, None)) } case ValidateResult(c, Right((tx, fundingTxStatus: UtxoStatus.Spent))) => if (fundingTxStatus.spendingTxConfirmed) { @@ -279,7 +307,7 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ } // there may be a record if we have just restarted db.removeChannel(c.shortChannelId) - false + None } // we also reprocess node and channel_update announcements related to channels that were just analyzed @@ -289,29 +317,31 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ val stash1 = d0.stash.copy(updates = d0.stash.updates -- reprocessUpdates.keys, nodes = d0.stash.nodes -- reprocessNodes.keys) // we remove channel from awaiting map val awaiting1 = d0.awaiting - c - if (success) { - // note: if the channel is graduating from private to public, the implementation (in the LocalChannelUpdate handler) guarantees that we will process a new channel_update - // right after the channel_announcement, channel_updates will be moved from private to public at that time - val d1 = d0.copy( - channels = d0.channels + (c.shortChannelId -> c), - privateChannels = d0.privateChannels - c.shortChannelId, // we remove fake announcements that we may have made before - rebroadcast = d0.rebroadcast.copy(channels = d0.rebroadcast.channels + (c -> d0.awaiting.getOrElse(c, Nil).toSet)), // we also add the newly validated channels to the rebroadcast queue - stash = stash1, - awaiting = awaiting1) - // we only reprocess updates and nodes if validation succeeded - val d2 = reprocessUpdates.foldLeft(d1) { - case (d, (u, origins)) => origins.foldLeft(d) { case (d, origin) => handle(u, origin, d) } // we reprocess the same channel_update for every origin (to preserve origin information) - } - val d3 = reprocessNodes.foldLeft(d2) { - case (d, (n, origins)) => origins.foldLeft(d) { case (d, origin) => handle(n, origin, d) } // we reprocess the same node_announcement for every origins (to preserve origin information) - } - stay using d3 - } else { - stay using d0.copy(stash = stash1, awaiting = awaiting1) + + publicChannel_opt match { + case Some(pc) => + // note: if the channel is graduating from private to public, the implementation (in the LocalChannelUpdate handler) guarantees that we will process a new channel_update + // right after the channel_announcement, channel_updates will be moved from private to public at that time + val d1 = d0.copy( + channels = d0.channels + (c.shortChannelId -> pc), + privateChannels = d0.privateChannels - c.shortChannelId, // we remove fake announcements that we may have made before + rebroadcast = d0.rebroadcast.copy(channels = d0.rebroadcast.channels + (c -> d0.awaiting.getOrElse(c, Nil).toSet)), // we also add the newly validated channels to the rebroadcast queue + stash = stash1, + awaiting = awaiting1) + // we only reprocess updates and nodes if validation succeeded + val d2 = reprocessUpdates.foldLeft(d1) { + case (d, (u, origins)) => origins.foldLeft(d) { case (d, origin) => handle(u, origin, d) } // we reprocess the same channel_update for every origin (to preserve origin information) + } + val d3 = reprocessNodes.foldLeft(d2) { + case (d, (n, origins)) => origins.foldLeft(d) { case (d, origin) => handle(n, origin, d) } // we reprocess the same node_announcement for every origins (to preserve origin information) + } + stay using d3 + case None => + stay using d0.copy(stash = stash1, awaiting = awaiting1) } case Event(WatchEventSpentBasic(BITCOIN_FUNDING_EXTERNAL_CHANNEL_SPENT(shortChannelId)), d) if d.channels.contains(shortChannelId) => - val lostChannel = d.channels(shortChannelId) + val lostChannel = d.channels(shortChannelId).ann log.info("funding tx of channelId={} has been spent", shortChannelId) // we need to remove nodes that aren't tied to any channels anymore val channels1 = d.channels - lostChannel.shortChannelId @@ -331,7 +361,7 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ db.removeNode(nodeId) context.system.eventStream.publish(NodeLost(nodeId)) } - stay using d.copy(nodes = d.nodes -- lostNodes, channels = d.channels - shortChannelId, updates = d.updates.filterKeys(_.shortChannelId != shortChannelId), graph = graph1) + stay using d.copy(nodes = d.nodes -- lostNodes, channels = d.channels - shortChannelId, graph = graph1) case Event(TickBroadcast, d) => if (d.rebroadcast.channels.isEmpty && d.rebroadcast.updates.isEmpty && d.rebroadcast.nodes.isEmpty) { @@ -345,28 +375,27 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ case Event(TickPruneStaleChannels, d) => // first we select channels that we will prune - val staleChannels = getStaleChannels(d.channels.values, d.updates) - // then we clean up the related channel updates - val staleUpdates = staleChannels.map(d.channels).flatMap(c => Seq(ChannelDesc(c.shortChannelId, c.nodeId1, c.nodeId2), ChannelDesc(c.shortChannelId, c.nodeId2, c.nodeId1))) - // finally we remove nodes that aren't tied to any channels anymore (and deduplicate them) - val potentialStaleNodes = staleChannels.map(d.channels).flatMap(c => Set(c.nodeId1, c.nodeId2)).toSet - val channels1 = d.channels -- staleChannels + val staleChannels = getStaleChannels(d.channels.values) + val staleChannelIds = staleChannels.map(_.ann.shortChannelId) + // then we remove nodes that aren't tied to any channels anymore (and deduplicate them) + val potentialStaleNodes = staleChannels.flatMap(c => Set(c.ann.nodeId1, c.ann.nodeId2)).toSet + val channels1 = d.channels -- staleChannelIds // no need to iterate on all nodes, just on those that are affected by current pruning val staleNodes = potentialStaleNodes.filterNot(nodeId => hasChannels(nodeId, channels1.values)) // let's clean the db and send the events - db.removeChannels(staleChannels) // NB: this also removes channel updates + db.removeChannels(staleChannelIds) // NB: this also removes channel updates // we keep track of recently pruned channels so we don't revalidate them (zombie churn) - db.addToPruned(staleChannels) - staleChannels.foreach { shortChannelId => + db.addToPruned(staleChannelIds) + staleChannelIds.foreach { shortChannelId => log.info("pruning shortChannelId={} (stale)", shortChannelId) context.system.eventStream.publish(ChannelLost(shortChannelId)) } val staleChannelsToRemove = new mutable.MutableList[ChannelDesc] - staleChannels.map(d.channels).foreach(ca => { - staleChannelsToRemove += ChannelDesc(ca.shortChannelId, ca.nodeId1, ca.nodeId2) - staleChannelsToRemove += ChannelDesc(ca.shortChannelId, ca.nodeId2, ca.nodeId1) + staleChannels.foreach(ca => { + staleChannelsToRemove += ChannelDesc(ca.ann.shortChannelId, ca.ann.nodeId1, ca.ann.nodeId2) + staleChannelsToRemove += ChannelDesc(ca.ann.shortChannelId, ca.ann.nodeId2, ca.ann.nodeId1) }) val graph1 = d.graph.removeEdges(staleChannelsToRemove) @@ -376,7 +405,7 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ db.removeNode(nodeId) context.system.eventStream.publish(NodeLost(nodeId)) } - stay using d.copy(nodes = d.nodes -- staleNodes, channels = channels1, updates = d.updates -- staleUpdates, graph = graph1) + stay using d.copy(nodes = d.nodes -- staleNodes, channels = channels1, graph = graph1) case Event(ExcludeChannel(desc@ChannelDesc(shortChannelId, nodeId, _)), d) => val banDuration = nodeParams.routerConf.channelExcludeDuration @@ -393,15 +422,16 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ stay case Event('channels, d) => - sender ! d.channels.values + sender ! d.channels.values.map(_.ann) stay - case Event('updates, d) => - sender ! (d.updates ++ d.privateUpdates).values + case Event('channelsMap, d) => + sender ! d.channels stay - case Event('updatesMap, d) => - sender ! (d.updates ++ d.privateUpdates) + case Event('updates, d) => + val updates: Iterable[ChannelUpdate] = d.channels.values.flatMap(d => d.update_1_opt ++ d.update_2_opt) ++ d.privateChannels.values.flatMap(d => d.update_1_opt ++ d.update_2_opt) + sender ! updates stay case Event('data, d) => @@ -418,17 +448,15 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ case Event(RouteRequest(start, end, amount, assistedRoutes, ignoreNodes, ignoreChannels, params_opt), d) => // we convert extra routing info provided in the payment request to fake channel_update // it takes precedence over all other channel_updates we know - val assistedUpdates = assistedRoutes.flatMap(toFakeUpdates(_, end)).toMap - // we also filter out updates corresponding to channels/nodes that are blacklisted for this particular request - // TODO: in case of duplicates, d.updates will be overridden by assistedUpdates even if they are more recent! - val ignoredUpdates = getIgnoredChannelDesc(d.updates ++ d.privateUpdates ++ assistedUpdates, ignoreNodes) ++ ignoreChannels ++ d.excludedChannels - val extraEdges = assistedUpdates.map { case (c, u) => GraphEdge(c, u) }.toSet + val assistedChannels: Map[ShortChannelId, AssistedChannel] = assistedRoutes.flatMap(toAssistedChannels(_, end)).toMap + val extraEdges = assistedChannels.values.map(ac => GraphEdge(ChannelDesc(ac.extraHop.shortChannelId, ac.extraHop.nodeId, ac.nextNodeId), toFakeUpdate(ac.extraHop))).toSet + val ignoredEdges = ignoreChannels ++ d.excludedChannels val params = params_opt.getOrElse(defaultRouteParams) val routesToFind = if (params.randomize) DEFAULT_ROUTES_COUNT else 1 - log.info(s"finding a route $start->$end with assistedChannels={} ignoreNodes={} ignoreChannels={} excludedChannels={}", assistedUpdates.keys.mkString(","), ignoreNodes.map(_.value).mkString(","), ignoreChannels.mkString(","), d.excludedChannels.mkString(",")) + log.info(s"finding a route $start->$end with assistedChannels={} ignoreNodes={} ignoreChannels={} excludedChannels={}", assistedChannels.keys.mkString(","), ignoreNodes.map(_.value).mkString(","), ignoreChannels.mkString(","), d.excludedChannels.mkString(",")) log.info(s"finding a route with randomize={} params={}", routesToFind > 1, params) - findRoute(d.graph, start, end, amount, numRoutes = routesToFind, extraEdges = extraEdges, ignoredEdges = ignoredUpdates.toSet, routeParams = params) + findRoute(d.graph, start, end, amount, numRoutes = routesToFind, extraEdges = extraEdges, ignoredEdges = ignoredEdges, ignoredVertices = ignoreNodes, routeParams = params) .map(r => sender ! RouteResponse(r, ignoreNodes, ignoreChannels)) .recover { case t => sender ! Status.Failure(t) } stay @@ -511,25 +539,24 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ sender ! TransportHandler.ReadAck(routingMessage) log.info("received query_channel_range with firstBlockNum={} numberOfBlocks={} extendedQueryFlags_opt={}", firstBlockNum, numberOfBlocks, extendedQueryFlags_opt) // keep channel ids that are in [firstBlockNum, firstBlockNum + numberOfBlocks] - val shortChannelIds: SortedSet[ShortChannelId] = d.channels.keySet.filter(keep(firstBlockNum, numberOfBlocks, _, d.channels, d.updates)) + val shortChannelIds: SortedSet[ShortChannelId] = d.channels.keySet.filter(keep(firstBlockNum, numberOfBlocks, _)) log.info("replying with {} items for range=({}, {})", shortChannelIds.size, firstBlockNum, numberOfBlocks) split(shortChannelIds) .foreach(chunk => { val (timestamps, checksums) = routingMessage.queryFlags_opt match { case Some(extension) if extension.wantChecksums | extension.wantTimestamps => // we always compute timestamps and checksums even if we don't need both, overhead is negligible - val (timestamps, checksums) = chunk.shortChannelIds.map(getChannelDigestInfo(d.channels, d.updates)).unzip + val (timestamps, checksums) = chunk.shortChannelIds.map(getChannelDigestInfo(d.channels)).unzip val encodedTimestamps = if (extension.wantTimestamps) Some(ReplyChannelRangeTlv.EncodedTimestamps(nodeParams.routerConf.encodingType, timestamps)) else None val encodedChecksums = if (extension.wantChecksums) Some(ReplyChannelRangeTlv.EncodedChecksums(checksums)) else None (encodedTimestamps, encodedChecksums) case _ => (None, None) } - val reply = ReplyChannelRange(chainHash, chunk.firstBlock, chunk.numBlocks, + transport ! ReplyChannelRange(chainHash, chunk.firstBlock, chunk.numBlocks, complete = 1, shortChannelIds = EncodedShortChannelIds(nodeParams.routerConf.encodingType, chunk.shortChannelIds), timestamps = timestamps, checksums = checksums) - transport ! reply }) stay @@ -541,7 +568,7 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ ids match { case Nil => acc.reverse case head :: tail => - val flag = computeFlag(d.channels, d.updates)(head, timestamps.headOption, checksums.headOption, nodeParams.routerConf.requestNodeAnnouncements) + val flag = computeFlag(d.channels)(head, timestamps.headOption, checksums.headOption, nodeParams.routerConf.requestNodeAnnouncements) // 0 means nothing to query, just don't include it val acc1 = if (flag != 0) ShortChannelIdAndFlag(head, flag) :: acc else acc loop(tail, timestamps.drop(1), checksums.drop(1), acc1) @@ -572,7 +599,7 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ )) .toList val (sync1, replynow_opt) = addToSync(d.sync, remoteNodeId, replies) - // we only send a rely right away if there were no pending requests + // we only send a reply right away if there were no pending requests replynow_opt.foreach(transport ! _) context.system.eventStream.publish(syncProgress(sync1)) stay using d.copy(sync = sync1) @@ -585,7 +612,7 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ var updateCount = 0 var nodeCount = 0 - Router.processChannelQuery(d.nodes, d.channels, d.updates)( + Router.processChannelQuery(d.nodes, d.channels)( shortChannelIds.array, flags, ca => { @@ -651,7 +678,7 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ context.system.eventStream.publish(NodeUpdated(n)) db.updateNode(n) d.copy(nodes = d.nodes + (n.nodeId -> n), rebroadcast = d.rebroadcast.copy(nodes = d.rebroadcast.nodes + (n -> Set(origin)))) - } else if (d.channels.values.exists(c => isRelatedTo(c, n.nodeId))) { + } else if (d.channels.values.exists(c => isRelatedTo(c.ann, n.nodeId))) { log.debug("added node nodeId={}", n.nodeId) context.system.eventStream.publish(NodesDiscovered(n :: Nil)) db.addNode(n) @@ -670,8 +697,8 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ if (d.channels.contains(u.shortChannelId)) { // related channel is already known (note: this means no related channel_update is in the stash) val publicChannel = true - val c = d.channels(u.shortChannelId) - val desc = getDesc(u, c) + val pc = d.channels(u.shortChannelId) + val desc = getDesc(u, pc.ann) if (d.rebroadcast.updates.contains(u)) { log.debug("ignoring {} (pending rebroadcast)", u) val origins = d.rebroadcast.updates(u) + origin @@ -679,30 +706,30 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ } else if (isStale(u)) { log.debug("ignoring {} (stale)", u) d - } else if (d.updates.contains(desc) && d.updates(desc).timestamp >= u.timestamp) { + } else if (pc.getChannelUpdateSameSideAs(u).exists(_.timestamp >= u.timestamp)) { log.debug("ignoring {} (duplicate)", u) d - } else if (!Announcements.checkSig(u, desc.a)) { + } else if (!Announcements.checkSig(u, pc.getNodeIdSameSideAs(u))) { log.warning("bad signature for announcement shortChannelId={} {}", u.shortChannelId, u) origin ! InvalidSignature(u) d - } else if (d.updates.contains(desc)) { + } else if (pc.getChannelUpdateSameSideAs(u).isDefined) { log.debug("updated channel_update for shortChannelId={} public={} flags={} {}", u.shortChannelId, publicChannel, u.channelFlags, u) context.system.eventStream.publish(ChannelUpdatesReceived(u :: Nil)) - db.updateChannelUpdate(u) + db.updateChannel(u) // update the graph val graph1 = Announcements.isEnabled(u.channelFlags) match { case true => d.graph.removeEdge(desc).addEdge(desc, u) case false => d.graph.removeEdge(desc) // if the channel is now disabled, we remove it from the graph } - d.copy(updates = d.updates + (desc -> u), rebroadcast = d.rebroadcast.copy(updates = d.rebroadcast.updates + (u -> Set(origin))), graph = graph1) + d.copy(channels = d.channels + (u.shortChannelId -> pc.updateChannelUpdateSameSideAs(u)), rebroadcast = d.rebroadcast.copy(updates = d.rebroadcast.updates + (u -> Set(origin))), graph = graph1) } else { log.debug("added channel_update for shortChannelId={} public={} flags={} {}", u.shortChannelId, publicChannel, u.channelFlags, u) context.system.eventStream.publish(ChannelUpdatesReceived(u :: Nil)) - db.addChannelUpdate(u) + db.updateChannel(u) // we also need to update the graph val graph1 = d.graph.addEdge(desc, u) - d.copy(updates = d.updates + (desc -> u), privateUpdates = d.privateUpdates - desc, rebroadcast = d.rebroadcast.copy(updates = d.rebroadcast.updates + (u -> Set(origin))), graph = graph1) + d.copy(channels = d.channels + (u.shortChannelId -> pc.updateChannelUpdateSameSideAs(u)), privateChannels = d.privateChannels - u.shortChannelId, rebroadcast = d.rebroadcast.copy(updates = d.rebroadcast.updates + (u -> Set(origin))), graph = graph1) } } else if (d.awaiting.keys.exists(c => c.shortChannelId == u.shortChannelId)) { // channel is currently being validated @@ -716,31 +743,30 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ } } else if (d.privateChannels.contains(u.shortChannelId)) { val publicChannel = false - val remoteNodeId = d.privateChannels(u.shortChannelId) - val (a, b) = if (Announcements.isNode1(nodeParams.nodeId, remoteNodeId)) (nodeParams.nodeId, remoteNodeId) else (remoteNodeId, nodeParams.nodeId) - val desc = if (Announcements.isNode1(u.channelFlags)) ChannelDesc(u.shortChannelId, a, b) else ChannelDesc(u.shortChannelId, b, a) + val pc = d.privateChannels(u.shortChannelId) + val desc = if (Announcements.isNode1(u.channelFlags)) ChannelDesc(u.shortChannelId, pc.nodeId1, pc.nodeId2) else ChannelDesc(u.shortChannelId, pc.nodeId2, pc.nodeId1) if (isStale(u)) { log.debug("ignoring {} (stale)", u) d - } else if (d.updates.contains(desc) && d.updates(desc).timestamp >= u.timestamp) { + } else if (pc.getChannelUpdateSameSideAs(u).exists(_.timestamp >= u.timestamp)) { log.debug("ignoring {} (already know same or newer)", u) d } else if (!Announcements.checkSig(u, desc.a)) { log.warning("bad signature for announcement shortChannelId={} {}", u.shortChannelId, u) origin ! InvalidSignature(u) d - } else if (d.privateUpdates.contains(desc)) { + } else if (pc.getChannelUpdateSameSideAs(u).isDefined) { log.debug("updated channel_update for shortChannelId={} public={} flags={} {}", u.shortChannelId, publicChannel, u.channelFlags, u) context.system.eventStream.publish(ChannelUpdatesReceived(u :: Nil)) // we also need to update the graph val graph1 = d.graph.removeEdge(desc).addEdge(desc, u) - d.copy(privateUpdates = d.privateUpdates + (desc -> u), graph = graph1) + d.copy(privateChannels = d.privateChannels + (u.shortChannelId -> pc.updateChannelUpdateSameSideAs(u)), graph = graph1) } else { log.debug("added channel_update for shortChannelId={} public={} flags={} {}", u.shortChannelId, publicChannel, u.channelFlags, u) context.system.eventStream.publish(ChannelUpdatesReceived(u :: Nil)) // we also need to update the graph val graph1 = d.graph.addEdge(desc, u) - d.copy(privateUpdates = d.privateUpdates + (desc -> u), graph = graph1) + d.copy(privateChannels = d.privateChannels + (u.shortChannelId -> pc.updateChannelUpdateSameSideAs(u)), graph = graph1) } } else if (db.isPruned(u.shortChannelId) && !isStale(u)) { // the channel was recently pruned, but if we are here, it means that the update is not stale so this is the case @@ -794,11 +820,12 @@ object Router { // what matters is that the `disable` bit is 0 so that this update doesn't get filtered out ChannelUpdate(signature = ByteVector64.Zeroes, chainHash = ByteVector32.Zeroes, extraHop.shortChannelId, Platform.currentTime.milliseconds.toSeconds, messageFlags = 0, channelFlags = 0, extraHop.cltvExpiryDelta, htlcMinimumMsat = MilliSatoshi(0), MilliSatoshi(extraHop.feeBaseMsat), extraHop.feeProportionalMillionths, None) - def toFakeUpdates(extraRoute: Seq[ExtraHop], targetNodeId: PublicKey): Map[ChannelDesc, ChannelUpdate] = { + + def toAssistedChannels(extraRoute: Seq[ExtraHop], targetNodeId: PublicKey): Map[ShortChannelId, AssistedChannel] = { // BOLT 11: "For each entry, the pubkey is the node ID of the start of the channel", and the last node is the destination val nextNodeIds = extraRoute.map(_.nodeId).drop(1) :+ targetNodeId extraRoute.zip(nextNodeIds).map { - case (extraHop: ExtraHop, nextNodeId) => (ChannelDesc(extraHop.shortChannelId, extraHop.nodeId, nextNodeId) -> toFakeUpdate(extraHop)) + case (extraHop: ExtraHop, nextNodeId) => extraHop.shortChannelId -> AssistedChannel(extraHop, nextNodeId) }.toMap } @@ -809,7 +836,7 @@ object Router { def isRelatedTo(c: ChannelAnnouncement, nodeId: PublicKey) = nodeId == c.nodeId1 || nodeId == c.nodeId2 - def hasChannels(nodeId: PublicKey, channels: Iterable[ChannelAnnouncement]): Boolean = channels.exists(c => isRelatedTo(c, nodeId)) + def hasChannels(nodeId: PublicKey, channels: Iterable[PublicChannel]): Boolean = channels.exists(c => isRelatedTo(c.ann, nodeId)) def isStale(u: ChannelUpdate): Boolean = isStale(u.timestamp) @@ -845,19 +872,12 @@ object Router { blockHeight < staleThresholdBlocks && update1_opt.map(isStale).getOrElse(true) && update2_opt.map(isStale).getOrElse(true) } - def getStaleChannels(channels: Iterable[ChannelAnnouncement], updates: Map[ChannelDesc, ChannelUpdate]): Iterable[ShortChannelId] = { - val staleChannels = channels.filter { c => - val update1 = updates.get(ChannelDesc(c.shortChannelId, c.nodeId1, c.nodeId2)) - val update2 = updates.get(ChannelDesc(c.shortChannelId, c.nodeId2, c.nodeId1)) - isStale(c, update1, update2) - } - staleChannels.map(_.shortChannelId) - } + def getStaleChannels(channels: Iterable[PublicChannel]): Iterable[PublicChannel] = channels.filter(data => isStale(data.ann, data.update_1_opt, data.update_2_opt)) /** * Filters channels that we want to send to nodes asking for a channel range */ - def keep(firstBlockNum: Long, numberOfBlocks: Long, id: ShortChannelId, channels: Map[ShortChannelId, ChannelAnnouncement], updates: Map[ChannelDesc, ChannelUpdate]): Boolean = { + def keep(firstBlockNum: Long, numberOfBlocks: Long, id: ShortChannelId): Boolean = { val TxCoordinates(height, _, _) = ShortChannelId.coordinates(id) height >= firstBlockNum && height <= (firstBlockNum + numberOfBlocks) } @@ -881,6 +901,7 @@ object Router { theirsIsMoreRecent && !theirsIsStale case (None, Some(theirChecksum)) => // if we only have their checksum, we request their channel_update if it is different from ours + // NB: a zero checksum means that they don't have the data val areDifferent = theirChecksum != 0 && ourChecksum != theirChecksum areDifferent case (None, None) => @@ -889,7 +910,7 @@ object Router { } } - def computeFlag(channels: SortedMap[ShortChannelId, ChannelAnnouncement], updates: Map[ChannelDesc, ChannelUpdate])( + def computeFlag(channels: SortedMap[ShortChannelId, PublicChannel])( shortChannelId: ShortChannelId, theirTimestamps_opt: Option[ReplyChannelRangeTlv.Timestamps], theirChecksums_opt: Option[ReplyChannelRangeTlv.Checksums], @@ -902,7 +923,7 @@ object Router { INCLUDE_CHANNEL_ANNOUNCEMENT | INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_CHANNEL_UPDATE_2 } else { // we already know this channel - val (ourTimestamps, ourChecksums) = Router.getChannelDigestInfo(channels, updates)(shortChannelId) + val (ourTimestamps, ourChecksums) = Router.getChannelDigestInfo(channels)(shortChannelId) // if they don't provide timestamps or checksums, we set appropriate default values: // - we assume their timestamp is more recent than ours by setting timestamp = Long.MaxValue // - we assume their update is different from ours by setting checkum = Long.MaxValue (NB: our default value for checksum is 0) @@ -920,8 +941,7 @@ object Router { * Handle a query message, which includes a list of channel ids and flags. * * @param nodes node id -> node announcement - * @param channels channel id -> channel announcement - * @param updates channel description -> channel update + * @param channels channel id -> channel announcement + updates * @param ids list of channel ids * @param flags list of query flags, either empty one flag per channel id * @param onChannel called when a channel announcement matches (i.e. its bit is set in the query flag and we have it) @@ -930,8 +950,7 @@ object Router { * */ def processChannelQuery(nodes: Map[PublicKey, NodeAnnouncement], - channels: SortedMap[ShortChannelId, ChannelAnnouncement], - updates: Map[ChannelDesc, ChannelUpdate])( + channels: SortedMap[ShortChannelId, PublicChannel])( ids: List[ShortChannelId], flags: List[Long], onChannel: ChannelAnnouncement => Unit, @@ -951,7 +970,7 @@ object Router { var numca1 = numca var numcu1 = numcu var sent1 = nodesSent - val ca = channels(head) + val pc = channels(head) val flag_opt = flags.headOption // no flag means send everything @@ -962,28 +981,28 @@ object Router { val includeNode2 = flag_opt.forall(QueryFlagType.includeNodeAnnouncement2) if (includeChannel) { - onChannel(ca) + onChannel(pc.ann) } if (includeUpdate1) { - updates.get(ChannelDesc(ca.shortChannelId, ca.nodeId1, ca.nodeId2)).foreach { u => + pc.update_1_opt.foreach { u => onUpdate(u) } } if (includeUpdate2) { - updates.get(ChannelDesc(ca.shortChannelId, ca.nodeId2, ca.nodeId1)).foreach { u => + pc.update_2_opt.foreach { u => onUpdate(u) } } - if (includeNode1 && !sent1.contains(ca.nodeId1)) { - nodes.get(ca.nodeId1).foreach { n => + if (includeNode1 && !sent1.contains(pc.ann.nodeId1)) { + nodes.get(pc.ann.nodeId1).foreach { n => onNode(n) - sent1 = sent1 + ca.nodeId1 + sent1 = sent1 + pc.ann.nodeId1 } } - if (includeNode2 && !sent1.contains(ca.nodeId2)) { - nodes.get(ca.nodeId2).foreach { n => + if (includeNode2 && !sent1.contains(pc.ann.nodeId2)) { + nodes.get(pc.ann.nodeId2).foreach { n => onNode(n) - sent1 = sent1 + ca.nodeId2 + sent1 = sent1 + pc.ann.nodeId2 } } loop(tail, flags.drop(1), numca1, numcu1, sent1) @@ -1013,44 +1032,24 @@ object Router { /** * This method is used after a payment failed, and we want to exclude some nodes that we know are failing */ - def getIgnoredChannelDesc(updates: Map[ChannelDesc, ChannelUpdate], ignoreNodes: Set[PublicKey]): Iterable[ChannelDesc] = { + def getIgnoredChannelDesc(channels: Map[ShortChannelId, PublicChannel], ignoreNodes: Set[PublicKey]): Iterable[ChannelDesc] = { val desc = if (ignoreNodes.isEmpty) { Iterable.empty[ChannelDesc] } else { // expensive, but node blacklisting shouldn't happen often - updates.keys.filter(desc => ignoreNodes.contains(desc.a) || ignoreNodes.contains(desc.b)) + channels.values + .filter(channelData => ignoreNodes.contains(channelData.ann.nodeId1) || ignoreNodes.contains(channelData.ann.nodeId2)) + .flatMap(channelData => Vector(ChannelDesc(channelData.ann.shortChannelId, channelData.ann.nodeId1, channelData.ann.nodeId2), ChannelDesc(channelData.ann.shortChannelId, channelData.ann.nodeId2, channelData.ann.nodeId1))) } desc } - /** - * - * @param channels id -> announcement map - * @param updates channel updates - * @param id short channel id - * @return the timestamp of the most recent update for this channel id, 0 if we don't have any - */ - def getTimestamp(channels: SortedMap[ShortChannelId, ChannelAnnouncement], updates: Map[ChannelDesc, ChannelUpdate])(id: ShortChannelId): Long = { - val ca = channels(id) - val opt1 = updates.get(ChannelDesc(ca.shortChannelId, ca.nodeId1, ca.nodeId2)) - val opt2 = updates.get(ChannelDesc(ca.shortChannelId, ca.nodeId2, ca.nodeId1)) - val timestamp = (opt1, opt2) match { - case (Some(u1), Some(u2)) => Math.max(u1.timestamp, u2.timestamp) - case (Some(u1), None) => u1.timestamp - case (None, Some(u2)) => u2.timestamp - case (None, None) => 0L - } - timestamp - } - - def getChannelDigestInfo(channels: SortedMap[ShortChannelId, ChannelAnnouncement], updates: Map[ChannelDesc, ChannelUpdate])(shortChannelId: ShortChannelId): (ReplyChannelRangeTlv.Timestamps, ReplyChannelRangeTlv.Checksums) = { + def getChannelDigestInfo(channels: SortedMap[ShortChannelId, PublicChannel])(shortChannelId: ShortChannelId): (ReplyChannelRangeTlv.Timestamps, ReplyChannelRangeTlv.Checksums) = { val c = channels(shortChannelId) - val u1_opt = updates.get(ChannelDesc(c.shortChannelId, c.nodeId1, c.nodeId2)) - val u2_opt = updates.get(ChannelDesc(c.shortChannelId, c.nodeId2, c.nodeId1)) - val timestamp1 = u1_opt.map(_.timestamp).getOrElse(0L) - val timestamp2 = u2_opt.map(_.timestamp).getOrElse(0L) - val checksum1 = u1_opt.map(getChecksum).getOrElse(0L) - val checksum2 = u2_opt.map(getChecksum).getOrElse(0L) + val timestamp1 = c.update_1_opt.map(_.timestamp).getOrElse(0L) + val timestamp2 = c.update_2_opt.map(_.timestamp).getOrElse(0L) + val checksum1 = c.update_1_opt.map(getChecksum).getOrElse(0L) + val checksum2 = c.update_2_opt.map(getChecksum).getOrElse(0L) (ReplyChannelRangeTlv.Timestamps(timestamp1 = timestamp1, timestamp2 = timestamp2), ReplyChannelRangeTlv.Checksums(checksum1 = checksum1, checksum2 = checksum2)) } @@ -1157,6 +1156,7 @@ object Router { numRoutes: Int, extraEdges: Set[GraphEdge] = Set.empty, ignoredEdges: Set[ChannelDesc] = Set.empty, + ignoredVertices: Set[PublicKey] = Set.empty, routeParams: RouteParams): Try[Seq[Hop]] = Try { if (localNodeId == targetNodeId) throw CannotRouteToSelf @@ -1180,9 +1180,9 @@ object Router { feeOk(weight.cost - amount, amount) && lengthOk(weight.length) && cltvOk(weight.cltv) } - val foundRoutes = Graph.yenKshortestPaths(g, localNodeId, targetNodeId, amount, ignoredEdges, extraEdges, numRoutes, routeParams.ratios, currentBlockHeight, boundaries).toList match { + val foundRoutes = Graph.yenKshortestPaths(g, localNodeId, targetNodeId, amount, ignoredEdges, ignoredVertices, extraEdges, numRoutes, routeParams.ratios, currentBlockHeight, boundaries).toList match { case Nil if routeParams.routeMaxLength < ROUTE_MAX_LENGTH => // if not found within the constraints we relax and repeat the search - return findRoute(g, localNodeId, targetNodeId, amount, numRoutes, extraEdges, ignoredEdges, routeParams.copy(routeMaxLength = ROUTE_MAX_LENGTH, routeMaxCltv = DEFAULT_ROUTE_MAX_CLTV)) + return findRoute(g, localNodeId, targetNodeId, amount, numRoutes, extraEdges, ignoredEdges, ignoredVertices, routeParams.copy(routeMaxLength = ROUTE_MAX_LENGTH, routeMaxCltv = DEFAULT_ROUTE_MAX_CLTV)) case Nil => throw RouteNotFound case routes => routes.find(_.path.size == 1) match { case Some(directRoute) => directRoute :: Nil diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageTypes.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageTypes.scala index a78fef82d8..e7c583b1b8 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageTypes.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageTypes.scala @@ -22,6 +22,7 @@ import java.nio.charset.StandardCharsets import com.google.common.base.Charsets import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} import fr.acinq.bitcoin.{ByteVector32, ByteVector64, Satoshi} +import fr.acinq.eclair.router.Announcements import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, MilliSatoshi, ShortChannelId, UInt64} import scodec.bits.ByteVector @@ -223,6 +224,8 @@ case class ChannelUpdate(signature: ByteVector64, htlcMaximumMsat: Option[MilliSatoshi], unknownFields: ByteVector = ByteVector.empty) extends RoutingMessage with HasTimestamp with HasChainHash { require(((messageFlags & 1) != 0) == htlcMaximumMsat.isDefined, "htlcMaximumMsat is not consistent with messageFlags") + + def isNode1 = Announcements.isNode1(channelFlags) } // @formatter:off @@ -233,11 +236,9 @@ object EncodingType { } // @formatter:on - case class EncodedShortChannelIds(encoding: EncodingType, array: List[ShortChannelId]) - case class QueryShortChannelIds(chainHash: ByteVector32, shortChannelIds: EncodedShortChannelIds, tlvStream: TlvStream[QueryShortChannelIdsTlv] = TlvStream.empty) extends RoutingMessage with HasChainHash { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala index b87b1ebcb6..54cb65c948 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala @@ -51,8 +51,8 @@ case class TlvStream[T <: Tlv](records: Traversable[T], unknown: Traversable[Gen /** * * @tparam R input type parameter, must be a subtype of the main TLV type - * @return the TLV record of of type that matches the input type parameter if any (there can be at most one, since BOLTs specify - * that TLV records are supposed to be unique + * @return the TLV record of type that matches the input type parameter if any (there can be at most one, since BOLTs specify + * that TLV records are supposed to be unique) */ def get[R <: T : ClassTag]: Option[R] = records.collectFirst { case r: R => r } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteNetworkDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteNetworkDbSpec.scala index f326d3eb89..b06ff1866c 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteNetworkDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/SqliteNetworkDbSpec.scala @@ -16,13 +16,18 @@ package fr.acinq.eclair.db +import java.sql.{Connection, DriverManager} + +import fr.acinq.bitcoin.Crypto.PrivateKey import fr.acinq.bitcoin.{Block, Crypto, Satoshi} import fr.acinq.eclair.db.sqlite.SqliteNetworkDb -import fr.acinq.eclair.router.Announcements +import fr.acinq.eclair.db.sqlite.SqliteUtils._ +import fr.acinq.eclair.router.{Announcements, PublicChannel} import fr.acinq.eclair.wire.{Color, NodeAddress, Tor2} import fr.acinq.eclair.{CltvExpiryDelta, MilliSatoshi, ShortChannelId, TestConstants, randomBytes32, randomKey} import org.scalatest.FunSuite -import org.sqlite.SQLiteException + +import scala.collection.SortedMap class SqliteNetworkDbSpec extends FunSuite { @@ -35,6 +40,45 @@ class SqliteNetworkDbSpec extends FunSuite { val db2 = new SqliteNetworkDb(sqlite) } + test("migration test 1->2") { + val sqlite = TestConstants.sqliteInMemory() + + using(sqlite.createStatement()) { statement => + getVersion(statement, "network", 1) // this will set version to 1 + statement.execute("PRAGMA foreign_keys = ON") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS nodes (node_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS channels (short_channel_id INTEGER NOT NULL PRIMARY KEY, txid STRING NOT NULL, data BLOB NOT NULL, capacity_sat INTEGER NOT NULL)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_updates (short_channel_id INTEGER NOT NULL, node_flag INTEGER NOT NULL, data BLOB NOT NULL, PRIMARY KEY(short_channel_id, node_flag), FOREIGN KEY(short_channel_id) REFERENCES channels(short_channel_id))") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_updates_idx ON channel_updates(short_channel_id)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS pruned (short_channel_id INTEGER NOT NULL PRIMARY KEY)") + } + + + using(sqlite.createStatement()) { statement => + assert(getVersion(statement, "network", 2) == 1) + } + + // first round: this will trigger a migration + simpleTest(sqlite) + + using(sqlite.createStatement()) { statement => + assert(getVersion(statement, "network", 2) == 2) + } + + using(sqlite.createStatement()) { statement => + statement.executeUpdate("DELETE FROM nodes") + statement.executeUpdate("DELETE FROM channels") + } + + // second round: no migration + simpleTest(sqlite) + + using(sqlite.createStatement()) { statement => + assert(getVersion(statement, "network", 2) == 2) + } + + } + test("add/remove/list nodes") { val sqlite = TestConstants.sqliteInMemory() val db = new SqliteNetworkDb(sqlite) @@ -60,15 +104,25 @@ class SqliteNetworkDbSpec extends FunSuite { assert(node_4.addresses == List(Tor2("aaaqeayeaudaocaj", 42000))) } - test("add/remove/list channels and channel_updates") { - val sqlite = TestConstants.sqliteInMemory() + def simpleTest(sqlite: Connection) = { val db = new SqliteNetworkDb(sqlite) def sig = Crypto.sign(randomBytes32, randomKey) - val channel_1 = Announcements.makeChannelAnnouncement(Block.RegtestGenesisBlock.hash, ShortChannelId(42), randomKey.publicKey, randomKey.publicKey, randomKey.publicKey, randomKey.publicKey, sig, sig, sig, sig) - val channel_2 = Announcements.makeChannelAnnouncement(Block.RegtestGenesisBlock.hash, ShortChannelId(43), randomKey.publicKey, randomKey.publicKey, randomKey.publicKey, randomKey.publicKey, sig, sig, sig, sig) - val channel_3 = Announcements.makeChannelAnnouncement(Block.RegtestGenesisBlock.hash, ShortChannelId(44), randomKey.publicKey, randomKey.publicKey, randomKey.publicKey, randomKey.publicKey, sig, sig, sig, sig) + def generatePubkeyHigherThan(priv: PrivateKey) = { + var res = priv + while(!Announcements.isNode1(priv.publicKey, res.publicKey)) res = randomKey + res + } + + // in order to differentiate channel_updates 1/2 we order public keys + val a = randomKey + val b = generatePubkeyHigherThan(a) + val c = generatePubkeyHigherThan(b) + + val channel_1 = Announcements.makeChannelAnnouncement(Block.RegtestGenesisBlock.hash, ShortChannelId(42), a.publicKey, b.publicKey, randomKey.publicKey, randomKey.publicKey, sig, sig, sig, sig) + val channel_2 = Announcements.makeChannelAnnouncement(Block.RegtestGenesisBlock.hash, ShortChannelId(43), a.publicKey, c.publicKey, randomKey.publicKey, randomKey.publicKey, sig, sig, sig, sig) + val channel_3 = Announcements.makeChannelAnnouncement(Block.RegtestGenesisBlock.hash, ShortChannelId(44), b.publicKey, c.publicKey, randomKey.publicKey, randomKey.publicKey, sig, sig, sig, sig) val txid_1 = randomBytes32 val txid_2 = randomBytes32 @@ -81,24 +135,34 @@ class SqliteNetworkDbSpec extends FunSuite { assert(db.listChannels().size === 1) db.addChannel(channel_2, txid_2, capacity) db.addChannel(channel_3, txid_3, capacity) - assert(db.listChannels().toSet === Set((channel_1, (txid_1, capacity)), (channel_2, (txid_2, capacity)), (channel_3, (txid_3, capacity)))) + assert(db.listChannels() === SortedMap( + channel_1.shortChannelId -> PublicChannel(channel_1, txid_1, capacity, None, None), + channel_2.shortChannelId -> PublicChannel(channel_2, txid_2, capacity, None, None), + channel_3.shortChannelId -> PublicChannel(channel_3, txid_3, capacity, None, None))) db.removeChannel(channel_2.shortChannelId) - assert(db.listChannels().toSet === Set((channel_1, (txid_1, capacity)), (channel_3, (txid_3, capacity)))) - - val channel_update_1 = Announcements.makeChannelUpdate(Block.RegtestGenesisBlock.hash, randomKey, randomKey.publicKey, ShortChannelId(42), CltvExpiryDelta(5), MilliSatoshi(7000000), MilliSatoshi(50000), 100, MilliSatoshi(500000000L), true) - val channel_update_2 = Announcements.makeChannelUpdate(Block.RegtestGenesisBlock.hash, randomKey, randomKey.publicKey, ShortChannelId(43), CltvExpiryDelta(5), MilliSatoshi(7000000), MilliSatoshi(50000), 100, MilliSatoshi(500000000L), true) - val channel_update_3 = Announcements.makeChannelUpdate(Block.RegtestGenesisBlock.hash, randomKey, randomKey.publicKey, ShortChannelId(44), CltvExpiryDelta(5), MilliSatoshi(7000000), MilliSatoshi(50000), 100, MilliSatoshi(500000000L), true) - - assert(db.listChannelUpdates().toSet === Set.empty) - db.addChannelUpdate(channel_update_1) - db.addChannelUpdate(channel_update_1) // duplicate is ignored - assert(db.listChannelUpdates().size === 1) - intercept[SQLiteException](db.addChannelUpdate(channel_update_2)) - db.addChannelUpdate(channel_update_3) + assert(db.listChannels() === SortedMap( + channel_1.shortChannelId -> PublicChannel(channel_1, txid_1, capacity, None, None), + channel_3.shortChannelId -> PublicChannel(channel_3, txid_3, capacity, None, None))) + + val channel_update_1 = Announcements.makeChannelUpdate(Block.RegtestGenesisBlock.hash, a, b.publicKey, ShortChannelId(42), CltvExpiryDelta(5), MilliSatoshi(7000000), MilliSatoshi(50000), 100, MilliSatoshi(500000000L), true) + val channel_update_2 = Announcements.makeChannelUpdate(Block.RegtestGenesisBlock.hash, b, a.publicKey, ShortChannelId(42), CltvExpiryDelta(5), MilliSatoshi(7000000), MilliSatoshi(50000), 100, MilliSatoshi(500000000L), true) + val channel_update_3 = Announcements.makeChannelUpdate(Block.RegtestGenesisBlock.hash, b, c.publicKey, ShortChannelId(44), CltvExpiryDelta(5), MilliSatoshi(7000000), MilliSatoshi(50000), 100, MilliSatoshi(500000000L), true) + + db.updateChannel(channel_update_1) + db.updateChannel(channel_update_1) // duplicate is ignored + db.updateChannel(channel_update_2) + db.updateChannel(channel_update_3) + assert(db.listChannels() === SortedMap( + channel_1.shortChannelId -> PublicChannel(channel_1, txid_1, capacity, Some(channel_update_1), Some(channel_update_2)), + channel_3.shortChannelId -> PublicChannel(channel_3, txid_3, capacity, Some(channel_update_3), None))) db.removeChannel(channel_3.shortChannelId) - assert(db.listChannels().toSet === Set((channel_1, (txid_1, capacity)))) - assert(db.listChannelUpdates().toSet === Set(channel_update_1)) - db.updateChannelUpdate(channel_update_1) + assert(db.listChannels() === SortedMap( + channel_1.shortChannelId -> PublicChannel(channel_1, txid_1, capacity, Some(channel_update_1), Some(channel_update_2)))) + } + + test("add/remove/list channels and channel_updates") { + val sqlite = TestConstants.sqliteInMemory() + simpleTest(sqlite) } test("remove many channels") { @@ -114,14 +178,12 @@ class SqliteNetworkDbSpec extends FunSuite { val updates = shortChannelIds.map(id => template.copy(shortChannelId = id)) val txid = randomBytes32 channels.foreach(ca => db.addChannel(ca, txid, capacity)) - updates.foreach(u => db.addChannelUpdate(u)) - assert(db.listChannels().keySet === channels.toSet) - assert(db.listChannelUpdates() === updates) + updates.foreach(u => db.updateChannel(u)) + assert(db.listChannels().keySet === channels.map(_.shortChannelId).toSet) val toDelete = channels.map(_.shortChannelId).drop(500).take(2500) db.removeChannels(toDelete) - assert(db.listChannels().keySet === channels.filterNot(a => toDelete.contains(a.shortChannelId)).toSet) - assert(db.listChannelUpdates().toSet === updates.filterNot(u => toDelete.contains(u.shortChannelId)).toSet) + assert(db.listChannels().keySet === (channels.map(_.shortChannelId).toSet -- toDelete)) } test("prune many channels") { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/integration/IntegrationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/integration/IntegrationSpec.scala index fda650db0d..219952906a 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/integration/IntegrationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/integration/IntegrationSpec.scala @@ -23,7 +23,7 @@ import akka.actor.{ActorRef, ActorSystem} import akka.testkit.{TestKit, TestProbe} import com.google.common.net.HostAndPort import com.typesafe.config.{Config, ConfigFactory} -import fr.acinq.bitcoin.Crypto.PrivateKey +import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} import fr.acinq.bitcoin.{Base58, Base58Check, Bech32, Block, ByteVector32, Crypto, OP_0, OP_CHECKSIG, OP_DUP, OP_EQUAL, OP_EQUALVERIFY, OP_HASH160, OP_PUSHDATA, Satoshi, Script, ScriptFlags, Transaction} import fr.acinq.eclair.blockchain.bitcoind.BitcoindService import fr.acinq.eclair.blockchain.bitcoind.rpc.ExtendedBitcoinClient @@ -39,11 +39,11 @@ import fr.acinq.eclair.payment.PaymentLifecycle.{State => _, _} import fr.acinq.eclair.payment.{LocalPaymentHandler, PaymentRequest} import fr.acinq.eclair.router.Graph.WeightRatios import fr.acinq.eclair.router.Router.ROUTE_MAX_LENGTH -import fr.acinq.eclair.router.{Announcements, AnnouncementsBatchValidationSpec, ChannelDesc, RouteParams} +import fr.acinq.eclair.router.{Announcements, AnnouncementsBatchValidationSpec, ChannelDesc, PublicChannel, RouteParams} import fr.acinq.eclair.transactions.Transactions import fr.acinq.eclair.transactions.Transactions.{HtlcSuccessTx, HtlcTimeoutTx} import fr.acinq.eclair.wire._ -import fr.acinq.eclair.{Globals, Kit, Setup, randomBytes32} +import fr.acinq.eclair.{Globals, Kit, Setup, ShortChannelId, randomBytes32} import grizzled.slf4j.Logging import org.json4s.JsonAST.JValue import org.json4s.{DefaultFormats, JString} @@ -295,12 +295,14 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService val ps = sender.expectMsgType[PaymentSucceeded](5 seconds) assert(ps.id == paymentId) + def updateFor(n: PublicKey, pc: PublicChannel): Option[ChannelUpdate] = if (n == pc.ann.nodeId1) pc.update_1_opt else if (n == pc.ann.nodeId2) pc.update_2_opt else throw new IllegalArgumentException("this node is unrelated to this channel") + awaitCond({ // in the meantime, the router will have updated its state - sender.send(nodes("A").router, 'updatesMap) + sender.send(nodes("A").router, 'channelsMap) // we then put everything back like before by asking B to refresh its channel update (this will override the one we created) - val update = sender.expectMsgType[Map[ChannelDesc, ChannelUpdate]](10 seconds).apply(ChannelDesc(channelUpdateBC.shortChannelId, nodes("B").nodeParams.nodeId, nodes("C").nodeParams.nodeId)) - update == channelUpdateBC + val u_opt = updateFor(nodes("B").nodeParams.nodeId, sender.expectMsgType[Map[ShortChannelId, PublicChannel]](10 seconds).apply(channelUpdateBC.shortChannelId)) + u_opt.contains(channelUpdateBC) }, max = 30 seconds, interval = 1 seconds) // first let's wait 3 seconds to make sure the timestamp of the new channel_update will be strictly greater than the former @@ -313,8 +315,8 @@ class IntegrationSpec extends TestKit(ActorSystem("test")) with BitcoindService assert(channelUpdateBC_new.timestamp > channelUpdateBC.timestamp) assert(channelUpdateBC_new.cltvExpiryDelta == nodes("B").nodeParams.expiryDeltaBlocks) awaitCond({ - sender.send(nodes("A").router, 'updatesMap) - val u = sender.expectMsgType[Map[ChannelDesc, ChannelUpdate]].apply(ChannelDesc(channelUpdateBC.shortChannelId, nodes("B").nodeParams.nodeId, nodes("C").nodeParams.nodeId)) + sender.send(nodes("A").router, 'channelsMap) + val u = updateFor(nodes("B").nodeParams.nodeId, sender.expectMsgType[Map[ShortChannelId, PublicChannel]](10 seconds).apply(channelUpdateBC.shortChannelId)).get u.cltvExpiryDelta == nodes("B").nodeParams.expiryDeltaBlocks }, max = 30 seconds, interval = 1 second) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala index e9f0414147..dab8376e51 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala @@ -45,9 +45,9 @@ class PeerSpec extends TestkitBaseClass with StateTestsHelperMethods { val fakeIPAddress = NodeAddress.fromParts("1.2.3.4", 42000).get val shortChannelIds = RoutingSyncSpec.shortChannelIds.take(100) val fakeRoutingInfo = shortChannelIds.map(makeFakeRoutingInfo) - val channels = fakeRoutingInfo.map(_._1).toList - val updates = (fakeRoutingInfo.map(_._2) ++ fakeRoutingInfo.map(_._3)).toList - val nodes = (fakeRoutingInfo.map(_._4) ++ fakeRoutingInfo.map(_._5)).toList + val channels = fakeRoutingInfo.map(_._1.ann).toList + val updates = (fakeRoutingInfo.flatMap(_._1.update_1_opt) ++ fakeRoutingInfo.flatMap(_._1.update_2_opt)).toList + val nodes = (fakeRoutingInfo.map(_._1.ann.nodeId1) ++ fakeRoutingInfo.map(_._1.ann.nodeId2)).map(RoutingSyncSpec.makeFakeNodeAnnouncement).toList case class FixtureParam(remoteNodeId: PublicKey, authenticator: TestProbe, watcher: TestProbe, router: TestProbe, relayer: TestProbe, connection: TestProbe, transport: TestProbe, peer: TestFSMRef[Peer.State, Peer.Data, Peer]) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/ChannelRangeQueriesSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/ChannelRangeQueriesSpec.scala index 7428007a8f..ca066c3e71 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/ChannelRangeQueriesSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/ChannelRangeQueriesSpec.scala @@ -16,6 +16,7 @@ package fr.acinq.eclair.router +import fr.acinq.bitcoin.{ByteVector32, Satoshi} import fr.acinq.eclair.wire.ReplyChannelRangeTlv._ import fr.acinq.eclair.{MilliSatoshi, randomKey} import org.scalatest.FunSuite @@ -89,39 +90,33 @@ class ChannelRangeQueriesSpec extends FunSuite { val ef = RouteCalculationSpec.makeChannel(167514L, e, f) val channels = SortedMap( - ab.shortChannelId -> ab, - cd.shortChannelId -> cd - ) - - val updates = Map( - ab1 -> uab1, - ab2 -> uab2, - cd1 -> ucd1 + ab.shortChannelId -> PublicChannel(ab, ByteVector32.Zeroes, Satoshi(0), Some(uab1), Some(uab2)), + cd.shortChannelId -> PublicChannel(cd, ByteVector32.Zeroes, Satoshi(0), Some(ucd1), None) ) import fr.acinq.eclair.wire.QueryShortChannelIdsTlv.QueryFlagType._ - assert(Router.getChannelDigestInfo(channels, updates)(ab.shortChannelId) == (Timestamps(now, now), Checksums(1697591108L, 1697591108L))) + assert(Router.getChannelDigestInfo(channels)(ab.shortChannelId) == (Timestamps(now, now), Checksums(1697591108L, 3692323747L))) // no extended info but we know the channel: we ask for the updates - assert(Router.computeFlag(channels, updates)(ab.shortChannelId, None, None, false) === (INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_CHANNEL_UPDATE_2)) - assert(Router.computeFlag(channels, updates)(ab.shortChannelId, None, None, true) === (INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_CHANNEL_UPDATE_2 | INCLUDE_NODE_ANNOUNCEMENT_1 | INCLUDE_NODE_ANNOUNCEMENT_2)) + assert(Router.computeFlag(channels)(ab.shortChannelId, None, None, false) === (INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_CHANNEL_UPDATE_2)) + assert(Router.computeFlag(channels)(ab.shortChannelId, None, None, true) === (INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_CHANNEL_UPDATE_2 | INCLUDE_NODE_ANNOUNCEMENT_1 | INCLUDE_NODE_ANNOUNCEMENT_2)) // same checksums, newer timestamps: we don't ask anything - assert(Router.computeFlag(channels, updates)(ab.shortChannelId, Some(Timestamps(now + 1, now + 1)), Some(Checksums(1697591108L, 1697591108L)), true) === 0) + assert(Router.computeFlag(channels)(ab.shortChannelId, Some(Timestamps(now + 1, now + 1)), Some(Checksums(1697591108L, 3692323747L)), true) === 0) // different checksums, newer timestamps: we ask for the updates - assert(Router.computeFlag(channels, updates)(ab.shortChannelId, Some(Timestamps(now + 1, now)), Some(Checksums(154654604, 1697591108L)), true) === (INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_NODE_ANNOUNCEMENT_1 | INCLUDE_NODE_ANNOUNCEMENT_2)) - assert(Router.computeFlag(channels, updates)(ab.shortChannelId, Some(Timestamps(now, now + 1)), Some(Checksums(1697591108L, 45664546)), true) === (INCLUDE_CHANNEL_UPDATE_2 | INCLUDE_NODE_ANNOUNCEMENT_1 | INCLUDE_NODE_ANNOUNCEMENT_2)) - assert(Router.computeFlag(channels, updates)(ab.shortChannelId, Some(Timestamps(now + 1, now + 1)), Some(Checksums(154654604, 45664546 + 6)), true) === (INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_CHANNEL_UPDATE_2| INCLUDE_NODE_ANNOUNCEMENT_1 | INCLUDE_NODE_ANNOUNCEMENT_2)) + assert(Router.computeFlag(channels)(ab.shortChannelId, Some(Timestamps(now + 1, now)), Some(Checksums(154654604, 3692323747L)), true) === (INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_NODE_ANNOUNCEMENT_1 | INCLUDE_NODE_ANNOUNCEMENT_2)) + assert(Router.computeFlag(channels)(ab.shortChannelId, Some(Timestamps(now, now + 1)), Some(Checksums(1697591108L, 45664546)), true) === (INCLUDE_CHANNEL_UPDATE_2 | INCLUDE_NODE_ANNOUNCEMENT_1 | INCLUDE_NODE_ANNOUNCEMENT_2)) + assert(Router.computeFlag(channels)(ab.shortChannelId, Some(Timestamps(now + 1, now + 1)), Some(Checksums(154654604, 45664546 + 6)), true) === (INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_CHANNEL_UPDATE_2| INCLUDE_NODE_ANNOUNCEMENT_1 | INCLUDE_NODE_ANNOUNCEMENT_2)) // different checksums, older timestamps: we don't ask anything - assert(Router.computeFlag(channels, updates)(ab.shortChannelId, Some(Timestamps(now - 1, now)), Some(Checksums(154654604, 1697591108L)), true) === 0) - assert(Router.computeFlag(channels, updates)(ab.shortChannelId, Some(Timestamps(now, now - 1)), Some(Checksums(1697591108L, 45664546)), true) === 0) - assert(Router.computeFlag(channels, updates)(ab.shortChannelId, Some(Timestamps(now - 1, now - 1)), Some(Checksums(154654604, 45664546)), true) === 0) + assert(Router.computeFlag(channels)(ab.shortChannelId, Some(Timestamps(now - 1, now)), Some(Checksums(154654604, 3692323747L)), true) === 0) + assert(Router.computeFlag(channels)(ab.shortChannelId, Some(Timestamps(now, now - 1)), Some(Checksums(1697591108L, 45664546)), true) === 0) + assert(Router.computeFlag(channels)(ab.shortChannelId, Some(Timestamps(now - 1, now - 1)), Some(Checksums(154654604, 45664546)), true) === 0) // missing channel update: we ask for it - assert(Router.computeFlag(channels, updates)(cd.shortChannelId, Some(Timestamps(now, now)), Some(Checksums(3297511804L, 3297511804L)), true) === (INCLUDE_CHANNEL_UPDATE_2 | INCLUDE_NODE_ANNOUNCEMENT_1 | INCLUDE_NODE_ANNOUNCEMENT_2)) + assert(Router.computeFlag(channels)(cd.shortChannelId, Some(Timestamps(now, now)), Some(Checksums(3297511804L, 3297511804L)), true) === (INCLUDE_CHANNEL_UPDATE_2 | INCLUDE_NODE_ANNOUNCEMENT_1 | INCLUDE_NODE_ANNOUNCEMENT_2)) // unknown channel: we ask everything - assert(Router.computeFlag(channels, updates)(ef.shortChannelId, None, None, false) === (INCLUDE_CHANNEL_ANNOUNCEMENT | INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_CHANNEL_UPDATE_2)) - assert(Router.computeFlag(channels, updates)(ef.shortChannelId, None, None, true) === (INCLUDE_CHANNEL_ANNOUNCEMENT | INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_CHANNEL_UPDATE_2 | INCLUDE_NODE_ANNOUNCEMENT_1 | INCLUDE_NODE_ANNOUNCEMENT_2)) + assert(Router.computeFlag(channels)(ef.shortChannelId, None, None, false) === (INCLUDE_CHANNEL_ANNOUNCEMENT | INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_CHANNEL_UPDATE_2)) + assert(Router.computeFlag(channels)(ef.shortChannelId, None, None, true) === (INCLUDE_CHANNEL_ANNOUNCEMENT | INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_CHANNEL_UPDATE_2 | INCLUDE_NODE_ANNOUNCEMENT_1 | INCLUDE_NODE_ANNOUNCEMENT_2)) } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/GraphSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/GraphSpec.scala index 9730365850..ea06cd1071 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/GraphSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/GraphSpec.scala @@ -54,7 +54,7 @@ class GraphSpec extends FunSuite { makeUpdate(6L, b, e, MilliSatoshi(0), 0) ) - DirectedGraph.makeGraph(updates.toMap) + DirectedGraph().addEdges(updates) } test("instantiate a graph, with vertices and then add edges") { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala index 840bbea8d0..9c63989d71 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala @@ -17,7 +17,7 @@ package fr.acinq.eclair.router import fr.acinq.bitcoin.Crypto.PublicKey -import fr.acinq.bitcoin.{Block, ByteVector32, ByteVector64} +import fr.acinq.bitcoin.{Block, ByteVector32, ByteVector64, Satoshi} import fr.acinq.eclair.payment.PaymentRequest.ExtraHop import fr.acinq.eclair.router.Graph.GraphStructure.DirectedGraph.graphEdgeToHop import fr.acinq.eclair.router.Graph.GraphStructure.{DirectedGraph, GraphEdge} @@ -28,6 +28,7 @@ import fr.acinq.eclair.{CltvExpiryDelta, Globals, MilliSatoshi, ShortChannelId, import org.scalatest.FunSuite import scodec.bits._ +import scala.collection.immutable.SortedMap import scala.util.{Failure, Success} /** @@ -445,13 +446,15 @@ class RouteCalculationSpec extends FunSuite { val extraHops = extraHop1 :: extraHop2 :: extraHop3 :: extraHop4 :: Nil - val fakeUpdates = Router.toFakeUpdates(extraHops, e) + val fakeUpdates: Map[ShortChannelId, ExtraHop] = Router.toAssistedChannels(extraHops, e).map { case (shortChannelId, assistedChannel) => + (shortChannelId, assistedChannel.extraHop) + } assert(fakeUpdates == Map( - ChannelDesc(extraHop1.shortChannelId, a, b) -> Router.toFakeUpdate(extraHop1), - ChannelDesc(extraHop2.shortChannelId, b, c) -> Router.toFakeUpdate(extraHop2), - ChannelDesc(extraHop3.shortChannelId, c, d) -> Router.toFakeUpdate(extraHop3), - ChannelDesc(extraHop4.shortChannelId, d, e) -> Router.toFakeUpdate(extraHop4) + extraHop1.shortChannelId -> extraHop1, + extraHop2.shortChannelId -> extraHop2, + extraHop3.shortChannelId -> extraHop3, + extraHop4.shortChannelId -> extraHop4 )) } @@ -540,8 +543,8 @@ class RouteCalculationSpec extends FunSuite { ShortChannelId(6L) -> makeChannel(6L, f, h), ShortChannelId(7L) -> makeChannel(7L, h, i), ShortChannelId(8L) -> makeChannel(8L, i, j) - ) + val updates = List( makeUpdate(1L, a, b, MilliSatoshi(10), 10), makeUpdate(2L, b, c, MilliSatoshi(10), 10), @@ -554,14 +557,20 @@ class RouteCalculationSpec extends FunSuite { makeUpdate(8L, i, j, MilliSatoshi(10), 10) ).toMap - val ignored = Router.getIgnoredChannelDesc(updates, ignoreNodes = Set(c, j, randomKey.publicKey)) + val publicChannels = channels.map { case (shortChannelId, announcement) => + val (_, update) = updates.find{ case (d, u) => d.shortChannelId == shortChannelId}.get + val (update_1_opt, update_2_opt) = if (Announcements.isNode1(update.channelFlags)) (Some(update), None) else (None, Some(update)) + val pc = PublicChannel(announcement, ByteVector32.Zeroes, Satoshi(1000), update_1_opt, update_2_opt) + (shortChannelId, pc) + } - assert(ignored.toSet === Set( - ChannelDesc(ShortChannelId(2L), b, c), - ChannelDesc(ShortChannelId(2L), c, b), - ChannelDesc(ShortChannelId(3L), c, d), - ChannelDesc(ShortChannelId(8L), i, j) - )) + + val ignored = Router.getIgnoredChannelDesc(publicChannels, ignoreNodes = Set(c, j, randomKey.publicKey)) + + assert(ignored.toSet.contains(ChannelDesc(ShortChannelId(2L), b, c))) + assert(ignored.toSet.contains(ChannelDesc(ShortChannelId(2L), c, b))) + assert(ignored.toSet.contains(ChannelDesc(ShortChannelId(3L), c, d))) + assert(ignored.toSet.contains(ChannelDesc(ShortChannelId(8L), i, j))) } test("limit routes to 20 hops") { @@ -702,11 +711,11 @@ class RouteCalculationSpec extends FunSuite { makeUpdate(5L, e, f, MilliSatoshi(1), 0), makeUpdate(6L, b, c, MilliSatoshi(1), 0), makeUpdate(7L, c, f, MilliSatoshi(1), 0) - ).toMap + ) - val graph = DirectedGraph.makeGraph(edges) + val graph = DirectedGraph().addEdges(edges) - val fourShortestPaths = Graph.yenKshortestPaths(graph, d, f, DEFAULT_AMOUNT_MSAT, Set.empty, Set.empty, pathsToFind = 4, None, 0, noopBoundaries) + val fourShortestPaths = Graph.yenKshortestPaths(graph, d, f, DEFAULT_AMOUNT_MSAT, Set.empty, Set.empty, Set.empty, pathsToFind = 4, None, 0, noopBoundaries) assert(fourShortestPaths.size === 4) assert(hops2Ids(fourShortestPaths(0).path.map(graphEdgeToHop)) === 2 :: 5 :: Nil) // D -> E -> F @@ -740,7 +749,7 @@ class RouteCalculationSpec extends FunSuite { val graph = DirectedGraph().addEdges(edges) - val twoShortestPaths = Graph.yenKshortestPaths(graph, c, h, DEFAULT_AMOUNT_MSAT, Set.empty, Set.empty, pathsToFind = 2, None, 0, noopBoundaries) + val twoShortestPaths = Graph.yenKshortestPaths(graph, c, h, DEFAULT_AMOUNT_MSAT, Set.empty, Set.empty, Set.empty, pathsToFind = 2, None, 0, noopBoundaries) assert(twoShortestPaths.size === 2) val shortest = twoShortestPaths(0) @@ -774,7 +783,7 @@ class RouteCalculationSpec extends FunSuite { val graph = DirectedGraph().addEdges(edges) //we ask for 3 shortest paths but only 2 can be found - val foundPaths = Graph.yenKshortestPaths(graph, a, f, DEFAULT_AMOUNT_MSAT, Set.empty, Set.empty, pathsToFind = 3, None, 0, noopBoundaries) + val foundPaths = Graph.yenKshortestPaths(graph, a, f, DEFAULT_AMOUNT_MSAT, Set.empty, Set.empty, Set.empty, pathsToFind = 3, None, 0, noopBoundaries) assert(foundPaths.size === 2) assert(hops2Ids(foundPaths(0).path.map(graphEdgeToHop)) === 1 :: 2 :: 3 :: Nil) // A -> B -> C -> F @@ -920,14 +929,29 @@ class RouteCalculationSpec extends FunSuite { // This test have a channel (542280x2156x0) that according to heuristics is very convenient but actually useless to reach the target, // then if the cost function is not monotonic the path-finding breaks because the result path contains a loop. - val updates = List( - ChannelDesc(ShortChannelId("565643x1216x0"), PublicKey(hex"03864ef025fde8fb587d989186ce6a4a186895ee44a926bfc370e2c366597a3f8f"), PublicKey(hex"024655b768ef40951b20053a5c4b951606d4d86085d51238f2c67c7dec29c792ca")) -> ChannelUpdate(ByteVector64.Zeroes, ByteVector32.Zeroes, ShortChannelId("565643x1216x0"), 0, 1.toByte, 1.toByte, CltvExpiryDelta(144), htlcMinimumMsat = MilliSatoshi(0), feeBaseMsat = MilliSatoshi(1000), 100, Some(MilliSatoshi(15000000000L))), - ChannelDesc(ShortChannelId("565643x1216x0"), PublicKey(hex"024655b768ef40951b20053a5c4b951606d4d86085d51238f2c67c7dec29c792ca"), PublicKey(hex"03864ef025fde8fb587d989186ce6a4a186895ee44a926bfc370e2c366597a3f8f")) -> ChannelUpdate(ByteVector64.Zeroes, ByteVector32.Zeroes, ShortChannelId("565643x1216x0"), 0, 1.toByte, 0.toByte, CltvExpiryDelta(14), htlcMinimumMsat = MilliSatoshi(1), MilliSatoshi(1000), 10, Some(MilliSatoshi(4294967295L))), - ChannelDesc(ShortChannelId("542280x2156x0"), PublicKey(hex"03864ef025fde8fb587d989186ce6a4a186895ee44a926bfc370e2c366597a3f8f"), PublicKey(hex"03cb7983dc247f9f81a0fa2dfa3ce1c255365f7279c8dd143e086ca333df10e278")) -> ChannelUpdate(ByteVector64.Zeroes, ByteVector32.Zeroes, ShortChannelId("542280x2156x0"), 0, 1.toByte, 1.toByte, CltvExpiryDelta(144), htlcMinimumMsat = MilliSatoshi(1000), feeBaseMsat = MilliSatoshi(1000), 100, Some(MilliSatoshi(16777000000L))), - ChannelDesc(ShortChannelId("542280x2156x0"), PublicKey(hex"03cb7983dc247f9f81a0fa2dfa3ce1c255365f7279c8dd143e086ca333df10e278"), PublicKey(hex"03864ef025fde8fb587d989186ce6a4a186895ee44a926bfc370e2c366597a3f8f")) -> ChannelUpdate(ByteVector64.Zeroes, ByteVector32.Zeroes, ShortChannelId("542280x2156x0"), 0, 1.toByte, 0.toByte, CltvExpiryDelta(144), htlcMinimumMsat = MilliSatoshi(1), MilliSatoshi(667), 1, Some(MilliSatoshi(16777000000L))), - ChannelDesc(ShortChannelId("565779x2711x0"), PublicKey(hex"03864ef025fde8fb587d989186ce6a4a186895ee44a926bfc370e2c366597a3f8f"), PublicKey(hex"036d65409c41ab7380a43448f257809e7496b52bf92057c09c4f300cbd61c50d96")) -> ChannelUpdate(ByteVector64.Zeroes, ByteVector32.Zeroes, ShortChannelId("565779x2711x0"), 0, 1.toByte, 3.toByte, CltvExpiryDelta(144), htlcMinimumMsat = MilliSatoshi(1), MilliSatoshi(1000), 100, Some(MilliSatoshi(230000000L))), - ChannelDesc(ShortChannelId("565779x2711x0"), PublicKey(hex"036d65409c41ab7380a43448f257809e7496b52bf92057c09c4f300cbd61c50d96"), PublicKey(hex"03864ef025fde8fb587d989186ce6a4a186895ee44a926bfc370e2c366597a3f8f")) -> ChannelUpdate(ByteVector64.Zeroes, ByteVector32.Zeroes, ShortChannelId("565779x2711x0"), 0, 1.toByte, 0.toByte, CltvExpiryDelta(144), htlcMinimumMsat = MilliSatoshi(1), MilliSatoshi(1000), 100, Some(MilliSatoshi(230000000L))) - ).toMap + val updates = SortedMap( + ShortChannelId("565643x1216x0") -> PublicChannel( + ann = makeChannel(ShortChannelId("565643x1216x0").toLong, PublicKey(hex"03864ef025fde8fb587d989186ce6a4a186895ee44a926bfc370e2c366597a3f8f"), PublicKey(hex"024655b768ef40951b20053a5c4b951606d4d86085d51238f2c67c7dec29c792ca")), + fundingTxid = ByteVector32.Zeroes, + capacity = Satoshi(0), + update_1_opt = Some(ChannelUpdate(ByteVector64.Zeroes, ByteVector32.Zeroes, ShortChannelId("565643x1216x0"), 0, 1.toByte, 0.toByte, CltvExpiryDelta(14), htlcMinimumMsat = MilliSatoshi(1), feeBaseMsat = MilliSatoshi(1000), 10, Some(MilliSatoshi(4294967295L)))), + update_2_opt = Some(ChannelUpdate(ByteVector64.Zeroes, ByteVector32.Zeroes, ShortChannelId("565643x1216x0"), 0, 1.toByte, 1.toByte, CltvExpiryDelta(144), htlcMinimumMsat = MilliSatoshi(0), feeBaseMsat = MilliSatoshi(1000), 100, Some(MilliSatoshi(15000000000L)))) + ), + ShortChannelId("542280x2156x0") -> PublicChannel( + ann = makeChannel(ShortChannelId("542280x2156x0").toLong, PublicKey(hex"03864ef025fde8fb587d989186ce6a4a186895ee44a926bfc370e2c366597a3f8f"), PublicKey(hex"03cb7983dc247f9f81a0fa2dfa3ce1c255365f7279c8dd143e086ca333df10e278")), + fundingTxid = ByteVector32.Zeroes, + capacity = Satoshi(0), + update_1_opt = Some(ChannelUpdate(ByteVector64.Zeroes, ByteVector32.Zeroes, ShortChannelId("542280x2156x0"), 0, 1.toByte, 0.toByte, CltvExpiryDelta(144), htlcMinimumMsat = MilliSatoshi(1000), feeBaseMsat = MilliSatoshi(1000), 100, Some(MilliSatoshi(16777000000L)))), + update_2_opt = Some(ChannelUpdate(ByteVector64.Zeroes, ByteVector32.Zeroes, ShortChannelId("542280x2156x0"), 0, 1.toByte, 1.toByte, CltvExpiryDelta(144), htlcMinimumMsat = MilliSatoshi(1), feeBaseMsat = MilliSatoshi(667), 1, Some(MilliSatoshi(16777000000L)))) + ), + ShortChannelId("565779x2711x0") -> PublicChannel( + ann = makeChannel(ShortChannelId("565779x2711x0").toLong, PublicKey(hex"036d65409c41ab7380a43448f257809e7496b52bf92057c09c4f300cbd61c50d96"), PublicKey(hex"03864ef025fde8fb587d989186ce6a4a186895ee44a926bfc370e2c366597a3f8f")), + fundingTxid = ByteVector32.Zeroes, + capacity = Satoshi(0), + update_1_opt = Some(ChannelUpdate(ByteVector64.Zeroes, ByteVector32.Zeroes, ShortChannelId("565779x2711x0"), 0, 1.toByte, 0.toByte, CltvExpiryDelta(144), htlcMinimumMsat = MilliSatoshi(1), feeBaseMsat = MilliSatoshi(1000), 100, Some(MilliSatoshi(230000000L)))), + update_2_opt = Some(ChannelUpdate(ByteVector64.Zeroes, ByteVector32.Zeroes, ShortChannelId("565779x2711x0"), 0, 1.toByte, 3.toByte, CltvExpiryDelta(144), htlcMinimumMsat = MilliSatoshi(1), feeBaseMsat = MilliSatoshi(1000), 100, Some(MilliSatoshi(230000000L)))) + ) + ) val g = DirectedGraph.makeGraph(updates) @@ -939,7 +963,7 @@ class RouteCalculationSpec extends FunSuite { val amount = MilliSatoshi(351000) Globals.blockCount.set(567634) // simulate mainnet block for heuristic - val Success(route) = Router.findRoute(g, thisNode, targetNode, amount, 1, Set.empty, Set.empty, params) + val Success(route) = Router.findRoute(g, thisNode, targetNode, amount, 1, Set.empty, Set.empty, Set.empty, params) assert(route.size == 2) assert(route.last.nextNodeId == targetNode) @@ -975,7 +999,7 @@ object RouteCalculationSpec { case Some(_) => 1 case None => 0 }, - channelFlags = 0, + channelFlags = if (Announcements.isNode1(nodeId1, nodeId2)) 0 else 1, cltvExpiryDelta = cltvDelta, htlcMinimumMsat = minHtlc, feeBaseMsat = feeBase, @@ -983,7 +1007,7 @@ object RouteCalculationSpec { htlcMaximumMsat = maxHtlc ) - def makeGraph(updates: Map[ChannelDesc, ChannelUpdate]) = DirectedGraph.makeGraph(updates) + def makeGraph(updates: Map[ChannelDesc, ChannelUpdate]) = DirectedGraph().addEdges(updates.toSeq) def hops2Ids(route: Seq[Hop]) = route.map(hop => hop.lastUpdate.shortChannelId.toLong) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala index 23bae1499e..b093dcd8bb 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala @@ -233,7 +233,7 @@ class RouterSpec extends BaseRouterSpec { val state = sender.expectMsgType[RoutingState] assert(state.channels.size == 4) assert(state.nodes.size == 6) - assert(state.updates.size == 8) + assert(state.channels.flatMap(c => c.update_1_opt.toSeq ++ c.update_2_opt.toSeq).size == 8) } test("given a pre-computed route add the proper channel updates") { fixture => diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/RoutingSyncSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/RoutingSyncSpec.scala index f0f7cfa1ec..f3d9d747cb 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/RoutingSyncSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/RoutingSyncSpec.scala @@ -19,7 +19,7 @@ package fr.acinq.eclair.router import akka.actor.{Actor, ActorSystem, Props} import akka.testkit.{TestFSMRef, TestKit, TestProbe} import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} -import fr.acinq.bitcoin.{Block, Satoshi, Script, Transaction, TxIn, TxOut} +import fr.acinq.bitcoin.{Block, ByteVector32, Satoshi, Script, Transaction, TxIn, TxOut} import fr.acinq.eclair.TestConstants.{Alice, Bob} import fr.acinq.eclair._ import fr.acinq.eclair.blockchain.{UtxoStatus, ValidateRequest, ValidateResult} @@ -41,10 +41,10 @@ class RoutingSyncSpec extends TestKit(ActorSystem("test")) with FunSuiteLike { import RoutingSyncSpec._ - val fakeRoutingInfo: TreeMap[ShortChannelId, (ChannelAnnouncement, ChannelUpdate, ChannelUpdate, NodeAnnouncement, NodeAnnouncement)] = RoutingSyncSpec + val fakeRoutingInfo: TreeMap[ShortChannelId, (PublicChannel, NodeAnnouncement, NodeAnnouncement)] = RoutingSyncSpec .shortChannelIds .take(4567) - .foldLeft(TreeMap.empty[ShortChannelId, (ChannelAnnouncement, ChannelUpdate, ChannelUpdate, NodeAnnouncement, NodeAnnouncement)]) { + .foldLeft(TreeMap.empty[ShortChannelId, (PublicChannel, NodeAnnouncement, NodeAnnouncement)]) { case (m, shortChannelId) => m + (shortChannelId -> makeFakeRoutingInfo(shortChannelId)) } @@ -119,6 +119,10 @@ class RoutingSyncSpec extends TestKit(ActorSystem("test")) with FunSuiteLike { SyncResult(rcrs, queries, channels, updates, nodes) } + def countUpdates(channels: Map[ShortChannelId, PublicChannel]) = channels.values.foldLeft(0) { + case (count, pc) => count + pc.update_1_opt.map(_ => 1).getOrElse(0) + pc.update_2_opt.map(_ => 1).getOrElse(0) + } + test("sync with standard channel queries") { val watcher = system.actorOf(Props(new YesWatcher())) val alice = TestFSMRef(new Router(Alice.nodeParams, watcher)) @@ -130,46 +134,42 @@ class RoutingSyncSpec extends TestKit(ActorSystem("test")) with FunSuiteLike { // tell alice to sync with bob assert(BasicSyncResult(ranges = 1, queries = 0, channels = 0, updates = 0, nodes = 0) === sync(alice, bob, extendedQueryFlags_opt).counts) awaitCond(alice.stateData.channels === bob.stateData.channels) - awaitCond(alice.stateData.updates === bob.stateData.updates) awaitCond(alice.stateData.nodes === bob.stateData.nodes) // add some channels and updates to bob and resync fakeRoutingInfo.take(40).values.foreach { - case (ca, cu1, cu2, na1, na2) => - sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, ca)) - sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, cu1)) + case (pc, na1, na2) => + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, pc.ann)) + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, pc.update_1_opt.get)) // we don't send channel_update #2 sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, na1)) sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, na2)) } - awaitCond(bob.stateData.channels.size === 40 && bob.stateData.updates.size === 40) + awaitCond(bob.stateData.channels.size === 40 && countUpdates(bob.stateData.channels) === 40) assert(BasicSyncResult(ranges = 1, queries = 1, channels = 40, updates = 40, nodes = 80) === sync(alice, bob, extendedQueryFlags_opt).counts) awaitCond(alice.stateData.channels === bob.stateData.channels) - awaitCond(alice.stateData.updates === bob.stateData.updates) // add some updates to bob and resync fakeRoutingInfo.take(40).values.foreach { - case (ca, cu1, cu2, na1, na2) => - sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, cu2)) + case (pc, na1, na2) => + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, pc.update_2_opt.get)) } - awaitCond(bob.stateData.channels.size === 40 && bob.stateData.updates.size === 80) + awaitCond(bob.stateData.channels.size === 40 && countUpdates(bob.stateData.channels) === 80) assert(BasicSyncResult(ranges = 1, queries = 1, channels = 40, updates = 80, nodes = 80) === sync(alice, bob, extendedQueryFlags_opt).counts) awaitCond(alice.stateData.channels === bob.stateData.channels) - awaitCond(alice.stateData.updates === bob.stateData.updates) // add everything (duplicates will be ignored) fakeRoutingInfo.values.foreach { - case (c, u1, u2, na1, na2) => - sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, c)) - sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, u1)) - sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, u2)) + case (pc, na1, na2) => + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, pc.ann)) + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, pc.update_1_opt.get)) + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, pc.update_2_opt.get)) sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, na1)) sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, na2)) } - awaitCond(bob.stateData.channels.size === fakeRoutingInfo.size && bob.stateData.updates.size === 2 * fakeRoutingInfo.size, max = 60 seconds) + awaitCond(bob.stateData.channels.size === fakeRoutingInfo.size && countUpdates(bob.stateData.channels) === 2 * fakeRoutingInfo.size, max = 60 seconds) assert(BasicSyncResult(ranges = 2, queries = 46, channels = fakeRoutingInfo.size, updates = 2 * fakeRoutingInfo.size, nodes = 2 * fakeRoutingInfo.size) === sync(alice, bob, extendedQueryFlags_opt).counts) awaitCond(alice.stateData.channels === bob.stateData.channels, max = 60 seconds) - awaitCond(alice.stateData.updates === bob.stateData.updates) } def syncWithExtendedQueries(requestNodeAnnouncements: Boolean) = { @@ -183,50 +183,46 @@ class RoutingSyncSpec extends TestKit(ActorSystem("test")) with FunSuiteLike { // tell alice to sync with bob assert(BasicSyncResult(ranges = 1, queries = 0, channels = 0, updates = 0, nodes = 0) === sync(alice, bob, extendedQueryFlags_opt).counts) awaitCond(alice.stateData.channels === bob.stateData.channels) - awaitCond(alice.stateData.updates === bob.stateData.updates) // add some channels and updates to bob and resync fakeRoutingInfo.take(40).values.foreach { - case (ca, cu1, cu2, na1, na2) => - sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, ca)) - sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, cu1)) + case (pc, na1, na2) => + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, pc.ann)) + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, pc.update_1_opt.get)) // we don't send channel_update #2 sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, na1)) sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, na2)) } - awaitCond(bob.stateData.channels.size === 40 && bob.stateData.updates.size === 40) + awaitCond(bob.stateData.channels.size === 40 && countUpdates(bob.stateData.channels) === 40) assert(BasicSyncResult(ranges = 1, queries = 1, channels = 40, updates = 40, nodes = if (requestNodeAnnouncements) 80 else 0) === sync(alice, bob, extendedQueryFlags_opt).counts) awaitCond(alice.stateData.channels === bob.stateData.channels, max = 60 seconds) - awaitCond(alice.stateData.updates === bob.stateData.updates) if (requestNodeAnnouncements) awaitCond(alice.stateData.nodes === bob.stateData.nodes) // add some updates to bob and resync fakeRoutingInfo.take(40).values.foreach { - case (ca, cu1, cu2, na1, na2) => - sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, cu2)) + case (pc, na1, na2) => + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, pc.update_2_opt.get)) } - awaitCond(bob.stateData.channels.size === 40 && bob.stateData.updates.size === 80) + awaitCond(bob.stateData.channels.size === 40 && countUpdates(bob.stateData.channels) === 80) assert(BasicSyncResult(ranges = 1, queries = 1, channels = 0, updates = 40, nodes = if (requestNodeAnnouncements) 80 else 0) === sync(alice, bob, extendedQueryFlags_opt).counts) awaitCond(alice.stateData.channels === bob.stateData.channels, max = 60 seconds) - awaitCond(alice.stateData.updates === bob.stateData.updates) // add everything (duplicates will be ignored) fakeRoutingInfo.values.foreach { - case (c, u1, u2, na1, na2) => - sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, c)) - sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, u1)) - sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, u2)) + case (pc, na1, na2) => + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, pc.ann)) + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, pc.update_1_opt.get)) + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, pc.update_2_opt.get)) sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, na1)) sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, na2)) } - awaitCond(bob.stateData.channels.size === fakeRoutingInfo.size && bob.stateData.updates.size === 2 * fakeRoutingInfo.size, max = 60 seconds) + awaitCond(bob.stateData.channels.size === fakeRoutingInfo.size && countUpdates(bob.stateData.channels) === 2 * fakeRoutingInfo.size, max = 60 seconds) assert(BasicSyncResult(ranges = 2, queries = 46, channels = fakeRoutingInfo.size - 40, updates = 2 * (fakeRoutingInfo.size - 40), nodes = if (requestNodeAnnouncements) 2 * (fakeRoutingInfo.size - 40) else 0) === sync(alice, bob, extendedQueryFlags_opt).counts) awaitCond(alice.stateData.channels === bob.stateData.channels, max = 60 seconds) - awaitCond(alice.stateData.updates === bob.stateData.updates) // bump random channel_updates def touchUpdate(shortChannelId: Int, side: Boolean) = { - val (c, u1, u2, _, _) = fakeRoutingInfo.values.toList(shortChannelId) + val PublicChannel(c, _, _, Some(u1), Some(u2)) = fakeRoutingInfo.values.toList(shortChannelId)._1 makeNewerChannelUpdate(c, if (side) u1 else u2) } @@ -234,7 +230,6 @@ class RoutingSyncSpec extends TestKit(ActorSystem("test")) with FunSuiteLike { bumpedUpdates.foreach(c => sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, c))) assert(BasicSyncResult(ranges = 2, queries = 2, channels = 0, updates = bumpedUpdates.size, nodes = if (requestNodeAnnouncements) 20 else 0) === sync(alice, bob, extendedQueryFlags_opt).counts) awaitCond(alice.stateData.channels === bob.stateData.channels, max = 60 seconds) - awaitCond(alice.stateData.updates === bob.stateData.updates) if (requestNodeAnnouncements) awaitCond(alice.stateData.nodes === bob.stateData.nodes) } @@ -311,7 +306,7 @@ object RoutingSyncSpec { val unused = randomKey - def makeFakeRoutingInfo(shortChannelId: ShortChannelId): (ChannelAnnouncement, ChannelUpdate, ChannelUpdate, NodeAnnouncement, NodeAnnouncement) = { + def makeFakeRoutingInfo(shortChannelId: ShortChannelId): (PublicChannel, NodeAnnouncement, NodeAnnouncement) = { val timestamp = Platform.currentTime / 1000 val (priv1, priv2) = { val (priv_a, priv_b) = (randomKey, randomKey) @@ -326,7 +321,8 @@ object RoutingSyncSpec { val channelUpdate_21 = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv2, priv1.publicKey, shortChannelId, cltvExpiryDelta = CltvExpiryDelta(7), MilliSatoshi(0), feeBaseMsat = MilliSatoshi(766000), feeProportionalMillionths = 10, MilliSatoshi(500000000L), timestamp = timestamp) val nodeAnnouncement_1 = makeNodeAnnouncement(priv1, "", Color(0, 0, 0), List()) val nodeAnnouncement_2 = makeNodeAnnouncement(priv2, "", Color(0, 0, 0), List()) - (channelAnn_12, channelUpdate_12, channelUpdate_21, nodeAnnouncement_1, nodeAnnouncement_2) + val publicChannel = PublicChannel(channelAnn_12, ByteVector32.Zeroes, Satoshi(0), Some(channelUpdate_12), Some(channelUpdate_21)) + (publicChannel, nodeAnnouncement_1, nodeAnnouncement_2) } def makeNewerChannelUpdate(channelAnnouncement: ChannelAnnouncement, channelUpdate: ChannelUpdate): ChannelUpdate = { @@ -337,4 +333,9 @@ object RoutingSyncSpec { channelUpdate.feeBaseMsat, channelUpdate.feeProportionalMillionths, channelUpdate.htlcMinimumMsat, Announcements.isEnabled(channelUpdate.channelFlags), channelUpdate.timestamp + 5000) } + + def makeFakeNodeAnnouncement(nodeId: PublicKey): NodeAnnouncement = { + val priv = pub2priv(nodeId) + makeNodeAnnouncement(priv, "", Color(0, 0, 0), List()) + } } diff --git a/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/controllers/AboutController.scala b/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/controllers/AboutController.scala index 39c449d2f9..1295344921 100644 --- a/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/controllers/AboutController.scala +++ b/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/controllers/AboutController.scala @@ -16,12 +16,11 @@ package fr.acinq.eclair.gui.controllers +import grizzled.slf4j.Logging import javafx.application.HostServices import javafx.fxml.FXML import javafx.scene.text.Text -import grizzled.slf4j.Logging - /** * Created by DPA on 28/09/2016. */ diff --git a/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/controllers/ChannelPaneController.scala b/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/controllers/ChannelPaneController.scala index c4691f0148..dc00192fec 100644 --- a/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/controllers/ChannelPaneController.scala +++ b/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/controllers/ChannelPaneController.scala @@ -22,16 +22,16 @@ import fr.acinq.eclair.MilliSatoshi import fr.acinq.eclair.CoinUtils import fr.acinq.eclair.channel.{CMD_CLOSE, CMD_FORCECLOSE, Commitments} import fr.acinq.eclair.gui.FxApp +import fr.acinq.eclair.gui.utils.{ContextMenuUtils, CopyAction} +import grizzled.slf4j.Logging import javafx.application.Platform import javafx.beans.value.{ChangeListener, ObservableValue} +import javafx.event.{ActionEvent, EventHandler} import javafx.fxml.FXML +import javafx.scene.control.Alert.AlertType import javafx.scene.control._ import javafx.scene.input.{ContextMenuEvent, MouseEvent} import javafx.scene.layout.VBox -import fr.acinq.eclair.gui.utils.{ContextMenuUtils, CopyAction} -import grizzled.slf4j.Logging -import javafx.event.{ActionEvent, EventHandler} -import javafx.scene.control.Alert.AlertType /** * Created by DPA on 23/09/2016. diff --git a/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/controllers/NodeInfoController.scala b/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/controllers/NodeInfoController.scala index 5cd4fc8b9d..639653c24f 100644 --- a/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/controllers/NodeInfoController.scala +++ b/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/controllers/NodeInfoController.scala @@ -16,16 +16,15 @@ package fr.acinq.eclair.gui.controllers +import fr.acinq.eclair.gui.Handlers +import fr.acinq.eclair.gui.utils.{ContextMenuUtils, QRCodeUtils} +import grizzled.slf4j.Logging import javafx.event.ActionEvent import javafx.fxml.FXML import javafx.scene.control._ import javafx.scene.image.ImageView import javafx.stage.Stage -import fr.acinq.eclair.gui.Handlers -import fr.acinq.eclair.gui.utils.{ContextMenuUtils, QRCodeUtils} -import grizzled.slf4j.Logging - import scala.util.{Failure, Success, Try} class NodeInfoController(val address: String, val handlers: Handlers, val stage: Stage) extends Logging { diff --git a/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/controllers/NotificationsController.scala b/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/controllers/NotificationsController.scala index ae12cf7af7..bc09998e03 100644 --- a/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/controllers/NotificationsController.scala +++ b/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/controllers/NotificationsController.scala @@ -16,6 +16,8 @@ package fr.acinq.eclair.gui.controllers +import fr.acinq.eclair.gui.utils.ContextMenuUtils +import grizzled.slf4j.Logging import javafx.animation._ import javafx.application.Platform import javafx.event.{ActionEvent, EventHandler} @@ -25,9 +27,6 @@ import javafx.scene.image.Image import javafx.scene.layout.{GridPane, VBox} import javafx.util.Duration -import fr.acinq.eclair.gui.utils.ContextMenuUtils -import grizzled.slf4j.Logging - sealed trait NotificationType case object NOTIFICATION_NONE extends NotificationType diff --git a/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/controllers/ReceivePaymentController.scala b/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/controllers/ReceivePaymentController.scala index 2abe281afb..07140a3234 100644 --- a/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/controllers/ReceivePaymentController.scala +++ b/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/controllers/ReceivePaymentController.scala @@ -16,6 +16,10 @@ package fr.acinq.eclair.gui.controllers +import fr.acinq.eclair.{CoinUtils, MilliSatoshi} +import fr.acinq.eclair.gui.utils._ +import fr.acinq.eclair.gui.{FxApp, Handlers} +import grizzled.slf4j.Logging import javafx.application.Platform import javafx.event.ActionEvent import javafx.fxml.FXML @@ -24,13 +28,6 @@ import javafx.scene.image.{ImageView, WritableImage} import javafx.scene.layout.GridPane import javafx.stage.Stage -import fr.acinq.eclair.MilliSatoshi -import fr.acinq.eclair.CoinUtils -import fr.acinq.eclair.gui.{FxApp, Handlers} -import fr.acinq.eclair.gui.utils._ -import fr.acinq.eclair.payment.PaymentRequest -import grizzled.slf4j.Logging - import scala.concurrent.ExecutionContext.Implicits.global import scala.util.{Failure, Success, Try} diff --git a/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/controllers/SendPaymentController.scala b/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/controllers/SendPaymentController.scala index 5cce7374c9..0b0ddd0505 100644 --- a/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/controllers/SendPaymentController.scala +++ b/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/controllers/SendPaymentController.scala @@ -16,6 +16,10 @@ package fr.acinq.eclair.gui.controllers +import fr.acinq.eclair.CoinUtils +import fr.acinq.eclair.gui.{FxApp, Handlers} +import fr.acinq.eclair.payment.PaymentRequest +import grizzled.slf4j.Logging import javafx.beans.value.{ChangeListener, ObservableValue} import javafx.event.{ActionEvent, EventHandler} import javafx.fxml.FXML @@ -24,11 +28,6 @@ import javafx.scene.input.KeyCode.{ENTER, TAB} import javafx.scene.input.KeyEvent import javafx.stage.Stage -import fr.acinq.eclair.CoinUtils -import fr.acinq.eclair.gui.{FxApp, Handlers} -import fr.acinq.eclair.payment.PaymentRequest -import grizzled.slf4j.Logging - import scala.util.{Failure, Success, Try} /** diff --git a/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/controllers/SplashController.scala b/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/controllers/SplashController.scala index 17e001ebb3..ca32d62662 100644 --- a/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/controllers/SplashController.scala +++ b/eclair-node-gui/src/main/scala/fr/acinq/eclair/gui/controllers/SplashController.scala @@ -16,6 +16,7 @@ package fr.acinq.eclair.gui.controllers +import grizzled.slf4j.Logging import javafx.animation._ import javafx.application.HostServices import javafx.fxml.FXML @@ -24,8 +25,6 @@ import javafx.scene.image.ImageView import javafx.scene.layout.{Pane, VBox} import javafx.util.Duration -import grizzled.slf4j.Logging - /** * Created by DPA on 22/09/2016. */