Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix improper sorting when getting the closest peers to a hash #1282

Merged
merged 2 commits into from
Jul 2, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ at anytime.
* `blob_list` raising an error when blobs in a stream haven't yet been created
* stopping a download from raising `NoneType object has no attribute finished_deferred`
* file manager startup locking up when there are many files for some channels
* improper sorting when getting the closest peers to a hash

### Deprecated
*
Expand Down
4 changes: 2 additions & 2 deletions lbrynet/dht/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,7 @@ def findNode(self, rpc_contact, key):
if len(key) != constants.key_bits / 8:
raise ValueError("invalid contact id length: %i" % len(key))

contacts = self._routingTable.findCloseNodes(key, constants.k, rpc_contact.id)
contacts = self._routingTable.findCloseNodes(key, sender_node_id=rpc_contact.id)
contact_triples = []
for contact in contacts:
contact_triples.append((contact.id, contact.address, contact.port))
Expand Down Expand Up @@ -644,7 +644,7 @@ def _iterativeFind(self, key, startupShortlist=None, rpc='findNode'):
raise ValueError("invalid key length: %i" % len(key))

if startupShortlist is None:
shortlist = self._routingTable.findCloseNodes(key, constants.k)
shortlist = self._routingTable.findCloseNodes(key)
# if key != self.node_id:
# # Update the "last accessed" timestamp for the appropriate k-bucket
# self._routingTable.touchKBucket(key)
Expand Down
52 changes: 12 additions & 40 deletions lbrynet/dht/routingtable.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,13 @@ def replaceContact(failure, deadContact):
self.touchKBucketByIndex(bucketIndex)
return defer.succeed(None)

def findCloseNodes(self, key, count, sender_node_id=None):
def findCloseNodes(self, key, count=None, sender_node_id=None):
""" Finds a number of known nodes closest to the node/value with the
specified key.

@param key: the n-bit key (i.e. the node or value ID) to search for
@type key: str
@param count: the amount of contacts to return
@param count: the amount of contacts to return, default of k (8)
@type count: int
@param sender_node_id: Used during RPC, this is be the sender's Node ID
Whatever ID is passed in the paramater will get
Expand All @@ -161,45 +161,17 @@ def findCloseNodes(self, key, count, sender_node_id=None):
node is returning all of the contacts that it knows of.
@rtype: list
"""
bucketIndex = self._kbucketIndex(key)

if bucketIndex < len(self._buckets):
# sort these
closestNodes = self._buckets[bucketIndex].getContacts(count, sender_node_id, sort_distance_to=key)
else:
closestNodes = []
# This method must return k contacts (even if we have the node
# with the specified key as node ID), unless there is less
# than k remote nodes in the routing table
i = 1
canGoLower = bucketIndex - i >= 0
canGoHigher = bucketIndex + i < len(self._buckets)

def get_remain(closest):
return min(count, constants.k) - len(closest)

exclude = [self._parentNodeID]
if sender_node_id:
exclude.append(sender_node_id)
if key in exclude:
exclude.remove(key)
count = count or constants.k
distance = Distance(key)

while len(closestNodes) < min(count, constants.k) and (canGoLower or canGoHigher):
iteration_contacts = []
# get contacts from lower and/or higher buckets without sorting them
if canGoLower and len(closestNodes) < min(count, constants.k):
lower_bucket = self._buckets[bucketIndex - i]
contacts = lower_bucket.getContacts(get_remain(closestNodes), sender_node_id, sort_distance_to=False)
iteration_contacts.extend(contacts)
canGoLower = bucketIndex - (i + 1) >= 0

if canGoHigher and len(closestNodes) < min(count, constants.k):
higher_bucket = self._buckets[bucketIndex + i]
contacts = higher_bucket.getContacts(get_remain(closestNodes), sender_node_id, sort_distance_to=False)
iteration_contacts.extend(contacts)
canGoHigher = bucketIndex + (i + 1) < len(self._buckets)
i += 1
# sort the combined contacts and add as many as possible/needed to the combined contact list
iteration_contacts.sort(key=lambda c: distance(c.id), reverse=True)
while len(iteration_contacts) and len(closestNodes) < min(count, constants.k):
closestNodes.append(iteration_contacts.pop())
return closestNodes
contacts = self.get_contacts()
contacts = [c for c in contacts if c.id not in exclude]
contacts.sort(key=lambda c: distance(c.id))
return contacts[:min(count, len(contacts))]

def getContact(self, contactID):
""" Returns the (known) contact with the specified node ID
Expand Down
2 changes: 1 addition & 1 deletion lbrynet/tests/functional/dht/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def test_storing_node_went_stale_then_came_back(self):
self.nodes.remove(announcing_node)
yield self.run_reactor(31, [announcing_node.stop()])
# run the network for an hour, which should expire the removed node and turn the announced value stale
self.pump_clock(constants.checkRefreshInterval * 4, constants.checkRefreshInterval/2)
self.pump_clock(constants.checkRefreshInterval * 5, constants.checkRefreshInterval/2)
self.verify_all_nodes_are_routable()

# make sure the contact isn't returned as a peer for the blob, but that we still have the entry in the
Expand Down
2 changes: 1 addition & 1 deletion lbrynet/tests/unit/dht/test_routingtable.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def testAddContact(self):
# Now add it...
yield self.routingTable.addContact(contact)
# ...and request the closest nodes to it (will retrieve it)
closestNodes = self.routingTable.findCloseNodes(contactID, constants.k)
closestNodes = self.routingTable.findCloseNodes(contactID)
self.failUnlessEqual(len(closestNodes), 1, 'Wrong amount of contacts returned; expected 1,'
' got %d' % len(closestNodes))
self.failUnless(contact in closestNodes, 'Added contact not found by issueing '
Expand Down