diff --git a/stix2/core.py b/stix2/core.py index 0d1fee5c..b03e3d74 100644 --- a/stix2/core.py +++ b/stix2/core.py @@ -7,10 +7,10 @@ import stix2 -from .base import _STIXBase +from .base import _Observable, _STIXBase from .exceptions import ParseError from .markings import _MarkingsMixin -from .utils import _get_dict +from .utils import SCO21_EXT_REGEX, TYPE_REGEX, _get_dict STIX2_OBJ_MAPS = {} @@ -258,22 +258,54 @@ def _register_observable(new_observable, version=None): OBJ_MAP_OBSERVABLE[new_observable._type] = new_observable -def _register_observable_extension(observable, new_extension, version=None): +def _register_observable_extension( + observable, new_extension, version=stix2.DEFAULT_VERSION, +): """Register a custom extension to a STIX Cyber Observable type. Args: - observable: An observable object + observable: An observable class or instance new_extension (class): A class to register in the Observables Extensions map. - version (str): Which STIX2 version to use. (e.g. "2.0", "2.1"). If - None, use latest version. + version (str): Which STIX2 version to use. (e.g. "2.0", "2.1"). + Defaults to the latest supported version. """ - if version: - v = 'v' + version.replace('.', '') - else: - # Use default version (latest) if no version was provided. - v = 'v' + stix2.DEFAULT_VERSION.replace('.', '') + obs_class = observable if isinstance(observable, type) else \ + type(observable) + ext_type = new_extension._type + + if not issubclass(obs_class, _Observable): + raise ValueError("'observable' must be a valid Observable class!") + + if version == "2.0": + if not re.match(TYPE_REGEX, ext_type): + raise ValueError( + "Invalid extension type name '%s': must only contain the " + "characters a-z (lowercase ASCII), 0-9, and hyphen (-)." % + ext_type, + ) + else: # 2.1+ + if not re.match(SCO21_EXT_REGEX, ext_type): + raise ValueError( + "Invalid extension type name '%s': must only contain the " + "characters a-z (lowercase ASCII), 0-9, hyphen (-), and end " + "with '-ext'." % ext_type, + ) + + if len(ext_type) < 3 or len(ext_type) > 250: + raise ValueError( + "Invalid extension type name '%s': must be between 3 and 250" + " characters." % ext_type, + ) + + if not new_extension._properties: + raise ValueError( + "Invalid extension: must define at least one property: " + + ext_type, + ) + + v = 'v' + version.replace('.', '') try: observable_type = observable._type @@ -287,7 +319,7 @@ def _register_observable_extension(observable, new_extension, version=None): EXT_MAP = STIX2_OBJ_MAPS[v]['observable-extensions'] try: - EXT_MAP[observable_type][new_extension._type] = new_extension + EXT_MAP[observable_type][ext_type] = new_extension except KeyError: if observable_type not in OBJ_MAP_OBSERVABLE: raise ValueError( @@ -296,7 +328,7 @@ def _register_observable_extension(observable, new_extension, version=None): % observable_type, ) else: - EXT_MAP[observable_type] = {new_extension._type: new_extension} + EXT_MAP[observable_type] = {ext_type: new_extension} def _collect_stix2_mappings(): diff --git a/stix2/custom.py b/stix2/custom.py index 802fd070..f3c89cf9 100644 --- a/stix2/custom.py +++ b/stix2/custom.py @@ -1,6 +1,8 @@ from collections import OrderedDict import re +import six + from .base import _cls_init, _Extension, _Observable, _STIXBase from .core import ( STIXDomainObject, _register_marking, _register_object, @@ -113,24 +115,23 @@ def __init__(self, **kwargs): def _custom_extension_builder(cls, observable, type, properties, version): - if not observable or not issubclass(observable, _Observable): - raise ValueError("'observable' must be a valid Observable class!") - - class _CustomExtension(cls, _Extension): - if not re.match(TYPE_REGEX, type): - raise ValueError( - "Invalid extension type name '%s': must only contain the " - "characters a-z (lowercase ASCII), 0-9, and hyphen (-)." % type, - ) - elif len(type) < 3 or len(type) > 250: - raise ValueError("Invalid extension type name '%s': must be between 3 and 250 characters." % type) + try: + prop_dict = OrderedDict(properties) + except TypeError as e: + six.raise_from( + ValueError( + "Extension properties must be dict-like, e.g. a list " + "containing tuples. For example, " + "[('property1', IntegerProperty())]", + ), + e, + ) - if not properties or not isinstance(properties, list): - raise ValueError("Must supply a list, containing tuples. For example, [('property1', IntegerProperty())]") + class _CustomExtension(cls, _Extension): _type = type - _properties = OrderedDict(properties) + _properties = prop_dict def __init__(self, **kwargs): _Extension.__init__(self, **kwargs) diff --git a/stix2/test/v20/test_custom.py b/stix2/test/v20/test_custom.py index ce1aac3e..b986777d 100644 --- a/stix2/test/v20/test_custom.py +++ b/stix2/test/v20/test_custom.py @@ -821,27 +821,24 @@ class BlaExtension(): def test_custom_extension_no_properties(): - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ValueError): @stix2.v20.CustomExtension(stix2.v20.DomainName, 'x-new-ext2', None) class BarExtension(): pass - assert "Must supply a list, containing tuples." in str(excinfo.value) def test_custom_extension_empty_properties(): - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ValueError): @stix2.v20.CustomExtension(stix2.v20.DomainName, 'x-new-ext2', []) class BarExtension(): pass - assert "Must supply a list, containing tuples." in str(excinfo.value) def test_custom_extension_dict_properties(): - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ValueError): @stix2.v20.CustomExtension(stix2.v20.DomainName, 'x-new-ext2', {}) class BarExtension(): pass - assert "Must supply a list, containing tuples." in str(excinfo.value) def test_custom_extension_no_init_1(): diff --git a/stix2/test/v21/test_custom.py b/stix2/test/v21/test_custom.py index b46288dc..1e6f6297 100644 --- a/stix2/test/v21/test_custom.py +++ b/stix2/test/v21/test_custom.py @@ -800,7 +800,7 @@ def test_custom_extension_wrong_observable_type(): ) def test_custom_extension_with_list_and_dict_properties_observable_type(data): @stix2.v21.CustomExtension( - stix2.v21.UserAccount, 'some-extension', [ + stix2.v21.UserAccount, 'x-some-extension-ext', [ ('keys', stix2.properties.ListProperty(stix2.properties.DictionaryProperty, required=True)), ], ) @@ -876,32 +876,29 @@ class BlaExtension(): def test_custom_extension_no_properties(): - with pytest.raises(ValueError) as excinfo: - @stix2.v21.CustomExtension(stix2.v21.DomainName, 'x-new-ext2', None) + with pytest.raises(ValueError): + @stix2.v21.CustomExtension(stix2.v21.DomainName, 'x-new2-ext', None) class BarExtension(): pass - assert "Must supply a list, containing tuples." in str(excinfo.value) def test_custom_extension_empty_properties(): - with pytest.raises(ValueError) as excinfo: - @stix2.v21.CustomExtension(stix2.v21.DomainName, 'x-new-ext2', []) + with pytest.raises(ValueError): + @stix2.v21.CustomExtension(stix2.v21.DomainName, 'x-new2-ext', []) class BarExtension(): pass - assert "Must supply a list, containing tuples." in str(excinfo.value) def test_custom_extension_dict_properties(): - with pytest.raises(ValueError) as excinfo: - @stix2.v21.CustomExtension(stix2.v21.DomainName, 'x-new-ext2', {}) + with pytest.raises(ValueError): + @stix2.v21.CustomExtension(stix2.v21.DomainName, 'x-new2-ext', {}) class BarExtension(): pass - assert "Must supply a list, containing tuples." in str(excinfo.value) def test_custom_extension_no_init_1(): @stix2.v21.CustomExtension( - stix2.v21.DomainName, 'x-new-extension', [ + stix2.v21.DomainName, 'x-new-extension-ext', [ ('property1', stix2.properties.StringProperty(required=True)), ], ) @@ -914,7 +911,7 @@ class NewExt(): def test_custom_extension_no_init_2(): @stix2.v21.CustomExtension( - stix2.v21.DomainName, 'x-new-ext2', [ + stix2.v21.DomainName, 'x-new2-ext', [ ('property1', stix2.properties.StringProperty(required=True)), ], ) @@ -949,14 +946,14 @@ def test_custom_and_spec_extension_mix(): file_obs = stix2.v21.File( name="my_file.dat", extensions={ - "x-custom1": { + "x-custom1-ext": { "a": 1, "b": 2, }, "ntfs-ext": { "sid": "S-1-whatever", }, - "x-custom2": { + "x-custom2-ext": { "z": 99.9, "y": False, }, @@ -969,8 +966,8 @@ def test_custom_and_spec_extension_mix(): allow_custom=True, ) - assert file_obs.extensions["x-custom1"] == {"a": 1, "b": 2} - assert file_obs.extensions["x-custom2"] == {"y": False, "z": 99.9} + assert file_obs.extensions["x-custom1-ext"] == {"a": 1, "b": 2} + assert file_obs.extensions["x-custom2-ext"] == {"y": False, "z": 99.9} assert file_obs.extensions["ntfs-ext"].sid == "S-1-whatever" assert file_obs.extensions["raster-image-ext"].image_height == 1024 diff --git a/stix2/utils.py b/stix2/utils.py index b23b0e4d..7b3b6cf2 100644 --- a/stix2/utils.py +++ b/stix2/utils.py @@ -26,6 +26,7 @@ STIX_UNMOD_PROPERTIES = ['created', 'created_by_ref', 'id', 'type'] TYPE_REGEX = r'^\-?[a-z0-9]+(-[a-z0-9]+)*\-?$' +SCO21_EXT_REGEX = r'^\-?[a-z0-9]+(-[a-z0-9]+)*\-ext$' class STIXdatetime(dt.datetime):