From b4443839bf8502e0a6a3b9f57c215ad7771b14b3 Mon Sep 17 00:00:00 2001 From: Galen Collins Date: Mon, 18 Jun 2012 14:30:00 -0500 Subject: [PATCH] fixes #9 on github --- pymodbus/client/async.py | 61 +++++++++++++++++++++++++--------------- pymodbus/server/async.py | 1 + pymodbus/transaction.py | 25 ++++++---------- test/test_transaction.py | 1 - 4 files changed, 48 insertions(+), 40 deletions(-) diff --git a/pymodbus/client/async.py b/pymodbus/client/async.py index 3220c91b6..8e9f0c52d 100644 --- a/pymodbus/client/async.py +++ b/pymodbus/client/async.py @@ -32,7 +32,6 @@ def process(): reactor.callLater(1, process) reactor.run() """ -from collections import deque from twisted.internet import defer, protocol from pymodbus.factory import ClientDecoder from pymodbus.exceptions import ConnectionException @@ -67,7 +66,7 @@ def __init__(self, framer=None): :param framer: The framer to use for the protocol ''' self.framer = framer or ModbusSocketFramer(ClientDecoder()) - self._requests = deque() # link queue to tid + self._requests = {} self._connected = False def connectionMade(self): @@ -83,8 +82,8 @@ def connectionLost(self, reason): ''' _logger.debug("Client disconnected from modbus server: %s" % reason) self._connected = False - while self._requests: - self._requests.popleft().errback(Failure( + for key in self._requests: + self._requests.pop(key).errback(Failure( ConnectionException('Connection lost during request'))) def dataReceived(self, data): @@ -92,26 +91,35 @@ def dataReceived(self, data): :param data: The data returned from the server ''' - def _callback(reply): # todo errback/callback - if self._requests: - self._requests.popleft().callback(reply) - - self.framer.processIncomingPacket(data, _callback) + self.framer.processIncomingPacket(data, self._handleResponse) def execute(self, request): ''' Starts the producer to send the next request to consumer.write(Frame(request)) ''' request.transaction_id = _manager.getNextTID() - #self.handler[request.transaction_id] = request packet = self.framer.buildPacket(request) self.transport.write(packet) - return self._buildResponse() + return self._buildResponse(request.transaction_id) - def _buildResponse(self): + def _handleResponse(self, reply): + ''' Handle the processed response and link to correct deferred + + :param reply: The reply to process + ''' + if self._requests and reply: + tid = reply.transaction_id + handler = self.requests.pop(tid, None) + if handler: + handler.callback(reply) + else: _logger.debug("Unrequested message: " + str(reply)) + # TODO errback handled somewhere + + def _buildResponse(self, tid): ''' Helper method to return a deferred response for the current request. + :param tid: The transaction identifier for this response :returns: A defer linked to the latest request ''' if not self._connected: @@ -119,7 +127,7 @@ def _buildResponse(self): ConnectionException('Client is not connected'))) d = defer.Deferred() - self._requests.append(d) + self._requests[tid] = d # TODO add request here as well return d #----------------------------------------------------------------------# @@ -153,31 +161,40 @@ def datagramReceived(self, data, (host, port)): :param data: The data returned from the server ''' - def _callback(reply): # todo errback/callback - if self._requests: - self._requests.popleft().callback(reply) - _logger.debug("Datagram from: %s:%d" % (host, port)) - self.framer.processIncomingPacket(data, _callback) + self.framer.processIncomingPacket(data, self._handleResponse) def execute(self, request): ''' Starts the producer to send the next request to consumer.write(Frame(request)) ''' request.transaction_id = _manager.getNextTID() - #self.handler[request.transaction_id] = request packet = self.framer.buildPacket(request) self.transport.write(packet) - return self._buildResponse() + return self._buildResponse(request.transaction_id) + + def _handleResponse(self, reply): + ''' Handle the processed response and link to correct deferred - def _buildResponse(self): + :param reply: The reply to process + ''' + if self._requests and reply: + tid = reply.transaction_id + handler = self.requests.pop(tid, None) + if handler: + handler.callback(reply) + else: _logger.debug("Unrequested message: " + str(reply)) + # TODO errback handled somewhere + + def _buildResponse(self, tid): ''' Helper method to return a deferred response for the current request. + :param tid: The transaction identifier for this response :returns: A defer linked to the latest request ''' d = defer.Deferred() - self._requests.append(d) + self._requests[tid] = d # TODO add request here as well return d diff --git a/pymodbus/server/async.py b/pymodbus/server/async.py index e5fc9a28a..12a8b627d 100644 --- a/pymodbus/server/async.py +++ b/pymodbus/server/async.py @@ -217,6 +217,7 @@ def StartUdpServer(context, identity=None): def StartSerialServer(context, identity=None, framer=ModbusAsciiFramer, **kwargs): ''' Helper method to start the Modbus Async Serial server + :param context: The server data context :param identify: The server identity to use (default empty) :param framer: The framer to use (default ModbusAsciiFramer) diff --git a/pymodbus/transaction.py b/pymodbus/transaction.py index 9bfc4d779..c3aaa3861 100644 --- a/pymodbus/transaction.py +++ b/pymodbus/transaction.py @@ -21,7 +21,7 @@ #---------------------------------------------------------------------------# # The Global Transaction Manager #---------------------------------------------------------------------------# -class ModbusTransactionManager(Singleton): +class ModbusTransactionManager(object): ''' Impelements a transaction for a manager The transaction protocol can be represented by the following pseudo code:: @@ -38,7 +38,7 @@ class ModbusTransactionManager(Singleton): ''' __tid = Defaults.TransactionId - __transactions = [] + __transactions = {} def __init__(self, client=None): ''' Initializes an instance of the ModbusTransactionManager @@ -51,11 +51,6 @@ def execute(self, request): ''' Starts the producer to send the next request to consumer.write(Frame(request)) ''' - def _set_result(message): - ''' a helper method so I can reuse the async framers''' - self.response = message - - self.response = None retries = Defaults.Retries request.transaction_id = self.getNextTID() _logger.debug("Running transaction %d" % request.transaction_id) @@ -68,13 +63,13 @@ def _set_result(message): # as this may not read the full result set, but right now # it should be fine... result = self.client._recv(1024) - self.client.framer.processIncomingPacket(result, _set_result) + self.client.framer.processIncomingPacket(result, self.addTransaction) break; except socket.error, msg: self.client.close() _logger.debug("Transaction failed. (%s) " % msg) retries -= 1 - return self.response + return self.getTransaction(request.transaction_id) def addTransaction(self, request): ''' Adds a transaction to the handler @@ -84,7 +79,8 @@ def addTransaction(self, request): :param request: The request to hold on to ''' - ModbusTransactionManager.__transactions.append(request) + tid = request.transaction_id + ModbusTransactionManager.__transactions[tid] = request def getTransaction(self, tid): ''' Returns a transaction matching the referenced tid @@ -93,19 +89,14 @@ def getTransaction(self, tid): :param tid: The transaction to retrieve ''' - for k, v in enumerate(ModbusTransactionManager.__transactions): - if v.transaction_id == tid: - return ModbusTransactionManager.__transactions.pop(k) - return None + return ModbusTransactionManager.__transactions.pop(tid, None) def delTransaction(self, tid): ''' Removes a transaction matching the referenced tid :param tid: The transaction to remove ''' - for k, v in enumerate(ModbusTransactionManager.__transactions): - if v.transaction_id == tid: - del ModbusTransactionManager.__transactions[k] + ModbusTransactionManager.__transactions.pop(tid, None) def getNextTID(self): ''' Retrieve the next unique transaction identifier diff --git a/test/test_transaction.py b/test/test_transaction.py index 07905a572..c89c1873e 100644 --- a/test/test_transaction.py +++ b/test/test_transaction.py @@ -34,7 +34,6 @@ def tearDown(self): #---------------------------------------------------------------------------# def testModbusTransactionManagerTID(self): ''' Test the tcp transaction manager TID ''' - self.assertEqual(id(self._manager), id(ModbusTransactionManager())) for tid in range(1, self._manager.getNextTID() + 10): self.assertEqual(tid+2, self._manager.getNextTID()) self._manager.resetTID()