Skip to content

Commit

Permalink
feat: adding "tune run-optuna" command (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
holtgrewe authored Sep 18, 2023
1 parent 7df38b8 commit 6cc753b
Show file tree
Hide file tree
Showing 24 changed files with 245 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .gitattributes
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
tests/data/** filter=lfs diff=lfs merge=lfs -text
data/param_opt/* filter=lfs diff=lfs merge=lfs -text
data/** filter=lfs diff=lfs merge=lfs -text
23 changes: 23 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,29 @@ This is a re-implementation of the [CADA](https://github.com/Chengyao-Peng/CADA)
- Discussion Forum: https://github.com/bihealth/cada-prio/discussions
- Bug Reports: https://github.com/bihealth/cada-prio/issues

## Running Hyperparameter Tuning

Install with `tune` feature enabled:

```
pip install cada-prio[tune]
```

Run tuning, e.g., on the "classic" model.
Thanks to [optuna](https://optuna.org/), you can run this in parallel as long as the database is shared.
Each run will use 4 CPUs in the example below and perform 1 trial.

```
cada-prio tune run-optuna \
sqlite:///local_data/cada-tune.sqlite \
--path-hgnc-json data/classic/hgnc_complete_set.json \
--path-hpo-genes-to-phenotype data/classic/genes_to_phenotype.all_source_all_freqs_etc.txt \
--path-hpo-obo data/classic/hp.obo \
--path-clinvar-phenotype-links data/classic/cases_train.jsonl \
--path-validation-links data/classic/cases_validate.jsonl \
--n-trials 1 \
--cpus=4
```

## Managing GitHub Project with Terraform

Expand Down
122 changes: 115 additions & 7 deletions cada_prio/cli.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
"""Console script for CADA"""

import json
import logging
import os
import sys
import tempfile
import typing

import cattr
import click
import logzero

from cada_prio import _version, inspection, param_opt, predict, train_model
try:
import optuna

_ = optuna
HAVE_OPTUNA = True
except ImportError:
HAVE_OPTUNA = False

# Lower the update interval of tqdm to 5 seconds if stdout is not a TTY.
if not sys.stdout.isatty():
os.environ["TQDM_MININTERVAL"] = "5"
from cada_prio import _version, inspection, param_opt, predict, train_model


@click.group()
Expand Down Expand Up @@ -117,17 +123,21 @@ def cli_utils():


@cli_utils.command("dump-graph")
@click.option(
"--hgnc-to-entrez/--no-hgnc-to-entrez", help="enable HGNC to Entrez mapping", default=True
)
@click.argument("path_graph", type=str)
@click.argument("path_hgnc_info", type=str)
@click.pass_context
def cli_dump_graph(
ctx: click.Context,
path_graph: str,
path_hgnc_info: str,
hgnc_to_entrez: bool,
):
"""dump graph edges for debugging"""
ctx.ensure_object(dict)
inspection.dump_graph(path_graph, path_hgnc_info)
inspection.dump_graph(path_graph, path_hgnc_info, hgnc_to_entrez)


@cli.group("tune")
Expand Down Expand Up @@ -171,7 +181,7 @@ def cli_tune():
)
@click.option("--cpus", type=int, help="number of CPUs to use", default=1)
@click.pass_context
def cli_param_opt(
def cli_train_eval(
ctx: click.Context,
path_out: str,
path_hgnc_json: str,
Expand Down Expand Up @@ -203,3 +213,101 @@ def cli_param_opt(
seed=seed,
cpus=cpus,
)


if HAVE_OPTUNA:

@cli_tune.command("run-optuna")
@click.argument("storage", type=str)
@click.option("--n-trials", type=int, help="number of trials to run; default: 100", default=100)
@click.option(
"--study-name",
type=str,
help="name of Optuna study; default: cada-tune",
default="cada-tune",
)
@click.option("--path-hgnc-json", type=str, help="path to HGNC JSON", required=True)
@click.option(
"--path-hpo-genes-to-phenotype",
type=str,
help="path to genes_to_phenotype.txt file",
required=True,
)
@click.option("--path-hpo-obo", type=str, help="path HPO OBO file", required=True)
@click.option(
"--path-clinvar-phenotype-links",
type=str,
help="path to ClinVar phenotype links JSONL",
required=True,
)
@click.option(
"--fraction-links",
type=float,
help="fraction of links to add to the graph (conflicts with --path-validation-links)",
)
@click.option(
"--path-validation-links",
type=str,
help="path to validation links JSONL (conflicts with --fraction-links)",
)
@click.option(
"--seed",
type=int,
help="seed for random number generator",
)
@click.option("--cpus", type=int, help="number of CPUs to use", default=1)
def cli_run_optuna(
storage: str,
study_name: str,
n_trials: int,
path_hgnc_json: str,
path_hpo_genes_to_phenotype: str,
path_hpo_obo: str,
path_clinvar_phenotype_links: str,
fraction_links: typing.Optional[float],
path_validation_links: typing.Optional[str],
seed: typing.Optional[int],
cpus: int,
):
"""run hyperparameter tuning"""

def objective(trial: optuna.trial.BaseTrial) -> float:
with tempfile.TemporaryDirectory() as tmpdir, open(
f"{tmpdir}/params.json", "wt"
) as tmpf:
json.dump(
cattr.unstructure(
train_model.EmbeddingParams(
dimensions=trial.suggest_int("dimensions", low=100, high=500, step=10),
walk_length=trial.suggest_int("walk_length", low=1, high=100),
p=trial.suggest_float("p", low=0.1, high=2.5),
q=trial.suggest_float("q", low=0.0, high=1.0),
num_walks=trial.suggest_int("num_walks", low=10, high=50),
window=trial.suggest_int("window", low=4, high=8),
epochs=trial.suggest_int("epochs", low=3, high=6),
use_skipgram=trial.suggest_categorical("use_skipgram", [True, False]),
min_count=trial.suggest_int("batch_words", low=1, high=5),
batch_words=trial.suggest_int("batch_words", low=3, high=6),
)
),
tmpf,
indent=2,
)
tmpf.flush()
result = param_opt.train_and_validate(
path_out=f"{tmpdir}/out",
path_hgnc_json=path_hgnc_json,
path_hpo_genes_to_phenotype=path_hpo_genes_to_phenotype,
path_hpo_obo=path_hpo_obo,
path_clinvar_phenotype_links=path_clinvar_phenotype_links,
fraction_links=fraction_links,
path_validation_links=path_validation_links,
path_embedding_params=f"{tmpdir}/params.json",
seed=seed,
cpus=cpus,
)
return result[100]

optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))
study = optuna.create_study(study_name=study_name, storage=storage, load_if_exists=True)
study.optimize(objective, n_trials=n_trials)
28 changes: 23 additions & 5 deletions cada_prio/inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,37 @@

import pickle

from logzero import logger
import networkx as nx

from cada_prio.predict import load_hgnc_info


def dump_graph(path_graph: str, path_hgnc_info: str):
def dump_graph(path_graph: str, path_hgnc_info: str, hgnc_to_entrez: bool):
_, hgnc_info_by_id = load_hgnc_info(path_hgnc_info)
hgnc_info_by_ncbi_gene_id = {
"Entrez:%s" % x.ncbi_gene_id: x for x in hgnc_info_by_id.values() if x.ncbi_gene_id
}
with open(path_graph, "rb") as inputf:
graph: nx.Graph = pickle.load(inputf)
for edge in sorted(graph.edges):
lhs, rhs = edge
if lhs.startswith("HGNC:"):
lhs = "Entrez:%s" % hgnc_info_by_id[lhs].ncbi_gene_id
if rhs.startswith("HGNC:"):
rhs = "Entrez:%s" % hgnc_info_by_id[rhs].ncbi_gene_id

if hgnc_to_entrez:
if lhs.startswith("HGNC:"):
lhs = "Entrez:%s" % hgnc_info_by_id[lhs].ncbi_gene_id
if rhs.startswith("HGNC:"):
rhs = "Entrez:%s" % hgnc_info_by_id[rhs].ncbi_gene_id
else: # entrez to hgnc
if lhs.startswith("Entrez:"):
if lhs not in hgnc_info_by_ncbi_gene_id:
logger.warning("no HGNC ID for Entrez ID %s", lhs)
continue
lhs = hgnc_info_by_ncbi_gene_id[lhs].hgnc_id
if rhs.startswith("Entrez:"):
if rhs not in hgnc_info_by_ncbi_gene_id:
logger.warning("no HGNC ID for Entrez ID %s", rhs)
continue
rhs = hgnc_info_by_ncbi_gene_id[rhs].hgnc_id

print(f"{lhs}\t{rhs}")
16 changes: 16 additions & 0 deletions cada_prio/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,10 +225,23 @@ class EmbeddingParams:
batch_words: int = 4
#: RNG seed for embedding
seed_embedding: int = 1
#: Whether to use skipgram for model fitting instead of CBOW
use_skipgram: bool = False
#: Number of epochs for fitting
epochs: int = 5
#: RNG seed for fitting
seed_fit: int = 1


def log_graph_examples(edges: typing.List[Edge]):
edges = list(sorted(edges))
all_nodes = set(itertools.chain([e[0] for e in edges], [e[1] for e in edges]))
hgnc_nodes = list(sorted([v for v in all_nodes if v.startswith("HGNC:")]))
hpo_nodes = list(sorted([v for v in all_nodes if v.startswith("HP:")]))
logger.info(" - ten HGNC nodes: %s", hgnc_nodes[:10])
logger.info(" - ten HPO nodes: %s", hpo_nodes[:10])


def build_and_fit_model(
*,
clinvar_gen2phen,
Expand All @@ -247,6 +260,7 @@ def build_and_fit_model(
yield_gene2phen_edges(clinvar_gen2phen),
)
)
log_graph_examples(training_edges)
logger.info("- graph construction")
training_graph = nx.Graph()
training_graph.add_edges_from(training_edges)
Expand Down Expand Up @@ -277,6 +291,8 @@ def build_and_fit_model(
min_count=embedding_params.min_count,
batch_words=embedding_params.batch_words,
seed=embedding_params.seed_fit,
sg=1 if embedding_params.use_skipgram else 0,
epochs=embedding_params.epochs,
workers=cpus,
)
fit_elapsed = time.time() - fit_start
Expand Down
3 changes: 3 additions & 0 deletions data/classic/cases_test.jsonl
Git LFS file not shown
3 changes: 3 additions & 0 deletions data/classic/cases_train.jsonl
Git LFS file not shown
3 changes: 3 additions & 0 deletions data/classic/cases_validate.jsonl
Git LFS file not shown
3 changes: 3 additions & 0 deletions data/classic/genes_to_phenotype.all_source_all_freqs_etc.txt
Git LFS file not shown
3 changes: 3 additions & 0 deletions data/classic/hgnc_complete_set.json
Git LFS file not shown
3 changes: 3 additions & 0 deletions data/classic/hp.obo
Git LFS file not shown
Git LFS file not shown
3 changes: 3 additions & 0 deletions data/classic/orig/cases_test.tsv
Git LFS file not shown
3 changes: 3 additions & 0 deletions data/classic/orig/cases_train.tsv
Git LFS file not shown
3 changes: 3 additions & 0 deletions data/classic/orig/cases_validate.tsv
Git LFS file not shown
3 changes: 3 additions & 0 deletions data/classic/orig/genes_to_phenotype.txt
Git LFS file not shown
3 changes: 3 additions & 0 deletions data/classic/orig/hp.obo
Git LFS file not shown
3 changes: 3 additions & 0 deletions data/classic/transform.py
Git LFS file not shown
Empty file added local_data/.gitkeep
Empty file.
1 change: 1 addition & 0 deletions requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ tqdm >=4.0
pronto >=2.5, <3.0
networkx
node2vec >=0.4.6, <0.5
gensim >=4.3.2, <5.0
uvicorn >=0.23.2
fastapi >=0.103, <0.104
python-dotenv >=1.0, <2.0
1 change: 1 addition & 0 deletions requirements/test.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
-r tune.txt
-r base.txt

black ==23.9.1
Expand Down
1 change: 1 addition & 0 deletions requirements/tune.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
optuna >=3.3.0, <3.4.0
4 changes: 4 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def parse_requirements(path):

test_requirements = parse_requirements("requirements/test.txt")
install_requirements = parse_requirements("requirements/base.txt")
tune_requirements = parse_requirements("requirements/tune.txt")


package_root = os.path.abspath(os.path.dirname(__file__))
Expand All @@ -54,6 +55,9 @@ def parse_requirements(path):
description="Phenotype-based prioritization of variants with CADA",
entry_points={"console_scripts": ["cada-prio=cada_prio.cli:cli"]},
install_requires=install_requirements,
extras_require={
"param-tuning": tune_requirements,
},
license="MIT license",
long_description=readme + "\n\n" + history,
long_description_content_type="text/markdown",
Expand Down
21 changes: 21 additions & 0 deletions tests/test_quality.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Model quality tests"""

from cada_prio import param_opt


def test_quality(tmpdir):
"""Test quality of model built from 'classic' data from the original paper."""
result = param_opt.train_and_validate(
path_out=f"{tmpdir}/quality-model",
path_hgnc_json="data/classic/hgnc_complete_set.json",
path_hpo_genes_to_phenotype="data/classic/genes_to_phenotype.all_source_all_freqs_etc.txt",
path_hpo_obo="data/classic/hp.obo",
path_embedding_params=None,
path_clinvar_phenotype_links="/dev/null",
path_validation_links="data/classic/cases_validate.jsonl",
fraction_links=None,
seed=1,
cpus=1,
)

assert result[100] >= 56.0, f"result={result}"

0 comments on commit 6cc753b

Please sign in to comment.