From f0f3958f14b0f6e34bf883986bd760551d024892 Mon Sep 17 00:00:00 2001 From: Jack Robison Date: Thu, 31 May 2018 10:50:11 -0400 Subject: [PATCH] add protocol version to the dht and migrate old arg format for store --- lbrynet/dht/constants.py | 2 ++ lbrynet/dht/contact.py | 4 +++ lbrynet/dht/node.py | 6 ++--- lbrynet/dht/protocol.py | 42 +++++++++++++++++++++++++++-- lbrynet/tests/unit/dht/test_node.py | 2 +- 5 files changed, 50 insertions(+), 6 deletions(-) diff --git a/lbrynet/dht/constants.py b/lbrynet/dht/constants.py index 2697e0d64d..e06aae1cdc 100644 --- a/lbrynet/dht/constants.py +++ b/lbrynet/dht/constants.py @@ -55,3 +55,5 @@ key_bits = 384 rpc_id_length = 20 + +protocolVersion = 1 diff --git a/lbrynet/dht/contact.py b/lbrynet/dht/contact.py index 736dfa4768..c338152e4c 100644 --- a/lbrynet/dht/contact.py +++ b/lbrynet/dht/contact.py @@ -34,6 +34,7 @@ def __init__(self, contactManager, id, ipAddress, udpPort, networkProtocol, firs self.getTime = self._contactManager._get_time self.lastReplied = None self.lastRequested = None + self.protocolVersion = None @property def lastInteracted(self): @@ -120,6 +121,9 @@ def update_last_failed(self): failures.append(self.getTime()) self._contactManager._rpc_failures[(self.address, self.port)] = failures + def update_protocol_version(self, version): + self.protocolVersion = version + def __str__(self): return '<%s.%s object; IP address: %s, UDP port: %d>' % ( self.__module__, self.__class__.__name__, self.address, self.port) diff --git a/lbrynet/dht/node.py b/lbrynet/dht/node.py index e4add4bdc6..98bcb91d96 100644 --- a/lbrynet/dht/node.py +++ b/lbrynet/dht/node.py @@ -315,7 +315,7 @@ def announceHaveBlob(self, blob_hash): self_contact = self.contact_manager.make_contact(self.node_id, self.externalIP, self.port, self._protocol) token = self.make_token(self_contact.compact_ip()) - yield self.store(self_contact, blob_hash, token, self.peerPort) + yield self.store(self_contact, blob_hash, token, self.peerPort, self.node_id, 0) elif self.externalIP is not None: pass else: @@ -328,7 +328,7 @@ def announce_to_contact(contact): known_nodes[contact.id] = contact try: responseMsg, originAddress = yield contact.findValue(blob_hash, rawResponse=True) - res = yield contact.store(blob_hash, responseMsg.response['token'], self.peerPort) + res = yield contact.store(blob_hash, responseMsg.response['token'], self.peerPort, self.node_id, 0) if res != "OK": raise ValueError(res) contacted.append(contact) @@ -506,7 +506,7 @@ def ping(self): return 'pong' @rpcmethod - def store(self, rpc_contact, blob_hash, token, port, originalPublisherID=None, age=0): + def store(self, rpc_contact, blob_hash, token, port, originalPublisherID, age): """ Store the received data in this node's local datastore @param blob_hash: The hash of the data diff --git a/lbrynet/dht/protocol.py b/lbrynet/dht/protocol.py index f7a39a8320..cc5cc43669 100644 --- a/lbrynet/dht/protocol.py +++ b/lbrynet/dht/protocol.py @@ -14,6 +14,41 @@ log = logging.getLogger(__name__) +def migrate_incoming_rpc_args(contact, method, *args): + if method == 'store': + if isinstance(args[-1], dict) and 'protocolVersion' in args[-1]: # args don't need reformatting + contact.update_protocol_version(args[-1].pop('protocolVersion')) # update the contact protocol version + return tuple(args[:-1]), args[-1] + elif isinstance(args[1], dict): # unpack the old 'value' dictionary argument + contact.update_protocol_version(0) # set the version for the contact so we know how to format our requests + blob_hash = args[0] + token = args[1].get('token', None) + port = args[1].get('port', -1) + originalPublisherID = args[1].get('lbryid', None) + age = 0 + return (blob_hash, token, port, originalPublisherID, age), {} + return args, {} + + +def migrate_outgoing_rpc_args(contact, method, *args): + if method == 'store' and contact.protocolVersion == 0: + blob_hash, token, port, originalPublisherID, age = args + old_value_arg = { + 'token': token, + 'port': port, + 'lbryid': originalPublisherID + } + return inject_protocol_version_argument(blob_hash, old_value_arg) + return args + + +def inject_protocol_version_argument(*args): + if args and isinstance(args[-1], dict): + args[-1]['protocolVersion'] = constants.protocolVersion + return args + return args + tuple({'protocolVersion': constants.protocolVersion}) + + class PingQueue(object): """ Schedules a 15 minute delayed ping after a new node sends us a query. This is so the new node gets added to the @@ -131,7 +166,7 @@ def sendRPC(self, contact, method, args, rawResponse=False): C{ErrorMessage}). @rtype: twisted.internet.defer.Deferred """ - msg = msgtypes.RequestMessage(self._node.node_id, method, args) + msg = msgtypes.RequestMessage(self._node.node_id, method, migrate_outgoing_rpc_args(contact, method, *args)) msgPrimitive = self._translator.toPrimitive(msg) encodedMsg = self._encoder.encode(msgPrimitive) @@ -152,6 +187,8 @@ def _remove_contact(failure): # remove the contact from the routing table and t except (ValueError, IndexError): pass contact.update_last_failed() + if failure.getErrorMessage() == "store() takes at least 5 arguments (4 given)": # lbrynet < 0.20.0 + contact.update_protocol_version(0) return failure def _update_contact(result): # refresh the contact in the routing table @@ -403,7 +440,8 @@ def handleResult(result): senderContact.address, senderContact.port) try: if method != 'ping': - result = func(senderContact, *args) + migrated_args, migrated_kwargs = migrate_incoming_rpc_args(senderContact, method, *args) + result = func(senderContact, *migrated_args, **migrated_kwargs) else: result = func() except Exception, e: diff --git a/lbrynet/tests/unit/dht/test_node.py b/lbrynet/tests/unit/dht/test_node.py index e04b07f9bb..f5fe876abd 100644 --- a/lbrynet/tests/unit/dht/test_node.py +++ b/lbrynet/tests/unit/dht/test_node.py @@ -62,7 +62,7 @@ def setUp(self): def testStore(self): """ Tests if the node can store (and privately retrieve) some data """ for key, port in self.cases: - yield self.node.store(self.contact, key, self.token, port, self.contact.id) + yield self.node.store(self.contact, key, self.token, port, self.contact.id, 0) for key, value in self.cases: expected_result = self.contact.compact_ip() + str(struct.pack('>H', value)) + \ self.contact.id