Skip to content

Commit

Permalink
consistently make tables with schema argument, add sql-type pass methods
Browse files Browse the repository at this point in the history
  • Loading branch information
rpiazza committed Nov 13, 2024
1 parent 7df7b9e commit cf9c0c9
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,24 @@ def schema_for(stix_class):
def schema_for_core():
return ""

# you must implement the next 4 methods in the subclass

@staticmethod
def determine_sql_type_for_property(): # noqa: F811
pass

@staticmethod
def determine_sql_type_for_binary_property(): # noqa: F811
pass

@staticmethod
def determine_sql_type_for_hex_property(): # noqa: F811
pass

@staticmethod
def determine_sql_type_for_timestamp_property(): # noqa: F811
pass

@staticmethod
def determine_sql_type_for_kill_chain_phase(): # noqa: F811
return None
Expand Down
12 changes: 7 additions & 5 deletions stix2/datastore/relational_db/table_creation.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# from collections import OrderedDict

from sqlalchemy import ( # create_engine,; insert,
ARRAY, TIMESTAMP, Boolean, CheckConstraint, Column, Float, ForeignKey,
Integer, LargeBinary, Table, Text, UniqueConstraint,
ARRAY, TIMESTAMP, Boolean, CheckConstraint, Column, ForeignKey,
Integer, Table, Text, UniqueConstraint
)

from stix2.datastore.relational_db.add_method import add_method
Expand Down Expand Up @@ -60,7 +60,7 @@ def create_array_child_table(metadata, db_backend, table_name, property_name, co
nullable=False,
),
]
return Table(canonicalize_table_name(table_name + "_" + "selector", schema_name), metadata, *columns)
return Table(canonicalize_table_name(table_name + "_" + "selector"), metadata, *columns, schema=schema_name)


def derive_column_name(prop):
Expand Down Expand Up @@ -171,11 +171,12 @@ def create_kill_chain_phases_table(name, metadata, db_backend, schema_name, tabl


def create_granular_markings_table(metadata, db_backend, sco_or_sdo):
schema_name = db_backend.schema_for_core()
columns = [
Column(
"id",
db_backend.determine_sql_type_for_key_as_id(),
ForeignKey("common.core_" + sco_or_sdo + ".id", ondelete="CASCADE"),
ForeignKey(canonicalize_table_name("core_" + sco_or_sdo, schema_name) + ".id", ondelete="CASCADE"),
nullable=False,
primary_key=True,
),
Expand All @@ -197,14 +198,15 @@ def create_granular_markings_table(metadata, db_backend, sco_or_sdo):

tables = [
Table(
canonicalize_table_name("granular_marking_" + sco_or_sdo, db_backend.schema_for_core()),
canonicalize_table_name("granular_marking_" + sco_or_sdo),
metadata,
*columns,
CheckConstraint(
"""(lang IS NULL AND marking_ref IS NOT NULL)
OR
(lang IS NOT NULL AND marking_ref IS NULL)""",
),
schema=schema_name
),
]
if child_table:
Expand Down

0 comments on commit cf9c0c9

Please sign in to comment.