Add more strict typing (#1253)
Related to #1201, the Bioregistry codebase had already started to get unruly
when I wasn't able to give consistent code quality reviews. Moving
towards making all files strictly typed is the next big step towards
making the project more sustainable. This is a first step: it focuses on
making sure that the code students are most likely to touch has the
highest level of checking on it.
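For context, "strictly typed" here means complete signatures using modern built-in generics (enabled by "from __future__ import annotations"), the style the hunks below adopt and that a strict mypy configuration can enforce. A minimal before/after sketch with a hypothetical function, not one taken from this diff:

from __future__ import annotations

# Before: implicit Any parameters and return type, which lenient checking accepts.
# def normalize(prefix, default=None):
#     ...

# After: fully annotated, so a strict checker can verify every call site.
def normalize(prefix: str, default: str | None = None) -> str | None:
    """Return a cleaned-up prefix, or the default if nothing is left (hypothetical)."""
    return prefix.strip().lower() or default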
cthoyt authored Nov 4, 2024
1 parent 4c49897 commit 9900364
Showing 44 changed files with 477 additions and 324 deletions.
2 changes: 1 addition & 1 deletion src/bioregistry/__main__.py
@@ -2,7 +2,7 @@

"""Command line interface for the bioregistry."""

from .cli import main
from .cli import main # type:ignore

if __name__ == "__main__":
main()
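The "# type:ignore" comment above is the usual escape hatch when a strictly checked module re-exports from code that mypy cannot yet see as typed: the one import is silenced while the rest of the file stays checked. A minimal sketch of the same pattern with a hypothetical untyped dependency (newer mypy releases also accept a scoped code such as "# type: ignore[import-untyped]"):

# "legacy_cli" is a stand-in for any dependency that ships without type hints;
# the comment silences only this import, not the rest of the module.
from legacy_cli import main  # type: ignore

if __name__ == "__main__":
    main()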
24 changes: 17 additions & 7 deletions src/bioregistry/analysis/bioregistry_diff.py
@@ -1,7 +1,10 @@
"""Given two dates, analyzes and visualizes changes in the Bioregistry."""

import json
from __future__ import annotations

import datetime
import logging
from typing import Any

import click
import matplotlib.pyplot as plt
@@ -21,7 +24,12 @@
FILE_PATH = "src/bioregistry/data/bioregistry.json"


def get_commit_before_date(date, owner=REPO_OWNER, name=REPO_NAME, branch=BRANCH):
def get_commit_before_date(
date: datetime.date,
owner: str = REPO_OWNER,
name: str = REPO_NAME,
branch: str = BRANCH,
) -> str | None:
"""Return the last commit before a given date.
:param date: The date to get the commit before.
@@ -46,7 +54,9 @@ def get_commit_before_date(date, owner=REPO_OWNER, name=REPO_NAME, branch=BRANCH
return None


def get_file_at_commit(file_path, commit_sha, owner=REPO_OWNER, name=REPO_NAME):
def get_file_at_commit(
file_path: str, commit_sha: str, owner: str = REPO_OWNER, name: str = REPO_NAME
) -> dict[str, Any]:
"""Return the content of a given file at a specific commit.
:param file_path: The file path in the repository.
@@ -61,9 +71,9 @@ def get_file_at_commit(file_path, commit_sha, owner=REPO_OWNER, name=REPO_NAME):
response.raise_for_status()
file_info = response.json()
download_url = file_info["download_url"]
file_content_response = requests.get(download_url)
file_content_response.raise_for_status()
return json.loads(file_content_response.text)
res = requests.get(download_url)
res.raise_for_status()
return res.json()


def compare_bioregistry(old_data, new_data):
@@ -246,7 +256,7 @@ def compare_dates(date1, date2):
:param date1: The starting date in the format YYYY-MM-DD.
:param date2: The ending date in the format YYYY-MM-DD.
"""
added, deleted, updated, update_details, old_data, new_data, all_mapping_keys = get_data(
added, deleted, updated, update_details, _old_data, _new_data, all_mapping_keys = get_data(
date1, date2
)
if added is not None and updated is not None:
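The helpers get_commit_before_date and get_file_at_commit above follow a common GitHub REST flow: list commits on a branch up to a cutoff date, then download a file's content as of that commit. A condensed, self-contained sketch of the same flow (unauthenticated, so subject to rate limits; the structure is assumed from the hunks, not copied from the module):

import datetime

import requests

def file_before_date(owner: str, repo: str, path: str, date: datetime.date) -> dict:
    """Fetch a repository file's JSON content as of the last commit before a date."""
    commits = requests.get(
        f"https://api.github.com/repos/{owner}/{repo}/commits",
        params={"until": date.isoformat(), "per_page": 1},
        timeout=30,
    )
    commits.raise_for_status()
    sha = commits.json()[0]["sha"]
    # The contents endpoint resolves the path at a specific ref and
    # returns a download_url for the raw file.
    info = requests.get(
        f"https://api.github.com/repos/{owner}/{repo}/contents/{path}",
        params={"ref": sha},
        timeout=30,
    )
    info.raise_for_status()
    res = requests.get(info.json()["download_url"], timeout=30)
    res.raise_for_status()
    return res.json()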
139 changes: 76 additions & 63 deletions src/bioregistry/analysis/paper_ranking.py
@@ -1,38 +1,48 @@
"""Train a TF-IDF classifier and use it to score the relevance of new PubMed papers to the Bioregistry."""

from __future__ import annotations

import datetime
import json
from collections import defaultdict
from pathlib import Path

import click
import indra.literature.pubmed_client as pubmed_client
import numpy as np
import pandas as pd
from numpy.typing import NDArray
from sklearn.base import ClassifierMixin
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import matthews_corrcoef, roc_auc_score
from sklearn.model_selection import cross_val_predict, train_test_split
from sklearn.svm import SVC, LinearSVC
from sklearn.tree import DecisionTreeClassifier
from tabulate import tabulate

DIRECTORY = Path("exports/analyses/paper_ranking")
HERE = Path(__file__).parent.resolve()
ROOT = HERE.parent.parent.parent.resolve()

BIOREGISTRY_PATH = ROOT.joinpath("src", "bioregistry", "data", "bioregistry.json")

DIRECTORY = ROOT.joinpath("exports", "analyses", "paper_ranking")
DIRECTORY.mkdir(exist_ok=True, parents=True)

URL = "https://docs.google.com/spreadsheets/d/e/2PACX-1vRPtP-tcXSx8zvhCuX6fqz_\
QvHowyAoDahnkixARk9rFTe0gfBN9GfdG6qTNQHHVL0i33XGSp_nV9XM/pub?output=csv"


def load_bioregistry_json(file_path):
def load_bioregistry_json(path: Path | None = None) -> pd.DataFrame:
"""Load bioregistry data from a JSON file, extracting publication details and fetching abstracts if missing.
:param file_path: Path to the bioregistry JSON file.
:type file_path: str
:param path: Path to the bioregistry JSON file.
:return: DataFrame containing publication details.
:rtype: pd.DataFrame
"""
if path is None:
path = BIOREGISTRY_PATH
try:
with open(file_path, "r") as f:
data = json.load(f)
data = json.loads(path.read_text())
except json.JSONDecodeError as e:
click.echo(f"JSONDecodeError: {e.msg}")
click.echo(f"Error at line {e.lineno}, column {e.colno}")
@@ -58,67 +68,62 @@ def load_bioregistry_json(file_path):
if pub["pubmed"] in fetched_metadata:
pub["abstract"] = fetched_metadata[pub["pubmed"]].get("abstract", "")

click.echo(f"Got {len(publications)} publications from the bioregistry")
click.echo(f"Got {len(publications):,} publications from the bioregistry")

return pd.DataFrame(publications)


def fetch_pubmed_papers():
def fetch_pubmed_papers() -> pd.DataFrame:
"""Fetch PubMed papers from the last 30 days using specific search terms.
:return: DataFrame containing PubMed paper details.
:rtype: pd.DataFrame
"""
click.echo("Starting fetch_pubmed_papers")

search_terms = ["database", "ontology", "resource", "vocabulary", "nomenclature"]
paper_to_terms = {}
paper_to_terms: defaultdict[str, list[str]] = defaultdict(list)

for term in search_terms:
pmids = pubmed_client.get_ids(term, use_text_word=True, reldate=30)
for pmid in pmids:
if pmid in paper_to_terms:
paper_to_terms[pmid].append(term)
else:
paper_to_terms[pmid] = [term]
pubmed_ids = pubmed_client.get_ids(term, use_text_word=True, reldate=30)
for pubmed_id in pubmed_ids:
paper_to_terms[pubmed_id].append(term)

all_pmids = list(paper_to_terms.keys())
click.echo(f"{len(all_pmids)} PMIDs found")
click.echo(f"{len(all_pmids):,} articles found")
if not all_pmids:
click.echo(f"No PMIDs found for the last 30 days with the search terms: {search_terms}")
click.echo(f"No articles found for the last 30 days with the search terms: {search_terms}")
return pd.DataFrame()

papers = {}
for chunk in [all_pmids[i : i + 200] for i in range(0, len(all_pmids), 200)]:
papers.update(pubmed_client.get_metadata_for_ids(chunk, get_abstracts=True))

records = []
for pmid, paper in papers.items():
for pubmed_id, paper in papers.items():
title = paper.get("title")
abstract = paper.get("abstract", "")

if title and abstract:
records.append(
{
"pubmed": pmid,
"pubmed": pubmed_id,
"title": title,
"abstract": abstract,
"year": paper.get("publication_date", {}).get("year"),
"search_terms": paper_to_terms.get(pmid),
"search_terms": paper_to_terms.get(pubmed_id),
}
)

click.echo(f"{len(records)} records fetched from PubMed")
click.echo(f"{len(records):,} records fetched from PubMed")
return pd.DataFrame(records)


def load_curation_data():
def load_curation_data() -> pd.DataFrame:
"""Download and load curation data from a Google Sheets URL.
:return: DataFrame containing curated publication details.
:rtype: pd.DataFrame
"""
click.echo("Downloading curation")
click.echo("Downloading curation sheet")
df = pd.read_csv(URL)
df["label"] = df["relevant"].map(_map_labels)
df = df[["pubmed", "title", "abstract", "label"]]
@@ -136,13 +141,11 @@ def load_curation_data():
return df


def _map_labels(s: str):
def _map_labels(s: str) -> int | None:
"""Map labels to binary values.
:param s: Label value.
:type s: str
:return: Mapped binary label value.
:rtype: int
"""
if s in {"1", "1.0", 1}:
return 1
@@ -151,15 +154,15 @@ def _map_labels(s: str):
return None


def train_classifiers(x_train, y_train):
Classifiers = list[tuple[str, ClassifierMixin]]


def train_classifiers(x_train: NDArray[np.float64], y_train: NDArray[np.str_]) -> Classifiers:
"""Train multiple classifiers on the training data.
:param x_train: Training features.
:type x_train: array-like
:param y_train: Training labels.
:type y_train: array-like
:return: List of trained classifiers.
:rtype: list
"""
classifiers = [
("rf", RandomForestClassifier()),
Expand All @@ -173,17 +176,15 @@ def train_classifiers(x_train, y_train):
return classifiers


def generate_meta_features(classifiers, x_train, y_train):
def generate_meta_features(
classifiers: Classifiers, x_train: NDArray[np.float64], y_train: NDArray[np.str_]
) -> pd.DataFrame:
"""Generate meta-features for training a meta-classifier using cross-validation predictions.
:param classifiers: List of trained classifiers.
:type classifiers: list
:param x_train: Training features.
:type x_train: array-like
:param y_train: Training labels.
:type y_train: array-like
:return: DataFrame containing meta-features.
:rtype: pd.DataFrame
"""
meta_features = pd.DataFrame()
for name, clf in classifiers:
@@ -197,42 +198,42 @@ def generate_meta_features(classifiers, x_train, y_train):
return meta_features


def evaluate_meta_classifier(meta_clf, x_test_meta, y_test):
def evaluate_meta_classifier(
meta_clf: ClassifierMixin, x_test_meta: NDArray[np.float64], y_test: NDArray[np.str_]
) -> tuple[float, float]:
"""Evaluate meta-classifier using MCC and AUC-ROC scores.
:param meta_clf: Trained meta-classifier.
:type meta_clf: classifier
:param x_test_meta: Test meta-features.
:type x_test_meta: array-like
:param y_test: Test labels.
:type y_test: array-like
:return: MCC and AUC-ROC scores.
:rtype: tuple
"""
y_pred = meta_clf.predict(x_test_meta)
mcc = matthews_corrcoef(y_test, y_pred)
roc_auc = roc_auc_score(y_test, meta_clf.predict_proba(x_test_meta)[:, 1])
return mcc, roc_auc
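Together with the base classifiers, generate_meta_features and evaluate_meta_classifier implement a standard stacking ensemble: out-of-fold predictions from each base model become the input features of a second-level model. A compressed sketch of that pattern on synthetic data (not the module's actual pipeline):

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict
from sklearn.svm import LinearSVC

rng = np.random.default_rng(0)
x = rng.normal(size=(200, 20))
y = (x[:, 0] + rng.normal(scale=0.5, size=200) > 0).astype(int)

# Cross-validated predictions keep each meta-feature out-of-fold,
# so the meta-classifier never sees labels leaked from training.
meta_features = pd.DataFrame()
for name, clf in [("rf", RandomForestClassifier()), ("svc", LinearSVC())]:
    if hasattr(clf, "predict_proba"):
        meta_features[name] = cross_val_predict(clf, x, y, method="predict_proba")[:, 1]
    else:
        meta_features[name] = cross_val_predict(clf, x, y, method="decision_function")

meta_clf = LogisticRegression().fit(meta_features, y)
print(f"training accuracy: {meta_clf.score(meta_features, y):.2f}")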


def truncate_text(text, max_length):
def truncate_text(text: str, max_length: int) -> str:
"""Truncate text to a specified maximum length."""
# FIXME replace with builtin textwrap function
return text if len(text) <= max_length else text[:max_length] + "..."


def predict_and_save(df, vectorizer, classifiers, meta_clf, filename):
def predict_and_save(
df: pd.DataFrame,
vectorizer: TfidfVectorizer,
classifiers: Classifiers,
meta_clf: ClassifierMixin,
filename: str | Path,
) -> None:
"""Predict and save scores for new data using trained classifiers and meta-classifier.
:param df: DataFrame containing new data.
:type df: pd.DataFrame
:param vectorizer: Trained TF-IDF vectorizer.
:type vectorizer: TfidfVectorizer
:param classifiers: List of trained classifiers.
:type classifiers: list
:param meta_clf: Trained meta-classifier.
:type meta_clf: classifier
:param filename: Filename to save the predictions.
:type filename: str
"""
x_meta = pd.DataFrame()
x_transformed = vectorizer.transform(df.title + " " + df.abstract)
@@ -249,23 +250,35 @@ def predict_and_save(df, vectorizer, classifiers, meta_clf, filename):
click.echo(f"Wrote predicted scores to {DIRECTORY.joinpath(filename)}")


def _first_of_month() -> str:
today = datetime.date.today()
return datetime.date(today.year, today.month, 1).isoformat()
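The options below switch to callable defaults; click invokes a zero-argument callable at command invocation time, so each run computes fresh dates instead of freezing them at import. A tiny standalone sketch of that behavior (command and option names hypothetical):

import datetime

import click

@click.command()
@click.option(
    "--as-of",
    # A zero-argument callable is re-evaluated on every invocation.
    default=lambda: datetime.date.today().isoformat(),
)
def report(as_of: str) -> None:
    """Echo the effective date (hypothetical demo command)."""
    click.echo(f"running for {as_of}")

if __name__ == "__main__":
    report()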


@click.command()
@click.option(
"--bioregistry-file",
default="src/bioregistry/data/bioregistry.json",
type=Path,
help="Path to the bioregistry.json file",
)
@click.option("--start-date", required=True, help="Start date of the period")
@click.option("--end-date", required=True, help="End date of the period")
def main(bioregistry_file, start_date, end_date):
@click.option(
"--start-date",
required=True,
help="Start date of the period",
default=_first_of_month,
)
@click.option(
"--end-date",
required=True,
help="End date of the period",
default=lambda: datetime.date.today().isoformat(),
)
def main(bioregistry_file: Path, start_date: str, end_date: str) -> None:
"""Load data, train classifiers, evaluate models, and predict new data.
:param bioregistry_file: Path to the bioregistry JSON file.
:type bioregistry_file: str
:param start_date: The start date of the period for which papers are being ranked.
:type start_date: str
:param end_date: The end date of the period for which papers are being ranked.
:type end_date: str
"""
publication_df = load_bioregistry_json(bioregistry_file)
curation_df = load_curation_data()
@@ -295,7 +308,7 @@ def main(bioregistry_file, start_date, end_date):
try:
mcc = matthews_corrcoef(y_test, y_pred)
except ValueError as e:
click.secho(f"{clf} failed to calculate MCC: {e}", fg="yellow")
mcc = None
try:
if hasattr(clf, "predict_proba"):
@@ -310,7 +323,7 @@ def main(bioregistry_file, start_date, end_date):
scores.append((name, mcc or float("nan"), roc_auc or float("nan")))

evaluation_df = pd.DataFrame(scores, columns=["classifier", "mcc", "auc_roc"]).round(3)
click.echo(tabulate(evaluation_df, showindex=False, headers=evaluation_df.columns))
click.echo(evaluation_df.to_markdown(index=False))

meta_features = generate_meta_features(classifiers, x_train, y_train)
meta_clf = LogisticRegression()
@@ -323,8 +336,8 @@ def main(bioregistry_file, start_date, end_date):
else:
x_test_meta[name] = clf.decision_function(x_test)

mcc, roc_auc = evaluate_meta_classifier(meta_clf, x_test_meta, y_test)
click.echo(f"Meta-Classifier MCC: {mcc}, AUC-ROC: {roc_auc}")
mcc, roc_auc = evaluate_meta_classifier(meta_clf, x_test_meta.to_numpy(), y_test)
click.echo(f"Meta-Classifier MCC: {mcc:.2f}, AUC-ROC: {roc_auc:.2f}")
new_row = {"classifier": "meta_classifier", "mcc": mcc, "auc_roc": roc_auc}
evaluation_df = pd.concat([evaluation_df, pd.DataFrame([new_row])], ignore_index=True)

@@ -349,7 +362,7 @@ def main(bioregistry_file, start_date, end_date):
.sort_values("rf_importance", ascending=False, key=abs)
.round(4)
)
click.echo(tabulate(importances_df.head(15), showindex=False, headers=importances_df.columns))
click.echo(importances_df.head(15).to_markdown(index=False))

importance_path = DIRECTORY.joinpath("importances.tsv")
click.echo(f"Writing feature (word) importances to {importance_path}")
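The feature-importance block above maps the random forest's importances back to the TF-IDF vocabulary, which is what makes the word-level report possible. A minimal sketch of that trick on a toy corpus (synthetic data, not the real curation sheet):

import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import TfidfVectorizer

docs = [
    "a database of protein identifiers",
    "an ontology of cell types",
    "a travel blog about mountains",
    "holiday photos and recipes",
]
labels = [1, 1, 0, 0]  # 1 = relevant to a registry, 0 = not

vectorizer = TfidfVectorizer()
x = vectorizer.fit_transform(docs)
clf = RandomForestClassifier(random_state=0).fit(x, labels)

# Align importances with the vocabulary to see which words drive predictions.
importances = pd.DataFrame(
    {
        "word": vectorizer.get_feature_names_out(),
        "importance": clf.feature_importances_,
    }
).sort_values("importance", ascending=False)
print(importances.head(5).to_markdown(index=False))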