Skip to content

Commit

Permalink
Rework router data structures (#902)
Browse files Browse the repository at this point in the history
Instead of using two separate maps (for channels and channel_updates), we now use a single map, which groups channel+channel_updates. This is also true for data storage, resulting in the removal of the channel_updates table.
  • Loading branch information
pm47 authored Aug 28, 2019
1 parent 2f42538 commit 8f7a415
Show file tree
Hide file tree
Showing 25 changed files with 486 additions and 401 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ object DBCompatChecker extends Logging {
* @param nodeParams
*/
def checkNetworkDBCompatibility(nodeParams: NodeParams): Unit =
Try(nodeParams.db.network.listChannels(), nodeParams.db.network.listNodes(), nodeParams.db.network.listChannelUpdates()) match {
Try(nodeParams.db.network.listChannels(), nodeParams.db.network.listNodes()) match {
case Success(_) => {}
case Failure(_) => throw IncompatibleNetworkDBException
}
Expand Down
20 changes: 7 additions & 13 deletions eclair-core/src/main/scala/fr/acinq/eclair/db/NetworkDb.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@ package fr.acinq.eclair.db
import fr.acinq.bitcoin.Crypto.PublicKey
import fr.acinq.bitcoin.{ByteVector32, Satoshi}
import fr.acinq.eclair.ShortChannelId
import fr.acinq.eclair.router.PublicChannel
import fr.acinq.eclair.wire.{ChannelAnnouncement, ChannelUpdate, NodeAnnouncement}

import scala.collection.immutable.SortedMap

trait NetworkDb {

def addNode(n: NodeAnnouncement)
Expand All @@ -35,22 +38,13 @@ trait NetworkDb {

def addChannel(c: ChannelAnnouncement, txid: ByteVector32, capacity: Satoshi)

def removeChannel(shortChannelId: ShortChannelId) = removeChannels(Seq(shortChannelId))

/**
* This method removes channel announcements and associated channel updates for a list of channel ids
*
* @param shortChannelIds list of short channel ids
*/
def removeChannels(shortChannelIds: Iterable[ShortChannelId])
def updateChannel(u: ChannelUpdate)

def listChannels(): Map[ChannelAnnouncement, (ByteVector32, Satoshi)]
def removeChannel(shortChannelId: ShortChannelId) = removeChannels(Set(shortChannelId))

def addChannelUpdate(u: ChannelUpdate)

def updateChannelUpdate(u: ChannelUpdate)
def removeChannels(shortChannelIds: Iterable[ShortChannelId])

def listChannelUpdates(): Seq[ChannelUpdate]
def listChannels(): SortedMap[ShortChannelId, PublicChannel]

def addToPruned(shortChannelIds: Iterable[ShortChannelId]): Unit

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,25 +21,39 @@ import java.sql.Connection
import fr.acinq.bitcoin.{ByteVector32, Crypto, Satoshi}
import fr.acinq.eclair.ShortChannelId
import fr.acinq.eclair.db.NetworkDb
import fr.acinq.eclair.router.Announcements
import fr.acinq.eclair.router.PublicChannel
import fr.acinq.eclair.wire.LightningMessageCodecs.{channelAnnouncementCodec, channelUpdateCodec, nodeAnnouncementCodec}
import fr.acinq.eclair.wire.{ChannelAnnouncement, ChannelUpdate, NodeAnnouncement}
import scodec.bits.BitVector
import grizzled.slf4j.Logging

class SqliteNetworkDb(sqlite: Connection) extends NetworkDb {
import scala.collection.immutable.SortedMap

class SqliteNetworkDb(sqlite: Connection) extends NetworkDb with Logging {

import SqliteUtils._
import SqliteUtils.ExtendedResultSet._

val DB_NAME = "network"
val CURRENT_VERSION = 1
val CURRENT_VERSION = 2

using(sqlite.createStatement()) { statement =>
require(getVersion(statement, DB_NAME, CURRENT_VERSION) == CURRENT_VERSION, s"incompatible version of $DB_NAME DB found") // there is only one version currently deployed
statement.execute("PRAGMA foreign_keys = ON")
getVersion(statement, DB_NAME, CURRENT_VERSION) match {
case 1 =>
// channel_update are cheap to retrieve, so let's just wipe them out and they'll get resynced
statement.execute("PRAGMA foreign_keys = ON")
logger.warn("migrating network db version 1->2")
statement.executeUpdate("ALTER TABLE channels RENAME COLUMN data TO channel_announcement")
statement.executeUpdate("ALTER TABLE channels ADD COLUMN channel_update_1 BLOB NULL")
statement.executeUpdate("ALTER TABLE channels ADD COLUMN channel_update_2 BLOB NULL")
statement.executeUpdate("DROP TABLE channel_updates")
statement.execute("PRAGMA foreign_keys = OFF")
setVersion(statement, DB_NAME, CURRENT_VERSION)
logger.warn("migration complete")
case 2 => () // nothing to do
case unknown => throw new IllegalArgumentException(s"unknown version $unknown for network db")
}
statement.executeUpdate("CREATE TABLE IF NOT EXISTS nodes (node_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channels (short_channel_id INTEGER NOT NULL PRIMARY KEY, txid STRING NOT NULL, data BLOB NOT NULL, capacity_sat INTEGER NOT NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_updates (short_channel_id INTEGER NOT NULL, node_flag INTEGER NOT NULL, data BLOB NOT NULL, PRIMARY KEY(short_channel_id, node_flag), FOREIGN KEY(short_channel_id) REFERENCES channels(short_channel_id))")
statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_updates_idx ON channel_updates(short_channel_id)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS channels (short_channel_id INTEGER NOT NULL PRIMARY KEY, txid STRING NOT NULL, channel_announcement BLOB NOT NULL, capacity_sat INTEGER NOT NULL, channel_update_1 BLOB NULL, channel_update_2 BLOB NULL)")
statement.executeUpdate("CREATE TABLE IF NOT EXISTS pruned (short_channel_id INTEGER NOT NULL PRIMARY KEY)")
}

Expand Down Expand Up @@ -82,7 +96,7 @@ class SqliteNetworkDb(sqlite: Connection) extends NetworkDb {
}

override def addChannel(c: ChannelAnnouncement, txid: ByteVector32, capacity: Satoshi): Unit = {
using(sqlite.prepareStatement("INSERT OR IGNORE INTO channels VALUES (?, ?, ?, ?)")) { statement =>
using(sqlite.prepareStatement("INSERT OR IGNORE INTO channels VALUES (?, ?, ?, ?, NULL, NULL)")) { statement =>
statement.setLong(1, c.shortChannelId.toLong)
statement.setString(2, txid.toHex)
statement.setBytes(3, channelAnnouncementCodec.encode(c).require.toByteArray)
Expand All @@ -91,56 +105,39 @@ class SqliteNetworkDb(sqlite: Connection) extends NetworkDb {
}
}

override def removeChannels(shortChannelIds: Iterable[ShortChannelId]): Unit = {

def removeChannelsInternal(shortChannelIds: Iterable[ShortChannelId]): Unit = {
val ids = shortChannelIds.map(_.toLong).mkString(",")
using(sqlite.createStatement) { statement =>
statement.execute("BEGIN TRANSACTION")
statement.executeUpdate(s"DELETE FROM channel_updates WHERE short_channel_id IN ($ids)")
statement.executeUpdate(s"DELETE FROM channels WHERE short_channel_id IN ($ids)")
statement.execute("COMMIT TRANSACTION")
}
override def updateChannel(u: ChannelUpdate): Unit = {
val column = if (u.isNode1) "channel_update_1" else "channel_update_2"
using(sqlite.prepareStatement(s"UPDATE channels SET $column=? WHERE short_channel_id=?")) { statement =>
statement.setBytes(1, channelUpdateCodec.encode(u).require.toByteArray)
statement.setLong(2, u.shortChannelId.toLong)
statement.executeUpdate()
}

// remove channels by batch of 1000
shortChannelIds.grouped(1000).foreach(removeChannelsInternal)
}

override def listChannels(): Map[ChannelAnnouncement, (ByteVector32, Satoshi)] = {
override def listChannels(): SortedMap[ShortChannelId, PublicChannel] = {
using(sqlite.createStatement()) { statement =>
val rs = statement.executeQuery("SELECT data, txid, capacity_sat FROM channels")
var m: Map[ChannelAnnouncement, (ByteVector32, Satoshi)] = Map()
val rs = statement.executeQuery("SELECT channel_announcement, txid, capacity_sat, channel_update_1, channel_update_2 FROM channels")
var m = SortedMap.empty[ShortChannelId, PublicChannel]
while (rs.next()) {
m += (channelAnnouncementCodec.decode(BitVector(rs.getBytes("data"))).require.value ->
(ByteVector32.fromValidHex(rs.getString("txid")), Satoshi(rs.getLong("capacity_sat"))))
val ann = channelAnnouncementCodec.decode(rs.getBitVectorOpt("channel_announcement").get).require.value
val txId = ByteVector32.fromValidHex(rs.getString("txid"))
val capacity = rs.getLong("capacity_sat")
val channel_update_1_opt = rs.getBitVectorOpt("channel_update_1").map(channelUpdateCodec.decode(_).require.value)
val channel_update_2_opt = rs.getBitVectorOpt("channel_update_2").map(channelUpdateCodec.decode(_).require.value)
m = m + (ann.shortChannelId -> PublicChannel(ann, txId, Satoshi(capacity), channel_update_1_opt, channel_update_2_opt))
}
m
}
}

override def addChannelUpdate(u: ChannelUpdate): Unit = {
using(sqlite.prepareStatement("INSERT OR IGNORE INTO channel_updates VALUES (?, ?, ?)")) { statement =>
statement.setLong(1, u.shortChannelId.toLong)
statement.setBoolean(2, Announcements.isNode1(u.channelFlags))
statement.setBytes(3, channelUpdateCodec.encode(u).require.toByteArray)
statement.executeUpdate()
}
}

override def updateChannelUpdate(u: ChannelUpdate): Unit = {
using(sqlite.prepareStatement("UPDATE channel_updates SET data=? WHERE short_channel_id=? AND node_flag=?")) { statement =>
statement.setBytes(1, channelUpdateCodec.encode(u).require.toByteArray)
statement.setLong(2, u.shortChannelId.toLong)
statement.setBoolean(3, Announcements.isNode1(u.channelFlags))
statement.executeUpdate()
}
}

override def listChannelUpdates(): Seq[ChannelUpdate] = {
using(sqlite.createStatement()) { statement =>
val rs = statement.executeQuery("SELECT data FROM channel_updates")
codecSequence(rs, channelUpdateCodec)
override def removeChannels(shortChannelIds: Iterable[ShortChannelId]): Unit = {
using(sqlite.createStatement) { statement =>
shortChannelIds
.grouped(1000) // remove channels by batch of 1000
.foreach {group =>
val ids = shortChannelIds.map(_.toLong).mkString(",")
statement.executeUpdate(s"DELETE FROM channels WHERE short_channel_id IN ($ids)")
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ object SqliteUtils {

case class ExtendedResultSet(rs: ResultSet) {

def getBitVectorOpt(columnLabel: String): Option[BitVector] = Option(rs.getBytes(columnLabel)).map(BitVector(_))

def getByteVector(columnLabel: String): ByteVector = ByteVector(rs.getBytes(columnLabel))

def getByteVector32(columnLabel: String): ByteVector32 = ByteVector32(ByteVector(rs.getBytes(columnLabel)))
Expand Down
6 changes: 3 additions & 3 deletions eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, authenticator: A
// we won't clean it up, but we won't remember the temporary id on channel termination
stay using d.copy(channels = d.channels + (FinalChannelId(channelId) -> channel))

case Event(RoutingState(channels, updates, nodes), d: ConnectedData) =>
case Event(RoutingState(channels, nodes), d: ConnectedData) =>
// let's send the messages
def send(announcements: Iterable[_ <: LightningMessage]) = announcements.foldLeft(0) {
case (c, ann) =>
Expand All @@ -343,9 +343,9 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, authenticator: A
}

log.info(s"sending all announcements to {}", remoteNodeId)
val channelsSent = send(channels)
val channelsSent = send(channels.map(_.ann))
val nodesSent = send(nodes)
val updatesSent = send(updates)
val updatesSent = send(channels.flatMap(c => c.update_1_opt.toSeq ++ c.update_2_opt.toSeq))
log.info(s"sent all announcements to {}: channels={} updates={} nodes={}", remoteNodeId, channelsSent, updatesSent, nodesSent)
stay

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import akka.actor.{Actor, ActorLogging, ActorRef, Props}
import fr.acinq.bitcoin.Crypto.PublicKey
import fr.acinq.eclair.crypto.Sphinx.DecryptedFailurePacket
import fr.acinq.eclair.payment.PaymentLifecycle.{PaymentFailed, PaymentResult, RemoteFailure, SendPayment}
import fr.acinq.eclair.router.{Announcements, Data}
import fr.acinq.eclair.router.{Announcements, Data, PublicChannel}
import fr.acinq.eclair.wire.IncorrectOrUnknownPaymentDetails
import fr.acinq.eclair.{MilliSatoshi, NodeParams, randomBytes32, secureRandom}

Expand Down Expand Up @@ -89,9 +89,10 @@ object Autoprobe {

def pickPaymentDestination(nodeId: PublicKey, routingData: Data): Option[PublicKey] = {
// we only pick direct peers with enabled public channels
val peers = routingData.updates
val peers = routingData.channels
.collect {
case (desc, u) if desc.a == nodeId && Announcements.isEnabled(u.channelFlags) && routingData.channels.contains(u.shortChannelId) => desc.b // we only consider outgoing channels that are enabled and announced
case (shortChannelId, c@PublicChannel(ann, _, _, Some(u1), _))
if c.getNodeIdSameSideAs(u1) == nodeId && Announcements.isEnabled(u1.channelFlags) && routingData.channels.exists(_._1 == shortChannelId) => ann.nodeId2 // we only consider outgoing channels that are enabled and announced
}
if (peers.isEmpty) {
None
Expand Down
32 changes: 23 additions & 9 deletions eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import fr.acinq.eclair.router.Graph.GraphStructure.{DirectedGraph, GraphEdge}
import fr.acinq.eclair.router.Router._
import fr.acinq.eclair.wire.ChannelUpdate

import scala.collection.immutable.SortedMap
import scala.collection.mutable

object Graph {
Expand Down Expand Up @@ -74,6 +75,7 @@ object Graph {
targetNode: PublicKey,
amount: MilliSatoshi,
ignoredEdges: Set[ChannelDesc],
ignoredVertices: Set[PublicKey],
extraEdges: Set[GraphEdge],
pathsToFind: Int,
wr: Option[WeightRatios],
Expand All @@ -89,7 +91,7 @@ object Graph {

// find the shortest path, k = 0
val initialWeight = RichWeight(cost = amount, 0, CltvExpiryDelta(0), 0)
val shortestPath = dijkstraShortestPath(graph, sourceNode, targetNode, ignoredEdges, extraEdges, initialWeight, boundaries, currentBlockHeight, wr)
val shortestPath = dijkstraShortestPath(graph, sourceNode, targetNode, ignoredEdges, ignoredVertices, extraEdges, initialWeight, boundaries, currentBlockHeight, wr)
shortestPaths += WeightedPath(shortestPath, pathWeight(shortestPath, amount, isPartial = false, currentBlockHeight, wr))

// avoid returning a list with an empty path
Expand Down Expand Up @@ -125,7 +127,7 @@ object Graph {
val returningEdges = rootPathEdges.lastOption.map(last => graph.getEdgesBetween(last.desc.b, last.desc.a)).toSeq.flatten.map(_.desc)

// find the "spur" path, a sub-path going from the spur edge to the target avoiding previously found sub-paths
val spurPath = dijkstraShortestPath(graph, spurEdge.desc.a, targetNode, ignoredEdges ++ edgesToIgnore.toSet ++ returningEdges.toSet, extraEdges, rootPathWeight, boundaries, currentBlockHeight, wr)
val spurPath = dijkstraShortestPath(graph, spurEdge.desc.a, targetNode, ignoredEdges ++ edgesToIgnore.toSet ++ returningEdges.toSet, ignoredVertices, extraEdges, rootPathWeight, boundaries, currentBlockHeight, wr)

// if there wasn't a path the spur will be empty
if (spurPath.nonEmpty) {
Expand Down Expand Up @@ -178,6 +180,7 @@ object Graph {
sourceNode: PublicKey,
targetNode: PublicKey,
ignoredEdges: Set[ChannelDesc],
ignoredVertices: Set[PublicKey],
extraEdges: Set[GraphEdge],
initialWeight: RichWeight,
boundaries: RichWeight => Boolean,
Expand Down Expand Up @@ -232,7 +235,7 @@ object Graph {
if (edge.update.htlcMaximumMsat.forall(newMinimumKnownWeight.cost <= _) &&
newMinimumKnownWeight.cost >= edge.update.htlcMinimumMsat &&
boundaries(newMinimumKnownWeight) && // check if this neighbor edge would break off the 'boundaries'
!ignoredEdges.contains(edge.desc)
!ignoredEdges.contains(edge.desc) && !ignoredVertices.contains(neighbor)
) {

// we call containsKey first because "getOrDefault" is not available in JDK7
Expand Down Expand Up @@ -533,21 +536,32 @@ object Graph {
def apply(edge: GraphEdge): DirectedGraph = new DirectedGraph(Map()).addEdge(edge.desc, edge.update)

def apply(edges: Seq[GraphEdge]): DirectedGraph = {
makeGraph(edges.map(e => e.desc -> e.update).toMap)
DirectedGraph().addEdges(edges.map(e => (e.desc, e.update)))
}

// optimized constructor
def makeGraph(descAndUpdates: Map[ChannelDesc, ChannelUpdate]): DirectedGraph = {
def makeGraph(channels: SortedMap[ShortChannelId, PublicChannel]): DirectedGraph = {

// initialize the map with the appropriate size to avoid resizing during the graph initialization
val mutableMap = new {} with mutable.HashMap[PublicKey, List[GraphEdge]] {
override def initialSize: Int = descAndUpdates.size + 1
override def initialSize: Int = channels.size + 1
}

// add all the vertices and edges in one go
descAndUpdates.foreach { case (desc, update) =>
// create or update vertex (desc.b) and update its neighbor
mutableMap.put(desc.b, GraphEdge(desc, update) +: mutableMap.getOrElse(desc.b, List.empty[GraphEdge]))
channels.values.foreach { channel =>
channel.update_1_opt.foreach { u1 =>
val desc1 = Router.getDesc(u1, channel.ann)
addDescToMap(desc1, u1)
}

channel.update_2_opt.foreach { u2 =>
val desc2 = Router.getDesc(u2, channel.ann)
addDescToMap(desc2, u2)
}
}

def addDescToMap(desc: ChannelDesc, u: ChannelUpdate) = {
mutableMap.put(desc.b, GraphEdge(desc, u) +: mutableMap.getOrElse(desc.b, List.empty[GraphEdge]))
mutableMap.get(desc.a) match {
case None => mutableMap += desc.a -> List.empty[GraphEdge]
case _ =>
Expand Down
Loading

0 comments on commit 8f7a415

Please sign in to comment.