Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade to psycopg 3 #730

Merged
merged 17 commits into from
Jul 4, 2024
Merged
14 changes: 7 additions & 7 deletions .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ jobs:
test_ord_schema:
strategy:
matrix:
os: [ubuntu-latest, macos-14]
os: ["ubuntu-latest", "macos-14"]
python-version: ["3.10", "3.11", "3.12"]
runs-on: ${{ matrix.os }}
env:
PGDATA: $GITHUB_WORKSPACE/rdkit-postgres
Expand All @@ -38,7 +39,7 @@ jobs:
initdb
- uses: actions/setup-python@v4
with:
python-version: '3.10'
python-version: ${{ matrix.python-version }}
- name: Install ord_schema
run: |
python -m pip install --upgrade pip
Expand All @@ -48,20 +49,21 @@ jobs:
shell: bash -l {0}
run: |
coverage erase
pytest -vv --cov=ord_schema --durations=0 --durations-min=1
pytest -vv --cov=ord_schema --durations=20
coverage xml
- uses: codecov/codecov-action@v1

test_notebooks:
strategy:
matrix:
os: [ ubuntu-latest, macos-14 ]
os: ["ubuntu-latest", "macos-14"]
python-version: ["3.10", "3.11", "3.12"]
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: '3.10'
python-version: ${{ matrix.python-version }}
- name: Install ord_schema
run: |
python -m pip install --upgrade pip
Expand All @@ -80,8 +82,6 @@ jobs:
python-version: '3.10'
- name: Install protoc
run: |
python -m pip install --upgrade pip
python -m pip install wheel
mkdir protoc
cd protoc
wget https://github.com/protocolbuffers/protobuf/releases/download/v22.3/protoc-22.3-linux-x86_64.zip
Expand Down
12 changes: 7 additions & 5 deletions format.sh
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,19 @@ else
fi
# Format python.
if ! command -v black &> /dev/null; then
pip install black
pip install black[jupyter]
fi
black "${ROOT_DIR}"
if ! command -v isort &> /dev/null; then
pip install isort
fi
isort "${ROOT_DIR}"
# Format proto.
if command -v clang-format-10 &> /dev/null; then
find "${ROOT_DIR}" -name '*.proto' -exec clang-format-10 -i --style=file {} +
elif command -v clang-format &> /dev/null; then
if command -v clang-format &> /dev/null; then
# NOTE(kearnes): Make sure you have version 10 or higher!
find "${ROOT_DIR}" -name '*.proto' -exec clang-format -i --style=file {} +
else
echo "Please install clang-format:"
echo " Linux: apt install clang-format-10"
echo " Linux: apt install clang-format"
echo " MacOS: brew install clang-format"
fi
2 changes: 1 addition & 1 deletion ord_schema/message_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,7 @@ def fetch_dataset(dataset_id: str, timeout: float = 10.0) -> dataset_pb2.Dataset
RuntimeError: If the request fails.
ValueError: If the dataset ID is invalid.
"""
from ord_schema import validations # Avoid circular import; pylint: disable=import-outside-toplevel.
from ord_schema import validations # pylint: disable=cyclic-import,import-outside-toplevel

if not validations.is_valid_dataset_id(dataset_id):
raise ValueError(f"Invalid dataset ID: {dataset_id}")
Expand Down
16 changes: 8 additions & 8 deletions ord_schema/orm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ from sqlalchemy import create_engine

from ord_schema.orm.database import prepare_database

connection_string = f"postgresql://{username}:{password}@{host}:{port}/{database}"
connection_string = f"postgresql+psycopg://{username}:{password}@{host}:{port}/{database}"
engine = create_engine(connection_string, future=True)
prepare_database(engine)
```
Expand All @@ -135,11 +135,11 @@ from sqlalchemy import create_engine
from sqlalchemy.orm import Session

from ord_schema.message_helpers import fetch_dataset
from ord_schema.orm.database import add_dataset, add_rdkit
from ord_schema.orm.database import add_dataset

dataset = fetch_dataset("ord_dataset-fc83743b978f4deea7d6856deacbfe53")

connection_string = f"postgresql://{username}:{password}@{host}:{port}/{database}"
connection_string = f"postgresql+psycopg://{username}:{password}@{host}:{port}/{database}"
engine = create_engine(connection_string, future=True)
with Session(engine) as session:
add_dataset(dataset, session)
Expand Down Expand Up @@ -177,7 +177,7 @@ from sqlalchemy.orm import Session
from ord_schema.orm.mappers import Mappers
from ord_schema.proto import reaction_pb2

connection_string = f"postgresql://{username}:{password}@{host}:{port}/{database}"
connection_string = f"postgresql+psycopg://{username}:{password}@{host}:{port}/{database}"
engine = create_engine(connection_string, future=True)
with Session(engine) as session:
query = (
Expand All @@ -202,18 +202,18 @@ from sqlalchemy import create_engine, select
from sqlalchemy.orm import Session

from ord_schema.orm.mappers import Mappers
from ord_schema.orm.rdkit_mappers import FingerprintType, RDKitMol
from ord_schema.orm.rdkit_mappers import FingerprintType, RDKitMols
from ord_schema.proto import reaction_pb2

connection_string = f"postgresql://{username}:{password}@{host}:{port}/{database}"
connection_string = f"postgresql+psycopg://{username}:{password}@{host}:{port}/{database}"
engine = create_engine(connection_string, future=True)
with Session(engine) as session:
query = (
select(Mappers.Reaction)
.join(Mappers.ReactionInput)
.join(Mappers.Compound)
.join(RDKitMol)
.where(RDKitMol.tanimoto("c1ccccc1CCC(O)C", FingerprintType.MORGAN_BFP) > 0.5)
.join(RDKitMols)
.where(RDKitMols.tanimoto("c1ccccc1CCC(O)C", FingerprintType.MORGAN_BFP) > 0.5)
)
results = session.execute(query)
reactions = [reaction_pb2.Reaction.FromString(result[0].proto) for result in results]
Expand Down
31 changes: 20 additions & 11 deletions ord_schema/orm/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@

"""Pytest fixtures."""
import os
import re
from typing import Iterator

import pytest
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session
from testing.postgresql import Postgresql

Expand All @@ -26,19 +28,26 @@
from ord_schema.proto import dataset_pb2


@pytest.fixture
def test_session() -> Iterator[Session]:
@pytest.fixture(name="test_engine")
def test_engine_fixture() -> Iterator[Engine]:
with Postgresql() as postgres:
# See https://docs.sqlalchemy.org/en/20/dialects/postgresql.html#module-sqlalchemy.dialects.postgresql.psycopg.
url = re.sub("postgresql://", "postgresql+psycopg://", postgres.url())
engine = create_engine(url, future=True)
yield engine


@pytest.fixture(name="test_session")
def test_session_fixture(test_engine) -> Iterator[Session]:
datasets = [
load_message(
os.path.join(os.path.dirname(__file__), "testdata", "ord-nielsen-example.pbtxt"), dataset_pb2.Dataset
)
]
with Postgresql() as postgres:
engine = create_engine(postgres.url(), future=True)
rdkit_cartridge = prepare_database(engine)
with Session(engine) as session:
for dataset in datasets:
add_dataset(dataset, session, rdkit_cartridge=rdkit_cartridge)
session.commit()
with Session(engine) as session:
yield session
rdkit_cartridge = prepare_database(test_engine)
with Session(test_engine) as session:
for dataset in datasets:
add_dataset(dataset, session, rdkit_cartridge=rdkit_cartridge)
session.commit()
with Session(test_engine) as session:
yield session
24 changes: 12 additions & 12 deletions ord_schema/orm/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from ord_schema.logging import get_logger
from ord_schema.orm.mappers import Base, Mappers, from_proto
from ord_schema.orm.rdkit_mappers import CString, FingerprintType, RDKitMol, RDKitReaction
from ord_schema.orm.rdkit_mappers import CString, FingerprintType, RDKitMols, RDKitReactions
from ord_schema.proto import dataset_pb2

logger = get_logger(__name__)
Expand All @@ -35,7 +35,7 @@ def get_connection_string(
database: str, username: str, password: str, host: str = "localhost", port: int = 5432
) -> str:
"""Creates an SQLAlchemy connection string."""
return f"postgresql://{username}:{password}@{host}:{port}/{database}?client_encoding=utf-8"
return f"postgresql+psycopg://{username}:{password}@{host}:{port}/{database}"


def prepare_database(engine: Engine) -> bool:
Expand Down Expand Up @@ -119,8 +119,8 @@ def update_rdkit_tables(dataset_id: str, session: Session) -> None:
def _update_rdkit_reactions(dataset_id: str, session: Session) -> None:
"""Updates the RDKit reactions table."""
logger.info("Updating RDKit reactions")
assert hasattr(RDKitReaction, "__table__") # Type hint.
table = RDKitReaction.__table__
assert hasattr(RDKitReactions, "__table__") # Type hint.
table = RDKitReactions.__table__
start = time.time()
session.execute(
insert(table)
Expand All @@ -144,8 +144,8 @@ def _update_rdkit_reactions(dataset_id: str, session: Session) -> None:
def _update_rdkit_mols(dataset_id: str, session: Session) -> None:
"""Updates the RDKit mols table."""
logger.info("Updating RDKit mols")
assert hasattr(RDKitMol, "__table__") # Type hint.
table = RDKitMol.__table__
assert hasattr(RDKitMols, "__table__") # Type hint.
table = RDKitMols.__table__
start = time.time()
# NOTE(skearnes): This join path will not include non-input compounds like workups, internal standards, etc.
session.execute(
Expand Down Expand Up @@ -201,8 +201,8 @@ def update_rdkit_ids(dataset_id: str, session: Session) -> None:
start = time.time()
# Update Reaction.
query = session.execute(
select(Mappers.Reaction.id, RDKitReaction.id)
.join(RDKitReaction, Mappers.Reaction.reaction_smiles == RDKitReaction.reaction_smiles)
select(Mappers.Reaction.id, RDKitReactions.id)
.join(RDKitReactions, Mappers.Reaction.reaction_smiles == RDKitReactions.reaction_smiles)
.join(Mappers.Dataset)
.where(Mappers.Dataset.dataset_id == dataset_id)
)
Expand All @@ -212,8 +212,8 @@ def update_rdkit_ids(dataset_id: str, session: Session) -> None:
session.execute(update(Mappers.Reaction), updates)
# Update Compound.
query = session.execute(
select(Mappers.Compound.id, RDKitMol.id)
.join(RDKitMol, Mappers.Compound.smiles == RDKitMol.smiles)
select(Mappers.Compound.id, RDKitMols.id)
.join(RDKitMols, Mappers.Compound.smiles == RDKitMols.smiles)
.join(Mappers.ReactionInput)
.join(Mappers.Reaction)
.join(Mappers.Dataset)
Expand All @@ -225,8 +225,8 @@ def update_rdkit_ids(dataset_id: str, session: Session) -> None:
session.execute(update(Mappers.Compound), updates)
# Update ProductCompound.
query = session.execute(
select(Mappers.ProductCompound.id, RDKitMol.id)
.join(RDKitMol, Mappers.ProductCompound.smiles == RDKitMol.smiles)
select(Mappers.ProductCompound.id, RDKitMols.id)
.join(RDKitMols, Mappers.ProductCompound.smiles == RDKitMols.smiles)
.join(Mappers.ReactionOutcome)
.join(Mappers.Reaction)
.join(Mappers.Dataset)
Expand Down
4 changes: 2 additions & 2 deletions ord_schema/orm/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,11 @@ def build_mapper( # pylint: disable=too-many-branches
attrs["proto"] = Column(LargeBinary, nullable=False)
attrs["reaction_smiles"] = Column(Text, index=True)
attrs["rdkit_reaction_id"] = Column(Integer, ForeignKey("rdkit.reactions.id"))
attrs["rdkit_reaction"] = relationship("RDKitReaction")
attrs["rdkit_reaction"] = relationship("RDKitReactions")
elif message_type in {reaction_pb2.Compound, reaction_pb2.ProductCompound}:
attrs["smiles"] = Column(Text, index=True)
attrs["rdkit_mol_id"] = Column(Integer, ForeignKey("rdkit.mols.id"))
attrs["rdkit_mol"] = relationship("RDKitMol")
attrs["rdkit_mol"] = relationship("RDKitMols")
elif message_type in {reaction_pb2.CompoundPreparation, reaction_pb2.CrudeComponent}:
# Add foreign key to reaction.reaction_id.
kwargs = {}
Expand Down
24 changes: 12 additions & 12 deletions ord_schema/orm/rdkit_mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def rdkit_cartridge() -> bool:
return bool(strtobool(os.environ.get("ORD_POSTGRES_RDKIT", "1")))


class _RDKitMol(UserDefinedType):
class RDKitMol(UserDefinedType):
"""https://github.com/rdkit/rdkit/blob/master/Code/PgSQL/rdkit/rdkit.sql.in#L4."""

