diff --git a/ord_schema/orm/database.py b/ord_schema/orm/database.py index c9e17165..133bd1d3 100644 --- a/ord_schema/orm/database.py +++ b/ord_schema/orm/database.py @@ -17,15 +17,13 @@ import time from unittest.mock import patch -from sqlalchemy import cast, delete, func, select, text, update -from sqlalchemy.dialects.postgresql import insert +from sqlalchemy import delete, select, text from sqlalchemy.engine import Engine from sqlalchemy.exc import NotSupportedError, OperationalError from sqlalchemy.orm import Session from ord_schema.logging import get_logger from ord_schema.orm.mappers import Base, Mappers, from_proto -from ord_schema.orm.rdkit_mappers import CString, FingerprintType, RDKitMols, RDKitReactions from ord_schema.proto import dataset_pb2 logger = get_logger(__name__) @@ -110,6 +108,7 @@ def delete_dataset(dataset_id: str, session: Session) -> None: def update_rdkit_tables(dataset_id: str, session: Session) -> None: """Updates RDKit PostgreSQL cartridge data.""" + logger.debug(f"Updating RDKit tables for {dataset_id=}") _update_rdkit_reactions(dataset_id, session) _update_rdkit_mols(dataset_id, session) @@ -117,84 +116,75 @@ def update_rdkit_tables(dataset_id: str, session: Session) -> None: def _update_rdkit_reactions(dataset_id: str, session: Session) -> None: """Updates the RDKit reactions table.""" logger.debug("Updating RDKit reactions") - assert hasattr(RDKitReactions, "__table__") # Type hint. - table = RDKitReactions.__table__ start = time.time() - session.execute( - insert(table) - .from_select( - ["reaction_smiles"], - select(Mappers.Reaction.reaction_smiles) - .join(Mappers.Dataset) - .where(Mappers.Dataset.dataset_id == dataset_id, Mappers.Reaction.reaction_smiles.is_not(None)) - .distinct(), - ) - .on_conflict_do_nothing(index_elements=["reaction_smiles"]) - ) - logger.debug(f"Updating reaction SMILES took {time.time() - start:g}s") - start = time.time() - session.execute( - update(table) - .where(table.c.reaction.is_(None)) - .values(reaction=func.reaction_from_smiles(cast(table.c.reaction_smiles, CString))) + result = session.execute( + text( + """ + INSERT INTO rdkit.reactions (reaction_smiles, reaction) + SELECT reaction_smiles, reaction_from_smiles(reaction_smiles::cstring) + FROM ( + SELECT reaction_smiles + FROM ord.reaction + JOIN ord.dataset ON ord.reaction.dataset_id = ord.dataset.id + WHERE ord.dataset.dataset_id = :dataset_id + AND ord.reaction.rdkit_reaction_id IS NULL + EXCEPT + SELECT reaction_smiles + FROM rdkit.reactions + ) subquery + """ + ), + {"dataset_id": dataset_id}, ) - logger.debug(f"reaction_from_smiles took {time.time() - start:g}s") + logger.debug(f"Updating reactions took {time.time() - start:g}s ({result.rowcount} rows)") def _update_rdkit_mols(dataset_id: str, session: Session) -> None: """Updates the RDKit mols table.""" logger.debug("Updating RDKit mols") - assert hasattr(RDKitMols, "__table__") # Type hint. - table = RDKitMols.__table__ start = time.time() - # NOTE(skearnes): This join path does not include non-input compounds like workups, internal standards, etc. - session.execute( - insert(table) - .from_select( - ["smiles"], - select(Mappers.Compound.smiles) - .join(Mappers.ReactionInput) - .join(Mappers.Reaction) - .join(Mappers.Dataset) - .where( - Mappers.Dataset.dataset_id == dataset_id, - Mappers.Compound.smiles.is_not(None), - # See https://github.com/open-reaction-database/ord-schema/issues/672. - Mappers.Compound.smiles.not_like("%[Ti+5]%"), - ) - .distinct(), - ) - .on_conflict_do_nothing(index_elements=["smiles"]) - ) - session.execute( - insert(table) - .from_select( - ["smiles"], - select(Mappers.ProductCompound.smiles) - .join(Mappers.ReactionOutcome) - .join(Mappers.Reaction) - .join(Mappers.Dataset) - .where(Mappers.Dataset.dataset_id == dataset_id, Mappers.ProductCompound.smiles.is_not(None)) - .distinct(), - ) - .on_conflict_do_nothing(index_elements=["smiles"]) + result = session.execute( + text( + """ + INSERT INTO rdkit.mols (smiles, mol, morgan_bfp, morgan_sfp) + SELECT smiles, mol, morgan_bfp, morgan_sfp + FROM ( + SELECT smiles, mol, morganbv_fp(mol) AS morgan_bfp, morgan_fp(mol) AS morgan_sfp + FROM ( + SELECT smiles, mol_from_smiles(smiles::cstring) AS mol + FROM ( + SELECT smiles + -- NOTE(skearnes): This join path does not include non-input compounds like workups, + -- internal standards, etc. + FROM ord.compound + JOIN ord.reaction_input ON ord.compound.reaction_input_id = ord.reaction_input.id + JOIN ord.reaction ON ord.reaction_input.reaction_id = ord.reaction.id + JOIN ord.dataset ON ord.reaction.dataset_id = ord.dataset.id + WHERE ord.dataset.dataset_id = :dataset_id + AND ord.compound.rdkit_mol_id IS NULL + UNION + SELECT smiles + FROM ord.product_compound + JOIN ord.reaction_outcome + ON ord.product_compound.reaction_outcome_id = ord.reaction_outcome.id + JOIN ord.reaction ON ord.reaction_outcome.reaction_id = ord.reaction.id + JOIN ord.dataset ON ord.reaction.dataset_id = ord.dataset.id + WHERE ord.dataset.dataset_id = :dataset_id + AND ord.product_compound.rdkit_mol_id IS NULL + EXCEPT + SELECT smiles + FROM rdkit.mols + ) smiles_subquery + -- See https://github.com/open-reaction-database/ord-schema/issues/672. + WHERE smiles NOT LIKE '%[Ti+5]%' + ) mol_subquery + ) fp_subquery + ON CONFLICT (smiles) DO NOTHING + """ + ), + {"dataset_id": dataset_id}, ) - logger.debug(f"Updating SMILES took {time.time() - start:g}s") - start = time.time() - session.execute( - update(table).where(table.c.mol.is_(None)).values(mol=func.mol_from_smiles(cast(table.c.smiles, CString))) - ) - logger.debug(f"mol_from_smiles took {time.time() - start:g}s") - logger.debug("Updating fingerprints") - for fp_type in FingerprintType: - start = time.time() - column = fp_type.name.lower() - session.execute( - update(table) - .where(getattr(table.c, column).is_(None), table.c.mol.is_not(None)) - .values(**{column: fp_type(table.c.mol)}) - ) - logger.debug(f"{fp_type} took {time.time() - start:g}s") + logger.debug(f"Updating mols took {time.time() - start:g}s ({result.rowcount} rows)") def update_rdkit_ids(dataset_id: str, session: Session) -> None: @@ -202,40 +192,63 @@ def update_rdkit_ids(dataset_id: str, session: Session) -> None: logger.debug("Updating RDKit ID associations") start = time.time() # Update Reaction. - query = session.execute( - select(Mappers.Reaction.id, RDKitReactions.id) - .join(RDKitReactions, Mappers.Reaction.reaction_smiles == RDKitReactions.reaction_smiles) - .join(Mappers.Dataset) - .where(Mappers.Dataset.dataset_id == dataset_id) + session.execute( + text( + """ + UPDATE ord.reaction + SET rdkit_reaction_id = subquery.rdkit_reaction_id + FROM ( + SELECT ord.reaction.id, rdkit.reactions.id AS rdkit_reaction_id + FROM ord.reaction + JOIN rdkit.reactions USING (reaction_smiles) + JOIN ord.dataset ON ord.reaction.dataset_id = ord.dataset.id + WHERE ord.dataset.dataset_id = :dataset_id + AND ord.reaction.rdkit_reaction_id IS NULL + ) AS subquery + WHERE ord.reaction.id = subquery.id + """ + ), + {"dataset_id": dataset_id}, ) - updates = [] - for ord_id, rdkit_id in query.fetchall(): - updates.append({"id": ord_id, "rdkit_reaction_id": rdkit_id}) - session.execute(update(Mappers.Reaction), updates) # Update Compound. - query = session.execute( - select(Mappers.Compound.id, RDKitMols.id) - .join(RDKitMols, Mappers.Compound.smiles == RDKitMols.smiles) - .join(Mappers.ReactionInput) - .join(Mappers.Reaction) - .join(Mappers.Dataset) - .where(Mappers.Dataset.dataset_id == dataset_id) + session.execute( + text( + """ + UPDATE ord.compound + SET rdkit_mol_id = subquery.rdkit_mol_id + FROM ( + SELECT ord.compound.id, rdkit.mols.id AS rdkit_mol_id + FROM ord.compound + JOIN rdkit.mols USING (smiles) + JOIN ord.reaction_input ON ord.compound.reaction_input_id = ord.reaction_input.id + JOIN ord.reaction ON ord.reaction_input.reaction_id = ord.reaction.id + JOIN ord.dataset ON ord.reaction.dataset_id = ord.dataset.id + WHERE ord.dataset.dataset_id = :dataset_id + AND ord.compound.rdkit_mol_id IS NULL + ) AS subquery + WHERE ord.compound.id = subquery.id + """ + ), + {"dataset_id": dataset_id}, ) - updates = [] - for ord_id, rdkit_id in query.fetchall(): - updates.append({"id": ord_id, "rdkit_mol_id": rdkit_id}) - session.execute(update(Mappers.Compound), updates) - # Update ProductCompound. - query = session.execute( - select(Mappers.ProductCompound.id, RDKitMols.id) - .join(RDKitMols, Mappers.ProductCompound.smiles == RDKitMols.smiles) - .join(Mappers.ReactionOutcome) - .join(Mappers.Reaction) - .join(Mappers.Dataset) - .where(Mappers.Dataset.dataset_id == dataset_id) + session.execute( + text( + """ + UPDATE ord.product_compound + SET rdkit_mol_id = subquery.rdkit_mol_id + FROM ( + SELECT ord.product_compound.id, rdkit.mols.id AS rdkit_mol_id + FROM ord.product_compound + JOIN rdkit.mols USING (smiles) + JOIN ord.reaction_outcome ON ord.product_compound.reaction_outcome_id = ord.reaction_outcome.id + JOIN ord.reaction ON ord.reaction_outcome.reaction_id = ord.reaction.id + JOIN ord.dataset ON ord.reaction.dataset_id = ord.dataset.id + WHERE ord.dataset.dataset_id = :dataset_id + AND ord.product_compound.rdkit_mol_id IS NULL + ) AS subquery + WHERE ord.product_compound.id = subquery.id + """ + ), + {"dataset_id": dataset_id}, ) - updates = [] - for ord_id, rdkit_id in query.fetchall(): - updates.append({"id": ord_id, "rdkit_mol_id": rdkit_id}) - session.execute(update(Mappers.ProductCompound), updates) logger.debug(f"Updating RDKit IDs took {time.time() - start:g}s") diff --git a/ord_schema/orm/scripts/add_datasets.py b/ord_schema/orm/scripts/add_datasets.py index c2ffa256..9275256b 100644 --- a/ord_schema/orm/scripts/add_datasets.py +++ b/ord_schema/orm/scripts/add_datasets.py @@ -28,7 +28,9 @@ --host= Database host [default: localhost] --port= Database port [default: 5432] --n_jobs= Number of parallel workers [default: 1] + --debug Enable debug logging. """ +import logging import os from concurrent.futures import ProcessPoolExecutor, as_completed from glob import glob @@ -101,7 +103,9 @@ def add_rdkit(dataset_id: str) -> None: def main(**kwargs): RDLogger.DisableLog("rdApp.*") - if kwargs.get("--url"): + if kwargs["--debug"]: + get_logger(database.__name__, level=logging.DEBUG) + if kwargs["--url"]: url = kwargs["--url"] else: url = database.get_connection_string(