Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support a schema update mode in stats runner #344

Merged
merged 6 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ function run_lint_fix {
echo -e "#### Fixing Python code"
python3 -m venv .env
source .env/bin/activate
pip3 install yapf==0.33.0 -q
pip3 install yapf==0.40.2 -q
if ! command -v isort &> /dev/null
then
pip3 install isort -q
Expand All @@ -35,12 +35,12 @@ function run_lint_fix {
function run_lint_test {
python3 -m venv .env
source .env/bin/activate
pip3 install yapf==0.33.0 -q
pip3 install yapf==0.40.2 -q
if ! command -v isort &> /dev/null
then
pip3 install isort -q
fi

echo -e "#### Checking Python style"
if ! yapf --recursive --diff --style='{based_on_style: google, indent_width: 2}' -p simple/ -e=*pb2.py -e=.env/*; then
echo "Fix Python lint errors by running ./run_test.sh -f"
Expand Down Expand Up @@ -74,9 +74,9 @@ function py_test {

python3 -m venv .env
source .env/bin/activate

cd simple
pip3 install -r requirements.txt
pip3 install -r requirements.txt -q

echo -e "#### Running stats tests"
python3 -m pytest tests/stats/ -s
Expand Down
2 changes: 1 addition & 1 deletion simple/run_stats.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ Options:
-c <file> Json config file for stats importer
-i <dir> Input directory to process
-o <dir> Output folder for stats importer. Default: $OUTPUT_DIR
-m <customdc|maindc> Mode of operation for simple importer. Default: $MODE
-m <customdc|schemaupdate|maindc> Mode of operation for simple importer. Default: $MODE
-k <api-key> DataCommons API Key
-j <jar> DC Import java jar file.
Download latest from https://github.com/datacommonsorg/import/releases/
Expand Down
63 changes: 45 additions & 18 deletions simple/stats/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,14 @@

_SELECT_ENTITY_NAMES = "select subject_id, object_value from triples where subject_id in (%s) and predicate = 'name' and object_value <> ''"

_INIT_STATEMENTS = [
_INIT_TABLE_STATEMENTS = [
_CREATE_TRIPLES_TABLE,
_CREATE_OBSERVATIONS_TABLE,
_CREATE_KEY_VALUE_STORE_TABLE,
_CREATE_IMPORTS_TABLE,
]

_CLEAR_TABLE_FOR_IMPORT_STATEMENTS = [
# Clearing tables for now (not the imports table though, since we want to maintain its history).
_DELETE_TRIPLES_STATEMENT,
_DELETE_OBSERVATIONS_STATEMENT,
Expand Down Expand Up @@ -195,6 +198,9 @@ class Db:
The "DB" could be a traditional sql db or a file system with the output being files.
"""

def maybe_clear_before_import(self):
  """Clears any previously imported data, if applicable for this Db.

  No-op in this base class; SQL-backed implementations override it to
  clear table data (and indexes) before a fresh import.
  """
  pass

def insert_triples(self, triples: list[Triple]):
  """Writes the given triples to the underlying store. Base-class no-op."""
  pass

Expand Down Expand Up @@ -285,8 +291,13 @@ class SqlDb(Db):

def __init__(self, config: dict) -> None:
self.engine = create_db_engine(config)
self.engine.init_or_update_tables()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since clear_tables_and_indexes_for_import() is now being called explicitly by clients, I'm wondering if init_or_update_tables() should be as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, so I tried this but it changed an unexpectedly large surface and seemed to make it easier to "hold it wrong" when creating any Db, even if it's not a SqlDb. I have tried instead to make it very clear with naming and comments what is happening. WDYT?

self.num_observations = 0
self.variables: set[str] = set()
self.indexes_cleared = False

def maybe_clear_before_import(self):
  # Delegates to the engine: clears table data (import history is kept)
  # and drops indexes so the upcoming bulk inserts run faster.
  self.engine.clear_tables_and_indexes()

def insert_triples(self, triples: list[Triple]):
logging.info("Writing %s triples to [%s]", len(triples), self.engine)
Expand Down Expand Up @@ -345,6 +356,12 @@ def from_triple_tuple(tuple: tuple) -> Triple:

class DbEngine:

def init_or_update_tables(self):
  """Creates tables if missing and applies schema updates. Base no-op."""
  pass

def clear_tables_and_indexes(self):
  """Clears table data before a fresh import. Base no-op."""
  pass

def execute(self, sql: str, parameters=None):
  """Executes one SQL statement with optional parameters. Base no-op."""
  pass

Expand Down Expand Up @@ -379,14 +396,8 @@ def __init__(self, db_params: dict) -> None:
logging.info("Connected to SQLite: %s", self.local_db_file_path)

self.cursor = self.connection.cursor()
# Drop indexes first so inserts are faster.
self._drop_indexes()
for statement in _INIT_STATEMENTS:
self.cursor.execute(statement)
# Apply schema updates.
self._schema_updates()

def _schema_updates(self) -> None:
def _maybe_update_schema(self) -> None:
"""
Add any sqlite schema updates here.
Ensure that all schema updates always check if the update is necessary before applying it.
Expand Down Expand Up @@ -415,6 +426,15 @@ def _create_indexes(self) -> None:
def __str__(self) -> str:
  # Identifies this engine by type and backing db file, e.g. in log lines.
  return f"{TYPE_SQLITE}: {self.db_file_path}"

def init_or_update_tables(self):
  """Creates any missing tables, then applies pending schema updates."""
  # The create statements are all "IF NOT EXISTS"-style init statements,
  # so re-running them is safe — presumably idempotent; verify against
  # the statement definitions.
  for statement in _INIT_TABLE_STATEMENTS:
    self.cursor.execute(statement)
  self._maybe_update_schema()

def clear_tables_and_indexes(self):
  """Clears data tables (import history is kept) and drops indexes.

  Dropping indexes makes the subsequent bulk inserts faster; they are
  recreated later via _create_indexes. Without the _drop_indexes() call
  this method did not match its name or the Cloud SQL engine's behavior.
  """
  for statement in _CLEAR_TABLE_FOR_IMPORT_STATEMENTS:
    self.cursor.execute(statement)
  # Keep behavior consistent with CloudSqlDbEngine.clear_tables_and_indexes
  # and with the pre-refactor sqlite path, which dropped indexes before
  # inserts for speed.
  self._drop_indexes()

def execute(self, sql: str, parameters=None):
if not parameters:
self.cursor.execute(sql)
Expand Down Expand Up @@ -461,8 +481,8 @@ def commit_and_close(self):
_CLOUD_MY_SQL_PARAMS = [CLOUD_MY_SQL_INSTANCE] + _CLOUD_MY_SQL_DB_CONNECT_PARAMS

_CLOUD_MYSQL_PROPERTIES_COLUMN_EXISTS_STATEMENT = """
SELECT 1
FROM INFORMATION_SCHEMA.COLUMNS
SELECT 1
FROM INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'observations' AND COLUMN_NAME = 'properties';
"""

Expand All @@ -486,14 +506,8 @@ def __init__(self, db_params: dict[str, str]) -> None:
db_params[CLOUD_MY_SQL_INSTANCE], db_params[CLOUD_MY_SQL_DB])
self.description = f"{TYPE_CLOUD_SQL}: {db_params[CLOUD_MY_SQL_INSTANCE]} ({db_params[CLOUD_MY_SQL_DB]})"
self.cursor: Cursor = self.connection.cursor()
# Drop indexes first so inserts are faster.
self._drop_indexes()
for statement in _INIT_STATEMENTS:
self.cursor.execute(statement)
# Apply schema updates.
self._schema_updates()

def _schema_updates(self) -> None:
def _maybe_update_schema(self) -> None:
"""
Add any cloud sql schema updates here.
Ensure that all schema updates always check if the update is necessary before applying it.
Expand Down Expand Up @@ -555,6 +569,16 @@ def _db_exists(cursor) -> bool:
def __str__(self) -> str:
  # Human-readable engine description (instance + db name) built in __init__.
  return self.description

def init_or_update_tables(self):
  """Creates any missing tables, then applies pending schema updates."""
  for statement in _INIT_TABLE_STATEMENTS:
    self.cursor.execute(statement)
  self._maybe_update_schema()

def clear_tables_and_indexes(self):
  """Clears data tables (import history is kept) and drops indexes.

  Indexes are dropped so the upcoming bulk inserts run faster.
  """
  for statement in _CLEAR_TABLE_FOR_IMPORT_STATEMENTS:
    self.cursor.execute(statement)
  self._drop_indexes()

def execute(self, sql: str, parameters=None):
  # _pymysql() presumably adapts the SQL (e.g. placeholder style) to the
  # pymysql dialect — confirm against its definition.
  self.cursor.execute(_pymysql(sql), parameters)

Expand Down Expand Up @@ -599,7 +623,10 @@ def create_db_engine(config: dict) -> DbEngine:
assert False


def create_db(config: dict) -> Db:
def create_and_update_db(config: dict) -> Db:
""" Creates and initializes a Db, performing any setup and updates
(e.g. table creation, table schema changes) that are needed.
"""
db_type = config[FIELD_DB_TYPE]
if db_type and db_type == TYPE_MAIN_DC:
return MainDcDb(config)
Expand Down
95 changes: 56 additions & 39 deletions simple/stats/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from stats.data import ParentSVG2ChildSpecializedNames
from stats.data import Triple
from stats.data import VerticalSpec
from stats.db import create_db
from stats.db import create_and_update_db
from stats.db import create_main_dc_config
from stats.db import create_sqlite_config
from stats.db import get_cloud_sql_config_from_env
Expand All @@ -48,6 +48,7 @@

class RunMode(StrEnum):
  """Operating modes for the stats runner."""
  # Full import for a Custom DC instance.
  CUSTOM_DC = "customdc"
  # Only initializes/updates DB tables and schema; data imports are skipped.
  SCHEMA_UPDATE = "schemaupdate"
  # Import targeting Main DC output.
  MAIN_DC = "maindc"


Expand Down Expand Up @@ -113,59 +114,75 @@ def __init__(self,
self.reporter = ImportReporter(report_fh=self.process_dir_fh.make_file(
constants.REPORT_JSON_FILE_NAME))

# DB setup.
def _get_db_config() -> dict:
if self.mode == RunMode.MAIN_DC:
logging.info("Using Main DC config.")
return create_main_dc_config(self.output_dir_fh.path)
# Attempt to get from env (cloud sql, then sqlite),
# then config file, then default.
db_cfg = get_cloud_sql_config_from_env()
if db_cfg:
logging.info("Using Cloud SQL settings from env.")
return db_cfg
db_cfg = get_sqlite_config_from_env()
if db_cfg:
logging.info("Using SQLite settings from env.")
return db_cfg
logging.info("Using default DB settings.")
return create_sqlite_config(
self.output_dir_fh.make_file(constants.DB_FILE_NAME).path)

self.db = create_db(_get_db_config())
self.nodes = Nodes(self.config)
self.db = None

def run(self):
try:
# Run all data imports.
self._run_imports()
if (self.db is None):
self.db = create_and_update_db(self._get_db_config())

# Generate triples.
triples = self.nodes.triples()
# Write triples to DB.
self.db.insert_triples(triples)
if self.mode == RunMode.SCHEMA_UPDATE:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we require clients to call init_or_update_tables() (from the other comment), this mode will need to call it.

nit: For clarity, consider implementing it in a separate _run_schema_update_mode() to have a named method for it and use it as a container for future changes.

Copy link
Contributor Author

@hqpho hqpho Oct 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Opted not to add _run_schema_update_mode since it makes it harder to tell what operations are shared and right now IMO there is not so much divergence that this method is hard to read. Can revisit in the future if there is more divergence!

logging.info("Skipping imports because run mode is schema update.")

# Generate SVG hierarchy.
self._generate_svg_hierarchy()
elif self.mode == RunMode.CUSTOM_DC or self.mode == RunMode.MAIN_DC:
self._run_imports_and_do_post_import_work()

# Generate SVG cache.
self._generate_svg_cache()

# Generate NL sentences for creating embeddings.
self._generate_nl_sentences()

# Write import info to DB.
self.db.insert_import_info(status=ImportStatus.SUCCESS)
else:
raise ValueError(f"Unsupported mode: {self.mode}")

# Commit and close DB.
self.db.commit_and_close()

# Report done.
self.reporter.report_done()
except Exception as e:
logging.exception("Error running import")
logging.exception("Error updating stats")
self.reporter.report_failure(error=str(e))

def _get_db_config(self) -> dict:
  """Resolves the DB config for this run.

  Main DC mode gets its dedicated config. Otherwise the first available
  source wins: Cloud SQL settings from env, then SQLite settings from
  env, then a default SQLite file in the output directory.
  """
  if self.mode == RunMode.MAIN_DC:
    logging.info("Using Main DC config.")
    return create_main_dc_config(self.output_dir_fh.path)
  if cfg := get_cloud_sql_config_from_env():
    logging.info("Using Cloud SQL settings from env.")
    return cfg
  if cfg := get_sqlite_config_from_env():
    logging.info("Using SQLite settings from env.")
    return cfg
  logging.info("Using default DB settings.")
  default_db_path = self.output_dir_fh.make_file(constants.DB_FILE_NAME).path
  return create_sqlite_config(default_db_path)

def _run_imports_and_do_post_import_work(self):
  """Runs all data imports, then the post-import generation steps.

  Order matters: tables are cleared first, data and triples are written,
  then the derived artifacts (SVG hierarchy/cache, NL sentences) are
  generated, and finally the import record is written.
  """
  # (SQL only) Drop data in existing tables (except import metadata).
  # Also drop indexes for faster writes.
  self.db.maybe_clear_before_import()

  # Import data from all input files.
  self._run_all_data_imports()

  # Generate triples.
  triples = self.nodes.triples()
  # Write triples to DB.
  self.db.insert_triples(triples)

  # Generate SVG hierarchy.
  self._generate_svg_hierarchy()

  # Generate SVG cache.
  self._generate_svg_cache()

  # Generate NL sentences for creating embeddings.
  self._generate_nl_sentences()

  # Write import info to DB.
  self.db.insert_import_info(status=ImportStatus.SUCCESS)

def _generate_nl_sentences(self):
triples: list[Triple] = []
# Get topic triples if generating topics else get SV triples.
Expand Down Expand Up @@ -247,7 +264,7 @@ def _maybe_set_special_fh(self, fh: FileHandler) -> bool:
return True
return False

def _run_imports(self):
def _run_all_data_imports(self):
input_fhs: list[FileHandler] = []
input_mcf_fhs: list[FileHandler] = []
for input_handler in self.input_handlers:
Expand Down
Loading