cache_ok = True
Expand All @@ -54,7 +54,7 @@ def get_col_spec(self, **kwargs):
return "mol" if rdkit_cartridge() else "bytea"


class _RDKitReaction(UserDefinedType):
class RDKitReaction(UserDefinedType):
"""https://github.com/rdkit/rdkit/blob/master/Code/PgSQL/rdkit/rdkit.sql.in#L129."""

cache_ok = True
Expand All @@ -69,7 +69,7 @@ def get_col_spec(self, **kwargs):
return "reaction" if rdkit_cartridge() else "bytea"


class _RDKitBfp(UserDefinedType):
class RDKitBfp(UserDefinedType):
"""https://github.com/rdkit/rdkit/blob/master/Code/PgSQL/rdkit/rdkit.sql.in#L81."""

cache_ok = True
Expand All @@ -84,7 +84,7 @@ def get_col_spec(self, **kwargs):
return "bfp" if rdkit_cartridge() else "bytea"


class _RDKitSfp(UserDefinedType):
class RDKitSfp(UserDefinedType):
"""https://github.com/rdkit/rdkit/blob/master/Code/PgSQL/rdkit/rdkit.sql.in#L105."""

cache_ok = True
Expand Down Expand Up @@ -125,15 +125,15 @@ def __call__(self, *args, **kwargs):
return self.value(*args, **kwargs)


