Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the memory store to support the new top-level 2.1 SCOs. #342

Merged
merged 2 commits into from
Feb 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 19 additions & 23 deletions stix2/datastore/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from stix2.core import parse
from stix2.datastore import DataSink, DataSource, DataStoreMixin
from stix2.datastore.filters import FilterSet, apply_common_filters
from stix2.utils import is_marking


def _add(store, stix_data, allow_custom=True, version=None):
Expand Down Expand Up @@ -47,12 +46,10 @@ def _add(store, stix_data, allow_custom=True, version=None):
else:
stix_obj = parse(stix_data, allow_custom, version)

# Map ID directly to the object, if it is a marking. Otherwise,
# map to a family, so we can track multiple versions.
if is_marking(stix_obj):
store._data[stix_obj["id"]] = stix_obj

else:
# Map ID to a _ObjectFamily if the object is versioned, so we can track
# multiple versions. Otherwise, map directly to the object. All
# versioned objects should have a "modified" property.
if "modified" in stix_obj:
if stix_obj["id"] in store._data:
obj_family = store._data[stix_obj["id"]]
else:
Expand All @@ -61,6 +58,9 @@ def _add(store, stix_data, allow_custom=True, version=None):

obj_family.add(stix_obj)

else:
store._data[stix_obj["id"]] = stix_obj


class _ObjectFamily(object):
"""
Expand Down Expand Up @@ -267,12 +267,12 @@ def get(self, stix_id, _composite_filters=None):
"""
stix_obj = None

if is_marking(stix_id):
stix_obj = self._data.get(stix_id)
else:
object_family = self._data.get(stix_id)
if object_family:
stix_obj = object_family.latest_version
mapped_value = self._data.get(stix_id)
if mapped_value:
if isinstance(mapped_value, _ObjectFamily):
stix_obj = mapped_value.latest_version
else:
stix_obj = mapped_value

if stix_obj:
all_filters = list(
Expand Down Expand Up @@ -300,17 +300,13 @@ def all_versions(self, stix_id, _composite_filters=None):

"""
results = []
stix_objs_to_filter = None
if is_marking(stix_id):
stix_obj = self._data.get(stix_id)
if stix_obj:
stix_objs_to_filter = [stix_obj]
else:
object_family = self._data.get(stix_id)
if object_family:
stix_objs_to_filter = object_family.all_versions.values()
mapped_value = self._data.get(stix_id)
if mapped_value:
if isinstance(mapped_value, _ObjectFamily):
stix_objs_to_filter = mapped_value.all_versions.values()
else:
stix_objs_to_filter = [mapped_value]

if stix_objs_to_filter:
all_filters = list(
itertools.chain(
_composite_filters or [],
Expand Down
21 changes: 21 additions & 0 deletions stix2/test/v20/test_datastore_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,3 +423,24 @@ def test_object_family_internal_components(mem_source):

assert "latest=2017-01-27 13:49:53.936000+00:00>>" in str_representation
assert "latest=2017-01-27 13:49:53.936000+00:00>>" in repr_representation


def test_unversioned_objects(mem_store):
marking = {
"type": "marking-definition",
"id": "marking-definition--48e83cde-e902-4404-85b3-6e81f75ccb62",
"created": "1988-01-02T16:44:04.000Z",
"definition_type": "statement",
"definition": {
"statement": "Copyright (C) ACME Corp.",
},
}

mem_store.add(marking)

obj = mem_store.get(marking["id"])
assert obj["id"] == marking["id"]

objs = mem_store.all_versions(marking["id"])
assert len(objs) == 1
assert objs[0]["id"] == marking["id"]
35 changes: 35 additions & 0 deletions stix2/test/v21/test_datastore_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,3 +438,38 @@ def test_object_family_internal_components(mem_source):

assert "latest=2017-01-27 13:49:53.936000+00:00>>" in str_representation
assert "latest=2017-01-27 13:49:53.936000+00:00>>" in repr_representation


def test_unversioned_objects(mem_store):
marking = {
"type": "marking-definition",
"spec_version": "2.1",
"id": "marking-definition--48e83cde-e902-4404-85b3-6e81f75ccb62",
"created": "1988-01-02T16:44:04.000Z",
"definition_type": "statement",
"definition": {
"statement": "Copyright (C) ACME Corp.",
},
}

file_sco = {
"type": "file",
"id": "file--bbd59c0c-1aa4-44f1-96de-80b8325372c7",
"name": "cats.png",
}

mem_store.add([marking, file_sco])

obj = mem_store.get(marking["id"])
assert obj["id"] == marking["id"]

obj = mem_store.get(file_sco["id"])
assert obj["id"] == file_sco["id"]

objs = mem_store.all_versions(marking["id"])
assert len(objs) == 1
assert objs[0]["id"] == marking["id"]

objs = mem_store.all_versions(file_sco["id"])
assert len(objs) == 1
assert objs[0]["id"] == file_sco["id"]