diff --git a/CHANGES.rst b/CHANGES.rst index 9b4c3afc..9297c9ba 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,7 +1,27 @@ -Revision 0.4.8, released XX-09-2019 +Revision 0.5.0, released XX-09-2019 ----------------------------------- +- Refactor BER/CER/DER decoder into a coroutine. + + The goal of this change is to make the decoder stopping on input + data starvation and resuming from where it stopped whenever the + caller decides to try again (hopefully making sure that some more + input becomes available). + + This change makes it possible for the decoder to operate on streams + of data (meaning that the entire DER blob might not be immediately + available on input). + + On top of that, the decoder yields partially reconstructed ASN.1 + object on input starvation making it possible for the caller to + inspect what has been decoded so far and possibly consume partial + ASN.1 data. + + All these new feature are natively available through + `StreamingDecoder` class. Previously published API is implemented + as a thin wrapper on top of that ensuring backward compatibility. + - Added ability of combining `SingleValueConstraint` and `PermittedAlphabetConstraint` objects into one for proper modeling `FROM ... EXCEPT ...` ASN.1 clause. diff --git a/README.md b/README.md index e36324b0..b01801b9 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,7 @@ Features * Generic implementation of ASN.1 types (X.208) * Standards compliant BER/CER/DER codecs +* Can operate on streams of serialized data * Dumps/loads ASN.1 structures from Python types * 100% Python, works with Python 2.4 up to Python 3.7 * MT-safe diff --git a/pyasn1/codec/ber/decoder.py b/pyasn1/codec/ber/decoder.py index 2a6448eb..d3de8ff8 100644 --- a/pyasn1/codec/ber/decoder.py +++ b/pyasn1/codec/ber/decoder.py @@ -5,15 +5,14 @@ # License: http://snmplabs.com/pyasn1/license.html # import os -import sys -from io import BytesIO, BufferedReader, IOBase, DEFAULT_BUFFER_SIZE from pyasn1 import debug from pyasn1 import error +from pyasn1.codec import streaming from pyasn1.codec.ber import eoo from pyasn1.compat.integer import from_bytes from pyasn1.compat.octets import oct2int, octs2ints, ints2octs, null -from pyasn1.error import PyAsn1Error, UnsupportedSubstrateError +from pyasn1.error import PyAsn1Error from pyasn1.type import base from pyasn1.type import char from pyasn1.type import tag @@ -22,165 +21,16 @@ from pyasn1.type import useful -__all__ = ['decodeStream', 'decode'] +__all__ = ['StreamingDecoder', 'Decoder', 'decode'] LOG = debug.registerLoggee(__name__, flags=debug.DEBUG_DECODER) noValue = base.noValue +SubstrateUnderrunError = error.SubstrateUnderrunError -_PY2 = sys.version_info < (3,) - -class _CachingStreamWrapper(IOBase): - """Wrapper around non-seekable streams. - - Note that the implementation is tied to the decoder, - not checking for dangerous arguments for the sake - of performance. - - The read bytes are kept in an internal cache until - setting _markedPosition which may reset the cache. - """ - def __init__(self, raw): - self._raw = raw - self._cache = BytesIO() - self._markedPosition_ = 0 - - def peek(self, n): - result = self.read(n) - self._cache.seek(-len(result), os.SEEK_CUR) - return result - - def seekable(self): - return True - - def seek(self, n=-1, whence=os.SEEK_SET): - # Note that this not safe for seeking forward. - return self._cache.seek(n, whence) - - def read(self, n=-1): - read_from_cache = self._cache.read(n) - if n != -1: - n -= len(read_from_cache) - if not n: # 0 bytes left to read - return read_from_cache - - read_from_raw = self._raw.read(n) - self._cache.write(read_from_raw) - return read_from_cache + read_from_raw - - @property - def _markedPosition(self): - """Position where the currently processed element starts. - - This is used for back-tracking in Decoder.__call__ - and (indefLen)ValueDecoder and should not be used for other purposes. - The client is not supposed to ever seek before this position. - """ - return self._markedPosition_ - - @_markedPosition.setter - def _markedPosition(self, value): - # By setting the value, we ensure we won't seek back before it. - # `value` should be the same as the current position - # We don't check for this for performance reasons. - self._markedPosition_ = value - - # Whenever we set _marked_position, we know for sure - # that we will not return back, and thus it is - # safe to drop all cached data. - if self._cache.tell() > DEFAULT_BUFFER_SIZE: - self._cache = BytesIO(self._cache.read()) - self._markedPosition_ = 0 - - def tell(self): - return self._cache.tell() - - -def _asSeekableStream(substrate): - """Convert object to seekable byte-stream. - - Parameters - ---------- - substrate: :py:class:`bytes` or :py:class:`io.IOBase` or :py:class:`univ.OctetString` - - Returns - ------- - : :py:class:`io.IOBase` - - Raises - ------ - ~pyasn1.error.PyAsn1Error - If the supplied substrate cannot be converted to a seekable stream. - """ - if isinstance(substrate, bytes): - return BytesIO(substrate) - elif isinstance(substrate, univ.OctetString): - return BytesIO(substrate.asOctets()) - try: - if _PY2 and isinstance(substrate, file): # Special case (it is not possible to set attributes) - return BufferedReader(substrate) - elif substrate.seekable(): # Will fail for most invalid types - return substrate - else: - return _CachingStreamWrapper(substrate) - except AttributeError: - raise UnsupportedSubstrateError("Cannot convert " + substrate.__class__.__name__ + " to a seekable bit stream.") - - -def _endOfStream(substrate): - """Check whether we have reached the end of a stream. - - Although it is more effective to read and catch exceptions, this - function - - Parameters - ---------- - substrate: :py:class:`IOBase` - Stream to check - - Returns - ------- - : :py:class:`bool` - """ - if isinstance(substrate, BytesIO): - cp = substrate.tell() - substrate.seek(0, os.SEEK_END) - result = substrate.tell() == cp - substrate.seek(cp, os.SEEK_SET) - return result - else: - return not substrate.peek(1) - - -def _peek(substrate, size=-1): - """Peek the stream. - - Parameters - ---------- - substrate: :py:class:`IOBase` - Stream to read from. - - size: :py:class:`int` - How many bytes to peek (-1 = all available) - - Returns - ------- - : :py:class:`bytes` or :py:class:`str` - The return type depends on Python major version - """ - if hasattr(substrate, "peek"): - return substrate.peek(size) - else: - current_position = substrate.tell() - try: - return substrate.read(size) - finally: - substrate.seek(current_position) - - -class AbstractDecoder(object): +class AbstractPayloadDecoder(object): protoComponent = None def valueDecoder(self, substrate, asn1Spec, @@ -189,10 +39,9 @@ def valueDecoder(self, substrate, asn1Spec, **options): """Decode value with fixed byte length. - If the decoder does not consume a precise byte length, - it is considered an error. + The decoder is allowed to consume as many bytes as necessary. """ - raise error.PyAsn1Error('Decoder not implemented for %s' % (tagSet,)) # TODO: Seems more like an NotImplementedError? + raise error.PyAsn1Error('SingleItemDecoder not implemented for %s' % (tagSet,)) # TODO: Seems more like an NotImplementedError? def indefLenValueDecoder(self, substrate, asn1Spec, tagSet=None, length=None, state=None, @@ -204,11 +53,19 @@ def indefLenValueDecoder(self, substrate, asn1Spec, """ raise error.PyAsn1Error('Indefinite length mode decoder not implemented for %s' % (tagSet,)) # TODO: Seems more like an NotImplementedError? + @staticmethod + def _passAsn1Object(asn1Object, options): + if 'asn1Object' not in options: + options['asn1Object'] = asn1Object + + return options + -class AbstractSimpleDecoder(AbstractDecoder): +class AbstractSimplePayloadDecoder(AbstractPayloadDecoder): @staticmethod - def substrateCollector(asn1Object, substrate, length): - return substrate.read(length) + def substrateCollector(asn1Object, substrate, length, options): + for chunk in streaming.read(substrate, length, options): + yield chunk def _createComponent(self, asn1Spec, tagSet, value, **options): if options.get('native'): @@ -221,7 +78,7 @@ def _createComponent(self, asn1Spec, tagSet, value, **options): return asn1Spec.clone(value) -class ExplicitTagDecoder(AbstractSimpleDecoder): +class RawPayloadDecoder(AbstractSimplePayloadDecoder): protoComponent = univ.Any('') def valueDecoder(self, substrate, asn1Spec, @@ -229,43 +86,45 @@ def valueDecoder(self, substrate, asn1Spec, decodeFun=None, substrateFun=None, **options): if substrateFun: - return substrateFun( - self._createComponent(asn1Spec, tagSet, '', **options), - substrate, length - ) - value = decodeFun(substrate, asn1Spec, tagSet, length, **options) + asn1Object = self._createComponent(asn1Spec, tagSet, '', **options) + + for chunk in substrateFun(asn1Object, substrate, length, options): + yield chunk - # TODO: - # if LOG: - # LOG('explicit tag container carries %d octets of trailing payload ' - # '(will be lost!): %s' % (len(_), debug.hexdump(_))) + return - return value + for value in decodeFun(substrate, asn1Spec, tagSet, length, **options): + yield value def indefLenValueDecoder(self, substrate, asn1Spec, tagSet=None, length=None, state=None, decodeFun=None, substrateFun=None, **options): if substrateFun: - return substrateFun( - self._createComponent(asn1Spec, tagSet, '', **options), - substrate, length - ) + asn1Object = self._createComponent(asn1Spec, tagSet, '', **options) - value = decodeFun(substrate, asn1Spec, tagSet, length, **options) + for chunk in substrateFun(asn1Object, substrate, length, options): + yield chunk - eooMarker = decodeFun(substrate, allowEoo=True, **options) + return - if eooMarker is eoo.endOfOctets: - return value - else: - raise error.PyAsn1Error('Missing end-of-octets terminator') + while True: + for value in decodeFun( + substrate, asn1Spec, tagSet, length, + allowEoo=True, **options): + if value is eoo.endOfOctets: + break + + yield value + + if value is eoo.endOfOctets: + break -explicitTagDecoder = ExplicitTagDecoder() +rawPayloadDecoder = RawPayloadDecoder() -class IntegerDecoder(AbstractSimpleDecoder): +class IntegerPayloadDecoder(AbstractSimplePayloadDecoder): protoComponent = univ.Integer(0) def valueDecoder(self, substrate, asn1Spec, @@ -276,24 +135,27 @@ def valueDecoder(self, substrate, asn1Spec, if tagSet[0].tagFormat != tag.tagFormatSimple: raise error.PyAsn1Error('Simple tag format expected') - the_bytes = substrate.read(length) - if not the_bytes: - return self._createComponent(asn1Spec, tagSet, 0, **options) + for chunk in streaming.read(substrate, length, options): + if isinstance(chunk, SubstrateUnderrunError): + yield chunk - value = from_bytes(the_bytes, signed=True) + if not chunk: + yield self._createComponent(asn1Spec, tagSet, 0, **options) - return self._createComponent(asn1Spec, tagSet, value, **options) + value = from_bytes(chunk, signed=True) + yield self._createComponent(asn1Spec, tagSet, value, **options) -class BooleanDecoder(IntegerDecoder): + +class BooleanPayloadDecoder(IntegerPayloadDecoder): protoComponent = univ.Boolean(0) def _createComponent(self, asn1Spec, tagSet, value, **options): - return IntegerDecoder._createComponent( + return IntegerPayloadDecoder._createComponent( self, asn1Spec, tagSet, value and 1 or 0, **options) -class BitStringDecoder(AbstractSimpleDecoder): +class BitStringPayloadDecoder(AbstractSimplePayloadDecoder): protoComponent = univ.BitString(()) supportConstructedForm = True @@ -303,24 +165,45 @@ def valueDecoder(self, substrate, asn1Spec, **options): if substrateFun: - return substrateFun(self._createComponent( - asn1Spec, tagSet, noValue, **options), substrate, length) + asn1Object = self._createComponent(asn1Spec, tagSet, noValue, **options) + + for chunk in substrateFun(asn1Object, substrate, length, options): + yield chunk + + return - if not length or _endOfStream(substrate): + if not length: + raise error.PyAsn1Error('Empty BIT STRING substrate') + + for chunk in streaming.isEndOfStream(substrate): + if isinstance(chunk, SubstrateUnderrunError): + yield chunk + + if chunk: raise error.PyAsn1Error('Empty BIT STRING substrate') if tagSet[0].tagFormat == tag.tagFormatSimple: # XXX what tag to check? - trailingBits = ord(substrate.read(1)) + for trailingBits in streaming.read(substrate, 1, options): + if isinstance(trailingBits, SubstrateUnderrunError): + yield trailingBits + + trailingBits = ord(trailingBits) if trailingBits > 7: raise error.PyAsn1Error( 'Trailing bits overflow %s' % trailingBits ) + for chunk in streaming.read(substrate, length - 1, options): + if isinstance(chunk, SubstrateUnderrunError): + yield chunk + value = self.protoComponent.fromOctetString( - substrate.read(length - 1), internalFormat=True, padding=trailingBits) + chunk, internalFormat=True, padding=trailingBits) + + yield self._createComponent(asn1Spec, tagSet, value, **options) - return self._createComponent(asn1Spec, tagSet, value, **options) + return if not self.supportConstructedForm: raise error.PyAsn1Error('Constructed encoding form prohibited ' @@ -337,8 +220,11 @@ def valueDecoder(self, substrate, asn1Spec, current_position = substrate.tell() while substrate.tell() - current_position < length: - component = decodeFun(substrate, self.protoComponent, - substrateFun=substrateFun, **options) + for component in decodeFun( + substrate, self.protoComponent, substrateFun=substrateFun, + **options): + if isinstance(component, SubstrateUnderrunError): + yield component trailingBits = oct2int(component[0]) if trailingBits > 7: @@ -351,7 +237,7 @@ def valueDecoder(self, substrate, asn1Spec, prepend=bitString, padding=trailingBits ) - return self._createComponent(asn1Spec, tagSet, bitString, **options) + yield self._createComponent(asn1Spec, tagSet, bitString, **options) def indefLenValueDecoder(self, substrate, asn1Spec, tagSet=None, length=None, state=None, @@ -359,21 +245,32 @@ def indefLenValueDecoder(self, substrate, asn1Spec, **options): if substrateFun: - return substrateFun(self._createComponent(asn1Spec, tagSet, noValue, **options), substrate, length) + asn1Object = self._createComponent(asn1Spec, tagSet, noValue, **options) + + for chunk in substrateFun(asn1Object, substrate, length, options): + yield chunk + + return # All inner fragments are of the same type, treat them as octet string substrateFun = self.substrateCollector bitString = self.protoComponent.fromOctetString(null, internalFormat=True) - while True: - component = decodeFun(substrate, self.protoComponent, - substrateFun=substrateFun, - allowEoo=True, **options) + while True: # loop over fragments + + for component in decodeFun( + substrate, self.protoComponent, substrateFun=substrateFun, + allowEoo=True, **options): + + if isinstance(component, SubstrateUnderrunError): + yield component + + if component is eoo.endOfOctets: + break + if component is eoo.endOfOctets: break - if component is None: - raise error.SubstrateUnderrunError('No EOO seen before substrate ends') trailingBits = oct2int(component[0]) if trailingBits > 7: @@ -386,10 +283,10 @@ def indefLenValueDecoder(self, substrate, asn1Spec, prepend=bitString, padding=trailingBits ) - return self._createComponent(asn1Spec, tagSet, bitString, **options) + yield self._createComponent(asn1Spec, tagSet, bitString, **options) -class OctetStringDecoder(AbstractSimpleDecoder): +class OctetStringPayloadDecoder(AbstractSimplePayloadDecoder): protoComponent = univ.OctetString('') supportConstructedForm = True @@ -398,11 +295,21 @@ def valueDecoder(self, substrate, asn1Spec, decodeFun=None, substrateFun=None, **options): if substrateFun: - return substrateFun(self._createComponent(asn1Spec, tagSet, noValue, **options), - substrate, length) + asn1Object = self._createComponent(asn1Spec, tagSet, noValue, **options) + + for chunk in substrateFun(asn1Object, substrate, length, options): + yield chunk + + return if tagSet[0].tagFormat == tag.tagFormatSimple: # XXX what tag to check? - return self._createComponent(asn1Spec, tagSet, substrate.read(length), **options) + for chunk in streaming.read(substrate, length, options): + if isinstance(chunk, SubstrateUnderrunError): + yield chunk + + yield self._createComponent(asn1Spec, tagSet, chunk, **options) + + return if not self.supportConstructedForm: raise error.PyAsn1Error('Constructed encoding form prohibited at %s' % self.__class__.__name__) @@ -418,12 +325,15 @@ def valueDecoder(self, substrate, asn1Spec, original_position = substrate.tell() # head = popSubstream(substrate, length) while substrate.tell() - original_position < length: - component = decodeFun(substrate, self.protoComponent, - substrateFun=substrateFun, - **options) + for component in decodeFun( + substrate, self.protoComponent, substrateFun=substrateFun, + **options): + if isinstance(component, SubstrateUnderrunError): + yield component + header += component - return self._createComponent(asn1Spec, tagSet, header, **options) + yield self._createComponent(asn1Spec, tagSet, header, **options) def indefLenValueDecoder(self, substrate, asn1Spec, tagSet=None, length=None, state=None, @@ -431,31 +341,38 @@ def indefLenValueDecoder(self, substrate, asn1Spec, **options): if substrateFun and substrateFun is not self.substrateCollector: asn1Object = self._createComponent(asn1Spec, tagSet, noValue, **options) - return substrateFun(asn1Object, substrate, length) + + for chunk in substrateFun(asn1Object, substrate, length, options): + yield chunk + + return # All inner fragments are of the same type, treat them as octet string substrateFun = self.substrateCollector header = null - while True: - component = decodeFun(substrate, - self.protoComponent, - substrateFun=substrateFun, - allowEoo=True, **options) + while True: # loop over fragments + + for component in decodeFun( + substrate, self.protoComponent, substrateFun=substrateFun, + allowEoo=True, **options): + + if isinstance(component, SubstrateUnderrunError): + yield component + + if component is eoo.endOfOctets: + break + if component is eoo.endOfOctets: break - if not component: - raise error.SubstrateUnderrunError( - 'No EOO seen before substrate ends' - ) header += component - return self._createComponent(asn1Spec, tagSet, header, **options) + yield self._createComponent(asn1Spec, tagSet, header, **options) -class NullDecoder(AbstractSimpleDecoder): +class NullPayloadDecoder(AbstractSimplePayloadDecoder): protoComponent = univ.Null('') def valueDecoder(self, substrate, asn1Spec, @@ -466,17 +383,19 @@ def valueDecoder(self, substrate, asn1Spec, if tagSet[0].tagFormat != tag.tagFormatSimple: raise error.PyAsn1Error('Simple tag format expected') - head = substrate.read(length) + for chunk in streaming.read(substrate, length, options): + if isinstance(chunk, SubstrateUnderrunError): + yield chunk component = self._createComponent(asn1Spec, tagSet, '', **options) - if head: + if chunk: raise error.PyAsn1Error('Unexpected %d-octet substrate for Null' % length) - return component + yield component -class ObjectIdentifierDecoder(AbstractSimpleDecoder): +class ObjectIdentifierPayloadDecoder(AbstractSimplePayloadDecoder): protoComponent = univ.ObjectIdentifier(()) def valueDecoder(self, substrate, asn1Spec, @@ -486,17 +405,20 @@ def valueDecoder(self, substrate, asn1Spec, if tagSet[0].tagFormat != tag.tagFormatSimple: raise error.PyAsn1Error('Simple tag format expected') - head = substrate.read(length) - if not head: + for chunk in streaming.read(substrate, length, options): + if isinstance(chunk, SubstrateUnderrunError): + yield chunk + + if not chunk: raise error.PyAsn1Error('Empty substrate') - head = octs2ints(head) + chunk = octs2ints(chunk) oid = () index = 0 - substrateLen = len(head) + substrateLen = len(chunk) while index < substrateLen: - subId = head[index] + subId = chunk[index] index += 1 if subId < 128: oid += (subId,) @@ -510,7 +432,7 @@ def valueDecoder(self, substrate, asn1Spec, raise error.SubstrateUnderrunError( 'Short substrate for sub-OID past %s' % (oid,) ) - nextSubId = head[index] + nextSubId = chunk[index] index += 1 oid += ((subId << 7) + nextSubId,) elif subId == 128: @@ -528,12 +450,12 @@ def valueDecoder(self, substrate, asn1Spec, elif oid[0] >= 80: oid = (2, oid[0] - 80) + oid[1:] else: - raise error.PyAsn1Error('Malformed first OID octet: %s' % head[0]) + raise error.PyAsn1Error('Malformed first OID octet: %s' % chunk[0]) - return self._createComponent(asn1Spec, tagSet, oid, **options) + yield self._createComponent(asn1Spec, tagSet, oid, **options) -class RealDecoder(AbstractSimpleDecoder): +class RealPayloadDecoder(AbstractSimplePayloadDecoder): protoComponent = univ.Real() def valueDecoder(self, substrate, asn1Spec, @@ -543,15 +465,18 @@ def valueDecoder(self, substrate, asn1Spec, if tagSet[0].tagFormat != tag.tagFormatSimple: raise error.PyAsn1Error('Simple tag format expected') - head = substrate.read(length) + for chunk in streaming.read(substrate, length, options): + if isinstance(chunk, SubstrateUnderrunError): + yield chunk - if not head: - return self._createComponent(asn1Spec, tagSet, 0.0, **options) + if not chunk: + yield self._createComponent(asn1Spec, tagSet, 0.0, **options) + return - fo = oct2int(head[0]) - head = head[1:] + fo = oct2int(chunk[0]) + chunk = chunk[1:] if fo & 0x80: # binary encoding - if not head: + if not chunk: raise error.PyAsn1Error("Incomplete floating-point value") if LOG: @@ -560,12 +485,12 @@ def valueDecoder(self, substrate, asn1Spec, n = (fo & 0x03) + 1 if n == 4: - n = oct2int(head[0]) - head = head[1:] + n = oct2int(chunk[0]) + chunk = chunk[1:] - eo, head = head[:n], head[n:] + eo, chunk = chunk[:n], chunk[n:] - if not eo or not head: + if not eo or not chunk: raise error.PyAsn1Error('Real exponent screwed') e = oct2int(eo[0]) & 0x80 and -1 or 0 @@ -587,10 +512,10 @@ def valueDecoder(self, substrate, asn1Spec, e *= 4 p = 0 - while head: # value + while chunk: # value p <<= 8 - p |= oct2int(head[0]) - head = head[1:] + p |= oct2int(chunk[0]) + chunk = chunk[1:] if fo & 0x40: # sign bit p = -p @@ -606,7 +531,7 @@ def valueDecoder(self, substrate, asn1Spec, value = fo & 0x01 and '-inf' or 'inf' elif fo & 0xc0 == 0: # character encoding - if not head: + if not chunk: raise error.PyAsn1Error("Incomplete floating-point value") if LOG: @@ -614,13 +539,13 @@ def valueDecoder(self, substrate, asn1Spec, try: if fo & 0x3 == 0x1: # NR1 - value = (int(head), 10, 0) + value = (int(chunk), 10, 0) elif fo & 0x3 == 0x2: # NR2 - value = float(head) + value = float(chunk) elif fo & 0x3 == 0x3: # NR3 - value = float(head) + value = float(chunk) else: raise error.SubstrateUnderrunError( @@ -637,14 +562,14 @@ def valueDecoder(self, substrate, asn1Spec, 'Unknown encoding (tag %s)' % fo ) - return self._createComponent(asn1Spec, tagSet, value, **options) + yield self._createComponent(asn1Spec, tagSet, value, **options) -class AbstractConstructedDecoder(AbstractDecoder): +class AbstractConstructedPayloadDecoder(AbstractPayloadDecoder): protoComponent = None -class UniversalConstructedTypeDecoder(AbstractConstructedDecoder): +class ConstructedPayloadDecoderBase(AbstractConstructedPayloadDecoder): protoRecordComponent = None protoSequenceComponent = None @@ -654,36 +579,43 @@ def _getComponentTagMap(self, asn1Object, idx): def _getComponentPositionByType(self, asn1Object, tagSet, idx): raise NotImplementedError() - def _decodeComponents(self, substrate, tagSet=None, decodeFun=None, **options): + def _decodeComponentsSchemaless( + self, substrate, tagSet=None, decodeFun=None, + length=None, **options): + + asn1Object = None + components = [] componentTypes = set() - while True: - component = decodeFun(substrate, **options) - if component is eoo.endOfOctets: - break - if component is None: - # TODO: Not an error in this case? + original_position = substrate.tell() + + while length == -1 or substrate.tell() < original_position + length: + for component in decodeFun(substrate, **options): + if isinstance(component, SubstrateUnderrunError): + yield component + + if length == -1 and component is eoo.endOfOctets: break components.append(component) componentTypes.add(component.tagSet) - # Now we have to guess is it SEQUENCE/SET or SEQUENCE OF/SET OF - # The heuristics is: - # * 1+ components of different types -> likely SEQUENCE/SET - # * otherwise -> likely SEQUENCE OF/SET OF - if len(componentTypes) > 1: - protoComponent = self.protoRecordComponent + # Now we have to guess is it SEQUENCE/SET or SEQUENCE OF/SET OF + # The heuristics is: + # * 1+ components of different types -> likely SEQUENCE/SET + # * otherwise -> likely SEQUENCE OF/SET OF + if len(componentTypes) > 1: + protoComponent = self.protoRecordComponent - else: - protoComponent = self.protoSequenceComponent + else: + protoComponent = self.protoSequenceComponent - asn1Object = protoComponent.clone( - # construct tagSet from base tag from prototype ASN.1 object - # and additional tags recovered from the substrate - tagSet=tag.TagSet(protoComponent.tagSet.baseTag, *tagSet.superTags) - ) + asn1Object = protoComponent.clone( + # construct tagSet from base tag from prototype ASN.1 object + # and additional tags recovered from the substrate + tagSet=tag.TagSet(protoComponent.tagSet.baseTag, *tagSet.superTags) + ) if LOG: LOG('guessed %r container type (pass `asn1Spec` to guide the ' @@ -696,7 +628,7 @@ def _decodeComponents(self, substrate, tagSet=None, decodeFun=None, **options): matchTags=False, matchConstraints=False ) - return asn1Object + yield asn1Object def valueDecoder(self, substrate, asn1Spec, tagSet=None, length=None, state=None, @@ -707,7 +639,7 @@ def valueDecoder(self, substrate, asn1Spec, original_position = substrate.tell() - if substrateFun is not None: + if substrateFun: if asn1Spec is not None: asn1Object = asn1Spec.clone() @@ -717,24 +649,36 @@ def valueDecoder(self, substrate, asn1Spec, else: asn1Object = self.protoRecordComponent, self.protoSequenceComponent - return substrateFun(asn1Object, substrate, length) + for chunk in substrateFun(asn1Object, substrate, length, options): + yield chunk + + return if asn1Spec is None: - asn1Object = self._decodeComponents( - substrate, tagSet=tagSet, decodeFun=decodeFun, **options - ) + for asn1Object in self._decodeComponentsSchemaless( + substrate, tagSet=tagSet, decodeFun=decodeFun, + length=length, **options): + if isinstance(asn1Object, SubstrateUnderrunError): + yield asn1Object if substrate.tell() < original_position + length: if LOG: - trailing = substrate.read() + for trailing in streaming.read(substrate, context=options): + if isinstance(trailing, SubstrateUnderrunError): + yield trailing + LOG('Unused trailing %d octets encountered: %s' % ( len(trailing), debug.hexdump(trailing))) - return asn1Object + yield asn1Object + + return asn1Object = asn1Spec.clone() asn1Object.clear() + options = self._passAsn1Object(asn1Object, options) + if asn1Spec.typeId in (univ.Sequence.typeId, univ.Set.typeId): namedTypes = asn1Spec.componentType @@ -772,7 +716,9 @@ def valueDecoder(self, substrate, asn1Spec, 'Excessive components decoded at %r' % (asn1Spec,) ) - component = decodeFun(substrate, componentType, **options) + for component in decodeFun(substrate, componentType, **options): + if isinstance(component, SubstrateUnderrunError): + yield component if not isDeterministic and namedTypes: if isSetType: @@ -859,18 +805,20 @@ def valueDecoder(self, substrate, asn1Spec, for pos, containerElement in enumerate( containerValue): - component = decodeFun( - _asSeekableStream(containerValue[pos].asOctets()), - asn1Spec=openType, **options - ) + stream = streaming.asSeekableStream(containerValue[pos].asOctets()) + + for component in decodeFun(stream, asn1Spec=openType, **options): + if isinstance(component, SubstrateUnderrunError): + yield component containerValue[pos] = component else: - component = decodeFun( - _asSeekableStream(asn1Object.getComponentByPosition(idx).asOctets()), - asn1Spec=openType, **options - ) + stream = streaming.asSeekableStream(asn1Object.getComponentByPosition(idx).asOctets()) + + for component in decodeFun(stream, asn1Spec=openType, **options): + if isinstance(component, SubstrateUnderrunError): + yield component asn1Object.setComponentByPosition(idx, component) @@ -880,9 +828,6 @@ def valueDecoder(self, substrate, asn1Spec, raise inconsistency else: - asn1Object = asn1Spec.clone() - asn1Object.clear() - componentType = asn1Spec.componentType if LOG: @@ -891,7 +836,10 @@ def valueDecoder(self, substrate, asn1Spec, idx = 0 while substrate.tell() - original_position < length: - component = decodeFun(substrate, componentType, **options) + for component in decodeFun(substrate, componentType, **options): + if isinstance(component, SubstrateUnderrunError): + yield component + asn1Object.setComponentByPosition( idx, component, verifyConstraints=False, @@ -900,7 +848,7 @@ def valueDecoder(self, substrate, asn1Spec, idx += 1 - return asn1Object + yield asn1Object def indefLenValueDecoder(self, substrate, asn1Spec, tagSet=None, length=None, state=None, @@ -919,17 +867,27 @@ def indefLenValueDecoder(self, substrate, asn1Spec, else: asn1Object = self.protoRecordComponent, self.protoSequenceComponent - return substrateFun(asn1Object, substrate, length) + for chunk in substrateFun(asn1Object, substrate, length, options): + yield chunk + + return if asn1Spec is None: - return self._decodeComponents( - substrate, tagSet=tagSet, decodeFun=decodeFun, - **dict(options, allowEoo=True) - ) + for asn1Object in self._decodeComponentsSchemaless( + substrate, tagSet=tagSet, decodeFun=decodeFun, + length=length, **dict(options, allowEoo=True)): + if isinstance(asn1Object, SubstrateUnderrunError): + yield asn1Object + + yield asn1Object + + return asn1Object = asn1Spec.clone() asn1Object.clear() + options = self._passAsn1Object(asn1Object, options) + if asn1Spec.typeId in (univ.Sequence.typeId, univ.Set.typeId): namedTypes = asn1Object.componentType @@ -943,8 +901,10 @@ def indefLenValueDecoder(self, substrate, asn1Spec, asn1Spec)) seenIndices = set() + idx = 0 - while True: #not endOfStream(substrate): + + while True: # loop over components if len(namedTypes) <= idx: asn1Spec = None @@ -967,17 +927,21 @@ def indefLenValueDecoder(self, substrate, asn1Spec, 'Excessive components decoded at %r' % (asn1Object,) ) - component = decodeFun(substrate, asn1Spec, allowEoo=True, **options) + for component in decodeFun(substrate, asn1Spec, allowEoo=True, **options): + + if isinstance(component, SubstrateUnderrunError): + yield component + + if component is eoo.endOfOctets: + break + if component is eoo.endOfOctets: break - if component is None: - raise error.SubstrateUnderrunError( - 'No EOO seen before substrate ends' - ) if not isDeterministic and namedTypes: if isSetType: idx = namedTypes.getPositionByType(component.effectiveTagSet) + elif namedTypes[idx].isOptional or namedTypes[idx].isDefaulted: idx = namedTypes.getPositionNearType(component.effectiveTagSet, idx) @@ -995,7 +959,9 @@ def indefLenValueDecoder(self, substrate, asn1Spec, if namedTypes: if not namedTypes.requiredComponents.issubset(seenIndices): - raise error.PyAsn1Error('ASN.1 object %s has uninitialized components' % asn1Object.__class__.__name__) + raise error.PyAsn1Error( + 'ASN.1 object %s has uninitialized ' + 'components' % asn1Object.__class__.__name__) if namedTypes.hasOpenTypes: @@ -1057,20 +1023,28 @@ def indefLenValueDecoder(self, substrate, asn1Spec, for pos, containerElement in enumerate( containerValue): - component = decodeFun( - _asSeekableStream(containerValue[pos].asOctets()), - asn1Spec=openType, **dict(options, allowEoo=True) - ) + stream = streaming.asSeekableStream(containerValue[pos].asOctets()) + + for component in decodeFun(stream, asn1Spec=openType, + **dict(options, allowEoo=True)): + if isinstance(component, SubstrateUnderrunError): + yield component + + if component is eoo.endOfOctets: + break containerValue[pos] = component else: - component = decodeFun( - _asSeekableStream(asn1Object.getComponentByPosition(idx).asOctets()), - asn1Spec=openType, **dict(options, allowEoo=True) - ) + stream = streaming.asSeekableStream(asn1Object.getComponentByPosition(idx).asOctets()) + for component in decodeFun(stream, asn1Spec=openType, + **dict(options, allowEoo=True)): + if isinstance(component, SubstrateUnderrunError): + yield component + + if component is eoo.endOfOctets: + break - if component is not eoo.endOfOctets: asn1Object.setComponentByPosition(idx, component) else: @@ -1079,9 +1053,6 @@ def indefLenValueDecoder(self, substrate, asn1Spec, raise inconsistency else: - asn1Object = asn1Spec.clone() - asn1Object.clear() - componentType = asn1Spec.componentType if LOG: @@ -1090,14 +1061,18 @@ def indefLenValueDecoder(self, substrate, asn1Spec, idx = 0 while True: - component = decodeFun(substrate, componentType, allowEoo=True, **options) + + for component in decodeFun( + substrate, componentType, allowEoo=True, **options): + + if isinstance(component, SubstrateUnderrunError): + yield component + + if component is eoo.endOfOctets: + break if component is eoo.endOfOctets: break - if component is None: - raise error.SubstrateUnderrunError( - 'No EOO seen before substrate ends' - ) asn1Object.setComponentByPosition( idx, component, @@ -1107,38 +1082,36 @@ def indefLenValueDecoder(self, substrate, asn1Spec, idx += 1 + yield asn1Object - return asn1Object - -class SequenceOrSequenceOfDecoder(UniversalConstructedTypeDecoder): +class SequenceOrSequenceOfPayloadDecoder(ConstructedPayloadDecoderBase): protoRecordComponent = univ.Sequence() protoSequenceComponent = univ.SequenceOf() -class SequenceDecoder(SequenceOrSequenceOfDecoder): +class SequencePayloadDecoder(SequenceOrSequenceOfPayloadDecoder): protoComponent = univ.Sequence() -class SequenceOfDecoder(SequenceOrSequenceOfDecoder): +class SequenceOfPayloadDecoder(SequenceOrSequenceOfPayloadDecoder): protoComponent = univ.SequenceOf() -class SetOrSetOfDecoder(UniversalConstructedTypeDecoder): +class SetOrSetOfPayloadDecoder(ConstructedPayloadDecoderBase): protoRecordComponent = univ.Set() protoSequenceComponent = univ.SetOf() -class SetDecoder(SetOrSetOfDecoder): +class SetPayloadDecoder(SetOrSetOfPayloadDecoder): protoComponent = univ.Set() - -class SetOfDecoder(SetOrSetOfDecoder): +class SetOfPayloadDecoder(SetOrSetOfPayloadDecoder): protoComponent = univ.SetOf() -class ChoiceDecoder(AbstractConstructedDecoder): +class ChoicePayloadDecoder(ConstructedPayloadDecoderBase): protoComponent = univ.Choice() def valueDecoder(self, substrate, asn1Spec, @@ -1154,24 +1127,31 @@ def valueDecoder(self, substrate, asn1Spec, asn1Object = asn1Spec.clone() if substrateFun: - return substrateFun(asn1Object, substrate, length) + for chunk in substrateFun(asn1Object, substrate, length, options): + yield chunk + + return + + options = self._passAsn1Object(asn1Object, options) if asn1Object.tagSet == tagSet: if LOG: LOG('decoding %s as explicitly tagged CHOICE' % (tagSet,)) - component = decodeFun( - substrate, asn1Object.componentTagMap, **options - ) + for component in decodeFun( + substrate, asn1Object.componentTagMap, **options): + if isinstance(component, SubstrateUnderrunError): + yield component else: if LOG: LOG('decoding %s as untagged CHOICE' % (tagSet,)) - component = decodeFun( - substrate, asn1Object.componentTagMap, - tagSet, length, state, **options - ) + for component in decodeFun( + substrate, asn1Object.componentTagMap, tagSet, length, + state, **options): + if isinstance(component, SubstrateUnderrunError): + yield component effectiveTagSet = component.effectiveTagSet @@ -1185,7 +1165,7 @@ def valueDecoder(self, substrate, asn1Spec, innerFlag=False ) - return asn1Object + yield asn1Object def indefLenValueDecoder(self, substrate, asn1Spec, tagSet=None, length=None, state=None, @@ -1193,53 +1173,67 @@ def indefLenValueDecoder(self, substrate, asn1Spec, **options): if asn1Spec is None: asn1Object = self.protoComponent.clone(tagSet=tagSet) + else: asn1Object = asn1Spec.clone() if substrateFun: - return substrateFun(asn1Object, substrate, length) + for chunk in substrateFun(asn1Object, substrate, length, options): + yield chunk - if asn1Object.tagSet == tagSet: - if LOG: - LOG('decoding %s as explicitly tagged CHOICE' % (tagSet,)) + return - component = decodeFun( - substrate, asn1Object.componentType.tagMapUnique, **options - ) + options = self._passAsn1Object(asn1Object, options) - # eat up EOO marker - eooMarker = decodeFun( - substrate, allowEoo=True, **options - ) + isTagged = asn1Object.tagSet == tagSet - if eooMarker is not eoo.endOfOctets: - raise error.PyAsn1Error('No EOO seen before substrate ends') + if LOG: + LOG('decoding %s as %stagged CHOICE' % ( + tagSet, isTagged and 'explicitly ' or 'un')) - else: - if LOG: - LOG('decoding %s as untagged CHOICE' % (tagSet,)) + while True: - component = decodeFun( - substrate, asn1Object.componentType.tagMapUnique, - tagSet, length, state, **options - ) + if isTagged: + iterator = decodeFun( + substrate, asn1Object.componentType.tagMapUnique, + **dict(options, allowEoo=True)) - effectiveTagSet = component.effectiveTagSet + else: + iterator = decodeFun( + substrate, asn1Object.componentType.tagMapUnique, + tagSet, length, state, **dict(options, allowEoo=True)) - if LOG: - LOG('decoded component %s, effective tag set %s' % (component, effectiveTagSet)) + for component in iterator: - asn1Object.setComponentByType( - effectiveTagSet, component, - verifyConstraints=False, - matchTags=False, matchConstraints=False, - innerFlag=False - ) + if isinstance(component, SubstrateUnderrunError): + yield component - return asn1Object + if component is eoo.endOfOctets: + break + effectiveTagSet = component.effectiveTagSet + + if LOG: + LOG('decoded component %s, effective tag set ' + '%s' % (component, effectiveTagSet)) -class AnyDecoder(AbstractSimpleDecoder): + asn1Object.setComponentByType( + effectiveTagSet, component, + verifyConstraints=False, + matchTags=False, matchConstraints=False, + innerFlag=False + ) + + if not isTagged: + break + + if not isTagged or component is eoo.endOfOctets: + break + + yield asn1Object + + +class AnyPayloadDecoder(AbstractSimplePayloadDecoder): protoComponent = univ.Any() def valueDecoder(self, substrate, asn1Spec, @@ -1256,22 +1250,32 @@ def valueDecoder(self, substrate, asn1Spec, isUntagged = tagSet != asn1Spec.tagSet if isUntagged: - fullPosition = substrate._markedPosition + fullPosition = substrate.markedPosition currentPosition = substrate.tell() substrate.seek(fullPosition, os.SEEK_SET) - length += (currentPosition - fullPosition) + length += currentPosition - fullPosition if LOG: - LOG('decoding as untagged ANY, substrate %s' % debug.hexdump(_peek(substrate, length))) + for chunk in streaming.peek(substrate, length): + if isinstance(chunk, SubstrateUnderrunError): + yield chunk + LOG('decoding as untagged ANY, substrate ' + '%s' % debug.hexdump(chunk)) if substrateFun: - return substrateFun(self._createComponent(asn1Spec, tagSet, noValue, **options), - substrate, length) + for chunk in substrateFun( + self._createComponent(asn1Spec, tagSet, noValue, **options), + substrate, length, options): + yield chunk + + return - head = substrate.read(length) + for chunk in streaming.read(substrate, length, options): + if isinstance(chunk, SubstrateUnderrunError): + yield chunk - return self._createComponent(asn1Spec, tagSet, head, **options) + yield self._createComponent(asn1Spec, tagSet, chunk, **options) def indefLenValueDecoder(self, substrate, asn1Spec, tagSet=None, length=None, state=None, @@ -1288,28 +1292,36 @@ def indefLenValueDecoder(self, substrate, asn1Spec, if isTagged: # tagged Any type -- consume header substrate - header = null + chunk = null if LOG: LOG('decoding as tagged ANY') else: # TODO: Seems not to be tested - fullPosition = substrate._markedPosition + fullPosition = substrate.markedPosition currentPosition = substrate.tell() substrate.seek(fullPosition, os.SEEK_SET) - header = substrate.read(currentPosition - fullPosition) + for chunk in streaming.read(substrate, currentPosition - fullPosition, options): + if isinstance(chunk, SubstrateUnderrunError): + yield chunk if LOG: - LOG('decoding as untagged ANY, header substrate %s' % debug.hexdump(header)) + LOG('decoding as untagged ANY, header substrate %s' % debug.hexdump(chunk)) # Any components do not inherit initial tag asn1Spec = self.protoComponent if substrateFun and substrateFun is not self.substrateCollector: - asn1Object = self._createComponent(asn1Spec, tagSet, noValue, **options) - return substrateFun(asn1Object, header + substrate, length + len(header)) + asn1Object = self._createComponent( + asn1Spec, tagSet, noValue, **options) + + for chunk in substrateFun( + asn1Object, chunk + substrate, length + len(chunk), options): + yield chunk + + return if LOG: LOG('assembling constructed serialization') @@ -1317,130 +1329,134 @@ def indefLenValueDecoder(self, substrate, asn1Spec, # All inner fragments are of the same type, treat them as octet string substrateFun = self.substrateCollector - while True: - component = decodeFun(substrate, asn1Spec, - substrateFun=substrateFun, - allowEoo=True, **options) + while True: # loop over fragments + + for component in decodeFun( + substrate, asn1Spec, substrateFun=substrateFun, + allowEoo=True, **options): + + if isinstance(component, SubstrateUnderrunError): + yield component + + if component is eoo.endOfOctets: + break + if component is eoo.endOfOctets: break - if not component: - raise error.SubstrateUnderrunError( - 'No EOO seen before substrate ends' - ) - header += component + chunk += component if substrateFun: - return header # TODO: Weird + yield chunk # TODO: Weird else: - return self._createComponent(asn1Spec, tagSet, header, **options) + yield self._createComponent(asn1Spec, tagSet, chunk, **options) # character string types -class UTF8StringDecoder(OctetStringDecoder): +class UTF8StringPayloadDecoder(OctetStringPayloadDecoder): protoComponent = char.UTF8String() -class NumericStringDecoder(OctetStringDecoder): +class NumericStringPayloadDecoder(OctetStringPayloadDecoder): protoComponent = char.NumericString() -class PrintableStringDecoder(OctetStringDecoder): +class PrintableStringPayloadDecoder(OctetStringPayloadDecoder): protoComponent = char.PrintableString() -class TeletexStringDecoder(OctetStringDecoder): +class TeletexStringPayloadDecoder(OctetStringPayloadDecoder): protoComponent = char.TeletexString() -class VideotexStringDecoder(OctetStringDecoder): +class VideotexStringPayloadDecoder(OctetStringPayloadDecoder): protoComponent = char.VideotexString() -class IA5StringDecoder(OctetStringDecoder): +class IA5StringPayloadDecoder(OctetStringPayloadDecoder): protoComponent = char.IA5String() -class GraphicStringDecoder(OctetStringDecoder): +class GraphicStringPayloadDecoder(OctetStringPayloadDecoder): protoComponent = char.GraphicString() -class VisibleStringDecoder(OctetStringDecoder): +class VisibleStringPayloadDecoder(OctetStringPayloadDecoder): protoComponent = char.VisibleString() -class GeneralStringDecoder(OctetStringDecoder): +class GeneralStringPayloadDecoder(OctetStringPayloadDecoder): protoComponent = char.GeneralString() -class UniversalStringDecoder(OctetStringDecoder): +class UniversalStringPayloadDecoder(OctetStringPayloadDecoder): protoComponent = char.UniversalString() -class BMPStringDecoder(OctetStringDecoder): +class BMPStringPayloadDecoder(OctetStringPayloadDecoder): protoComponent = char.BMPString() # "useful" types -class ObjectDescriptorDecoder(OctetStringDecoder): +class ObjectDescriptorPayloadDecoder(OctetStringPayloadDecoder): protoComponent = useful.ObjectDescriptor() -class GeneralizedTimeDecoder(OctetStringDecoder): +class GeneralizedTimePayloadDecoder(OctetStringPayloadDecoder): protoComponent = useful.GeneralizedTime() -class UTCTimeDecoder(OctetStringDecoder): +class UTCTimePayloadDecoder(OctetStringPayloadDecoder): protoComponent = useful.UTCTime() -tagMap = { - univ.Integer.tagSet: IntegerDecoder(), - univ.Boolean.tagSet: BooleanDecoder(), - univ.BitString.tagSet: BitStringDecoder(), - univ.OctetString.tagSet: OctetStringDecoder(), - univ.Null.tagSet: NullDecoder(), - univ.ObjectIdentifier.tagSet: ObjectIdentifierDecoder(), - univ.Enumerated.tagSet: IntegerDecoder(), - univ.Real.tagSet: RealDecoder(), - univ.Sequence.tagSet: SequenceOrSequenceOfDecoder(), # conflicts with SequenceOf - univ.Set.tagSet: SetOrSetOfDecoder(), # conflicts with SetOf - univ.Choice.tagSet: ChoiceDecoder(), # conflicts with Any +TAG_MAP = { + univ.Integer.tagSet: IntegerPayloadDecoder(), + univ.Boolean.tagSet: BooleanPayloadDecoder(), + univ.BitString.tagSet: BitStringPayloadDecoder(), + univ.OctetString.tagSet: OctetStringPayloadDecoder(), + univ.Null.tagSet: NullPayloadDecoder(), + univ.ObjectIdentifier.tagSet: ObjectIdentifierPayloadDecoder(), + univ.Enumerated.tagSet: IntegerPayloadDecoder(), + univ.Real.tagSet: RealPayloadDecoder(), + univ.Sequence.tagSet: SequenceOrSequenceOfPayloadDecoder(), # conflicts with SequenceOf + univ.Set.tagSet: SetOrSetOfPayloadDecoder(), # conflicts with SetOf + univ.Choice.tagSet: ChoicePayloadDecoder(), # conflicts with Any # character string types - char.UTF8String.tagSet: UTF8StringDecoder(), - char.NumericString.tagSet: NumericStringDecoder(), - char.PrintableString.tagSet: PrintableStringDecoder(), - char.TeletexString.tagSet: TeletexStringDecoder(), - char.VideotexString.tagSet: VideotexStringDecoder(), - char.IA5String.tagSet: IA5StringDecoder(), - char.GraphicString.tagSet: GraphicStringDecoder(), - char.VisibleString.tagSet: VisibleStringDecoder(), - char.GeneralString.tagSet: GeneralStringDecoder(), - char.UniversalString.tagSet: UniversalStringDecoder(), - char.BMPString.tagSet: BMPStringDecoder(), + char.UTF8String.tagSet: UTF8StringPayloadDecoder(), + char.NumericString.tagSet: NumericStringPayloadDecoder(), + char.PrintableString.tagSet: PrintableStringPayloadDecoder(), + char.TeletexString.tagSet: TeletexStringPayloadDecoder(), + char.VideotexString.tagSet: VideotexStringPayloadDecoder(), + char.IA5String.tagSet: IA5StringPayloadDecoder(), + char.GraphicString.tagSet: GraphicStringPayloadDecoder(), + char.VisibleString.tagSet: VisibleStringPayloadDecoder(), + char.GeneralString.tagSet: GeneralStringPayloadDecoder(), + char.UniversalString.tagSet: UniversalStringPayloadDecoder(), + char.BMPString.tagSet: BMPStringPayloadDecoder(), # useful types - useful.ObjectDescriptor.tagSet: ObjectDescriptorDecoder(), - useful.GeneralizedTime.tagSet: GeneralizedTimeDecoder(), - useful.UTCTime.tagSet: UTCTimeDecoder() + useful.ObjectDescriptor.tagSet: ObjectDescriptorPayloadDecoder(), + useful.GeneralizedTime.tagSet: GeneralizedTimePayloadDecoder(), + useful.UTCTime.tagSet: UTCTimePayloadDecoder() } # Type-to-codec map for ambiguous ASN.1 types -typeMap = { - univ.Set.typeId: SetDecoder(), - univ.SetOf.typeId: SetOfDecoder(), - univ.Sequence.typeId: SequenceDecoder(), - univ.SequenceOf.typeId: SequenceOfDecoder(), - univ.Choice.typeId: ChoiceDecoder(), - univ.Any.typeId: AnyDecoder() +TYPE_MAP = { + univ.Set.typeId: SetPayloadDecoder(), + univ.SetOf.typeId: SetOfPayloadDecoder(), + univ.Sequence.typeId: SequencePayloadDecoder(), + univ.SequenceOf.typeId: SequenceOfPayloadDecoder(), + univ.Choice.typeId: ChoicePayloadDecoder(), + univ.Any.typeId: AnyPayloadDecoder() } # Put in non-ambiguous types for faster codec lookup -for typeDecoder in tagMap.values(): +for typeDecoder in TAG_MAP.values(): if typeDecoder.protoComponent is not None: typeId = typeDecoder.protoComponent.__class__.typeId - if typeId is not None and typeId not in typeMap: - typeMap[typeId] = typeDecoder + if typeId is not None and typeId not in TYPE_MAP: + TYPE_MAP[typeId] = typeDecoder (stDecodeTag, @@ -1455,16 +1471,19 @@ class UTCTimeDecoder(OctetStringDecoder): stStop) = [x for x in range(10)] -class Decoder(object): +class SingleItemDecoder(object): defaultErrorState = stErrorCondition #defaultErrorState = stDumpRawValue - defaultRawDecoder = AnyDecoder() + defaultRawDecoder = AnyPayloadDecoder() + supportIndefLength = True - # noinspection PyDefaultArgument - def __init__(self, tagMap, typeMap={}): - self.__tagMap = tagMap - self.__typeMap = typeMap + TAG_MAP = TAG_MAP + TYPE_MAP = TYPE_MAP + + def __init__(self, tagMap=None, typeMap=None): + self.__tagMap = tagMap or self.TAG_MAP + self.__typeMap = typeMap or self.TYPE_MAP # Tag & TagSet objects caches self.__tagCache = {} self.__tagSetCache = {} @@ -1475,29 +1494,37 @@ def __call__(self, substrate, asn1Spec=None, decodeFun=None, substrateFun=None, **options): - if LOG: - LOG('decoder called at scope %s with state %d, working with up to %s octets of substrate: %s' % (debug.scope, state, length, substrate)) - allowEoo = options.pop('allowEoo', False) + if LOG: + LOG('decoder called at scope %s with state %d, working with up ' + 'to %s octets of substrate: ' + '%s' % (debug.scope, state, length, substrate)) + # Look for end-of-octets sentinel if allowEoo and self.supportIndefLength: - eoo_candidate = substrate.read(2) + + for eoo_candidate in streaming.read(substrate, 2, options): + if isinstance(eoo_candidate, SubstrateUnderrunError): + yield eoo_candidate + if eoo_candidate == self.__eooSentinel: if LOG: LOG('end-of-octets sentinel found') - return eoo.endOfOctets + yield eoo.endOfOctets + return + else: substrate.seek(-2, os.SEEK_CUR) - value = noValue - tagMap = self.__tagMap typeMap = self.__typeMap tagCache = self.__tagCache tagSetCache = self.__tagSetCache - substrate._markedPosition = substrate.tell() + value = noValue + + substrate.markedPosition = substrate.tell() while state is not stStop: @@ -1505,9 +1532,9 @@ def __call__(self, substrate, asn1Spec=None, # Decode tag isShortTag = True - firstByte = substrate.read(1) - if not firstByte: - return None + for firstByte in streaming.read(substrate, 1, options): + if isinstance(firstByte, SubstrateUnderrunError): + yield firstByte firstOctet = ord(firstByte) @@ -1526,15 +1553,20 @@ def __call__(self, substrate, asn1Spec=None, tagId = 0 while True: - integerByte = substrate.read(1) + for integerByte in streaming.read(substrate, 1, options): + if isinstance(integerByte, SubstrateUnderrunError): + yield integerByte + if not integerByte: raise error.SubstrateUnderrunError( 'Short octet stream on long tag decoding' ) + integerTag = ord(integerByte) lengthOctetIdx += 1 tagId <<= 7 tagId |= (integerTag & 0x7F) + if not integerTag & 0x80: break @@ -1568,12 +1600,11 @@ def __call__(self, substrate, asn1Spec=None, if state is stDecodeLength: # Decode length - try: - firstOctet = ord(substrate.read(1)) - except: - raise error.SubstrateUnderrunError( - 'Short octet stream on length decoding' - ) + for firstOctet in streaming.read(substrate, 1, options): + if isinstance(firstOctet, SubstrateUnderrunError): + yield firstOctet + + firstOctet = ord(firstOctet) if firstOctet < 128: length = firstOctet @@ -1581,7 +1612,10 @@ def __call__(self, substrate, asn1Spec=None, elif firstOctet > 128: size = firstOctet & 0x7F # encoded in size bytes - encodedLength = list(substrate.read(size)) + for encodedLength in streaming.read(substrate, size, options): + if isinstance(encodedLength, SubstrateUnderrunError): + yield encodedLength + encodedLength = list(encodedLength) # missing check on maximum size, which shouldn't be a # problem, we can handle more than is possible if len(encodedLength) != size: @@ -1726,25 +1760,30 @@ def __call__(self, substrate, asn1Spec=None, original_position = substrate.tell() if length == -1: # indef length - value = concreteDecoder.indefLenValueDecoder( - substrate, asn1Spec, - tagSet, length, stGetValueDecoder, - self, substrateFun, - **options - ) + for value in concreteDecoder.indefLenValueDecoder( + substrate, asn1Spec, + tagSet, length, stGetValueDecoder, + self, substrateFun, **options): + if isinstance(value, SubstrateUnderrunError): + yield value + else: - value = concreteDecoder.valueDecoder( - substrate, asn1Spec, - tagSet, length, stGetValueDecoder, - self, substrateFun, - **options - ) - bytes_read = substrate.tell() - original_position - if bytes_read != length: - raise PyAsn1Error("Read %s bytes instead of expected %s." % (bytes_read, length)) + for value in concreteDecoder.valueDecoder( + substrate, asn1Spec, + tagSet, length, stGetValueDecoder, + self, substrateFun, **options): + if isinstance(value, SubstrateUnderrunError): + yield value + + bytesRead = substrate.tell() - original_position + if bytesRead != length: + raise PyAsn1Error( + "Read %s bytes instead of expected %s." % (bytesRead, length)) if LOG: - LOG('codec %s yields type %s, value:\n%s\n...' % (concreteDecoder.__class__.__name__, value.__class__.__name__, isinstance(value, base.Asn1Item) and value.prettyPrint() or value)) + LOG('codec %s yields type %s, value:\n%s\n...' % ( + concreteDecoder.__class__.__name__, value.__class__.__name__, + isinstance(value, base.Asn1Item) and value.prettyPrint() or value)) state = stStop break @@ -1754,7 +1793,7 @@ def __call__(self, substrate, asn1Spec=None, tagSet[0].tagFormat == tag.tagFormatConstructed and tagSet[0].tagClass != tag.tagClassUniversal): # Assume explicit tagging - concreteDecoder = explicitTagDecoder + concreteDecoder = rawPayloadDecoder state = stDecodeValue else: @@ -1781,25 +1820,187 @@ def __call__(self, substrate, asn1Spec=None, debug.scope.pop() LOG('decoder left scope %s, call completed' % debug.scope) - return value + yield value -_decode = Decoder(tagMap, typeMap) +class StreamingDecoder(object): + """Create an iterator that turns BER/CER/DER byte stream into ASN.1 objects. + On each iteration, consume whatever BER/CER/DER serialization is + available in the `substrate` stream-like object and turns it into + one or more, possibly nested, ASN.1 objects. -def decodeStream(substrate, asn1Spec=None, **kwargs): - """Iterator of objects in a substrate.""" - # TODO: This should become `decode` after API-breaking approved - try: - substrate = _asSeekableStream(substrate) - except TypeError: - raise PyAsn1Error - while True: - result = _decode(substrate, asn1Spec, **kwargs) - if result is None: - break - yield result - # TODO: Check about eoo.endOfOctets? + Parameters + ---------- + substrate: :py:class:`file`, :py:class:`io.BytesIO` + BER/CER/DER serialization in form of a byte stream + + Keyword Args + ------------ + asn1Spec: :py:class:`~pyasn1.type.base.PyAsn1Item` + A pyasn1 type object to act as a template guiding the decoder. + Depending on the ASN.1 structure being decoded, `asn1Spec` may + or may not be required. One of the reasons why `asn1Spec` may + me required is that ASN.1 structure is encoded in the *IMPLICIT* + tagging mode. + + Yields + ------ + : :py:class:`~pyasn1.type.base.PyAsn1Item`, :py:class:`~pyasn1.error.SubstrateUnderrunError` + Decoded ASN.1 object (possibly, nested) or + :py:class:`~pyasn1.error.SubstrateUnderrunError` object indicating + insufficient BER/CER/DER serialization on input to fully recover ASN.1 + objects from it. + + In the latter case the caller is advised to ensure some more data in + the input stream, then call the iterator again. The decoder will resume + the decoding process using the newly arrived data. + + The `context` property of :py:class:`~pyasn1.error.SubstrateUnderrunError` + object might hold a reference to the partially populated ASN.1 object + being reconstructed. + + Raises + ------ + ~pyasn1.error.PyAsn1Error, ~pyasn1.error.EndOfStreamError + `PyAsn1Error` on deserialization error, `EndOfStreamError` on + premature stream closure. + + Examples + -------- + Decode BER serialisation without ASN.1 schema + + .. code-block:: pycon + + >>> stream = io.BytesIO( + ... b'0\t\x02\x01\x01\x02\x01\x02\x02\x01\x03') + >>> + >>> for asn1Object in StreamingDecoder(stream): + ... print(asn1Object) + >>> + SequenceOf: + 1 2 3 + + Decode BER serialisation with ASN.1 schema + + .. code-block:: pycon + + >>> stream = io.BytesIO( + ... b'0\t\x02\x01\x01\x02\x01\x02\x02\x01\x03') + >>> + >>> schema = SequenceOf(componentType=Integer()) + >>> + >>> decoder = StreamingDecoder(stream, asn1Spec=schema) + >>> for asn1Object in decoder: + ... print(asn1Object) + >>> + SequenceOf: + 1 2 3 + """ + + SINGLE_ITEM_DECODER = SingleItemDecoder + + def __init__(self, substrate, asn1Spec=None, **kwargs): + self._substrate = streaming.asSeekableStream(substrate) + self._asn1Spec = asn1Spec + self._options = kwargs + self._decoder = self.SINGLE_ITEM_DECODER() + + def __iter__(self): + while True: + for asn1Object in self._decoder( + self._substrate, self._asn1Spec, **self._options): + yield asn1Object + + for chunk in streaming.isEndOfStream(self._substrate): + if isinstance(chunk, SubstrateUnderrunError): + yield + + break + + if chunk: + break + + +class Decoder(object): + """Create a BER decoder object. + + Parse BER/CER/DER octet-stream into one, possibly nested, ASN.1 object. + """ + STREAMING_DECODER = StreamingDecoder + + @classmethod + def __call__(cls, substrate, asn1Spec=None, **kwargs): + """Turns BER/CER/DER octet stream into an ASN.1 object. + + Takes BER/CER/DER octet-stream in form of :py:class:`bytes` (Python 3) + or :py:class:`str` (Python 2) and decode it into an ASN.1 object + (e.g. :py:class:`~pyasn1.type.base.PyAsn1Item` derivative) which + may be a scalar or an arbitrary nested structure. + + Parameters + ---------- + substrate: :py:class:`bytes` (Python 3) or :py:class:`str` (Python 2) + BER/CER/DER octet-stream to parse + + Keyword Args + ------------ + asn1Spec: :py:class:`~pyasn1.type.base.PyAsn1Item` + A pyasn1 type object (:py:class:`~pyasn1.type.base.PyAsn1Item` + derivative) to act as a template guiding the decoder. + Depending on the ASN.1 structure being decoded, `asn1Spec` may or + may not be required. Most common reason for it to require is that + ASN.1 structure is encoded in *IMPLICIT* tagging mode. + + Returns + ------- + : :py:class:`tuple` + A tuple of :py:class:`~pyasn1.type.base.PyAsn1Item` object + recovered from BER/CER/DER substrate and the unprocessed trailing + portion of the `substrate` (may be empty) + + Raises + ------ + : :py:class:`~pyasn1.error.PyAsn1Error` + :py:class:`~pyasn1.error.SubstrateUnderrunError` on insufficient + input or :py:class:`~pyasn1.error.PyAsn1Error` on decoding error. + + Examples + -------- + Decode BER/CER/DER serialisation without ASN.1 schema + + .. code-block:: pycon + + >>> s, unprocessed = decode(b'0\t\x02\x01\x01\x02\x01\x02\x02\x01\x03') + >>> str(s) + SequenceOf: + 1 2 3 + + Decode BER/CER/DER serialisation with ASN.1 schema + + .. code-block:: pycon + + >>> seq = SequenceOf(componentType=Integer()) + >>> s, unprocessed = decode( + b'0\t\x02\x01\x01\x02\x01\x02\x02\x01\x03', asn1Spec=seq) + >>> str(s) + SequenceOf: + 1 2 3 + + """ + substrate = streaming.asSeekableStream(substrate) + + for asn1Object in cls.STREAMING_DECODER(substrate, asn1Spec, **kwargs): + if isinstance(asn1Object, SubstrateUnderrunError): + raise error.SubstrateUnderrunError('Short substrate on input') + + try: + tail = next(streaming.read(substrate)) + + except error.EndOfStreamError: + tail = null + + return asn1Object, tail #: Turns BER octet stream into an ASN.1 object. @@ -1831,6 +2032,11 @@ def decodeStream(substrate, asn1Spec=None, **kwargs): #: ~pyasn1.error.PyAsn1Error, ~pyasn1.error.SubstrateUnderrunError #: On decoding errors #: +#: Notes +#: ----- +#: This function is deprecated. Please use :py:class:`Decoder` or +#: :py:class:`StreamingDecoder` class instance. +#: #: Examples #: -------- #: Decode BER serialisation without ASN.1 schema @@ -1852,13 +2058,4 @@ def decodeStream(substrate, asn1Spec=None, **kwargs): #: SequenceOf: #: 1 2 3 #: -def decode(substrate, asn1Spec=None, **kwargs): - # TODO: Temporary solution before merging with upstream - # It preserves the original API - substrate = _asSeekableStream(substrate) - value = _decode(substrate, asn1Spec=asn1Spec, **kwargs) - return value, substrate.read() - - -# XXX -# non-recursive decoding; return position rather than substrate +decode = Decoder() diff --git a/pyasn1/codec/ber/encoder.py b/pyasn1/codec/ber/encoder.py index 778aa867..6b77b703 100644 --- a/pyasn1/codec/ber/encoder.py +++ b/pyasn1/codec/ber/encoder.py @@ -17,7 +17,7 @@ from pyasn1.type import univ from pyasn1.type import useful -__all__ = ['encode'] +__all__ = ['Encoder', 'encode'] LOG = debug.registerLoggee(__name__, flags=debug.DEBUG_ENCODER) @@ -706,7 +706,7 @@ def encodeValue(self, value, asn1Spec, encodeFun, **options): return value, not options.get('defMode', True), True -tagMap = { +TAG_MAP = { eoo.endOfOctets.tagSet: EndOfOctetsEncoder(), univ.Boolean.tagSet: BooleanEncoder(), univ.Integer.tagSet: IntegerEncoder(), @@ -739,7 +739,7 @@ def encodeValue(self, value, asn1Spec, encodeFun, **options): } # Put in ambiguous & non-ambiguous types for faster codec lookup -typeMap = { +TYPE_MAP = { univ.Boolean.typeId: BooleanEncoder(), univ.Integer.typeId: IntegerEncoder(), univ.BitString.typeId: BitStringEncoder(), @@ -774,14 +774,16 @@ def encodeValue(self, value, asn1Spec, encodeFun, **options): } -class Encoder(object): +class SingleItemEncoder(object): fixedDefLengthMode = None fixedChunkSize = None - # noinspection PyDefaultArgument - def __init__(self, tagMap, typeMap={}): - self.__tagMap = tagMap - self.__typeMap = typeMap + TAG_MAP = TAG_MAP + TYPE_MAP = TYPE_MAP + + def __init__(self, tagMap=None, typeMap=None): + self.__tagMap = tagMap or self.TAG_MAP + self.__typeMap = typeMap or self.TYPE_MAP def __call__(self, value, asn1Spec=None, **options): try: @@ -795,8 +797,11 @@ def __call__(self, value, asn1Spec=None, **options): 'and "asn1Spec" not given' % (value,)) if LOG: - LOG('encoder called in %sdef mode, chunk size %s for ' - 'type %s, value:\n%s' % (not options.get('defMode', True) and 'in' or '', options.get('maxChunkSize', 0), asn1Spec is None and value.prettyPrintType() or asn1Spec.prettyPrintType(), value)) + LOG('encoder called in %sdef mode, chunk size %s for type %s, ' + 'value:\n%s' % (not options.get('defMode', True) and 'in' or '', + options.get('maxChunkSize', 0), + asn1Spec is None and value.prettyPrintType() or + asn1Spec.prettyPrintType(), value)) if self.fixedDefLengthMode is not None: options.update(defMode=self.fixedDefLengthMode) @@ -804,12 +809,12 @@ def __call__(self, value, asn1Spec=None, **options): if self.fixedChunkSize is not None: options.update(maxChunkSize=self.fixedChunkSize) - try: concreteEncoder = self.__typeMap[typeId] if LOG: - LOG('using value codec %s chosen by type ID %s' % (concreteEncoder.__class__.__name__, typeId)) + LOG('using value codec %s chosen by type ID ' + '%s' % (concreteEncoder.__class__.__name__, typeId)) except KeyError: if asn1Spec is None: @@ -827,15 +832,28 @@ def __call__(self, value, asn1Spec=None, **options): raise error.PyAsn1Error('No encoder for %r (%s)' % (value, tagSet)) if LOG: - LOG('using value codec %s chosen by tagSet %s' % (concreteEncoder.__class__.__name__, tagSet)) + LOG('using value codec %s chosen by tagSet ' + '%s' % (concreteEncoder.__class__.__name__, tagSet)) substrate = concreteEncoder.encode(value, asn1Spec, self, **options) if LOG: - LOG('codec %s built %s octets of substrate: %s\nencoder completed' % (concreteEncoder, len(substrate), debug.hexdump(substrate))) + LOG('codec %s built %s octets of substrate: %s\nencoder ' + 'completed' % (concreteEncoder, len(substrate), + debug.hexdump(substrate))) return substrate + +class Encoder(object): + SINGLE_ITEM_ENCODER = SingleItemEncoder + + @classmethod + def __call__(cls, pyObject, asn1Spec=None, **options): + singleItemEncoder = cls.SINGLE_ITEM_ENCODER() + return singleItemEncoder(pyObject, asn1Spec=asn1Spec, **options) + + #: Turns ASN.1 object into BER octet stream. #: #: Takes any ASN.1 object (e.g. :py:class:`~pyasn1.type.base.PyAsn1Item` derivative) @@ -887,4 +905,4 @@ def __call__(self, value, asn1Spec=None, **options): #: >>> encode(seq) #: b'0\t\x02\x01\x01\x02\x01\x02\x02\x01\x03' #: -encode = Encoder(tagMap, typeMap) +encode = Encoder() diff --git a/pyasn1/codec/cer/decoder.py b/pyasn1/codec/cer/decoder.py index b709313a..08f9ec81 100644 --- a/pyasn1/codec/cer/decoder.py +++ b/pyasn1/codec/cer/decoder.py @@ -4,79 +4,89 @@ # Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # -from io import BytesIO - from pyasn1 import error +from pyasn1.codec import streaming from pyasn1.codec.ber import decoder -from pyasn1.codec.ber.decoder import _asSeekableStream from pyasn1.compat.octets import oct2int from pyasn1.type import univ -__all__ = ['decode', 'decodeStream'] +__all__ = ['decode', 'StreamingDecoder'] + +SubstrateUnderrunError = error.SubstrateUnderrunError -class BooleanDecoder(decoder.AbstractSimpleDecoder): +class BooleanPayloadDecoder(decoder.AbstractSimplePayloadDecoder): protoComponent = univ.Boolean(0) def valueDecoder(self, substrate, asn1Spec, tagSet=None, length=None, state=None, decodeFun=None, substrateFun=None, **options): - head = substrate.read(1) - if not head or length != 1: + + if length != 1: raise error.PyAsn1Error('Not single-octet Boolean payload') - byte = oct2int(head[0]) + + for chunk in streaming.read(substrate, length, options): + if isinstance(chunk, SubstrateUnderrunError): + yield chunk + + byte = oct2int(chunk[0]) + # CER/DER specifies encoding of TRUE as 0xFF and FALSE as 0x0, while # BER allows any non-zero value as TRUE; cf. sections 8.2.2. and 11.1 # in https://www.itu.int/ITU-T/studygroups/com17/languages/X.690-0207.pdf if byte == 0xff: value = 1 + elif byte == 0x00: value = 0 + else: raise error.PyAsn1Error('Unexpected Boolean payload: %s' % byte) - return self._createComponent(asn1Spec, tagSet, value, **options) + + yield self._createComponent(asn1Spec, tagSet, value, **options) + # TODO: prohibit non-canonical encoding -BitStringDecoder = decoder.BitStringDecoder -OctetStringDecoder = decoder.OctetStringDecoder -RealDecoder = decoder.RealDecoder - -tagMap = decoder.tagMap.copy() -tagMap.update( - {univ.Boolean.tagSet: BooleanDecoder(), - univ.BitString.tagSet: BitStringDecoder(), - univ.OctetString.tagSet: OctetStringDecoder(), - univ.Real.tagSet: RealDecoder()} +BitStringPayloadDecoder = decoder.BitStringPayloadDecoder +OctetStringPayloadDecoder = decoder.OctetStringPayloadDecoder +RealPayloadDecoder = decoder.RealPayloadDecoder + +TAG_MAP = decoder.TAG_MAP.copy() +TAG_MAP.update( + {univ.Boolean.tagSet: BooleanPayloadDecoder(), + univ.BitString.tagSet: BitStringPayloadDecoder(), + univ.OctetString.tagSet: OctetStringPayloadDecoder(), + univ.Real.tagSet: RealPayloadDecoder()} ) -typeMap = decoder.typeMap.copy() +TYPE_MAP = decoder.TYPE_MAP.copy() # Put in non-ambiguous types for faster codec lookup -for typeDecoder in tagMap.values(): +for typeDecoder in TAG_MAP.values(): if typeDecoder.protoComponent is not None: typeId = typeDecoder.protoComponent.__class__.typeId - if typeId is not None and typeId not in typeMap: - typeMap[typeId] = typeDecoder + if typeId is not None and typeId not in TYPE_MAP: + TYPE_MAP[typeId] = typeDecoder -class Decoder(decoder.Decoder): - pass +class SingleItemDecoder(decoder.SingleItemDecoder): + __doc__ = decoder.SingleItemDecoder.__doc__ + + TAG_MAP = TAG_MAP + TYPE_MAP = TYPE_MAP -_decode = Decoder(tagMap, typeMap) +class StreamingDecoder(decoder.StreamingDecoder): + __doc__ = decoder.StreamingDecoder.__doc__ + SINGLE_ITEM_DECODER = SingleItemDecoder + + +class Decoder(decoder.Decoder): + __doc__ = decoder.Decoder.__doc__ -def decodeStream(substrate, asn1Spec=None, **kwargs): - """Iterator of objects in a substrate.""" - # TODO: This should become `decode` after API-breaking approved - substrate = _asSeekableStream(substrate) - while True: - result = _decode(substrate, asn1Spec, **kwargs) - if result is None: - break - yield result - # TODO: Check about eoo.endOfOctets? + STREAMING_DECODER = StreamingDecoder #: Turns CER octet stream into an ASN.1 object. @@ -129,9 +139,4 @@ def decodeStream(substrate, asn1Spec=None, **kwargs): #: SequenceOf: #: 1 2 3 #: -def decode(substrate, asn1Spec=None, **kwargs): - # TODO: Temporary solution before merging with upstream - # It preserves the original API - substrate = _asSeekableStream(substrate) - value = _decode(substrate, asn1Spec=asn1Spec, **kwargs) - return value, substrate.read() +decode = Decoder() diff --git a/pyasn1/codec/cer/encoder.py b/pyasn1/codec/cer/encoder.py index 935b6965..9e6cdac8 100644 --- a/pyasn1/codec/cer/encoder.py +++ b/pyasn1/codec/cer/encoder.py @@ -10,7 +10,7 @@ from pyasn1.type import univ from pyasn1.type import useful -__all__ = ['encode'] +__all__ = ['Encoder', 'encode'] class BooleanEncoder(encoder.IntegerEncoder): @@ -234,8 +234,9 @@ class SequenceEncoder(encoder.SequenceEncoder): omitEmptyOptionals = True -tagMap = encoder.tagMap.copy() -tagMap.update({ +TAG_MAP = encoder.TAG_MAP.copy() + +TAG_MAP.update({ univ.Boolean.tagSet: BooleanEncoder(), univ.Real.tagSet: RealEncoder(), useful.GeneralizedTime.tagSet: GeneralizedTimeEncoder(), @@ -245,8 +246,9 @@ class SequenceEncoder(encoder.SequenceEncoder): univ.Sequence.typeId: SequenceEncoder() }) -typeMap = encoder.typeMap.copy() -typeMap.update({ +TYPE_MAP = encoder.TYPE_MAP.copy() + +TYPE_MAP.update({ univ.Boolean.typeId: BooleanEncoder(), univ.Real.typeId: RealEncoder(), useful.GeneralizedTime.typeId: GeneralizedTimeEncoder(), @@ -259,10 +261,18 @@ class SequenceEncoder(encoder.SequenceEncoder): }) -class Encoder(encoder.Encoder): +class SingleItemEncoder(encoder.SingleItemEncoder): fixedDefLengthMode = False fixedChunkSize = 1000 + TAG_MAP = TAG_MAP + TYPE_MAP = TYPE_MAP + + +class Encoder(encoder.Encoder): + SINGLE_ITEM_ENCODER = SingleItemEncoder + + #: Turns ASN.1 object into CER octet stream. #: #: Takes any ASN.1 object (e.g. :py:class:`~pyasn1.type.base.PyAsn1Item` derivative) @@ -308,6 +318,6 @@ class Encoder(encoder.Encoder): #: >>> encode(seq) #: b'0\x80\x02\x01\x01\x02\x01\x02\x02\x01\x03\x00\x00' #: -encode = Encoder(tagMap, typeMap) +encode = Encoder() # EncoderFactory queries class instance and builds a map of tags -> encoders diff --git a/pyasn1/codec/der/decoder.py b/pyasn1/codec/der/decoder.py index e3399703..b9526c3e 100644 --- a/pyasn1/codec/der/decoder.py +++ b/pyasn1/codec/der/decoder.py @@ -4,59 +4,59 @@ # Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # -from io import BytesIO - -from pyasn1.codec.ber.decoder import _asSeekableStream from pyasn1.codec.cer import decoder from pyasn1.type import univ -__all__ = ['decode', 'decodeStream'] +__all__ = ['decode', 'StreamingDecoder'] -class BitStringDecoder(decoder.BitStringDecoder): +class BitStringPayloadDecoder(decoder.BitStringPayloadDecoder): supportConstructedForm = False -class OctetStringDecoder(decoder.OctetStringDecoder): +class OctetStringPayloadDecoder(decoder.OctetStringPayloadDecoder): supportConstructedForm = False + # TODO: prohibit non-canonical encoding -RealDecoder = decoder.RealDecoder +RealPayloadDecoder = decoder.RealPayloadDecoder -tagMap = decoder.tagMap.copy() -tagMap.update( - {univ.BitString.tagSet: BitStringDecoder(), - univ.OctetString.tagSet: OctetStringDecoder(), - univ.Real.tagSet: RealDecoder()} +TAG_MAP = decoder.TAG_MAP.copy() +TAG_MAP.update( + {univ.BitString.tagSet: BitStringPayloadDecoder(), + univ.OctetString.tagSet: OctetStringPayloadDecoder(), + univ.Real.tagSet: RealPayloadDecoder()} ) -typeMap = decoder.typeMap.copy() +TYPE_MAP = decoder.TYPE_MAP.copy() # Put in non-ambiguous types for faster codec lookup -for typeDecoder in tagMap.values(): +for typeDecoder in TAG_MAP.values(): if typeDecoder.protoComponent is not None: typeId = typeDecoder.protoComponent.__class__.typeId - if typeId is not None and typeId not in typeMap: - typeMap[typeId] = typeDecoder + if typeId is not None and typeId not in TYPE_MAP: + TYPE_MAP[typeId] = typeDecoder -class Decoder(decoder.Decoder): +class SingleItemDecoder(decoder.SingleItemDecoder): + __doc__ = decoder.SingleItemDecoder.__doc__ + + TAG_MAP = TAG_MAP + TYPE_MAP = TYPE_MAP + supportIndefLength = False -_decode = Decoder(tagMap, decoder.typeMap) +class StreamingDecoder(decoder.StreamingDecoder): + __doc__ = decoder.StreamingDecoder.__doc__ + + SINGLE_ITEM_DECODER = SingleItemDecoder -def decodeStream(substrate, asn1Spec=None, **kwargs): - """Iterator of objects in a substrate.""" - # TODO: This should become `decode` after API-breaking approved - substrate = _asSeekableStream(substrate) - while True: - result = _decode(substrate, asn1Spec, **kwargs) - if result is None: - break - yield result - # TODO: Check about eoo.endOfOctets? +class Decoder(decoder.Decoder): + __doc__ = decoder.Decoder.__doc__ + + STREAMING_DECODER = StreamingDecoder #: Turns DER octet stream into an ASN.1 object. @@ -109,9 +109,4 @@ def decodeStream(substrate, asn1Spec=None, **kwargs): #: SequenceOf: #: 1 2 3 #: -def decode(substrate, asn1Spec=None, **kwargs): - # TODO: Temporary solution before merging with upstream - # It preserves the original API - substrate = _asSeekableStream(substrate) - value = _decode(substrate, asn1Spec=asn1Spec, **kwargs) - return value, substrate.read() \ No newline at end of file +decode = Decoder() diff --git a/pyasn1/codec/der/encoder.py b/pyasn1/codec/der/encoder.py index 90e982da..1a6af82b 100644 --- a/pyasn1/codec/der/encoder.py +++ b/pyasn1/codec/der/encoder.py @@ -8,7 +8,7 @@ from pyasn1.codec.cer import encoder from pyasn1.type import univ -__all__ = ['encode'] +__all__ = ['Encoder', 'encode'] class SetEncoder(encoder.SetEncoder): @@ -42,23 +42,34 @@ def _componentSortKey(componentAndType): else: return compType.tagSet -tagMap = encoder.tagMap.copy() -tagMap.update({ + +TAG_MAP = encoder.TAG_MAP.copy() + +TAG_MAP.update({ # Set & SetOf have same tags univ.Set.tagSet: SetEncoder() }) -typeMap = encoder.typeMap.copy() -typeMap.update({ +TYPE_MAP = encoder.TYPE_MAP.copy() + +TYPE_MAP.update({ # Set & SetOf have same tags univ.Set.typeId: SetEncoder() }) -class Encoder(encoder.Encoder): +class SingleItemEncoder(encoder.SingleItemEncoder): fixedDefLengthMode = True fixedChunkSize = 0 + TAG_MAP = TAG_MAP + TYPE_MAP = TYPE_MAP + + +class Encoder(encoder.Encoder): + SINGLE_ITEM_ENCODER = SingleItemEncoder + + #: Turns ASN.1 object into DER octet stream. #: #: Takes any ASN.1 object (e.g. :py:class:`~pyasn1.type.base.PyAsn1Item` derivative) @@ -104,4 +115,4 @@ class Encoder(encoder.Encoder): #: >>> encode(seq) #: b'0\t\x02\x01\x01\x02\x01\x02\x02\x01\x03' #: -encode = Encoder(tagMap, typeMap) +encode = Encoder() diff --git a/pyasn1/codec/native/decoder.py b/pyasn1/codec/native/decoder.py index 104b92e6..ecb1b161 100644 --- a/pyasn1/codec/native/decoder.py +++ b/pyasn1/codec/native/decoder.py @@ -17,17 +17,17 @@ LOG = debug.registerLoggee(__name__, flags=debug.DEBUG_DECODER) -class AbstractScalarDecoder(object): +class AbstractScalarPayloadDecoder(object): def __call__(self, pyObject, asn1Spec, decodeFun=None, **options): return asn1Spec.clone(pyObject) -class BitStringDecoder(AbstractScalarDecoder): +class BitStringPayloadDecoder(AbstractScalarPayloadDecoder): def __call__(self, pyObject, asn1Spec, decodeFun=None, **options): return asn1Spec.clone(univ.BitString.fromBinaryString(pyObject)) -class SequenceOrSetDecoder(object): +class SequenceOrSetPayloadDecoder(object): def __call__(self, pyObject, asn1Spec, decodeFun=None, **options): asn1Value = asn1Spec.clone() @@ -40,7 +40,7 @@ def __call__(self, pyObject, asn1Spec, decodeFun=None, **options): return asn1Value -class SequenceOfOrSetOfDecoder(object): +class SequenceOfOrSetOfPayloadDecoder(object): def __call__(self, pyObject, asn1Spec, decodeFun=None, **options): asn1Value = asn1Spec.clone() @@ -50,7 +50,7 @@ def __call__(self, pyObject, asn1Spec, decodeFun=None, **options): return asn1Value -class ChoiceDecoder(object): +class ChoicePayloadDecoder(object): def __call__(self, pyObject, asn1Spec, decodeFun=None, **options): asn1Value = asn1Spec.clone() @@ -64,87 +64,92 @@ def __call__(self, pyObject, asn1Spec, decodeFun=None, **options): return asn1Value -tagMap = { - univ.Integer.tagSet: AbstractScalarDecoder(), - univ.Boolean.tagSet: AbstractScalarDecoder(), - univ.BitString.tagSet: BitStringDecoder(), - univ.OctetString.tagSet: AbstractScalarDecoder(), - univ.Null.tagSet: AbstractScalarDecoder(), - univ.ObjectIdentifier.tagSet: AbstractScalarDecoder(), - univ.Enumerated.tagSet: AbstractScalarDecoder(), - univ.Real.tagSet: AbstractScalarDecoder(), - univ.Sequence.tagSet: SequenceOrSetDecoder(), # conflicts with SequenceOf - univ.Set.tagSet: SequenceOrSetDecoder(), # conflicts with SetOf - univ.Choice.tagSet: ChoiceDecoder(), # conflicts with Any +TAG_MAP = { + univ.Integer.tagSet: AbstractScalarPayloadDecoder(), + univ.Boolean.tagSet: AbstractScalarPayloadDecoder(), + univ.BitString.tagSet: BitStringPayloadDecoder(), + univ.OctetString.tagSet: AbstractScalarPayloadDecoder(), + univ.Null.tagSet: AbstractScalarPayloadDecoder(), + univ.ObjectIdentifier.tagSet: AbstractScalarPayloadDecoder(), + univ.Enumerated.tagSet: AbstractScalarPayloadDecoder(), + univ.Real.tagSet: AbstractScalarPayloadDecoder(), + univ.Sequence.tagSet: SequenceOrSetPayloadDecoder(), # conflicts with SequenceOf + univ.Set.tagSet: SequenceOrSetPayloadDecoder(), # conflicts with SetOf + univ.Choice.tagSet: ChoicePayloadDecoder(), # conflicts with Any # character string types - char.UTF8String.tagSet: AbstractScalarDecoder(), - char.NumericString.tagSet: AbstractScalarDecoder(), - char.PrintableString.tagSet: AbstractScalarDecoder(), - char.TeletexString.tagSet: AbstractScalarDecoder(), - char.VideotexString.tagSet: AbstractScalarDecoder(), - char.IA5String.tagSet: AbstractScalarDecoder(), - char.GraphicString.tagSet: AbstractScalarDecoder(), - char.VisibleString.tagSet: AbstractScalarDecoder(), - char.GeneralString.tagSet: AbstractScalarDecoder(), - char.UniversalString.tagSet: AbstractScalarDecoder(), - char.BMPString.tagSet: AbstractScalarDecoder(), + char.UTF8String.tagSet: AbstractScalarPayloadDecoder(), + char.NumericString.tagSet: AbstractScalarPayloadDecoder(), + char.PrintableString.tagSet: AbstractScalarPayloadDecoder(), + char.TeletexString.tagSet: AbstractScalarPayloadDecoder(), + char.VideotexString.tagSet: AbstractScalarPayloadDecoder(), + char.IA5String.tagSet: AbstractScalarPayloadDecoder(), + char.GraphicString.tagSet: AbstractScalarPayloadDecoder(), + char.VisibleString.tagSet: AbstractScalarPayloadDecoder(), + char.GeneralString.tagSet: AbstractScalarPayloadDecoder(), + char.UniversalString.tagSet: AbstractScalarPayloadDecoder(), + char.BMPString.tagSet: AbstractScalarPayloadDecoder(), # useful types - useful.ObjectDescriptor.tagSet: AbstractScalarDecoder(), - useful.GeneralizedTime.tagSet: AbstractScalarDecoder(), - useful.UTCTime.tagSet: AbstractScalarDecoder() + useful.ObjectDescriptor.tagSet: AbstractScalarPayloadDecoder(), + useful.GeneralizedTime.tagSet: AbstractScalarPayloadDecoder(), + useful.UTCTime.tagSet: AbstractScalarPayloadDecoder() } # Put in ambiguous & non-ambiguous types for faster codec lookup -typeMap = { - univ.Integer.typeId: AbstractScalarDecoder(), - univ.Boolean.typeId: AbstractScalarDecoder(), - univ.BitString.typeId: BitStringDecoder(), - univ.OctetString.typeId: AbstractScalarDecoder(), - univ.Null.typeId: AbstractScalarDecoder(), - univ.ObjectIdentifier.typeId: AbstractScalarDecoder(), - univ.Enumerated.typeId: AbstractScalarDecoder(), - univ.Real.typeId: AbstractScalarDecoder(), +TYPE_MAP = { + univ.Integer.typeId: AbstractScalarPayloadDecoder(), + univ.Boolean.typeId: AbstractScalarPayloadDecoder(), + univ.BitString.typeId: BitStringPayloadDecoder(), + univ.OctetString.typeId: AbstractScalarPayloadDecoder(), + univ.Null.typeId: AbstractScalarPayloadDecoder(), + univ.ObjectIdentifier.typeId: AbstractScalarPayloadDecoder(), + univ.Enumerated.typeId: AbstractScalarPayloadDecoder(), + univ.Real.typeId: AbstractScalarPayloadDecoder(), # ambiguous base types - univ.Set.typeId: SequenceOrSetDecoder(), - univ.SetOf.typeId: SequenceOfOrSetOfDecoder(), - univ.Sequence.typeId: SequenceOrSetDecoder(), - univ.SequenceOf.typeId: SequenceOfOrSetOfDecoder(), - univ.Choice.typeId: ChoiceDecoder(), - univ.Any.typeId: AbstractScalarDecoder(), + univ.Set.typeId: SequenceOrSetPayloadDecoder(), + univ.SetOf.typeId: SequenceOfOrSetOfPayloadDecoder(), + univ.Sequence.typeId: SequenceOrSetPayloadDecoder(), + univ.SequenceOf.typeId: SequenceOfOrSetOfPayloadDecoder(), + univ.Choice.typeId: ChoicePayloadDecoder(), + univ.Any.typeId: AbstractScalarPayloadDecoder(), # character string types - char.UTF8String.typeId: AbstractScalarDecoder(), - char.NumericString.typeId: AbstractScalarDecoder(), - char.PrintableString.typeId: AbstractScalarDecoder(), - char.TeletexString.typeId: AbstractScalarDecoder(), - char.VideotexString.typeId: AbstractScalarDecoder(), - char.IA5String.typeId: AbstractScalarDecoder(), - char.GraphicString.typeId: AbstractScalarDecoder(), - char.VisibleString.typeId: AbstractScalarDecoder(), - char.GeneralString.typeId: AbstractScalarDecoder(), - char.UniversalString.typeId: AbstractScalarDecoder(), - char.BMPString.typeId: AbstractScalarDecoder(), + char.UTF8String.typeId: AbstractScalarPayloadDecoder(), + char.NumericString.typeId: AbstractScalarPayloadDecoder(), + char.PrintableString.typeId: AbstractScalarPayloadDecoder(), + char.TeletexString.typeId: AbstractScalarPayloadDecoder(), + char.VideotexString.typeId: AbstractScalarPayloadDecoder(), + char.IA5String.typeId: AbstractScalarPayloadDecoder(), + char.GraphicString.typeId: AbstractScalarPayloadDecoder(), + char.VisibleString.typeId: AbstractScalarPayloadDecoder(), + char.GeneralString.typeId: AbstractScalarPayloadDecoder(), + char.UniversalString.typeId: AbstractScalarPayloadDecoder(), + char.BMPString.typeId: AbstractScalarPayloadDecoder(), # useful types - useful.ObjectDescriptor.typeId: AbstractScalarDecoder(), - useful.GeneralizedTime.typeId: AbstractScalarDecoder(), - useful.UTCTime.typeId: AbstractScalarDecoder() + useful.ObjectDescriptor.typeId: AbstractScalarPayloadDecoder(), + useful.GeneralizedTime.typeId: AbstractScalarPayloadDecoder(), + useful.UTCTime.typeId: AbstractScalarPayloadDecoder() } -class Decoder(object): +class SingleItemDecoder(object): + + TAG_MAP = TAG_MAP + TYPE_MAP = TYPE_MAP - # noinspection PyDefaultArgument - def __init__(self, tagMap, typeMap): - self.__tagMap = tagMap - self.__typeMap = typeMap + def __init__(self, tagMap=None, typeMap=None): + self.__tagMap = tagMap or self.TAG_MAP + self.__typeMap = typeMap or self.TYPE_MAP def __call__(self, pyObject, asn1Spec, **options): if LOG: debug.scope.push(type(pyObject).__name__) - LOG('decoder called at scope %s, working with type %s' % (debug.scope, type(pyObject).__name__)) + LOG('decoder called at scope %s, working with ' + 'type %s' % (debug.scope, type(pyObject).__name__)) if asn1Spec is None or not isinstance(asn1Spec, base.Asn1Item): - raise error.PyAsn1Error('asn1Spec is not valid (should be an instance of an ASN.1 Item, not %s)' % asn1Spec.__class__.__name__) + raise error.PyAsn1Error( + 'asn1Spec is not valid (should be an instance of an ASN.1 ' + 'Item, not %s)' % asn1Spec.__class__.__name__) try: valueDecoder = self.__typeMap[asn1Spec.typeId] @@ -155,21 +160,35 @@ def __call__(self, pyObject, asn1Spec, **options): try: valueDecoder = self.__tagMap[baseTagSet] + except KeyError: raise error.PyAsn1Error('Unknown ASN.1 tag %s' % asn1Spec.tagSet) if LOG: - LOG('calling decoder %s on Python type %s <%s>' % (type(valueDecoder).__name__, type(pyObject).__name__, repr(pyObject))) + LOG('calling decoder %s on Python type %s ' + '<%s>' % (type(valueDecoder).__name__, + type(pyObject).__name__, repr(pyObject))) value = valueDecoder(pyObject, asn1Spec, self, **options) if LOG: - LOG('decoder %s produced ASN.1 type %s <%s>' % (type(valueDecoder).__name__, type(value).__name__, repr(value))) + LOG('decoder %s produced ASN.1 type %s ' + '<%s>' % (type(valueDecoder).__name__, + type(value).__name__, repr(value))) debug.scope.pop() return value +class Decoder(object): + SINGLE_ITEM_DECODER = SingleItemDecoder + + @classmethod + def __call__(cls, pyObject, asn1Spec=None, **kwargs): + singleItemDecoder = cls.SINGLE_ITEM_DECODER() + return singleItemDecoder(pyObject, asn1Spec=asn1Spec, **kwargs) + + #: Turns Python objects of built-in types into ASN.1 objects. #: #: Takes Python objects of built-in types and turns them into a tree of @@ -210,4 +229,4 @@ def __call__(self, pyObject, asn1Spec, **options): #: SequenceOf: #: 1 2 3 #: -decode = Decoder(tagMap, typeMap) +decode = Decoder() diff --git a/pyasn1/codec/native/encoder.py b/pyasn1/codec/native/encoder.py index 4318abde..a3e17a9b 100644 --- a/pyasn1/codec/native/encoder.py +++ b/pyasn1/codec/native/encoder.py @@ -107,7 +107,7 @@ def encode(self, value, encodeFun, **options): return value.asOctets() -tagMap = { +TAG_MAP = { univ.Boolean.tagSet: BooleanEncoder(), univ.Integer.tagSet: IntegerEncoder(), univ.BitString.tagSet: BitStringEncoder(), @@ -140,7 +140,7 @@ def encode(self, value, encodeFun, **options): # Put in ambiguous & non-ambiguous types for faster codec lookup -typeMap = { +TYPE_MAP = { univ.Boolean.typeId: BooleanEncoder(), univ.Integer.typeId: IntegerEncoder(), univ.BitString.typeId: BitStringEncoder(), @@ -175,20 +175,24 @@ def encode(self, value, encodeFun, **options): } -class Encoder(object): +class SingleItemEncoder(object): + + TAG_MAP = TAG_MAP + TYPE_MAP = TYPE_MAP - # noinspection PyDefaultArgument - def __init__(self, tagMap, typeMap={}): - self.__tagMap = tagMap - self.__typeMap = typeMap + def __init__(self, tagMap=None, typeMap=None): + self.__tagMap = tagMap or self.TAG_MAP + self.__typeMap = typeMap or self.TYPE_MAP def __call__(self, value, **options): if not isinstance(value, base.Asn1Item): - raise error.PyAsn1Error('value is not valid (should be an instance of an ASN.1 Item)') + raise error.PyAsn1Error( + 'value is not valid (should be an instance of an ASN.1 Item)') if LOG: debug.scope.push(type(value).__name__) - LOG('encoder called for type %s <%s>' % (type(value).__name__, value.prettyPrint())) + LOG('encoder called for type %s ' + '<%s>' % (type(value).__name__, value.prettyPrint())) tagSet = value.tagSet @@ -197,7 +201,8 @@ def __call__(self, value, **options): except KeyError: # use base type for codec lookup to recover untagged types - baseTagSet = tag.TagSet(value.tagSet.baseTag, value.tagSet.baseTag) + baseTagSet = tag.TagSet( + value.tagSet.baseTag, value.tagSet.baseTag) try: concreteEncoder = self.__tagMap[baseTagSet] @@ -206,17 +211,28 @@ def __call__(self, value, **options): raise error.PyAsn1Error('No encoder for %s' % (value,)) if LOG: - LOG('using value codec %s chosen by %s' % (concreteEncoder.__class__.__name__, tagSet)) + LOG('using value codec %s chosen by ' + '%s' % (concreteEncoder.__class__.__name__, tagSet)) pyObject = concreteEncoder.encode(value, self, **options) if LOG: - LOG('encoder %s produced: %s' % (type(concreteEncoder).__name__, repr(pyObject))) + LOG('encoder %s produced: ' + '%s' % (type(concreteEncoder).__name__, repr(pyObject))) debug.scope.pop() return pyObject +class Encoder(object): + SINGLE_ITEM_ENCODER = SingleItemEncoder + + @classmethod + def __call__(cls, pyObject, asn1Spec=None, **kwargs): + singleItemEncoder = cls.SINGLE_ITEM_ENCODER() + return singleItemEncoder(pyObject, asn1Spec=asn1Spec, **kwargs) + + #: Turns ASN.1 object into a Python built-in type object(s). #: #: Takes any ASN.1 object (e.g. :py:class:`~pyasn1.type.base.PyAsn1Item` derivative) @@ -253,4 +269,4 @@ def __call__(self, value, **options): #: >>> encode(seq) #: [1, 2, 3] #: -encode = Encoder(tagMap, typeMap) +encode = SingleItemEncoder() diff --git a/pyasn1/codec/streaming.py b/pyasn1/codec/streaming.py new file mode 100644 index 00000000..18896772 --- /dev/null +++ b/pyasn1/codec/streaming.py @@ -0,0 +1,240 @@ +# +# This file is part of pyasn1 software. +# +# Copyright (c) 2005-2019, Ilya Etingof +# License: http://snmplabs.com/pyasn1/license.html +# +import io +import os +import sys + +from pyasn1 import error +from pyasn1.type import univ + +_PY2 = sys.version_info < (3,) + + +class CachingStreamWrapper(io.IOBase): + """Wrapper around non-seekable streams. + + Note that the implementation is tied to the decoder, + not checking for dangerous arguments for the sake + of performance. + + The read bytes are kept in an internal cache until + setting _markedPosition which may reset the cache. + """ + def __init__(self, raw): + self._raw = raw + self._cache = io.BytesIO() + self._markedPosition = 0 + + def peek(self, n): + result = self.read(n) + self._cache.seek(-len(result), os.SEEK_CUR) + return result + + def seekable(self): + return True + + def seek(self, n=-1, whence=os.SEEK_SET): + # Note that this not safe for seeking forward. + return self._cache.seek(n, whence) + + def read(self, n=-1): + read_from_cache = self._cache.read(n) + if n != -1: + n -= len(read_from_cache) + if not n: # 0 bytes left to read + return read_from_cache + + read_from_raw = self._raw.read(n) + + self._cache.write(read_from_raw) + + return read_from_cache + read_from_raw + + @property + def markedPosition(self): + """Position where the currently processed element starts. + + This is used for back-tracking in SingleItemDecoder.__call__ + and (indefLen)ValueDecoder and should not be used for other purposes. + The client is not supposed to ever seek before this position. + """ + return self._markedPosition + + @markedPosition.setter + def markedPosition(self, value): + # By setting the value, we ensure we won't seek back before it. + # `value` should be the same as the current position + # We don't check for this for performance reasons. + self._markedPosition = value + + # Whenever we set _marked_position, we know for sure + # that we will not return back, and thus it is + # safe to drop all cached data. + if self._cache.tell() > io.DEFAULT_BUFFER_SIZE: + self._cache = io.BytesIO(self._cache.read()) + self._markedPosition = 0 + + def tell(self): + return self._cache.tell() + + +def asSeekableStream(substrate): + """Convert object to seekable byte-stream. + + Parameters + ---------- + substrate: :py:class:`bytes` or :py:class:`io.IOBase` or :py:class:`univ.OctetString` + + Returns + ------- + : :py:class:`io.IOBase` + + Raises + ------ + : :py:class:`~pyasn1.error.PyAsn1Error` + If the supplied substrate cannot be converted to a seekable stream. + """ + if isinstance(substrate, bytes): + return io.BytesIO(substrate) + + elif isinstance(substrate, univ.OctetString): + return io.BytesIO(substrate.asOctets()) + + try: + # Special case: impossible to set attributes on `file` built-in + if _PY2 and isinstance(substrate, file): + return io.BufferedReader(substrate) + + elif substrate.seekable(): # Will fail for most invalid types + return substrate + + else: + return CachingStreamWrapper(substrate) + + except AttributeError: + raise error.UnsupportedSubstrateError( + "Cannot convert " + substrate.__class__.__name__ + + " to a seekable bit stream.") + + +def isEndOfStream(substrate): + """Check whether we have reached the end of a stream. + + Although it is more effective to read and catch exceptions, this + function + + Parameters + ---------- + substrate: :py:class:`IOBase` + Stream to check + + Returns + ------- + : :py:class:`bool` + """ + if isinstance(substrate, io.BytesIO): + cp = substrate.tell() + substrate.seek(0, os.SEEK_END) + result = substrate.tell() == cp + substrate.seek(cp, os.SEEK_SET) + yield result + + else: + received = substrate.read(1) + if received is None: + yield + + if received: + substrate.seek(-1, os.SEEK_CUR) + + yield not received + + +def peek(substrate, size=-1): + """Peek the stream. + + Parameters + ---------- + substrate: :py:class:`IOBase` + Stream to read from. + + size: :py:class:`int` + How many bytes to peek (-1 = all available) + + Returns + ------- + : :py:class:`bytes` or :py:class:`str` + The return type depends on Python major version + """ + if hasattr(substrate, "peek"): + received = substrate.peek(size) + if received is None: + yield + + while len(received) < size: + yield + + yield received + + else: + current_position = substrate.tell() + try: + for chunk in read(substrate, size): + yield chunk + + finally: + substrate.seek(current_position) + + +def read(substrate, size=-1, context=None): + """Read from the stream. + + Parameters + ---------- + substrate: :py:class:`IOBase` + Stream to read from. + + Keyword parameters + ------------------ + size: :py:class:`int` + How many bytes to read (-1 = all available) + + context: :py:class:`dict` + Opaque caller context will be attached to exception objects created + by this function. + + Yields + ------ + : :py:class:`bytes` or :py:class:`str` or None + Returns read data or :py:class:`~pyasn1.error.SubstrateUnderrunError` + object if no `size` bytes is readily available in the stream. The + data type depends on Python major version + + Raises + ------ + : :py:class:`~pyasn1.error.EndOfStreamError` + Input stream is exhausted + """ + while True: + # this will block unless stream is non-blocking + received = substrate.read(size) + if received is None: # non-blocking stream can do this + yield error.SubstrateUnderrunError(context=context) + + elif size != 0 and not received: # end-of-stream + raise error.EndOfStreamError(context=context) + + elif len(received) < size: + substrate.seek(-len(received), os.SEEK_CUR) + + # behave like a non-blocking stream + yield error.SubstrateUnderrunError(context=context) + + else: + break + + yield received diff --git a/pyasn1/error.py b/pyasn1/error.py index 85a31ff2..08ec1b39 100644 --- a/pyasn1/error.py +++ b/pyasn1/error.py @@ -12,7 +12,36 @@ class PyAsn1Error(Exception): `PyAsn1Error` is the base exception class (based on :class:`Exception`) that represents all possible ASN.1 related errors. + + Parameters + ---------- + args: + Opaque positional parameters + + Keyword Args + ------------ + kwargs: + Opaque keyword parameters + """ + def __init__(self, *args, **kwargs): + self._args = args + self._kwargs = kwargs + + @property + def context(self): + """Return exception context + + When exception object is created, the caller can supply some opaque + context for the upper layers to better understand the cause of the + exception. + + Returns + ------- + : :py:class:`dict` + Dict holding context specific data + """ + return self._kwargs.get('context', {}) class ValueConstraintError(PyAsn1Error): @@ -34,6 +63,14 @@ class SubstrateUnderrunError(PyAsn1Error): """ +class EndOfStreamError(SubstrateUnderrunError): + """ASN.1 data structure deserialization error + + The `EndOfStreamError` exception indicates the condition of the input + stream has been closed. + """ + + class UnsupportedSubstrateError(PyAsn1Error): """Unsupported substrate type to parse as ASN.1 data.""" diff --git a/tests/codec/__main__.py b/tests/codec/__main__.py index 7a4cf207..dbd744ae 100644 --- a/tests/codec/__main__.py +++ b/tests/codec/__main__.py @@ -11,7 +11,8 @@ import unittest suite = unittest.TestLoader().loadTestsFromNames( - ['tests.codec.ber.__main__.suite', + ['tests.codec.streaming.__main__.suite', + 'tests.codec.ber.__main__.suite', 'tests.codec.cer.__main__.suite', 'tests.codec.der.__main__.suite', 'tests.codec.native.__main__.suite'] diff --git a/tests/codec/ber/test_decoder.py b/tests/codec/ber/test_decoder.py index e72e025b..2430ff44 100644 --- a/tests/codec/ber/test_decoder.py +++ b/tests/codec/ber/test_decoder.py @@ -23,10 +23,11 @@ from pyasn1.type import opentype from pyasn1.type import univ from pyasn1.type import char +from pyasn1.codec import streaming from pyasn1.codec.ber import decoder from pyasn1.codec.ber import eoo from pyasn1.compat.octets import ints2octs, str2octs, null -from pyasn1.error import PyAsn1Error, SubstrateUnderrunError, UnsupportedSubstrateError +from pyasn1 import error class LargeTagDecoderTestCase(BaseTestCase): @@ -78,7 +79,7 @@ def testSpec(self): decoder.decode( ints2octs((2, 1, 12)), asn1Spec=univ.Null() ) == (12, null) - except PyAsn1Error: + except error.PyAsn1Error: pass else: assert 0, 'wrong asn1Spec worked out' @@ -89,7 +90,7 @@ def testSpec(self): def testTagFormat(self): try: decoder.decode(ints2octs((34, 1, 12))) - except PyAsn1Error: + except error.PyAsn1Error: pass else: assert 0, 'wrong tagFormat worked out' @@ -111,7 +112,7 @@ def testFalse(self): def testTagFormat(self): try: decoder.decode(ints2octs((33, 1, 1))) - except PyAsn1Error: + except error.PyAsn1Error: pass else: assert 0, 'wrong tagFormat worked out' @@ -138,24 +139,22 @@ def testIndefModeChunked(self): ints2octs((35, 128, 3, 2, 0, 169, 3, 2, 1, 138, 0, 0)) ) == ((1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1), null) - # TODO: Not clear how to deal with substrateFun in stream implementation - # def testDefModeChunkedSubst(self): - # assert decoder.decode( - # ints2octs((35, 8, 3, 2, 0, 169, 3, 2, 1, 138)), - # substrateFun=lambda a, b, c: (b, b[c:]) - # ) == (ints2octs((3, 2, 0, 169, 3, 2, 1, 138)), str2octs('')) + def testDefModeChunkedSubst(self): + assert decoder.decode( + ints2octs((35, 8, 3, 2, 0, 169, 3, 2, 1, 138)), + substrateFun=lambda a, b, c, d: streaming.read(b, c) + ) == (ints2octs((3, 2, 0, 169, 3, 2, 1, 138)), str2octs('')) - # TODO: Not clear how to deal with substrateFun in stream implementation - # def testIndefModeChunkedSubst(self): - # assert decoder.decode( - # ints2octs((35, 128, 3, 2, 0, 169, 3, 2, 1, 138, 0, 0)), - # substrateFun=lambda a, b, c: (b, str2octs('')) - # ) == (ints2octs((3, 2, 0, 169, 3, 2, 1, 138, 0, 0)), str2octs('')) + def testIndefModeChunkedSubst(self): + assert decoder.decode( + ints2octs((35, 128, 3, 2, 0, 169, 3, 2, 1, 138, 0, 0)), + substrateFun=lambda a, b, c, d: streaming.read(b, c) + ) == (ints2octs((3, 2, 0, 169, 3, 2, 1, 138, 0, 0)), str2octs('')) def testTypeChecking(self): try: decoder.decode(ints2octs((35, 4, 2, 2, 42, 42))) - except PyAsn1Error: + except error.PyAsn1Error: pass else: assert 0, 'accepted mis-encoded bit-string constructed out of an integer' @@ -183,22 +182,20 @@ def testIndefModeChunked(self): ints2octs((36, 128, 4, 4, 81, 117, 105, 99, 4, 4, 107, 32, 98, 114, 4, 4, 111, 119, 110, 32, 4, 3, 102, 111, 120, 0, 0)) ) == (str2octs('Quick brown fox'), null) - # TODO: Not clear how to deal with substrateFun in stream implementation - # def testDefModeChunkedSubst(self): - # assert decoder.decode( - # ints2octs( - # (36, 23, 4, 4, 81, 117, 105, 99, 4, 4, 107, 32, 98, 114, 4, 4, 111, 119, 110, 32, 4, 3, 102, 111, 120)), - # substrateFun=lambda a, b, c: (b, b[c:]) - # ) == (ints2octs((4, 4, 81, 117, 105, 99, 4, 4, 107, 32, 98, 114, 4, 4, 111, 119, 110, 32, 4, 3, 102, 111, 120)), str2octs('')) + def testDefModeChunkedSubst(self): + assert decoder.decode( + ints2octs( + (36, 23, 4, 4, 81, 117, 105, 99, 4, 4, 107, 32, 98, 114, 4, 4, 111, 119, 110, 32, 4, 3, 102, 111, 120)), + substrateFun=lambda a, b, c, d: streaming.read(b, c) + ) == (ints2octs((4, 4, 81, 117, 105, 99, 4, 4, 107, 32, 98, 114, 4, 4, 111, 119, 110, 32, 4, 3, 102, 111, 120)), str2octs('')) - # TODO: Not clear how to deal with substrateFun in stream implementation - # def testIndefModeChunkedSubst(self): - # assert decoder.decode( - # ints2octs((36, 128, 4, 4, 81, 117, 105, 99, 4, 4, 107, 32, 98, 114, 4, 4, 111, 119, 110, 32, 4, 3, 102, 111, - # 120, 0, 0)), - # substrateFun=lambda a, b, c: (b, str2octs('')) - # ) == (ints2octs( - # (4, 4, 81, 117, 105, 99, 4, 4, 107, 32, 98, 114, 4, 4, 111, 119, 110, 32, 4, 3, 102, 111, 120, 0, 0)), str2octs('')) + def testIndefModeChunkedSubst(self): + assert decoder.decode( + ints2octs((36, 128, 4, 4, 81, 117, 105, 99, 4, 4, 107, 32, 98, 114, 4, 4, 111, 119, 110, 32, 4, 3, 102, 111, + 120, 0, 0)), + substrateFun=lambda a, b, c, d: streaming.read(b, c) + ) == (ints2octs( + (4, 4, 81, 117, 105, 99, 4, 4, 107, 32, 98, 114, 4, 4, 111, 119, 110, 32, 4, 3, 102, 111, 120, 0, 0)), str2octs('')) class ExpTaggedOctetStringDecoderTestCase(BaseTestCase): @@ -246,22 +243,20 @@ def testIndefModeChunked(self): assert self.o.tagSet == o.tagSet assert self.o.isSameTypeWith(o) - # TODO: Not clear how to deal with substrateFun in stream implementation - # def testDefModeSubst(self): - # assert decoder.decode( - # ints2octs((101, 17, 4, 15, 81, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, 32, 102, 111, 120)), - # substrateFun=lambda a, b, c: (b, b[c:]) - # ) == (ints2octs((4, 15, 81, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, 32, 102, 111, 120)), str2octs('')) + def testDefModeSubst(self): + assert decoder.decode( + ints2octs((101, 17, 4, 15, 81, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, 32, 102, 111, 120)), + substrateFun=lambda a, b, c, d: streaming.read(b, c) + ) == (ints2octs((4, 15, 81, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, 32, 102, 111, 120)), str2octs('')) - # TODO: Not clear how to deal with substrateFun in stream implementation - # def testIndefModeSubst(self): - # assert decoder.decode( - # ints2octs(( - # 101, 128, 36, 128, 4, 15, 81, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, 32, 102, 111, 120, 0, - # 0, 0, 0)), - # substrateFun=lambda a, b, c: (b, str2octs('')) - # ) == (ints2octs( - # (36, 128, 4, 15, 81, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, 32, 102, 111, 120, 0, 0, 0, 0)), str2octs('')) + def testIndefModeSubst(self): + assert decoder.decode( + ints2octs(( + 101, 128, 36, 128, 4, 15, 81, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, 32, 102, 111, 120, 0, + 0, 0, 0)), + substrateFun=lambda a, b, c, d: streaming.read(b, c) + ) == (ints2octs( + (36, 128, 4, 15, 81, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, 32, 102, 111, 120, 0, 0, 0, 0)), str2octs('')) class NullDecoderTestCase(BaseTestCase): @@ -271,7 +266,7 @@ def testNull(self): def testTagFormat(self): try: decoder.decode(ints2octs((37, 0))) - except PyAsn1Error: + except error.PyAsn1Error: pass else: assert 0, 'wrong tagFormat worked out' @@ -340,7 +335,7 @@ def testLeading0x80Case1(self): decoder.decode( ints2octs((6, 5, 85, 4, 128, 129, 0)) ) - except PyAsn1Error: + except error.PyAsn1Error: pass else: assert 0, 'Leading 0x80 tolerated' @@ -350,7 +345,7 @@ def testLeading0x80Case2(self): decoder.decode( ints2octs((6, 7, 1, 0x80, 0x80, 0x80, 0x80, 0x80, 0x7F)) ) - except PyAsn1Error: + except error.PyAsn1Error: pass else: assert 0, 'Leading 0x80 tolerated' @@ -360,7 +355,7 @@ def testLeading0x80Case3(self): decoder.decode( ints2octs((6, 2, 0x80, 1)) ) - except PyAsn1Error: + except error.PyAsn1Error: pass else: assert 0, 'Leading 0x80 tolerated' @@ -370,7 +365,7 @@ def testLeading0x80Case4(self): decoder.decode( ints2octs((6, 2, 0x80, 0x7F)) ) - except PyAsn1Error: + except error.PyAsn1Error: pass else: assert 0, 'Leading 0x80 tolerated' @@ -378,7 +373,7 @@ def testLeading0x80Case4(self): def testTagFormat(self): try: decoder.decode(ints2octs((38, 1, 239))) - except PyAsn1Error: + except error.PyAsn1Error: pass else: assert 0, 'wrong tagFormat worked out' @@ -386,7 +381,7 @@ def testTagFormat(self): def testZeroLength(self): try: decoder.decode(ints2octs((6, 0, 0))) - except PyAsn1Error: + except error.PyAsn1Error: pass else: assert 0, 'zero length tolerated' @@ -394,7 +389,7 @@ def testZeroLength(self): def testIndefiniteLength(self): try: decoder.decode(ints2octs((6, 128, 0))) - except PyAsn1Error: + except error.PyAsn1Error: pass else: assert 0, 'indefinite length tolerated' @@ -402,7 +397,7 @@ def testIndefiniteLength(self): def testReservedLength(self): try: decoder.decode(ints2octs((6, 255, 0))) - except PyAsn1Error: + except error.PyAsn1Error: pass else: assert 0, 'reserved length tolerated' @@ -479,7 +474,7 @@ def testEmpty(self): def testTagFormat(self): try: decoder.decode(ints2octs((41, 0))) - except PyAsn1Error: + except error.PyAsn1Error: pass else: assert 0, 'wrong tagFormat worked out' @@ -487,7 +482,7 @@ def testTagFormat(self): def testShortEncoding(self): try: decoder.decode(ints2octs((9, 1, 131))) - except PyAsn1Error: + except error.PyAsn1Error: pass else: assert 0, 'accepted too-short real' @@ -684,27 +679,25 @@ def testWithOptionalAndDefaultedIndefModeChunked(self): ints2octs((48, 128, 5, 0, 36, 128, 4, 4, 113, 117, 105, 99, 4, 4, 107, 32, 98, 114, 4, 3, 111, 119, 110, 0, 0, 2, 1, 1, 0, 0)) ) == (self.s, null) - # TODO: Not clear how to deal with substrateFun in stream implementation - # def testWithOptionalAndDefaultedDefModeSubst(self): - # assert decoder.decode( - # ints2octs((48, 18, 5, 0, 4, 11, 113, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, 2, 1, 1)), - # substrateFun=lambda a, b, c: (b, b[c:]) - # ) == (ints2octs((5, 0, 4, 11, 113, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, 2, 1, 1)), str2octs('')) - - # TODO: Not clear how to deal with substrateFun in stream implementation - # def testWithOptionalAndDefaultedIndefModeSubst(self): - # assert decoder.decode( - # ints2octs((48, 128, 5, 0, 36, 128, 4, 11, 113, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, 0, 0, 2, 1, 1, 0, 0)), - # substrateFun=lambda a, b, c: (b, str2octs('')) - # ) == (ints2octs( - # (5, 0, 36, 128, 4, 11, 113, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, 0, 0, 2, 1, 1, 0, 0)), str2octs('')) + def testWithOptionalAndDefaultedDefModeSubst(self): + assert decoder.decode( + ints2octs((48, 18, 5, 0, 4, 11, 113, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, 2, 1, 1)), + substrateFun=lambda a, b, c, d: streaming.read(b, c) + ) == (ints2octs((5, 0, 4, 11, 113, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, 2, 1, 1)), str2octs('')) + + def testWithOptionalAndDefaultedIndefModeSubst(self): + assert decoder.decode( + ints2octs((48, 128, 5, 0, 36, 128, 4, 11, 113, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, 0, 0, 2, 1, 1, 0, 0)), + substrateFun=lambda a, b, c, d: streaming.read(b, c) + ) == (ints2octs( + (5, 0, 36, 128, 4, 11, 113, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, 0, 0, 2, 1, 1, 0, 0)), str2octs('')) def testTagFormat(self): try: decoder.decode( ints2octs((16, 18, 5, 0, 4, 11, 113, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, 2, 1, 1)) ) - except PyAsn1Error: + except error.PyAsn1Error: pass else: assert 0, 'wrong tagFormat worked out' @@ -886,7 +879,7 @@ def testDecodeOpenTypesUnknownType(self): decodeOpenTypes=True ) - except PyAsn1Error: + except error.PyAsn1Error: pass else: @@ -1025,7 +1018,7 @@ def testDecodeOpenTypesUnknownType(self): decodeOpenTypes=True ) - except PyAsn1Error: + except error.PyAsn1Error: pass else: @@ -1172,27 +1165,25 @@ def testWithOptionalAndDefaultedIndefModeChunked(self): ints2octs((49, 128, 5, 0, 36, 128, 4, 4, 113, 117, 105, 99, 4, 4, 107, 32, 98, 114, 4, 3, 111, 119, 110, 0, 0, 2, 1, 1, 0, 0)) ) == (self.s, null) - # TODO: Not clear how to deal with substrateFun in stream implementation - # def testWithOptionalAndDefaultedDefModeSubst(self): - # assert decoder.decode( - # ints2octs((49, 18, 5, 0, 4, 11, 113, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, 2, 1, 1)), - # substrateFun=lambda a, b, c: (b, b[c:]) - # ) == (ints2octs((5, 0, 4, 11, 113, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, 2, 1, 1)), str2octs('')) - - # TODO: Not clear how to deal with substrateFun in stream implementation - # def testWithOptionalAndDefaultedIndefModeSubst(self): - # assert decoder.decode( - # ints2octs((49, 128, 5, 0, 36, 128, 4, 11, 113, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, 0, 0, 2, 1, 1, 0, 0)), - # substrateFun=lambda a, b, c: (b, str2octs('')) - # ) == (ints2octs( - # (5, 0, 36, 128, 4, 11, 113, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, 0, 0, 2, 1, 1, 0, 0)), str2octs('')) + def testWithOptionalAndDefaultedDefModeSubst(self): + assert decoder.decode( + ints2octs((49, 18, 5, 0, 4, 11, 113, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, 2, 1, 1)), + substrateFun=lambda a, b, c, d: streaming.read(b, c) + ) == (ints2octs((5, 0, 4, 11, 113, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, 2, 1, 1)), str2octs('')) + + def testWithOptionalAndDefaultedIndefModeSubst(self): + assert decoder.decode( + ints2octs((49, 128, 5, 0, 36, 128, 4, 11, 113, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, 0, 0, 2, 1, 1, 0, 0)), + substrateFun=lambda a, b, c, d: streaming.read(b, c) + ) == (ints2octs( + (5, 0, 36, 128, 4, 11, 113, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, 0, 0, 2, 1, 1, 0, 0)), str2octs('')) def testTagFormat(self): try: decoder.decode( ints2octs((16, 18, 5, 0, 4, 11, 113, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, 2, 1, 1)) ) - except PyAsn1Error: + except error.PyAsn1Error: pass else: assert 0, 'wrong tagFormat worked out' @@ -1505,28 +1496,26 @@ def testTaggedImIndefMode(self): s = univ.Any('\004\003fox').subtype(implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 4)) assert decoder.decode(ints2octs((164, 128, 4, 3, 102, 111, 120, 0, 0)), asn1Spec=s) == (s, null) - # TODO: Not clear how to deal with substrateFun in stream implementation - # def testByUntaggedSubst(self): - # assert decoder.decode( - # ints2octs((4, 3, 102, 111, 120)), - # asn1Spec=self.s, - # substrateFun=lambda a, b, c: (b, b[c:]) - # ) == (ints2octs((4, 3, 102, 111, 120)), str2octs('')) + def testByUntaggedSubst(self): + assert decoder.decode( + ints2octs((4, 3, 102, 111, 120)), + asn1Spec=self.s, + substrateFun=lambda a, b, c, d: streaming.read(b, c) + ) == (ints2octs((4, 3, 102, 111, 120)), str2octs('')) - # TODO: Not clear how to deal with substrateFun in stream implementation - # def testTaggedExSubst(self): - # assert decoder.decode( - # ints2octs((164, 5, 4, 3, 102, 111, 120)), - # asn1Spec=self.s, - # substrateFun=lambda a, b, c: (b, b[c:]) - # ) == (ints2octs((164, 5, 4, 3, 102, 111, 120)), str2octs('')) + def testTaggedExSubst(self): + assert decoder.decode( + ints2octs((164, 5, 4, 3, 102, 111, 120)), + asn1Spec=self.s, + substrateFun=lambda a, b, c, d: streaming.read(b, c) + ) == (ints2octs((164, 5, 4, 3, 102, 111, 120)), str2octs('')) class EndOfOctetsTestCase(BaseTestCase): def testUnexpectedEoo(self): try: decoder.decode(ints2octs((0, 0))) - except PyAsn1Error: + except error.PyAsn1Error: pass else: assert 0, 'end-of-contents octets accepted at top level' @@ -1539,7 +1528,7 @@ def testExpectedEoo(self): def testDefiniteNoEoo(self): try: decoder.decode(ints2octs((0x23, 0x02, 0x00, 0x00))) - except PyAsn1Error: + except error.PyAsn1Error: pass else: assert 0, 'end-of-contents octets accepted inside definite-length encoding' @@ -1551,7 +1540,7 @@ def testIndefiniteEoo(self): def testNoLongFormEoo(self): try: decoder.decode(ints2octs((0x23, 0x80, 0x00, 0x81, 0x00))) - except PyAsn1Error: + except error.PyAsn1Error: pass else: assert 0, 'end-of-contents octets accepted with invalid long-form length' @@ -1559,7 +1548,7 @@ def testNoLongFormEoo(self): def testNoConstructedEoo(self): try: decoder.decode(ints2octs((0x23, 0x80, 0x20, 0x00))) - except PyAsn1Error: + except error.PyAsn1Error: pass else: assert 0, 'end-of-contents octets accepted with invalid constructed encoding' @@ -1567,7 +1556,7 @@ def testNoConstructedEoo(self): def testNoEooData(self): try: decoder.decode(ints2octs((0x23, 0x80, 0x00, 0x01, 0x00))) - except PyAsn1Error: + except error.PyAsn1Error: pass else: assert 0, 'end-of-contents octets accepted with unexpected data' @@ -1590,41 +1579,50 @@ def setUp(self): self.substrate = ints2octs([48, 18, 5, 0, 4, 11, 113, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, 2, 1, 1]) def testOctetString(self): - s = list(decoder.decodeStream(univ.OctetString(self.substrate), asn1Spec=self.s)) + s = list(decoder.StreamingDecoder( + univ.OctetString(self.substrate), asn1Spec=self.s)) assert [self.s] == s def testAny(self): - s = list(decoder.decodeStream(univ.Any(self.substrate), asn1Spec=self.s)) + s = list(decoder.StreamingDecoder( + univ.Any(self.substrate), asn1Spec=self.s)) assert [self.s] == s class ErrorOnDecodingTestCase(BaseTestCase): def testErrorCondition(self): - decode = decoder.Decoder(decoder.tagMap, decoder.typeMap) - substrate = b'abc' - stream = decoder._asSeekableStream(substrate) + decode = decoder.SingleItemDecoder(decoder.TAG_MAP, decoder.TYPE_MAP) + substrate = ints2octs((00, 1, 2)) + stream = streaming.asSeekableStream(substrate) try: - asn1Object = decode(stream) + asn1Object = next(decode(stream)) - except PyAsn1Error: + except error.PyAsn1Error: exc = sys.exc_info()[1] - assert isinstance(exc, PyAsn1Error), ( + assert isinstance(exc, error.PyAsn1Error), ( 'Unexpected exception raised %r' % (exc,)) else: assert False, 'Unexpected decoder result %r' % (asn1Object,) def testRawDump(self): - decode = decoder.Decoder(decoder.tagMap, decoder.typeMap) substrate = ints2octs((31, 8, 2, 1, 1, 131, 3, 2, 1, 12)) - stream = decoder._asSeekableStream(substrate, ) + stream = streaming.asSeekableStream(substrate) + + class StateMachine(decoder.SingleItemDecoder): + defaultErrorState = decoder.stDumpRawValue - decode.defaultErrorState = decoder.stDumpRawValue + class StreamingDecoder(decoder.StreamingDecoder): + SINGLE_ITEM_DECODER = StateMachine - asn1Object = decode(stream) - rest = stream.read() + class OneShotDecoder(decoder.Decoder): + STREAMING_DECODER = StreamingDecoder + + d = OneShotDecoder() + + asn1Object, rest = d(stream) assert isinstance(asn1Object, univ.Any), ( 'Unexpected raw dump type %r' % (asn1Object,)) @@ -1643,7 +1641,7 @@ def testOneObject(self): out.write(ints2octs((2, 1, 12))) with open(path, "rb") as source: - values = list(decoder.decodeStream(source)) + values = list(decoder.StreamingDecoder(source)) assert values == [12] finally: @@ -1656,9 +1654,10 @@ def testMoreObjects(self): out.write(ints2octs((2, 1, 12, 35, 128, 3, 2, 0, 169, 3, 2, 1, 138, 0, 0))) with open(path, "rb") as source: - values = list(decoder.decodeStream(source)) + values = list(decoder.StreamingDecoder(source)) assert values == [12, (1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1)] + finally: os.remove(path) @@ -1669,8 +1668,11 @@ def testInvalidFileContent(self): out.write(ints2octs((2, 1, 12, 35, 128, 3, 2, 0, 169, 3, 2, 1, 138, 0, 0, 7))) with open(path, "rb") as source: - with self.assertRaises(SubstrateUnderrunError): - _ = list(decoder.decodeStream(source)) + list(decoder.StreamingDecoder(source)) + + except error.EndOfStreamError: + pass + finally: os.remove(path) @@ -1679,7 +1681,7 @@ class BytesIOTestCase(BaseTestCase): def testRead(self): source = ints2octs((2, 1, 12, 35, 128, 3, 2, 0, 169, 3, 2, 1, 138, 0, 0)) stream = io.BytesIO(source) - values = list(decoder.decodeStream(stream)) + values = list(decoder.StreamingDecoder(stream)) assert values == [12, (1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1)] @@ -1687,8 +1689,114 @@ class UnicodeTestCase(BaseTestCase): def testFail(self): # This ensures that unicode objects in Python 2 & str objects in Python 3.7 cannot be parsed. source = ints2octs((2, 1, 12, 35, 128, 3, 2, 0, 169, 3, 2, 1, 138, 0, 0)).decode("latin-1") - with self.assertRaises(UnsupportedSubstrateError): - _ = next(decoder.decodeStream(source)) + try: + next(decoder.StreamingDecoder(source)) + + except error.UnsupportedSubstrateError: + pass + + else: + assert False, 'Tolerated parsing broken unicode strings' + + +class RestartableDecoderTestCase(BaseTestCase): + + class NonBlockingStream(io.BytesIO): + block = False + + def read(self, size=-1): + self.block = not self.block + if self.block: + return # this is what non-blocking streams sometimes do + + return io.BytesIO.read(self, size) + + def setUp(self): + BaseTestCase.setUp(self) + + self.s = univ.SequenceOf(componentType=univ.OctetString()) + self.s.setComponentByPosition(0, univ.OctetString('quick brown')) + source = ints2octs( + (48, 26, + 4, 11, 113, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, + 4, 11, 113, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110)) + self.stream = self.NonBlockingStream(source) + + def testPartialReadingFromNonBlockingStream(self): + iterator = iter(decoder.StreamingDecoder(self.stream, asn1Spec=self.s)) + + res = next(iterator) + + assert isinstance(res, error.SubstrateUnderrunError) + assert 'asn1Object' not in res.context + + res = next(iterator) + + assert isinstance(res, error.SubstrateUnderrunError) + assert 'asn1Object' not in res.context + + res = next(iterator) + + assert isinstance(res, error.SubstrateUnderrunError) + assert 'asn1Object' in res.context + assert isinstance(res.context['asn1Object'], univ.SequenceOf) + assert res.context['asn1Object'].isValue + assert len(res.context['asn1Object']) == 0 + + res = next(iterator) + + assert isinstance(res, error.SubstrateUnderrunError) + assert 'asn1Object' in res.context + assert isinstance(res.context['asn1Object'], univ.SequenceOf) + assert res.context['asn1Object'].isValue + assert len(res.context['asn1Object']) == 0 + + res = next(iterator) + + assert isinstance(res, error.SubstrateUnderrunError) + assert 'asn1Object' in res.context + assert isinstance(res.context['asn1Object'], univ.SequenceOf) + assert res.context['asn1Object'].isValue + assert len(res.context['asn1Object']) == 0 + + res = next(iterator) + + assert isinstance(res, error.SubstrateUnderrunError) + assert 'asn1Object' in res.context + assert isinstance(res.context['asn1Object'], univ.SequenceOf) + assert res.context['asn1Object'].isValue + assert len(res.context['asn1Object']) == 1 + + res = next(iterator) + + assert isinstance(res, error.SubstrateUnderrunError) + assert 'asn1Object' in res.context + assert isinstance(res.context['asn1Object'], univ.SequenceOf) + assert res.context['asn1Object'].isValue + assert len(res.context['asn1Object']) == 1 + + res = next(iterator) + + assert isinstance(res, error.SubstrateUnderrunError) + assert 'asn1Object' in res.context + assert isinstance(res.context['asn1Object'], univ.SequenceOf) + assert res.context['asn1Object'].isValue + assert len(res.context['asn1Object']) == 1 + + res = next(iterator) + + assert isinstance(res, univ.SequenceOf) + assert res.isValue + assert len(res) == 2 + + try: + next(iterator) + + except StopIteration: + pass + + else: + assert False, 'End of stream not raised' class CompressedFilesTestCase(BaseTestCase): @@ -1699,9 +1807,10 @@ def testGzip(self): out.write(ints2octs((2, 1, 12, 35, 128, 3, 2, 0, 169, 3, 2, 1, 138, 0, 0))) with gzip.open(path, "rb") as source: - values = list(decoder.decodeStream(source)) + values = list(decoder.StreamingDecoder(source)) assert values == [12, (1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1)] + finally: os.remove(path) @@ -1715,7 +1824,7 @@ def testZipfile(self): with zipfile.ZipFile(path, "r") as myzip: with myzip.open("data", "r") as source: - values = list(decoder.decodeStream(source)) + values = list(decoder.StreamingDecoder(source)) assert values == [12, (1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1)] finally: os.remove(path) @@ -1729,63 +1838,12 @@ def testZipfileMany(self): with zipfile.ZipFile(path, "r") as myzip: with myzip.open("data", "r") as source: - values = list(decoder.decodeStream(source)) + values = list(decoder.StreamingDecoder(source)) assert values == [12, (1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1)] * 1000 finally: os.remove(path) -class CachingStreamWrapperTestCase(BaseTestCase): - def setUp(self): - self.shortText = b"abcdefghij" - self.longText = self.shortText * (io.DEFAULT_BUFFER_SIZE * 5) - self.shortStream = io.BytesIO(self.shortText) - self.longStream = io.BytesIO(self.longText) - - def testReadJustFromCache(self): - wrapper = decoder._CachingStreamWrapper(self.shortStream) - wrapper.read(6) - wrapper.seek(3) - assert wrapper.read(1) == b"d" - assert wrapper.read(1) == b"e" - assert wrapper.tell() == 5 - - def testReadFromCacheAndStream(self): - wrapper = decoder._CachingStreamWrapper(self.shortStream) - wrapper.read(6) - wrapper.seek(3) - assert wrapper.read(4) == b"defg" - assert wrapper.tell() == 7 - - def testReadJustFromStream(self): - wrapper = decoder._CachingStreamWrapper(self.shortStream) - assert wrapper.read(6) == b"abcdef" - assert wrapper.tell() == 6 - - def testPeek(self): - wrapper = decoder._CachingStreamWrapper(self.longStream) - read_bytes = wrapper.peek(io.DEFAULT_BUFFER_SIZE + 73) - assert len(read_bytes) == io.DEFAULT_BUFFER_SIZE + 73 - assert read_bytes.startswith(b"abcdefg") - assert wrapper.tell() == 0 - assert wrapper.read(4) == b"abcd" - - def testMarkedPositionResets(self): - wrapper = decoder._CachingStreamWrapper(self.longStream) - wrapper.read(10) - wrapper._markedPosition = wrapper.tell() - assert wrapper._markedPosition == 10 - - # Reach the maximum capacity of cache - wrapper.read(io.DEFAULT_BUFFER_SIZE) - assert wrapper.tell() == 10 + io.DEFAULT_BUFFER_SIZE - - # The following should clear the cache - wrapper._markedPosition = wrapper.tell() - assert wrapper._markedPosition == 0 - assert len(wrapper._cache.getvalue()) == 0 - - suite = unittest.TestLoader().loadTestsFromModule(sys.modules[__name__]) if __name__ == '__main__': diff --git a/tests/codec/ber/test_encoder.py b/tests/codec/ber/test_encoder.py index df82e7b4..b8802853 100644 --- a/tests/codec/ber/test_encoder.py +++ b/tests/codec/ber/test_encoder.py @@ -382,19 +382,19 @@ def testBin2(self): def testBin3(self): # change binEncBase in the RealEncoder instance => for all further Real - binEncBase, encoder.typeMap[univ.Real.typeId].binEncBase = encoder.typeMap[univ.Real.typeId].binEncBase, 16 + binEncBase, encoder.TYPE_MAP[univ.Real.typeId].binEncBase = encoder.TYPE_MAP[univ.Real.typeId].binEncBase, 16 assert encoder.encode( univ.Real((0.00390625, 2, 0)) # check encbase = 16 ) == ints2octs((9, 3, 160, 254, 1)) - encoder.typeMap[univ.Real.typeId].binEncBase = binEncBase + encoder.TYPE_MAP[univ.Real.typeId].binEncBase = binEncBase def testBin4(self): # choose binEncBase automatically for all further Real (testBin[4-7]) - binEncBase, encoder.typeMap[univ.Real.typeId].binEncBase = encoder.typeMap[univ.Real.typeId].binEncBase, None + binEncBase, encoder.TYPE_MAP[univ.Real.typeId].binEncBase = encoder.TYPE_MAP[univ.Real.typeId].binEncBase, None assert encoder.encode( univ.Real((1, 2, 0)) # check exponent = 0 ) == ints2octs((9, 3, 128, 0, 1)) - encoder.typeMap[univ.Real.typeId].binEncBase = binEncBase + encoder.TYPE_MAP[univ.Real.typeId].binEncBase = binEncBase def testBin5(self): assert encoder.encode( diff --git a/tests/codec/cer/test_decoder.py b/tests/codec/cer/test_decoder.py index bb5ce93b..d628061b 100644 --- a/tests/codec/cer/test_decoder.py +++ b/tests/codec/cer/test_decoder.py @@ -41,6 +41,7 @@ def testOverflow(self): except PyAsn1Error: pass + class BitStringDecoderTestCase(BaseTestCase): def testShortMode(self): assert decoder.decode( diff --git a/tests/codec/cer/test_encoder.py b/tests/codec/cer/test_encoder.py index e155571b..ce263878 100644 --- a/tests/codec/cer/test_encoder.py +++ b/tests/codec/cer/test_encoder.py @@ -84,7 +84,6 @@ def testMissingTimezone(self): else: assert 0, 'Missing timezone tolerated' - def testDecimalCommaPoint(self): try: assert encoder.encode( diff --git a/tests/codec/test_streaming.py b/tests/codec/test_streaming.py new file mode 100644 index 00000000..c608b111 --- /dev/null +++ b/tests/codec/test_streaming.py @@ -0,0 +1,75 @@ +# +# This file is part of pyasn1 software. +# +# Copyright (c) 2005-2019, Ilya Etingof +# License: http://snmplabs.com/pyasn1/license.html +# +import io +import sys + +try: + import unittest2 as unittest + +except ImportError: + import unittest + +from tests.base import BaseTestCase + +from pyasn1.codec import streaming + + +class CachingStreamWrapperTestCase(BaseTestCase): + def setUp(self): + self.shortText = b"abcdefghij" + self.longText = self.shortText * (io.DEFAULT_BUFFER_SIZE * 5) + self.shortStream = io.BytesIO(self.shortText) + self.longStream = io.BytesIO(self.longText) + + def testReadJustFromCache(self): + wrapper = streaming.CachingStreamWrapper(self.shortStream) + wrapper.read(6) + wrapper.seek(3) + assert wrapper.read(1) == b"d" + assert wrapper.read(1) == b"e" + assert wrapper.tell() == 5 + + def testReadFromCacheAndStream(self): + wrapper = streaming.CachingStreamWrapper(self.shortStream) + wrapper.read(6) + wrapper.seek(3) + assert wrapper.read(4) == b"defg" + assert wrapper.tell() == 7 + + def testReadJustFromStream(self): + wrapper = streaming.CachingStreamWrapper(self.shortStream) + assert wrapper.read(6) == b"abcdef" + assert wrapper.tell() == 6 + + def testPeek(self): + wrapper = streaming.CachingStreamWrapper(self.longStream) + read_bytes = wrapper.peek(io.DEFAULT_BUFFER_SIZE + 73) + assert len(read_bytes) == io.DEFAULT_BUFFER_SIZE + 73 + assert read_bytes.startswith(b"abcdefg") + assert wrapper.tell() == 0 + assert wrapper.read(4) == b"abcd" + + def testMarkedPositionResets(self): + wrapper = streaming.CachingStreamWrapper(self.longStream) + wrapper.read(10) + wrapper.markedPosition = wrapper.tell() + assert wrapper.markedPosition == 10 + + # Reach the maximum capacity of cache + wrapper.read(io.DEFAULT_BUFFER_SIZE) + assert wrapper.tell() == 10 + io.DEFAULT_BUFFER_SIZE + + # The following should clear the cache + wrapper.markedPosition = wrapper.tell() + assert wrapper.markedPosition == 0 + assert len(wrapper._cache.getvalue()) == 0 + + +suite = unittest.TestLoader().loadTestsFromModule(sys.modules[__name__]) + +if __name__ == '__main__': + unittest.TextTestRunner(verbosity=2).run(suite)