class RDKitMol(Base):
class RDKitMols(Base):
"""Table for storing compound structures and associated RDKit cartridge data."""

__tablename__ = "mols"
id = Column(Integer, primary_key=True)
smiles = Column(Text, index=True, unique=True)
mol = Column(_RDKitMol)
morgan_bfp = Column(_RDKitBfp)
morgan_sfp = Column(_RDKitSfp)
mol = Column(RDKitMol)
morgan_bfp = Column(RDKitBfp)
morgan_sfp = Column(RDKitSfp)

__table_args__ = (
Index("mol_index", "mol", postgresql_using="gist"),
Expand All @@ -144,24 +144,24 @@ class RDKitMol(Base):

@classmethod
def tanimoto(cls, other: str, fp_type: FingerprintType = FingerprintType.MORGAN_BFP) -> ColumnElement[float]:
return func.tanimoto_sml(getattr(cls, fp_type.name.lower()), fp_type(other))
return func.tanimoto_sml(getattr(cls, fp_type.name.lower()), fp_type(cast(other, RDKitMol)))

@classmethod
def contains_substructure(cls, pattern: str) -> ColumnElement[bool]:
return func.substruct(cls.mol, pattern)
return func.substruct(cls.mol, cast(pattern, RDKitMol))

@classmethod
def matches_smarts(cls, pattern: str) -> ColumnElement[bool]:
return func.substruct(cls.mol, func.qmol_from_smarts(cast(pattern, CString)))


class RDKitReaction(Base):
class RDKitReactions(Base):
"""Table for storing reaction objects and associated RDKit cartridge data."""

__tablename__ = "reactions"
id = Column(Integer, primary_key=True)
reaction_smiles = Column(Text, index=True, unique=True)
reaction = Column(_RDKitReaction)
reaction = Column(RDKitReaction)

__table_args__ = (
Index("reaction_index", "reaction", postgresql_using="gist"),
Expand Down
Loading
Loading