Skip to content

Commit

Permalink
make some functions internal, add some docs for them
Browse files Browse the repository at this point in the history
  • Loading branch information
emmanvg committed Feb 16, 2021
1 parent 09fd8c0 commit d2d85ba
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 11 deletions.
14 changes: 8 additions & 6 deletions stix2/equivalence/graph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import logging

from ..object import (
WEIGHTS, bucket_per_type, exact_match, list_reference_check, object_pairs,
object_similarity, partial_string_based, partial_timestamp_based,
reference_check,
WEIGHTS, _bucket_per_type, _object_pairs, exact_match,
list_reference_check, object_similarity, partial_string_based,
partial_timestamp_based, reference_check,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -99,9 +99,11 @@ def graph_similarity(ds1, ds2, prop_scores={}, **weight_dict):
raise ValueError("weight_dict['_internal']['max_depth'] must be greater than 0")
depth = weights["_internal"]["max_depth"]

graph1 = bucket_per_type(ds1.query([]))
graph2 = bucket_per_type(ds2.query([]))
pairs = object_pairs(graph1, graph2, weights)
pairs = _object_pairs(
_bucket_per_type(ds1.query([])),
_bucket_per_type(ds2.query([])),
weights,
)

for object1, object2 in pairs:
iprop_score1 = {}
Expand Down
16 changes: 11 additions & 5 deletions stix2/equivalence/object/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,9 +398,9 @@ def list_reference_check(refs1, refs2, ds1, ds2, **weights):
weighted on the amount of unique objects that could 1) be de-referenced 2) """
results = {}

pairs = object_pairs(
bucket_per_type(refs1, "id-split"),
bucket_per_type(refs2, "id-split"),
pairs = _object_pairs(
_bucket_per_type(refs1, "id-split"),
_bucket_per_type(refs2, "id-split"),
weights,
)

Expand Down Expand Up @@ -433,7 +433,10 @@ def list_reference_check(refs1, refs2, ds1, ds2, **weights):
return result


def bucket_per_type(g, mode="type"):
def _bucket_per_type(g, mode="type"):
"""Given a list of objects or references, bucket them by type.
Depending on the mode: extract the type from the 'type' property or
from the 'id'"""
buckets = collections.defaultdict(list)
if mode == "type":
[buckets[obj["type"]].append(obj) for obj in g]
Expand All @@ -442,7 +445,10 @@ def bucket_per_type(g, mode="type"):
return buckets


def object_pairs(g1, g2, w):
def _object_pairs(g1, g2, w):
"""Returns a generator with the product of the comparable
objects for the graph similarity process. It determines the object
types in common between both graphs that also have weights."""
types_in_common = set(g1.keys()).intersection(g2.keys())
testable_types = types_in_common.intersection(w.keys())

Expand Down

0 comments on commit d2d85ba

Please sign in to comment.