Skip to content

Commit

Permalink
Faster rdkit operations in the ORM (#733)
Browse files Browse the repository at this point in the history
* temp tables

* fix query

* comment

* id queries

* --debug

* lint

* skip existing associations

* skip existing associations

* no temp table for reactions

* remove parens

* try with no temp table?

* lint

* alias
  • Loading branch information
skearnes authored Jul 11, 2024
1 parent 0e55df9 commit 61a2eeb
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 105 deletions.
221 changes: 117 additions & 104 deletions ord_schema/orm/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,13 @@
import time
from unittest.mock import patch

from sqlalchemy import cast, delete, func, select, text, update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy import delete, select, text
from sqlalchemy.engine import Engine
from sqlalchemy.exc import NotSupportedError, OperationalError
from sqlalchemy.orm import Session

from ord_schema.logging import get_logger
from ord_schema.orm.mappers import Base, Mappers, from_proto
from ord_schema.orm.rdkit_mappers import CString, FingerprintType, RDKitMols, RDKitReactions
from ord_schema.proto import dataset_pb2

logger = get_logger(__name__)
Expand Down Expand Up @@ -110,132 +108,147 @@ def delete_dataset(dataset_id: str, session: Session) -> None:

def update_rdkit_tables(dataset_id: str, session: Session) -> None:
"""Updates RDKit PostgreSQL cartridge data."""
logger.debug(f"Updating RDKit tables for {dataset_id=}")
_update_rdkit_reactions(dataset_id, session)
_update_rdkit_mols(dataset_id, session)


def _update_rdkit_reactions(dataset_id: str, session: Session) -> None:
"""Updates the RDKit reactions table."""
logger.debug("Updating RDKit reactions")
assert hasattr(RDKitReactions, "__table__") # Type hint.
table = RDKitReactions.__table__
start = time.time()
session.execute(
insert(table)
.from_select(
["reaction_smiles"],
select(Mappers.Reaction.reaction_smiles)
.join(Mappers.Dataset)
.where(Mappers.Dataset.dataset_id == dataset_id, Mappers.Reaction.reaction_smiles.is_not(None))
.distinct(),
)
.on_conflict_do_nothing(index_elements=["reaction_smiles"])
)
logger.debug(f"Updating reaction SMILES took {time.time() - start:g}s")
start = time.time()
session.execute(
update(table)
.where(table.c.reaction.is_(None))
.values(reaction=func.reaction_from_smiles(cast(table.c.reaction_smiles, CString)))
result = session.execute(
text(
"""
INSERT INTO rdkit.reactions (reaction_smiles, reaction)
SELECT reaction_smiles, reaction_from_smiles(reaction_smiles::cstring)
FROM (
SELECT reaction_smiles
FROM ord.reaction
JOIN ord.dataset ON ord.reaction.dataset_id = ord.dataset.id
WHERE ord.dataset.dataset_id = :dataset_id
AND ord.reaction.rdkit_reaction_id IS NULL
EXCEPT
SELECT reaction_smiles
FROM rdkit.reactions
) subquery
"""
),
{"dataset_id": dataset_id},
)
logger.debug(f"reaction_from_smiles took {time.time() - start:g}s")
logger.debug(f"Updating reactions took {time.time() - start:g}s ({result.rowcount} rows)")


def _update_rdkit_mols(dataset_id: str, session: Session) -> None:
"""Updates the RDKit mols table."""
logger.debug("Updating RDKit mols")
assert hasattr(RDKitMols, "__table__") # Type hint.
table = RDKitMols.__table__
start = time.time()
# NOTE(skearnes): This join path does not include non-input compounds like workups, internal standards, etc.
session.execute(
insert(table)
.from_select(
["smiles"],
select(Mappers.Compound.smiles)
.join(Mappers.ReactionInput)
.join(Mappers.Reaction)
.join(Mappers.Dataset)
.where(
Mappers.Dataset.dataset_id == dataset_id,
Mappers.Compound.smiles.is_not(None),
# See https://github.com/open-reaction-database/ord-schema/issues/672.
Mappers.Compound.smiles.not_like("%[Ti+5]%"),
)
.distinct(),
)
.on_conflict_do_nothing(index_elements=["smiles"])
)
session.execute(
insert(table)
.from_select(
["smiles"],
select(Mappers.ProductCompound.smiles)
.join(Mappers.ReactionOutcome)
.join(Mappers.Reaction)
.join(Mappers.Dataset)
.where(Mappers.Dataset.dataset_id == dataset_id, Mappers.ProductCompound.smiles.is_not(None))
.distinct(),
)
.on_conflict_do_nothing(index_elements=["smiles"])
result = session.execute(
text(
"""
INSERT INTO rdkit.mols (smiles, mol, morgan_bfp, morgan_sfp)
SELECT smiles, mol, morgan_bfp, morgan_sfp
FROM (
SELECT smiles, mol, morganbv_fp(mol) AS morgan_bfp, morgan_fp(mol) AS morgan_sfp
FROM (
SELECT smiles, mol_from_smiles(smiles::cstring) AS mol
FROM (
SELECT smiles
-- NOTE(skearnes): This join path does not include non-input compounds like workups,
-- internal standards, etc.
FROM ord.compound
JOIN ord.reaction_input ON ord.compound.reaction_input_id = ord.reaction_input.id
JOIN ord.reaction ON ord.reaction_input.reaction_id = ord.reaction.id
JOIN ord.dataset ON ord.reaction.dataset_id = ord.dataset.id
WHERE ord.dataset.dataset_id = :dataset_id
AND ord.compound.rdkit_mol_id IS NULL
UNION
SELECT smiles
FROM ord.product_compound
JOIN ord.reaction_outcome
ON ord.product_compound.reaction_outcome_id = ord.reaction_outcome.id
JOIN ord.reaction ON ord.reaction_outcome.reaction_id = ord.reaction.id
JOIN ord.dataset ON ord.reaction.dataset_id = ord.dataset.id
WHERE ord.dataset.dataset_id = :dataset_id
AND ord.product_compound.rdkit_mol_id IS NULL
EXCEPT
SELECT smiles
FROM rdkit.mols
) smiles_subquery
-- See https://github.com/open-reaction-database/ord-schema/issues/672.
WHERE smiles NOT LIKE '%[Ti+5]%'
) mol_subquery
) fp_subquery
ON CONFLICT (smiles) DO NOTHING
"""
),
{"dataset_id": dataset_id},
)
logger.debug(f"Updating SMILES took {time.time() - start:g}s")
start = time.time()
session.execute(
update(table).where(table.c.mol.is_(None)).values(mol=func.mol_from_smiles(cast(table.c.smiles, CString)))
)
logger.debug(f"mol_from_smiles took {time.time() - start:g}s")
logger.debug("Updating fingerprints")
for fp_type in FingerprintType:
start = time.time()
column = fp_type.name.lower()
session.execute(
update(table)
.where(getattr(table.c, column).is_(None), table.c.mol.is_not(None))
.values(**{column: fp_type(table.c.mol)})
)
logger.debug(f"{fp_type} took {time.time() - start:g}s")
logger.debug(f"Updating mols took {time.time() - start:g}s ({result.rowcount} rows)")


def update_rdkit_ids(dataset_id: str, session: Session) -> None:
"""Updates RDKit reaction and mol ID associations in the ORD tables."""
logger.debug("Updating RDKit ID associations")
start = time.time()
# Update Reaction.
query = session.execute(
select(Mappers.Reaction.id, RDKitReactions.id)
.join(RDKitReactions, Mappers.Reaction.reaction_smiles == RDKitReactions.reaction_smiles)
.join(Mappers.Dataset)
.where(Mappers.Dataset.dataset_id == dataset_id)
session.execute(
text(
"""
UPDATE ord.reaction
SET rdkit_reaction_id = subquery.rdkit_reaction_id
FROM (
SELECT ord.reaction.id, rdkit.reactions.id AS rdkit_reaction_id
FROM ord.reaction
JOIN rdkit.reactions USING (reaction_smiles)
JOIN ord.dataset ON ord.reaction.dataset_id = ord.dataset.id
WHERE ord.dataset.dataset_id = :dataset_id
AND ord.reaction.rdkit_reaction_id IS NULL
) AS subquery
WHERE ord.reaction.id = subquery.id
"""
),
{"dataset_id": dataset_id},
)
updates = []
for ord_id, rdkit_id in query.fetchall():
updates.append({"id": ord_id, "rdkit_reaction_id": rdkit_id})
session.execute(update(Mappers.Reaction), updates)
# Update Compound.
query = session.execute(
select(Mappers.Compound.id, RDKitMols.id)
.join(RDKitMols, Mappers.Compound.smiles == RDKitMols.smiles)
.join(Mappers.ReactionInput)
.join(Mappers.Reaction)
.join(Mappers.Dataset)
.where(Mappers.Dataset.dataset_id == dataset_id)
session.execute(
text(
"""
UPDATE ord.compound
SET rdkit_mol_id = subquery.rdkit_mol_id
FROM (
SELECT ord.compound.id, rdkit.mols.id AS rdkit_mol_id
FROM ord.compound
JOIN rdkit.mols USING (smiles)
JOIN ord.reaction_input ON ord.compound.reaction_input_id = ord.reaction_input.id
JOIN ord.reaction ON ord.reaction_input.reaction_id = ord.reaction.id
JOIN ord.dataset ON ord.reaction.dataset_id = ord.dataset.id
WHERE ord.dataset.dataset_id = :dataset_id
AND ord.compound.rdkit_mol_id IS NULL
) AS subquery
WHERE ord.compound.id = subquery.id
"""
),
{"dataset_id": dataset_id},
)
updates = []
for ord_id, rdkit_id in query.fetchall():
updates.append({"id": ord_id, "rdkit_mol_id": rdkit_id})
session.execute(update(Mappers.Compound), updates)
# Update ProductCompound.
query = session.execute(
select(Mappers.ProductCompound.id, RDKitMols.id)
.join(RDKitMols, Mappers.ProductCompound.smiles == RDKitMols.smiles)
.join(Mappers.ReactionOutcome)
.join(Mappers.Reaction)
.join(Mappers.Dataset)
.where(Mappers.Dataset.dataset_id == dataset_id)
session.execute(
text(
"""
UPDATE ord.product_compound
SET rdkit_mol_id = subquery.rdkit_mol_id
FROM (
SELECT ord.product_compound.id, rdkit.mols.id AS rdkit_mol_id
FROM ord.product_compound
JOIN rdkit.mols USING (smiles)
JOIN ord.reaction_outcome ON ord.product_compound.reaction_outcome_id = ord.reaction_outcome.id
JOIN ord.reaction ON ord.reaction_outcome.reaction_id = ord.reaction.id
JOIN ord.dataset ON ord.reaction.dataset_id = ord.dataset.id
WHERE ord.dataset.dataset_id = :dataset_id
AND ord.product_compound.rdkit_mol_id IS NULL
) AS subquery
WHERE ord.product_compound.id = subquery.id
"""
),
{"dataset_id": dataset_id},
)
updates = []
for ord_id, rdkit_id in query.fetchall():
updates.append({"id": ord_id, "rdkit_mol_id": rdkit_id})
session.execute(update(Mappers.ProductCompound), updates)
logger.debug(f"Updating RDKit IDs took {time.time() - start:g}s")
6 changes: 5 additions & 1 deletion ord_schema/orm/scripts/add_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
--host=<str> Database host [default: localhost]
--port=<int> Database port [default: 5432]
--n_jobs=<int> Number of parallel workers [default: 1]
--debug Enable debug logging.
"""
import logging
import os
from concurrent.futures import ProcessPoolExecutor, as_completed
from glob import glob
Expand Down Expand Up @@ -101,7 +103,9 @@ def add_rdkit(dataset_id: str) -> None:

def main(**kwargs):
RDLogger.DisableLog("rdApp.*")
if kwargs.get("--url"):
if kwargs["--debug"]:
get_logger(database.__name__, level=logging.DEBUG)
if kwargs["--url"]:
url = kwargs["--url"]
else:
url = database.get_connection_string(
Expand Down

0 comments on commit 61a2eeb

Please sign in to comment.