From c0e83b897185a1589e71a3009418a77dd34caa68 Mon Sep 17 00:00:00 2001 From: Rich Piazza Date: Tue, 2 Apr 2024 17:58:06 -0400 Subject: [PATCH] Dictionary Tables use valid_types --- stix2/datastore/relational_db/add_method.py | 7 +- .../datastore/relational_db/table_creation.py | 65 ++++++++++++------- stix2/datastore/relational_db/utils.py | 44 +++++++++++++ stix2/properties.py | 5 +- 4 files changed, 95 insertions(+), 26 deletions(-) diff --git a/stix2/datastore/relational_db/add_method.py b/stix2/datastore/relational_db/add_method.py index ad08a2e0..02d10c80 100644 --- a/stix2/datastore/relational_db/add_method.py +++ b/stix2/datastore/relational_db/add_method.py @@ -1,3 +1,5 @@ +import re + # _ALLOWABLE_CLASSES = get_all_subclasses(_STIXBase21) # # @@ -7,9 +9,8 @@ def create_real_method_name(name, klass_name): # if klass_name not in _ALLOWABLE_CLASSES: # raise NameError - # split_up_klass_name = re.findall('[A-Z][^A-Z]*', klass_name) - # split_up_klass_name.remove("Type") - return name + "_" + "_".join([x.lower() for x in klass_name]) + split_up_klass_name = re.findall('[A-Z][^A-Z]*', klass_name) + return name + "_" + "_".join([x.lower() for x in split_up_klass_name]) def add_method(cls): diff --git a/stix2/datastore/relational_db/table_creation.py b/stix2/datastore/relational_db/table_creation.py index 1176a51f..0b8b7446 100644 --- a/stix2/datastore/relational_db/table_creation.py +++ b/stix2/datastore/relational_db/table_creation.py @@ -8,7 +8,8 @@ from stix2.datastore.relational_db.add_method import add_method from stix2.datastore.relational_db.utils import ( SCO_COMMON_PROPERTIES, SDO_COMMON_PROPERTIES, canonicalize_table_name, - flat_classes, get_stix_object_classes, schema_for, + determine_column_name, determine_sql_type_from_class, flat_classes, + get_stix_object_classes, schema_for, ) from stix2.properties import ( BinaryProperty, BooleanProperty, DictionaryProperty, @@ -193,19 +194,24 @@ def create_core_table(metadata, schema_name): ) +@add_method(Property) +def determine_sql_type(self): # noqa: F811 + pass + + @add_method(KillChainPhase) -def determine_sql_type(self): +def determine_sql_type(self): # noqa: F811 return None -@add_method(BooleanProperty) +@add_method(BinaryProperty) def determine_sql_type(self): # noqa: F811 return Boolean -@add_method(Property) +@add_method(BooleanProperty) def determine_sql_type(self): # noqa: F811 - pass + return Boolean @add_method(FloatProperty) @@ -213,15 +219,31 @@ def determine_sql_type(self): # noqa: F811 return Float +@add_method(HexProperty) +def determine_sql_type(self): # noqa: F811 + return LargeBinary + + @add_method(IntegerProperty) def determine_sql_type(self): # noqa: F811 return Integer +@add_method(ReferenceProperty) +def determine_sql_type(self): # noqa: F811 + return Text + + @add_method(StringProperty) def determine_sql_type(self): # noqa: F811 return Text + +@add_method(TimestampProperty) +def determine_sql_type(self): # noqa: F811 + return TIMESTAMP(timezone=True) + + # ----------------------------- generate_table_information methods ---------------------------- @@ -264,36 +286,35 @@ def generate_table_information(self, name, metadata, schema_name, table_name, is nullable=False, ), ) - if len(self.specifics) == 1: - if self.specifics[0] != "string_list": + if len(self.valid_types) == 1: + if not isinstance(self.valid_types[0], ListProperty): columns.append( Column( "value", - Text if self.specifics[0] == "string" else Integer, + # its a class + determine_sql_type_from_class(self.valid_types[0]), nullable=False, ), ) else: + contained_class = self.valid_types[0].contained columns.append( Column( "value", - ARRAY(Text), + # its an instance, not a class + ARRAY(contained_class.determine_sql_type()), nullable=False, ), ) else: - columns.append( - Column( - "string_value", - Text, - ), - ) - columns.append( - Column( - "integer_value", - Integer, - ), - ) + for column_type in self.valid_types: + sql_type = determine_sql_type_from_class(column_type) + columns.append( + Column( + determine_column_name(column_type), + sql_type, + ), + ) return [Table(canonicalize_table_name(table_name + "_" + name), metadata, *columns, schema=schema_name)] @@ -315,7 +336,7 @@ def generate_table_information(self, name, **kwargs): # noqa: F811 CheckConstraint( f"{name} ~ '^{enum_re}$'", ), - nullable=not (self.required), + nullable=not self.required, ) diff --git a/stix2/datastore/relational_db/utils.py b/stix2/datastore/relational_db/utils.py index bee8d54d..0958958f 100644 --- a/stix2/datastore/relational_db/utils.py +++ b/stix2/datastore/relational_db/utils.py @@ -1,7 +1,15 @@ from collections.abc import Iterable, Mapping import inflection +from sqlalchemy import ( # create_engine,; insert, + TIMESTAMP, Boolean, Float, Integer, LargeBinary, Text, +) +from stix2.properties import ( + BinaryProperty, BooleanProperty, FloatProperty, HexProperty, + IntegerProperty, Property, ReferenceProperty, StringProperty, + TimestampProperty, +) from stix2.v21.base import ( _DomainObject, _Extension, _MetaObject, _Observable, _RelationshipObject, ) @@ -112,3 +120,39 @@ def flat_classes(class_or_classes): yield from flat_classes(class_) else: yield class_or_classes + + +def determine_sql_type_from_class(cls): # noqa: F811 + if cls == BinaryProperty: + return LargeBinary + elif cls == BooleanProperty: + return Boolean + elif cls == FloatProperty: + return Float + elif cls == HexProperty: + return LargeBinary + elif cls == IntegerProperty: + return Integer + elif cls == StringProperty or cls == ReferenceProperty: + return Text + elif cls == TimestampProperty: + return TIMESTAMP(timezone=True) + elif cls == Property: + return Text + + +def determine_column_name(cls): # noqa: F811 + if cls == BinaryProperty: + return "binary_value" + elif cls == BooleanProperty: + return "boolean_value" + elif cls == FloatProperty: + return "float_value" + elif cls == HexProperty: + return "hex_value" + elif cls == IntegerProperty: + return "integer_value" + elif cls == StringProperty or cls == ReferenceProperty: + return "string_value" + elif cls == TimestampProperty: + return "timestamp_value" diff --git a/stix2/properties.py b/stix2/properties.py index a733c042..a535cd02 100644 --- a/stix2/properties.py +++ b/stix2/properties.py @@ -404,7 +404,10 @@ class DictionaryProperty(Property): def __init__(self, valid_types=None, spec_version=DEFAULT_VERSION, **kwargs): self.spec_version = spec_version - simple_types = [BinaryProperty, BooleanProperty, FloatProperty, HexProperty, IntegerProperty, StringProperty, TimestampProperty, ReferenceProperty] + simple_types = [ + BinaryProperty, BooleanProperty, FloatProperty, HexProperty, IntegerProperty, StringProperty, + TimestampProperty, ReferenceProperty, + ] if not valid_types: valid_types = [Property] else: