From 8fd7226f27f04d61190b366676a237b31385475f Mon Sep 17 00:00:00 2001 From: Rich Piazza Date: Thu, 7 Nov 2024 13:14:52 -0500 Subject: [PATCH] handle ARRAY --- .../database_backend_base.py | 64 ++- .../database_backends/postgres_backend.py | 30 +- .../datastore/relational_db/relational_db.py | 6 +- .../relational_db/relational_db_testing.py | 120 +++--- .../datastore/relational_db/table_creation.py | 386 ++++++++++-------- stix2/datastore/relational_db/utils.py | 24 +- 6 files changed, 374 insertions(+), 256 deletions(-) diff --git a/stix2/datastore/relational_db/database_backends/database_backend_base.py b/stix2/datastore/relational_db/database_backends/database_backend_base.py index f4e791e8..5b65119b 100644 --- a/stix2/datastore/relational_db/database_backends/database_backend_base.py +++ b/stix2/datastore/relational_db/database_backends/database_backend_base.py @@ -4,7 +4,10 @@ from sqlalchemy import create_engine from sqlalchemy_utils import create_database, database_exists, drop_database - +from sqlalchemy import ( # create_engine,; insert, + ARRAY, TIMESTAMP, Boolean, CheckConstraint, Column, Float, ForeignKey, + Integer, LargeBinary, Table, Text, UniqueConstraint, +) class DatabaseBackend: def __init__(self, database_connection_url, force_recreate=False, **kwargs: Any): @@ -20,7 +23,7 @@ def __init__(self, database_connection_url, force_recreate=False, **kwargs: Any) self.database_connection = create_engine(database_connection_url) def _create_schemas(self): - pass + return @staticmethod def _determine_schema_name(stix_object): @@ -32,4 +35,61 @@ def _create_database(self): create_database(self.database_connection.url) self.database_exists = database_exists(self.database_connection.url) + def schema_for(self, stix_class): + return "" + + @staticmethod + def determine_sql_type_for_property(): # noqa: F811 + pass + + @staticmethod + def determine_sql_type_for_kill_chain_phase(): # noqa: F811 + return None + + @staticmethod + def 
determine_sql_type_for_binary_property(): # noqa: F811 + return Text + + @staticmethod + def determine_sql_type_for_boolean_property(): # noqa: F811 + return Boolean + + @staticmethod + def determine_sql_type_for_float_property(): # noqa: F811 + return Float + + @staticmethod + def determine_sql_type_for_hex_property(): # noqa: F811 + return LargeBinary + + @staticmethod + def determine_sql_type_for_integer_property(): # noqa: F811 + return Integer + + @staticmethod + def determine_sql_type_for_reference_property(): # noqa: F811 + return Text + + @staticmethod + def determine_sql_type_for_string_property(): # noqa: F811 + return Text + + @staticmethod + def determine_sql_type_for_timestamp_property(): # noqa: F811 + return TIMESTAMP(timezone=True) + + @staticmethod + def determine_sql_type_for_key_as_int(): # noqa: F811 + return Integer + + + @staticmethod + def determine_sql_type_for_key_as_id(): # noqa: F811 + return Text + + @staticmethod + def array_allowed(): + return False + + diff --git a/stix2/datastore/relational_db/database_backends/postgres_backend.py b/stix2/datastore/relational_db/database_backends/postgres_backend.py index b9672f97..08ca3c9e 100644 --- a/stix2/datastore/relational_db/database_backends/postgres_backend.py +++ b/stix2/datastore/relational_db/database_backends/postgres_backend.py @@ -1,13 +1,26 @@ import os from typing import Any from sqlalchemy.schema import CreateSchema +from sqlalchemy import ( # create_engine,; insert, + ARRAY, TIMESTAMP, Boolean, CheckConstraint, Column, Float, ForeignKey, + Integer, LargeBinary, Table, Text, UniqueConstraint, +) -from .database_backend_base import DatabaseBackend +from stix2.datastore.relational_db.utils import schema_for from stix2.base import ( - _DomainObject, _MetaObject, _Observable, _RelationshipObject, _STIXBase, + _DomainObject, _Extension, _MetaObject, _Observable, _RelationshipObject, _STIXBase, +) + +from stix2.properties import ( + BinaryProperty, BooleanProperty, DictionaryProperty, + 
EmbeddedObjectProperty, EnumProperty, ExtensionsProperty, FloatProperty, + HashesProperty, HexProperty, IDProperty, IntegerProperty, ListProperty, + ObjectReferenceProperty, Property, ReferenceProperty, StringProperty, + TimestampProperty, TypeProperty, ) +from .database_backend_base import DatabaseBackend class PostgresBackend(DatabaseBackend): default_database_connection_url = \ @@ -35,4 +48,15 @@ def _determine_schema_name(stix_object): elif isinstance(stix_object, _RelationshipObject): return "sro" elif isinstance(stix_object, _MetaObject): - return "common" \ No newline at end of file + return "common" + + def schema_for(self, stix_class): + return schema_for(stix_class) + + + @staticmethod + def array_allowed(): + return True + + + diff --git a/stix2/datastore/relational_db/relational_db.py b/stix2/datastore/relational_db/relational_db.py index 3e09d33b..6fbb027d 100644 --- a/stix2/datastore/relational_db/relational_db.py +++ b/stix2/datastore/relational_db/relational_db.py @@ -83,7 +83,7 @@ def __init__( self.metadata = MetaData() create_table_objects( - self.metadata, stix_object_classes, + self.metadata, db_backend, stix_object_classes, ) super().__init__( @@ -251,11 +251,11 @@ def __init__( else: self.metadata = MetaData() create_table_objects( - self.metadata, stix_object_classes, + self.metadata, db_backend, stix_object_classes, ) def get(self, stix_id, version=None, _composite_filters=None): - with self.db_backend.connect() as conn: + with self.db_backend.database_connection.connect() as conn: stix_obj = read_object( stix_id, self.metadata, diff --git a/stix2/datastore/relational_db/relational_db_testing.py b/stix2/datastore/relational_db/relational_db_testing.py index aadb3a0b..fe35a454 100644 --- a/stix2/datastore/relational_db/relational_db_testing.py +++ b/stix2/datastore/relational_db/relational_db_testing.py @@ -194,61 +194,61 @@ def custom_obj(): return obj -@stix2.CustomObject( - "test-object", [ - ("prop_name", 
stix2.properties.ListProperty(stix2.properties.BinaryProperty())) - ], - "extension-definition--15de9cdb-3515-4271-8479-8141154c5647", - is_sdo=True -) -class TestClass: - pass - - -def test_binary_list(): - return TestClass(prop_name=["AREi", "7t3M"]) - -@stix2.CustomObject( - "test2-object", [ - ("prop_name", stix2.properties.ListProperty( - stix2.properties.HexProperty() - )) - ], - "extension-definition--15de9cdb-4567-4271-8479-8141154c5647", - is_sdo=True - ) - -class Test2Class: - pass - -def test_hex_list(): - return Test2Class( - prop_name=["1122", "fedc"] - ) - -@stix2.CustomObject( - "test3-object", [ - ("prop_name", - stix2.properties.DictionaryProperty( - valid_types=[ - stix2.properties.IntegerProperty, - stix2.properties.FloatProperty, - stix2.properties.StringProperty - ] - ) - ) - ], - "extension-definition--15de9cdb-1234-4271-8479-8141154c5647", - is_sdo=True - ) -class Test3Class: - pass - - -def test_dictionary(): - return Test3Class( - prop_name={"a": 1, "b": 2.3, "c": "foo"} - ) +# @stix2.CustomObject( +# "test-object", [ +# ("prop_name", stix2.properties.ListProperty(stix2.properties.BinaryProperty())) +# ], +# "extension-definition--15de9cdb-3515-4271-8479-8141154c5647", +# is_sdo=True +# ) +# class TestClass: +# pass +# +# +# def test_binary_list(): +# return TestClass(prop_name=["AREi", "7t3M"]) +# +# @stix2.CustomObject( +# "test2-object", [ +# ("prop_name", stix2.properties.ListProperty( +# stix2.properties.HexProperty() +# )) +# ], +# "extension-definition--15de9cdb-4567-4271-8479-8141154c5647", +# is_sdo=True +# ) +# +# class Test2Class: +# pass +# +# def test_hex_list(): +# return Test2Class( +# prop_name=["1122", "fedc"] +# ) +# +# @stix2.CustomObject( +# "test3-object", [ +# ("prop_name", +# stix2.properties.DictionaryProperty( +# valid_types=[ +# stix2.properties.IntegerProperty, +# stix2.properties.FloatProperty, +# stix2.properties.StringProperty +# ] +# ) +# ) +# ], +# "extension-definition--15de9cdb-1234-4271-8479-8141154c5647", 
+# is_sdo=True +# ) +# class Test3Class: +# pass +# +# +# def test_dictionary(): +# return Test3Class( +# prop_name={"a": 1, "b": 2.3, "c": "foo"} +# ) def main(): @@ -263,17 +263,17 @@ def main(): store.sink.generate_stix_schema() store.sink.clear_tables() - td = test_dictionary() + # td = test_dictionary() - store.add(td) + # store.add(td) - th = test_hex_list() + # th = test_hex_list() # store.add(th) - tb = test_binary_list() + # tb = test_binary_list() - store.add(tb) + # store.add(tb) diff --git a/stix2/datastore/relational_db/table_creation.py b/stix2/datastore/relational_db/table_creation.py index fc3e6b5d..05b0e97d 100644 --- a/stix2/datastore/relational_db/table_creation.py +++ b/stix2/datastore/relational_db/table_creation.py @@ -8,8 +8,8 @@ from stix2.datastore.relational_db.add_method import add_method from stix2.datastore.relational_db.utils import ( SCO_COMMON_PROPERTIES, SDO_COMMON_PROPERTIES, canonicalize_table_name, - determine_column_name, determine_sql_type_from_class, flat_classes, - get_stix_object_classes, schema_for, + determine_column_name, determine_sql_type_from_stix, flat_classes, + get_stix_object_classes, ) from stix2.properties import ( BinaryProperty, BooleanProperty, DictionaryProperty, @@ -19,7 +19,7 @@ TimestampProperty, TypeProperty, ) from stix2.v21.base import _Extension -from stix2.v21.common import KillChainPhase +from stix2.v21.common import KillChainPhase, GranularMarking def aux_table_property(prop, name, core_properties): @@ -40,9 +40,10 @@ def derive_column_name(prop): return "value" -def create_object_markings_refs_table(metadata, sco_or_sdo): +def create_object_markings_refs_table(metadata, db_backend, sco_or_sdo): return create_ref_table( metadata, + db_backend, {"marking-definition"}, "object_marking_refs_" + sco_or_sdo, "common.core_" + sco_or_sdo + ".id", @@ -51,12 +52,12 @@ def create_object_markings_refs_table(metadata, sco_or_sdo): ) -def create_ref_table(metadata, specifics, table_name, foreign_key_name, 
schema_name, auth_type=0): +def create_ref_table(metadata, db_backend, specifics, table_name, foreign_key_name, schema_name, auth_type=0): columns = list() columns.append( Column( "id", - Text, + db_backend.determine_sql_type_for_key_as_id(), ForeignKey( foreign_key_name, ondelete="CASCADE", @@ -64,11 +65,11 @@ def create_ref_table(metadata, specifics, table_name, foreign_key_name, schema_n nullable=False, ), ) - columns.append(ref_column("ref_id", specifics, auth_type)) + columns.append(ref_column("ref_id", specifics, db_backend, auth_type)) return Table(table_name, metadata, *columns, schema=schema_name) -def create_hashes_table(name, metadata, schema_name, table_name, key_type=Text, level=1): +def create_hashes_table(name, metadata, db_backend, schema_name, table_name, key_type=Text, level=1): columns = list() # special case, perhaps because its a single embedded object with hashes, and not a list of embedded object # making the parent table's primary key does seem to worl @@ -88,14 +89,14 @@ def create_hashes_table(name, metadata, schema_name, table_name, key_type=Text, columns.append( Column( "hash_name", - Text, + db_backend.determine_sql_type_for_string_property(), nullable=False, ), ) columns.append( Column( "hash_value", - Text, + db_backend.determine_sql_type_for_string_property(), nullable=False, ), ) @@ -108,12 +109,12 @@ def create_hashes_table(name, metadata, schema_name, table_name, key_type=Text, ) -def create_kill_chain_phases_table(name, metadata, schema_name, table_name): +def create_kill_chain_phases_table(name, metadata, db_backend, schema_name, table_name): columns = list() columns.append( Column( "id", - Text, + db_backend.determine_sql_type_for_key_as_id(), ForeignKey( canonicalize_table_name(table_name, schema_name) + ".id", ondelete="CASCADE", @@ -124,93 +125,120 @@ def create_kill_chain_phases_table(name, metadata, schema_name, table_name): columns.append( Column( "kill_chain_name", - Text, + 
db_backend.determine_sql_type_for_string_property(), nullable=False, ), ) columns.append( Column( "phase_name", - Text, + db_backend.determine_sql_type_for_string_property(), nullable=False, ), ) return Table(canonicalize_table_name(table_name + "_" + name), metadata, *columns, schema=schema_name) -def create_granular_markings_table(metadata, sco_or_sdo): - return Table( - "granular_marking_" + sco_or_sdo, - metadata, +def create_granular_markings_table(metadata, db_backend, sco_or_sdo): + tables = list() + columns = [ Column( "id", - Text, + db_backend.determine_sql_type_for_key_as_id(), ForeignKey("common.core_" + sco_or_sdo + ".id", ondelete="CASCADE"), nullable=False, ), - Column("lang", Text), + Column("lang", db_backend.determine_sql_type_for_string_property()), Column( "marking_ref", - Text, + db_backend.determine_sql_type_for_reference_property(), CheckConstraint( - "marking_ref ~ '^marking-definition--[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$'", # noqa: E131 + "marking_ref ~ '^marking-definition--[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$'", + # noqa: E131 ), - ), - Column( - "selectors", - ARRAY(Text), - CheckConstraint("array_length(selectors, 1) IS NOT NULL"), - nullable=False, - ), - CheckConstraint( - """(lang IS NULL AND marking_ref IS NOT NULL) - OR - (lang IS NOT NULL AND marking_ref IS NULL)""", - ), - schema="common", - ) + ) + ] + if db_backend.array_allowed(): + columns.append( + Column( + "selectors", + ARRAY(Text), + CheckConstraint("array_length(selectors, 1) IS NOT NULL"), + nullable=False, + )) + else: + table_name = "granular_marking_" + sco_or_sdo + schema_name = determine_sql_type_from_stix(GranularMarking, db_backend) + columns = [ + Column( + "id", + db_backend.determine_sql_type_for_key_as_id(), + ForeignKey( + canonicalize_table_name(table_name, schema_name) + ".id", + ondelete="CASCADE", + ), + nullable=False, + ), + Column( + "selector", 
+ db_backend.determine_sql_type_for_string_property(), + nullable=False, + ) + ] + tables.append(Table(canonicalize_table_name(table_name + "_" + "selector"), metadata, *columns, schema=schema_name)) + tables.append(Table( + "granular_marking_" + sco_or_sdo, + metadata, + *columns, + CheckConstraint( + """(lang IS NULL AND marking_ref IS NOT NULL) + OR + (lang IS NOT NULL AND marking_ref IS NULL)""", + ), + schema="common")) + return tables -def create_external_references_tables(metadata): +def create_external_references_tables(metadata, db_backend): columns = [ Column( "id", - Text, + db_backend.determine_sql_type_for_key_as_id(), ForeignKey("common.core_sdo" + ".id", ondelete="CASCADE"), CheckConstraint( "id ~ '^[a-z][a-z0-9-]+[a-z0-9]--[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$'", # noqa: E131 ), ), - Column("source_name", Text), - Column("description", Text), - Column("url", Text), - Column("external_id", Text), + Column("source_name", db_backend.determine_sql_type_for_string_property()), + Column("description", db_backend.determine_sql_type_for_string_property()), + Column("url", db_backend.determine_sql_type_for_string_property()), + Column("external_id", db_backend.determine_sql_type_for_string_property()), # all such keys are generated using the global sequence. 
- Column("hash_ref_id", Integer, primary_key=True, autoincrement=False), + Column("hash_ref_id", db_backend.determine_sql_type_for_key_as_int(), primary_key=True, autoincrement=False), ] return [ Table("external_references", metadata, *columns, schema="common"), - create_hashes_table("hashes", metadata, "common", "external_references", Integer), + create_hashes_table("hashes", metadata, db_backend, "common", "external_references", Integer), ] -def create_core_table(metadata, schema_name): +def create_core_table(metadata, db_backend, schema_name): columns = [ Column( "id", - Text, + db_backend.determine_sql_type_for_key_as_id(), CheckConstraint( "id ~ '^[a-z][a-z0-9-]+[a-z0-9]--[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$'", # noqa: E131 ), primary_key=True, ), - Column("spec_version", Text, default="2.1"), + Column("spec_version", db_backend.determine_sql_type_for_string_property(), default="2.1"), ] if schema_name == "sdo": sdo_columns = [ Column( "created_by_ref", - Text, + db_backend.determine_sql_type_for_reference_property(), CheckConstraint( "created_by_ref ~ '^identity--[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$'", # noqa: E131 ), @@ -219,12 +247,13 @@ def create_core_table(metadata, schema_name): Column("modified", TIMESTAMP(timezone=True)), Column("revoked", Boolean), Column("confidence", Integer), - Column("lang", Text), + Column("lang", db_backend.determine_sql_type_for_string_property()), Column("labels", ARRAY(Text)), ] columns.extend(sdo_columns) else: - columns.append(Column("defanged", Boolean, default=False)), + columns.append(Column("defanged", db_backend.determine_sql_type_for_boolean_property(), default=False)) + return Table( "core_" + schema_name, metadata, @@ -234,60 +263,60 @@ def create_core_table(metadata, schema_name): @add_method(Property) -def determine_sql_type(self): # noqa: F811 +def determine_sql_type(self, db_backend): # noqa: F811 pass 
@add_method(KillChainPhase) -def determine_sql_type(self): # noqa: F811 - return None +def determine_sql_type(self, db_backend): # noqa: F811 + return db_backend.determine_sql_type_for_kill_chain_phase() @add_method(BinaryProperty) -def determine_sql_type(self): # noqa: F811 - return Text +def determine_sql_type(self, db_backend): # noqa: F811 + return db_backend.determine_sql_type_for_binary_property() @add_method(BooleanProperty) -def determine_sql_type(self): # noqa: F811 - return Boolean +def determine_sql_type(self, db_backend): # noqa: F811 + return db_backend.determine_sql_type_for_boolean_property() @add_method(FloatProperty) -def determine_sql_type(self): # noqa: F811 - return Float +def determine_sql_type(self, db_backend): # noqa: F811 + return db_backend.determine_sql_type_for_float_property() @add_method(HexProperty) -def determine_sql_type(self): # noqa: F811 - return LargeBinary +def determine_sql_type(self, db_backend): # noqa: F811 + return db_backend.determine_sql_type_for_hex_property() @add_method(IntegerProperty) -def determine_sql_type(self): # noqa: F811 - return Integer +def determine_sql_type(self, db_backend): # noqa: F811 + return db_backend.determine_sql_type_for_integer_property() @add_method(ReferenceProperty) -def determine_sql_type(self): # noqa: F811 - return Text +def determine_sql_type(self, db_backend): # noqa: F811 + return db_backend.determine_sql_type_for_reference_property() @add_method(StringProperty) -def determine_sql_type(self): # noqa: F811 - return Text +def determine_sql_type(self, db_backend): # noqa: F811 + return db_backend.determine_sql_type_for_string_property() @add_method(TimestampProperty) -def determine_sql_type(self): # noqa: F811 - return TIMESTAMP(timezone=True) +def determine_sql_type(self, db_backend): # noqa: F811 + return db_backend.determine_sql_type_for_timestamp_property() # ----------------------------- generate_table_information methods ---------------------------- @add_method(KillChainPhase) def 
generate_table_information( # noqa: F811 - self, name, metadata, schema_name, table_name, is_extension=False, is_list=False, + self, name, db_backend, metadata, schema_name, table_name, is_extension=False, is_list=False, **kwargs, ): level = kwargs.get("level") @@ -298,15 +327,15 @@ def generate_table_information( # noqa: F811 @add_method(Property) -def generate_table_information(self, name, **kwargs): # noqa: F811 +def generate_table_information(self, name, db_backend, **kwargs): # noqa: F811 pass @add_method(BinaryProperty) -def generate_table_information(self, name, **kwargs): # noqa: F811 +def generate_table_information(self, name, db_backend, **kwargs): # noqa: F811 return Column( name, - Text, + self.determine_sql_type(db_backend), CheckConstraint( # this regular expression might accept or reject some legal base64 strings f"{name} ~ " + "'^[-A-Za-z0-9+/]*={0,3}$'", @@ -316,30 +345,30 @@ def generate_table_information(self, name, **kwargs): # noqa: F811 @add_method(BooleanProperty) -def generate_table_information(self, name, **kwargs): # noqa: F811 +def generate_table_information(self, name, db_backend, **kwargs): # noqa: F811 return Column( name, - Boolean, + self.determine_sql_type(db_backend), nullable=not self.required, default=self._fixed_value if hasattr(self, "_fixed_value") else None, ) @add_method(DictionaryProperty) -def generate_table_information(self, name, metadata, schema_name, table_name, is_extension=False, **kwargs): # noqa: F811 +def generate_table_information(self, name, db_backend, metadata, schema_name, table_name, is_extension=False, **kwargs): # noqa: F811 columns = list() columns.append( Column( "id", - Text, + db_backend.determine_sql_type_for_key_as_id(), ForeignKey(canonicalize_table_name(table_name, schema_name) + ".id", ondelete="CASCADE"), ), ) columns.append( Column( "name", - Text, + db_backend.determine_sql_type_for_string_property(), nullable=False, ), ) @@ -350,7 +379,7 @@ def generate_table_information(self, name, metadata, 
schema_name, table_name, is Column( "value", # its a class - determine_sql_type_from_class(self.valid_types[0]), + determine_sql_type_from_stix(self.valid_types[0], db_backend), nullable=False, ), ) @@ -360,13 +389,13 @@ def generate_table_information(self, name, metadata, schema_name, table_name, is Column( "value", # its an instance, not a class - ARRAY(contained_class.determine_sql_type()), + ARRAY(contained_class.determine_sql_type(db_backend)), nullable=False, ), ) else: for column_type in self.valid_types: - sql_type = determine_sql_type_from_class(column_type) + sql_type = determine_sql_type_from_stix(column_type, db_backend) columns.append( Column( determine_column_name(column_type), @@ -377,7 +406,7 @@ def generate_table_information(self, name, metadata, schema_name, table_name, is columns.append( Column( "value", - Text, + db_backend.determine_sql_type_for_string_property(), nullable=False, ), ) @@ -392,24 +421,21 @@ def generate_table_information(self, name, metadata, schema_name, table_name, is ] - - - @add_method(EmbeddedObjectProperty) -def generate_table_information(self, name, metadata, schema_name, table_name, is_extension=False, is_list=False, **kwargs): # noqa: F811 +def generate_table_information(self, name, db_backend, metadata, schema_name, table_name, is_extension=False, is_list=False, **kwargs): # noqa: F811 level = kwargs.get("level") return generate_object_table( - self.type, metadata, schema_name, table_name, is_extension, True, is_list, + self.type, db_backend, metadata, schema_name, table_name, is_extension, True, is_list, parent_table_name=table_name, level=level+1 if is_list else level, ) @add_method(EnumProperty) -def generate_table_information(self, name, **kwargs): # noqa: F811 +def generate_table_information(self, name, db_backend, **kwargs): # noqa: F811 enum_re = "|".join(self.allowed) return Column( name, - Text, + self.determine_sql_type(db_backend), CheckConstraint( f"{name} ~ '^{enum_re}$'", ), @@ -418,12 +444,12 @@ def 
generate_table_information(self, name, **kwargs): # noqa: F811 @add_method(ExtensionsProperty) -def generate_table_information(self, name, metadata, schema_name, table_name, **kwargs): # noqa: F811 +def generate_table_information(self, name, db_backend, metadata, schema_name, table_name, **kwargs): # noqa: F811 columns = list() columns.append( Column( "id", - Text, + db_backend.determine_sql_type_for_key_as_id(), ForeignKey(canonicalize_table_name(table_name, schema_name) + ".id", ondelete="CASCADE"), nullable=False, ), @@ -431,7 +457,7 @@ def generate_table_information(self, name, metadata, schema_name, table_name, ** columns.append( Column( "ext_table_name", - Text, + db_backend.determine_sql_type_for_string_property(), nullable=False, ), ) @@ -439,17 +465,17 @@ def generate_table_information(self, name, metadata, schema_name, table_name, ** @add_method(FloatProperty) -def generate_table_information(self, name, **kwargs): # noqa: F811 +def generate_table_information(self, name, db_backend, **kwargs): # noqa: F811 return Column( name, - Float, + self.determine_sql_type(db_backend), nullable=not self.required, default=self._fixed_value if hasattr(self, "_fixed_value") else None, ) @add_method(HashesProperty) -def generate_table_information(self, name, metadata, schema_name, table_name, is_extension=False, **kwargs): # noqa: F811 +def generate_table_information(self, name, db_backend, metadata, schema_name, table_name, is_extension=False, **kwargs): # noqa: F811 level = kwargs.get("level") if kwargs.get("is_embedded_object"): if not kwargs.get("is_list") or level == 0: @@ -462,6 +488,7 @@ def generate_table_information(self, name, metadata, schema_name, table_name, is create_hashes_table( name, metadata, + db_backend, schema_name, table_name, key_type=key_type, @@ -471,21 +498,19 @@ def generate_table_information(self, name, metadata, schema_name, table_name, is @add_method(HexProperty) -def generate_table_information(self, name, **kwargs): # noqa: F811 +def 
generate_table_information(self, name, db_backend, **kwargs): # noqa: F811 return Column( name, - LargeBinary, + db_backend.determine_sql_type_for_hex_property(), nullable=not self.required, ) @add_method(IDProperty) -def generate_table_information(self, name, **kwargs): # noqa: F811 +def generate_table_information(self, name, db_backend, **kwargs): # noqa: F811 schema_name = kwargs.get('schema_name') - if schema_name in ["sro", "common"]: - # sro, smo common properties are the same as sdo's - schema_name = "sdo" table_name = kwargs.get("table_name") + core_table = kwargs.get("core_table") # if schema_name == "common": # return Column( # name, @@ -498,10 +523,13 @@ def generate_table_information(self, name, **kwargs): # noqa: F811 # nullable=not (self.required), # ) # else: - foreign_key_column = f"common.core_{schema_name}.id" + if schema_name: + foreign_key_column = f"common.core_{core_table}.id" + else: + foreign_key_column = f"core_{core_table}.id" return Column( name, - Text, + db_backend.determine_sql_type_for_key_as_id(), ForeignKey(foreign_key_column, ondelete="CASCADE"), CheckConstraint( f"{name} ~ '^{table_name}" + "--[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$'", @@ -511,63 +539,28 @@ def generate_table_information(self, name, **kwargs): # noqa: F811 nullable=not (self.required), ) - return Column( - name, - Text, - ForeignKey(foreign_key_column, ondelete="CASCADE"), - CheckConstraint( - f"{name} ~ '^{table_name}" + "--[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$'", # noqa: E131 - ), - primary_key=True, - nullable=not (self.required), - ) - @add_method(IntegerProperty) -def generate_table_information(self, name, **kwargs): # noqa: F811 +def generate_table_information(self, name, db_backend, **kwargs): # noqa: F811 return Column( name, - Integer, + self.determine_sql_type(db_backend), nullable=not self.required, default=self._fixed_value if hasattr(self, 
"_fixed_value") else None, ) @add_method(ListProperty) -def generate_table_information(self, name, metadata, schema_name, table_name, **kwargs): # noqa: F811 +def generate_table_information(self, name, db_backend, metadata, schema_name, table_name, **kwargs): # noqa: F811 is_extension = kwargs.get('is_extension') tables = list() - if isinstance(self.contained, ReferenceProperty): - return [ - create_ref_table( - metadata, - self.contained.specifics, - canonicalize_table_name(table_name + "_" + name), - canonicalize_table_name(table_name, schema_name) + ".id", - schema_name, - ), - ] - elif isinstance(self.contained, EnumProperty): + # handle more complext embedded object before deciding if the ARRAY type is usable + if isinstance(self.contained, EmbeddedObjectProperty): columns = list() columns.append( Column( "id", - Text, - ForeignKey( - canonicalize_table_name(table_name, schema_name) + ".id", - ondelete="CASCADE", - ), - nullable=False, - ), - ) - columns.append(self.contained.generate_table_information(name)) - tables.append(Table(canonicalize_table_name(table_name + "_" + name), metadata, *columns, schema=schema_name)) - elif isinstance(self.contained, EmbeddedObjectProperty): - columns = list() - columns.append( - Column( - "id", - Text, + db_backend.determine_sql_type_for_key_as_id(), ForeignKey( canonicalize_table_name(table_name, schema_name) + ".id", ondelete="CASCADE", @@ -577,7 +570,7 @@ def generate_table_information(self, name, metadata, schema_name, table_name, ** columns.append( Column( "ref_id", - Integer, + db_backend.determine_sql_type_for_key_as_int(), primary_key=True, nullable=False, # all such keys are generated using the global sequence. 
@@ -588,9 +581,11 @@ def generate_table_information(self, name, metadata, schema_name, table_name, ** tables.extend( self.contained.generate_table_information( name, + db_backend, metadata, schema_name, - canonicalize_table_name(table_name + "_" + name, None), # if sub_table_needed else canonicalize_table_name(table_name, None), + canonicalize_table_name(table_name + "_" + name, None), + # if sub_table_needed else canonicalize_table_name(table_name, None), is_extension, parent_table_name=table_name, is_list=True, @@ -598,12 +593,41 @@ def generate_table_information(self, name, metadata, schema_name, table_name, ** ), ) return tables + elif isinstance(self.contained, ReferenceProperty): + return [ + create_ref_table( + metadata, + db_backend, + self.contained.specifics, + canonicalize_table_name(table_name + "_" + name), + canonicalize_table_name(table_name, schema_name) + ".id", + schema_name, + ), + ] + elif ((isinstance(self.contained, + (StringProperty, IntegerProperty, FloatProperty)) and not db_backend.array_allowed()) or + isinstance(self.contained, EnumProperty)): + columns = list() + columns.append( + Column( + "id", + db_backend.determine_sql_type_for_key_as_id(), + ForeignKey( + canonicalize_table_name(table_name, schema_name) + ".id", + ondelete="CASCADE", + ), + nullable=False, + ), + ) + columns.append(self.contained.generate_table_information(name, db_backend)) + tables.append(Table(canonicalize_table_name(table_name + "_" + name), metadata, *columns, schema=schema_name)) + elif self.contained == KillChainPhase: - tables.append(create_kill_chain_phases_table(name, metadata, schema_name, table_name)) + tables.append(create_kill_chain_phases_table(name, metadata, db_backend, schema_name, table_name)) return tables else: if isinstance(self.contained, Property): - sql_type = self.contained.determine_sql_type() + sql_type = self.contained.determine_sql_type(db_backend) if sql_type: return Column( name, @@ -612,7 +636,7 @@ def 
generate_table_information(self, name, metadata, schema_name, table_name, ** ) -def ref_column(name, specifics, auth_type=0): +def ref_column(name, specifics, db_backend, auth_type=0): if specifics: types = "|".join(specifics) if auth_type == 0: @@ -627,41 +651,41 @@ def ref_column(name, specifics, auth_type=0): f"(NOT({name} ~ '^({types})')) AND ({name} ~ " + "'--[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$')", ) - return Column(name, Text, constraint) + return Column(name, db_backend.determine_sql_type_for_reference_property(), constraint) else: return Column( name, - Text, + db_backend.determine_sql_type_for_reference_property(), nullable=False, ) @add_method(ObjectReferenceProperty) -def generate_table_information(self, name, **kwargs): # noqa: F811 +def generate_table_information(self, name, db_backend, **kwargs): # noqa: F811 table_name = kwargs.get('table_name') raise ValueError(f"Property {name} in {table_name} is of type ObjectReferenceProperty, which is for STIX 2.0 only") @add_method(ReferenceProperty) -def generate_table_information(self, name, **kwargs): # noqa: F811 - return ref_column(name, self.specifics, self.auth_type) +def generate_table_information(self, name, db_backend, **kwargs): # noqa: F811 + return ref_column(name, self.specifics, db_backend, self.auth_type) @add_method(StringProperty) -def generate_table_information(self, name, **kwargs): # noqa: F811 +def generate_table_information(self, name, db_backend, **kwargs): # noqa: F811 return Column( name, - Text, + db_backend.determine_sql_type_for_string_property(), nullable=not self.required, default=self._fixed_value if hasattr(self, "_fixed_value") else None, ) @add_method(TimestampProperty) -def generate_table_information(self, name, **kwargs): # noqa: F811 +def generate_table_information(self, name, db_backend, **kwargs): # noqa: F811 return Column( name, - TIMESTAMP(timezone=True), + self.determine_sql_type(db_backend), # CheckConstraint( # 
f"{name} ~ '^{enum_re}$'" # ), @@ -670,17 +694,17 @@ def generate_table_information(self, name, **kwargs): # noqa: F811 @add_method(TypeProperty) -def generate_table_information(self, name, **kwargs): # noqa: F811 +def generate_table_information(self, name, db_backend, **kwargs): # noqa: F811 return Column( name, - Text, + db_backend.determine_sql_type_for_string_property(), nullable=not self.required, default=self._fixed_value if hasattr(self, "_fixed_value") else None, ) def generate_object_table( - stix_object_class, metadata, schema_name, foreign_key_name=None, + stix_object_class, db_backend, metadata, schema_name, foreign_key_name=None, is_extension=False, is_embedded_object=False, is_list=False, parent_table_name=None, level=0, ): properties = stix_object_class._properties @@ -704,11 +728,17 @@ def generate_object_table( core_properties = list() columns = list() tables = list() + if schema_name == "sco": + core_table = "sco" + else: + # sro, smo common properties are the same as sdo's + core_table = "sdo" for name, prop in properties.items(): # type is never a column since it is implicit in the table if (name == 'id' or name not in core_properties) and name != 'type': col = prop.generate_table_information( name, + db_backend, metadata=metadata, schema_name=schema_name, table_name=table_name, @@ -717,6 +747,7 @@ def generate_object_table( is_list=is_list, level=level, parent_table_name=parent_table_name, + core_table=core_table, ) if col is not None and isinstance(col, Column): columns.append(col) @@ -726,7 +757,7 @@ def generate_object_table( columns.append( Column( "id", - Text, + db_backend.determine_sql_type_for_key_as_id(), # no Foreign Key because it could be for different tables primary_key=True, ), @@ -736,7 +767,7 @@ def generate_object_table( if is_extension and not is_embedded_object: column = Column( "id", - Text, + db_backend.determine_sql_type_for_key_as_id(), ForeignKey( canonicalize_table_name(foreign_key_name, schema_name) + ".id", 
ondelete="CASCADE", @@ -745,7 +776,7 @@ def generate_object_table( elif is_embedded_object: column = Column( "id", - Integer if is_list else Text, + db_backend.determine_sql_type_for_key_as_int() if is_list else db_backend.determine_sql_type_for_key_as_id(), ForeignKey( canonicalize_table_name(foreign_key_name, schema_name) + (".ref_id" if is_list else ".id"), ondelete="CASCADE", @@ -756,7 +787,7 @@ def generate_object_table( elif level > 0 and is_embedded_object: column = Column( "id", - Integer if (is_embedded_object and is_list) else Text, + db_backend.determine_sql_type_for_key_as_int() if (is_embedded_object and is_list) else db_backend.determine_sql_type_for_key_as_id(), ForeignKey( canonicalize_table_name(foreign_key_name, schema_name) + (".ref_id" if (is_embedded_object and is_list) else ".id"), ondelete="CASCADE", @@ -767,7 +798,7 @@ def generate_object_table( else: column = Column( "id", - Text, + db_backend.determine_sql_type_for_key_as_id(), ForeignKey( canonicalize_table_name(foreign_key_name, schema_name) + ".id", ondelete="CASCADE", @@ -783,20 +814,20 @@ def generate_object_table( return tables -def create_core_tables(metadata): +def create_core_tables(metadata, db_backend): tables = [ - create_core_table(metadata, "sdo"), - create_granular_markings_table(metadata, "sdo"), - create_core_table(metadata, "sco"), - create_granular_markings_table(metadata, "sco"), - create_object_markings_refs_table(metadata, "sdo"), - create_object_markings_refs_table(metadata, "sco"), + create_core_table(metadata, db_backend, "sdo"), + create_granular_markings_table(metadata, db_backend, "sdo"), + create_core_table(metadata, db_backend, "sco"), + create_granular_markings_table(metadata, db_backend, "sco"), + create_object_markings_refs_table(metadata, db_backend, "sdo"), + create_object_markings_refs_table(metadata, db_backend, "sco"), ] - tables.extend(create_external_references_tables(metadata)) + tables.extend(create_external_references_tables(metadata, db_backend)) 
return tables -def create_table_objects(metadata, stix_object_classes): +def create_table_objects(metadata, db_backend, stix_object_classes): if stix_object_classes: # If classes are given, allow some flexibility regarding lists of # classes vs single classes @@ -806,16 +837,17 @@ def create_table_objects(metadata, stix_object_classes): # If no classes given explicitly, discover them automatically stix_object_classes = get_stix_object_classes() - tables = create_core_tables(metadata) + tables = create_core_tables(metadata, db_backend) for stix_class in stix_object_classes: - schema_name = schema_for(stix_class) + schema_name = db_backend.schema_for(stix_class) is_extension = issubclass(stix_class, _Extension) tables.extend( generate_object_table( stix_class, + db_backend, metadata, schema_name, is_extension=is_extension, diff --git a/stix2/datastore/relational_db/utils.py b/stix2/datastore/relational_db/utils.py index 5d06de7b..76a4268b 100644 --- a/stix2/datastore/relational_db/utils.py +++ b/stix2/datastore/relational_db/utils.py @@ -136,23 +136,25 @@ def is_class_or_instance(cls_or_inst, cls): return cls_or_inst == cls or isinstance(cls_or_inst, cls) -def determine_sql_type_from_class(cls_or_inst): # noqa: F811 +def determine_sql_type_from_stix(cls_or_inst, db_backend): # noqa: F811 if is_class_or_instance(cls_or_inst, BinaryProperty): - return LargeBinary + return db_backend.determine_sql_type_for_binary_property() elif is_class_or_instance(cls_or_inst, BooleanProperty): - return Boolean - elif is_class_or_instance(cls_or_inst, FloatProperty ): - return Float + return db_backend.determine_sql_type_for_boolean_property() + elif is_class_or_instance(cls_or_inst, FloatProperty): + return db_backend.determine_sql_type_for_float_property() elif is_class_or_instance(cls_or_inst, HexProperty): - return LargeBinary + return db_backend.determine_sql_type_for_hex_property() elif is_class_or_instance(cls_or_inst, IntegerProperty): - return Integer - elif 
is_class_or_instance(cls_or_inst, StringProperty) or is_class_or_instance(cls_or_inst, ReferenceProperty): - return Text + return db_backend.determine_sql_type_for_integer_property() + elif is_class_or_instance(cls_or_inst, StringProperty): + return db_backend.determine_sql_type_for_string_property() + elif is_class_or_instance(cls_or_inst, ReferenceProperty): + return db_backend.determine_sql_type_for_reference_property() elif is_class_or_instance(cls_or_inst, TimestampProperty): - return TIMESTAMP(timezone=True) + return db_backend.determine_sql_type_for_timestamp_property() elif is_class_or_instance(cls_or_inst, Property): - return Text + return db_backend.determine_sql_type_for_string_property() def determine_column_name(cls_or_inst): # noqa: F811