Skip to content

Commit

Permalink
Ping candidates and remove bad nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
lavrov committed Sep 15, 2024
1 parent 5abd251 commit 826fc2d
Show file tree
Hide file tree
Showing 9 changed files with 229 additions and 175 deletions.
37 changes: 11 additions & 26 deletions cmd/src/main/scala/Main.scala
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,12 @@ object Main
async[ResourceIO] {
given Random[IO] = Resource.eval(Random.scalaUtilRandom[IO]).await

val selfId = Resource.eval(NodeId.generate[IO]).await
val selfPeerId = Resource.eval(PeerId.generate[IO]).await
val infoHash = Resource.eval(infoHashFromString(infoHashOption)).await
val table = Resource.eval(RoutingTable[IO](selfId)).await
val node = Node(selfId, none, QueryHandler(selfId, table)).await
Resource.eval(RoutingTableBootstrap[IO](table, node.client)).await
val discovery = PeerDiscovery.make(table, node.client).await
val node = Node().await

val swarm = Swarm(
discovery.discover(infoHash),
node.discovery.discover(infoHash),
Connection.connect(selfPeerId, _, infoHash)
).await
val metadata = DownloadMetadata(swarm).toResource.await
Expand Down Expand Up @@ -132,23 +128,16 @@ object Main
throw new Exception("Missing info-hash")

given Random[IO] = Resource.eval(Random.scalaUtilRandom[IO]).await

val selfId = Resource.eval(NodeId.generate[IO]).await
val selfPeerId = Resource.eval(PeerId.generate[IO]).await
val peerAddress = peerAddressOption.flatMap(SocketAddress.fromStringIp)
val peers: Stream[IO, PeerInfo] =
peerAddress match
case Some(peerAddress) =>
Stream.emit(PeerInfo(peerAddress)).covary[IO]
case None =>
val bootstrapNodeAddress = dhtNodeAddressOption
.map(SocketAddress.fromString(_).toList)
.getOrElse(RoutingTableBootstrap.PublicBootstrapNodes)
val table = Resource.eval(RoutingTable[IO](selfId)).await
val node = Node(selfId, none, QueryHandler(selfId, table)).await
Resource.eval(RoutingTableBootstrap(table, node.client, bootstrapNodeAddress)).await
val discovery = PeerDiscovery.make(table, node.client).await
discovery.discover(infoHash)
val bootstrapNodeAddress = dhtNodeAddressOption.flatMap(SocketAddress.fromString)
val node = Node(none, bootstrapNodeAddress).await
node.discovery.discover(infoHash)
val swarm = Swarm(peers, peerInfo => Connection.connect(selfPeerId, peerInfo, infoHash)).await
val metadata =
torrentFile match
Expand Down Expand Up @@ -278,11 +267,7 @@ object Main
async[ResourceIO] {
val port = Port.fromInt(portParam).liftTo[ResourceIO](new Exception("Invalid port")).await
given Random[IO] = Resource.eval(Random.scalaUtilRandom[IO]).await
val selfId = Resource.eval(NodeId.generate[IO]).await
val table = Resource.eval(RoutingTable[IO](selfId)).await
val node = Node(selfId, Some(port), QueryHandler(selfId, table)).await
Resource.eval(RoutingTableBootstrap(table, node.client)).await
PingRoutine(table, node.client).runForever.background.await
Node(Some(port)).await
}.useForever
}
}
Expand All @@ -299,13 +284,13 @@ object Main
val nodeAddress = SocketAddress.fromString(nodeAddressParam).liftTo[ResourceIO](new Exception("Invalid address")).await
val nodeIpAddress = nodeAddress.resolve[IO].toResource.await
given Random[IO] = Resource.eval(Random.scalaUtilRandom[IO]).await
val infoHash = infoHashFromString(infoHashParam).toResource.await
val selfId = Resource.eval(NodeId.generate[IO]).await
val table = Resource.eval(RoutingTable[IO](selfId)).await
val node = Node(selfId, none, QueryHandler(selfId, table)).await
val infoHash = infoHashFromString(infoHashParam).toResource.await
val messageSocket = MessageSocket(none).await
val client = Client(selfId, messageSocket, QueryHandler.noop).await
async[IO]:
val pong = node.client.ping(nodeIpAddress).await
val response = node.client.getPeers(NodeInfo(pong.id, nodeIpAddress), infoHash).await
val pong = client.ping(nodeIpAddress).await
val response = client.getPeers(NodeInfo(pong.id, nodeIpAddress), infoHash).await
IO.println(response).await
ExitCode.Success
}.useEval
Expand Down
Original file line number Diff line number Diff line change
@@ -1,77 +1,104 @@
package com.github.torrentdam.bittorrent.dht

