Skip to content

Commit

Permalink
fixes #9 on github
Browse files Browse the repository at this point in the history
  • Loading branch information
bashwork committed Jun 18, 2012
1 parent 3b191bd commit b444383
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 40 deletions.
61 changes: 39 additions & 22 deletions pymodbus/client/async.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ def process():
reactor.callLater(1, process)
reactor.run()
"""
from collections import deque
from twisted.internet import defer, protocol
from pymodbus.factory import ClientDecoder
from pymodbus.exceptions import ConnectionException
Expand Down Expand Up @@ -67,7 +66,7 @@ def __init__(self, framer=None):
:param framer: The framer to use for the protocol
'''
self.framer = framer or ModbusSocketFramer(ClientDecoder())
self._requests = deque() # link queue to tid
self._requests = {}
self._connected = False

def connectionMade(self):
Expand All @@ -83,43 +82,52 @@ def connectionLost(self, reason):
'''
_logger.debug("Client disconnected from modbus server: %s" % reason)
self._connected = False
while self._requests:
self._requests.popleft().errback(Failure(
for key in self._requests:
self._requests.pop(key).errback(Failure(
ConnectionException('Connection lost during request')))

def dataReceived(self, data):
''' Get response, check for valid message, decode result
:param data: The data returned from the server
'''
def _callback(reply): # todo errback/callback
if self._requests:
self._requests.popleft().callback(reply)

self.framer.processIncomingPacket(data, _callback)
self.framer.processIncomingPacket(data, self._handleResponse)

def execute(self, request):
''' Starts the producer to send the next request to
consumer.write(Frame(request))
'''
request.transaction_id = _manager.getNextTID()
#self.handler[request.transaction_id] = request
packet = self.framer.buildPacket(request)
self.transport.write(packet)
return self._buildResponse()
return self._buildResponse(request.transaction_id)

def _buildResponse(self):
def _handleResponse(self, reply):
''' Handle the processed response and link to correct deferred
:param reply: The reply to process
'''
if self._requests and reply:
tid = reply.transaction_id
handler = self.requests.pop(tid, None)
if handler:
handler.callback(reply)
else: _logger.debug("Unrequested message: " + str(reply))
# TODO errback handled somewhere

def _buildResponse(self, tid):
''' Helper method to return a deferred response
for the current request.
:param tid: The transaction identifier for this response
:returns: A defer linked to the latest request
'''
if not self._connected:
return defer.fail(Failure(
ConnectionException('Client is not connected')))

d = defer.Deferred()
self._requests.append(d)
self._requests[tid] = d # TODO add request here as well
return d

#----------------------------------------------------------------------#
Expand Down Expand Up @@ -153,31 +161,40 @@ def datagramReceived(self, data, (host, port)):
:param data: The data returned from the server
'''
def _callback(reply): # todo errback/callback
if self._requests:
self._requests.popleft().callback(reply)

_logger.debug("Datagram from: %s:%d" % (host, port))
self.framer.processIncomingPacket(data, _callback)
self.framer.processIncomingPacket(data, self._handleResponse)

def execute(self, request):
''' Starts the producer to send the next request to
consumer.write(Frame(request))
'''
request.transaction_id = _manager.getNextTID()
#self.handler[request.transaction_id] = request
packet = self.framer.buildPacket(request)
self.transport.write(packet)
return self._buildResponse()
return self._buildResponse(request.transaction_id)

def _handleResponse(self, reply):
''' Handle the processed response and link to correct deferred
def _buildResponse(self):
:param reply: The reply to process
'''
if self._requests and reply:
tid = reply.transaction_id
handler = self.requests.pop(tid, None)
if handler:
handler.callback(reply)
else: _logger.debug("Unrequested message: " + str(reply))
# TODO errback handled somewhere

def _buildResponse(self, tid):
''' Helper method to return a deferred response
for the current request.
:param tid: The transaction identifier for this response
:returns: A defer linked to the latest request
'''
d = defer.Deferred()
self._requests.append(d)
self._requests[tid] = d # TODO add request here as well
return d


Expand Down
1 change: 1 addition & 0 deletions pymodbus/server/async.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ def StartUdpServer(context, identity=None):
def StartSerialServer(context, identity=None,
framer=ModbusAsciiFramer, **kwargs):
''' Helper method to start the Modbus Async Serial server
:param context: The server data context
:param identify: The server identity to use (default empty)
:param framer: The framer to use (default ModbusAsciiFramer)
Expand Down
25 changes: 8 additions & 17 deletions pymodbus/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
#---------------------------------------------------------------------------#
# The Global Transaction Manager
#---------------------------------------------------------------------------#
class ModbusTransactionManager(Singleton):
class ModbusTransactionManager(object):
''' Impelements a transaction for a manager
The transaction protocol can be represented by the following pseudo code::
Expand All @@ -38,7 +38,7 @@ class ModbusTransactionManager(Singleton):
'''

__tid = Defaults.TransactionId
__transactions = []
__transactions = {}

def __init__(self, client=None):
''' Initializes an instance of the ModbusTransactionManager
Expand All @@ -51,11 +51,6 @@ def execute(self, request):
''' Starts the producer to send the next request to
consumer.write(Frame(request))
'''
def _set_result(message):
''' a helper method so I can reuse the async framers'''
self.response = message

self.response = None
retries = Defaults.Retries
request.transaction_id = self.getNextTID()
_logger.debug("Running transaction %d" % request.transaction_id)
Expand All @@ -68,13 +63,13 @@ def _set_result(message):
# as this may not read the full result set, but right now
# it should be fine...
result = self.client._recv(1024)
self.client.framer.processIncomingPacket(result, _set_result)
self.client.framer.processIncomingPacket(result, self.addTransaction)
break;
except socket.error, msg:
self.client.close()
_logger.debug("Transaction failed. (%s) " % msg)
retries -= 1
return self.response
return self.getTransaction(request.transaction_id)

def addTransaction(self, request):
''' Adds a transaction to the handler
Expand All @@ -84,7 +79,8 @@ def addTransaction(self, request):
:param request: The request to hold on to
'''
ModbusTransactionManager.__transactions.append(request)
tid = request.transaction_id
ModbusTransactionManager.__transactions[tid] = request

def getTransaction(self, tid):
''' Returns a transaction matching the referenced tid
Expand All @@ -93,19 +89,14 @@ def getTransaction(self, tid):
:param tid: The transaction to retrieve
'''
for k, v in enumerate(ModbusTransactionManager.__transactions):
if v.transaction_id == tid:
return ModbusTransactionManager.__transactions.pop(k)
return None
return ModbusTransactionManager.__transactions.pop(tid, None)

def delTransaction(self, tid):
''' Removes a transaction matching the referenced tid
:param tid: The transaction to remove
'''
for k, v in enumerate(ModbusTransactionManager.__transactions):
if v.transaction_id == tid:
del ModbusTransactionManager.__transactions[k]
ModbusTransactionManager.__transactions.pop(tid, None)

def getNextTID(self):
''' Retrieve the next unique transaction identifier
Expand Down
1 change: 0 additions & 1 deletion test/test_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def tearDown(self):
#---------------------------------------------------------------------------#
def testModbusTransactionManagerTID(self):
''' Test the tcp transaction manager TID '''
self.assertEqual(id(self._manager), id(ModbusTransactionManager()))
for tid in range(1, self._manager.getNextTID() + 10):
self.assertEqual(tid+2, self._manager.getNextTID())
self._manager.resetTID()
Expand Down

0 comments on commit b444383

Please sign in to comment.