import cats.effect.kernel.Temporal
import cats.effect.Concurrent
import cats.effect.Resource
import cats.effect.Sync
import cats.effect.std.{Queue, Random}
import cats.effect.{Concurrent, IO, Resource, Sync}
import cats.syntax.all.*
import com.comcast.ip4s.*
import com.github.torrentdam.bittorrent.InfoHash

import java.net.InetSocketAddress
import org.legogroup.woof.given
import org.legogroup.woof.Logger
import scodec.bits.ByteVector

trait Client[F[_]] {
trait Client {

def id: NodeId

def getPeers(nodeInfo: NodeInfo, infoHash: InfoHash): F[Either[Response.Nodes, Response.Peers]]
def getPeers(nodeInfo: NodeInfo, infoHash: InfoHash): IO[Either[Response.Nodes, Response.Peers]]

def findNodes(nodeInfo: NodeInfo, target: NodeId): F[Response.Nodes]
def findNodes(nodeInfo: NodeInfo, target: NodeId): IO[Response.Nodes]

def ping(address: SocketAddress[IpAddress]): F[Response.Ping]
def ping(address: SocketAddress[IpAddress]): IO[Response.Ping]

def sampleInfoHashes(nodeInfo: NodeInfo, target: NodeId): F[Either[Response.Nodes, Response.SampleInfoHashes]]
def sampleInfoHashes(nodeInfo: NodeInfo, target: NodeId): IO[Either[Response.Nodes, Response.SampleInfoHashes]]
}

object Client {

def apply[F[_]](
def generateTransactionId(using random: Random[IO]): IO[ByteVector] =
val nextChar = random.nextAlphaNumeric
(nextChar, nextChar).mapN((a, b) => ByteVector.encodeAscii(List(a, b).mkString).toOption.get)

def apply(
selfId: NodeId,
sendQueryMessage: (SocketAddress[IpAddress], Message.QueryMessage) => F[Unit],
receiveResponse: F[(SocketAddress[IpAddress], Either[Message.ErrorMessage, Message.ResponseMessage])],
generateTransactionId: F[ByteVector]
)(using
F: Temporal[F],
logger: Logger[F]
): Resource[F, Client[F]] = {
for {
messageSocket: MessageSocket,
queryHandler: QueryHandler[IO]
)(using Logger[IO], Random[IO]): Resource[IO, Client] = {
for
responses <- Resource.eval {
Queue.unbounded[IO, (SocketAddress[IpAddress], Message.ErrorMessage | Message.ResponseMessage)]
}
requestResponse <- RequestResponse.make(
generateTransactionId,
sendQueryMessage,
receiveResponse
messageSocket.writeMessage,
responses.take
)
} yield new Client[F] {
_ <-
messageSocket.readMessage
.flatMap {
case (a, m: Message.QueryMessage) =>
Logger[IO].debug(s"Received $m") >>
queryHandler(a, m.query).flatMap {
case Some(response) =>
val responseMessage = Message.ResponseMessage(m.transactionId, response)
Logger[IO].debug(s"Responding with $responseMessage") >>
messageSocket.writeMessage(a, responseMessage)
case None =>
Logger[IO].debug(s"No response for $m")
}
case (a, m: Message.ResponseMessage) => responses.offer((a, m))
case (a, m: Message.ErrorMessage) => responses.offer((a, m))
}
.recoverWith { case e: Throwable =>
Logger[IO].debug(s"Failed to read message: $e")
}
.foreverM
.background
yield new Client {

def id: NodeId = selfId

def getPeers(
nodeInfo: NodeInfo,
infoHash: InfoHash
): F[Either[Response.Nodes, Response.Peers]] =
): IO[Either[Response.Nodes, Response.Peers]] =
requestResponse.sendQuery(nodeInfo.address, Query.GetPeers(selfId, infoHash)).flatMap {
case nodes: Response.Nodes => nodes.asLeft.pure
case peers: Response.Peers => peers.asRight.pure
case _ => F.raiseError(InvalidResponse())
case _ => IO.raiseError(InvalidResponse())
}

def findNodes(nodeInfo: NodeInfo, target: NodeId): F[Response.Nodes] =
def findNodes(nodeInfo: NodeInfo, target: NodeId): IO[Response.Nodes] =
requestResponse.sendQuery(nodeInfo.address, Query.FindNode(selfId, target)).flatMap {
case nodes: Response.Nodes => nodes.pure
case _ => Concurrent[F].raiseError(InvalidResponse())
case _ => IO.raiseError(InvalidResponse())
}

def ping(address: SocketAddress[IpAddress]): F[Response.Ping] =
def ping(address: SocketAddress[IpAddress]): IO[Response.Ping] =
requestResponse.sendQuery(address, Query.Ping(selfId)).flatMap {
case ping: Response.Ping => ping.pure
case _ => Concurrent[F].raiseError(InvalidResponse())
case _ => IO.raiseError(InvalidResponse())
}
def sampleInfoHashes(nodeInfo: NodeInfo, target: NodeId): F[Either[Response.Nodes, Response.SampleInfoHashes]] =
def sampleInfoHashes(nodeInfo: NodeInfo, target: NodeId): IO[Either[Response.Nodes, Response.SampleInfoHashes]] =
requestResponse.sendQuery(nodeInfo.address, Query.SampleInfoHashes(selfId, target)).flatMap {
case response: Response.SampleInfoHashes => response.asRight[Response.Nodes].pure
case response: Response.Nodes => response.asLeft[Response.SampleInfoHashes].pure
case _ => Concurrent[F].raiseError(InvalidResponse())
case _ => IO.raiseError(InvalidResponse())
}
}
}

case class BootstrapError(message: String) extends Throwable(message)
case class InvalidResponse() extends Throwable
}
109 changes: 73 additions & 36 deletions dht/src/main/scala/com/github/torrentdam/bittorrent/dht/Node.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,57 +9,94 @@ import cats.effect.Resource
import cats.effect.Sync
import cats.implicits.*
import com.comcast.ip4s.*
import fs2.io.net.DatagramSocketGroup
import fs2.Stream
import com.github.torrentdam.bittorrent.InfoHash

import java.net.InetSocketAddress
import org.legogroup.woof.given
import org.legogroup.woof.Logger
import scodec.bits.ByteVector

trait Node {
def client: Client[IO]
}
import scala.concurrent.duration.DurationInt

class Node(val id: NodeId, val client: Client, val routingTable: RoutingTable[IO], val discovery: PeerDiscovery)

object Node {

def apply(
selfId: NodeId,
port: Option[Port],
queryHandler: QueryHandler[IO]
port: Option[Port] = None,
bootstrapNodeAddress: Option[SocketAddress[Host]] = None
)(using
random: Random[IO],
logger: Logger[IO]
): Resource[IO, Node] =

def generateTransactionId: IO[ByteVector] =
val nextChar = random.nextAlphaNumeric
(nextChar, nextChar).mapN((a, b) => ByteVector.encodeAscii(List(a, b).mkString).toOption.get)

for
selfId <- Resource.eval(NodeId.generate[IO])
messageSocket <- MessageSocket(port)
responses <- Resource.eval {
Queue.unbounded[IO, (SocketAddress[IpAddress], Either[Message.ErrorMessage, Message.ResponseMessage])]
routingTable <- RoutingTable[IO](selfId).toResource
queryingNodes <- Queue.unbounded[IO, NodeInfo].toResource
queryHandler = reportingQueryHandler(queryingNodes, QueryHandler.simple(selfId, routingTable))
client <- Client(selfId, messageSocket, queryHandler)
insertingClient = new InsertingClient(client, routingTable)
bootstrapNodes = bootstrapNodeAddress.map(List(_)).getOrElse(RoutingTableBootstrap.PublicBootstrapNodes)
discovery = PeerDiscovery(routingTable, insertingClient)
_ <- RoutingTableBootstrap(routingTable, insertingClient, discovery, bootstrapNodes).toResource
_ <- PingRoutine(routingTable, client).runForever.background
_ <- pingCandidates(queryingNodes, client, routingTable).background
yield new Node(selfId, insertingClient, routingTable, discovery)

private class InsertingClient(client: Client, routingTable: RoutingTable[IO]) extends Client {

def id: NodeId = client.id

def getPeers(nodeInfo: NodeInfo, infoHash: InfoHash): IO[Either[Response.Nodes, Response.Peers]] =
client.getPeers(nodeInfo, infoHash) <* routingTable.insert(nodeInfo)

def findNodes(nodeInfo: NodeInfo, target: NodeId): IO[Response.Nodes] =
client.findNodes(nodeInfo, target).flatTap { response =>
routingTable.insert(NodeInfo(response.id, nodeInfo.address))
}
client0 <- Client(selfId, messageSocket.writeMessage, responses.take, generateTransactionId)
_ <-
messageSocket.readMessage
.flatMap {
case (a, m: Message.QueryMessage) =>
logger.debug(s"Received $m") >>
queryHandler(a, m.query).flatMap { response =>
val responseMessage = Message.ResponseMessage(m.transactionId, response)
logger.debug(s"Responding with $responseMessage") >>
messageSocket.writeMessage(a, responseMessage)
}
case (a, m: Message.ResponseMessage) => responses.offer((a, m.asRight))
case (a, m: Message.ErrorMessage) => responses.offer((a, m.asLeft))
}
.recoverWith { case e: Throwable =>
logger.trace(s"Failed to read message: $e")
}
.foreverM
.background

yield new Node {
def client: Client[IO] = client0

def ping(address: SocketAddress[IpAddress]): IO[Response.Ping] =
client.ping(address).flatTap { response =>
routingTable.insert(NodeInfo(response.id, address))
}

def sampleInfoHashes(nodeInfo: NodeInfo, target: NodeId): IO[Either[Response.Nodes, Response.SampleInfoHashes]] =
client.sampleInfoHashes(nodeInfo, target).flatTap { response =>
routingTable.insert(
response match
case Left(response) => NodeInfo(response.id, nodeInfo.address)
case Right(response) => NodeInfo(response.id, nodeInfo.address)
)
}

override def toString: String = s"InsertingClient($client)"
}

private def pingCandidate(node: NodeInfo, client: Client, routingTable: RoutingTable[IO])(using Logger[IO]) =
routingTable.lookup(node.id).flatMap {
case Some(_) => IO.unit
case None =>
Logger[IO].info(s"Pinging $node") *>
client.ping(node.address).timeout(5.seconds).attempt.flatMap {
case Right(_) =>
Logger[IO].info(s"Got pong from $node -- insert as good") *>
routingTable.insert(node)
case Left(_) => IO.unit
}
}

private def pingCandidates(nodes: Queue[IO, NodeInfo], client: Client, routingTable: RoutingTable[IO])(using Logger[IO]) =
nodes.take.flatMap(pingCandidate(_, client, routingTable).attempt.void).foreverM


private def reportingQueryHandler(queue: Queue[IO, NodeInfo], next: QueryHandler[IO]): QueryHandler[IO] = (address, query) =>
val nodeInfo = query match
case Query.Ping(id) => NodeInfo(id, address)
case Query.FindNode(id, _) => NodeInfo(id, address)
case Query.GetPeers(id, _) => NodeInfo(id, address)
case Query.AnnouncePeer(id, _, _) => NodeInfo(id, address)
case Query.SampleInfoHashes(id, _) => NodeInfo(id, address)
queue.offer(nodeInfo) *> next(address, query)
}
Loading

0 comments on commit 826fc2d

Please sign in to comment.