From 6cc753b3b4f92aa75d961c3cf314e097d174ede0 Mon Sep 17 00:00:00 2001 From: Manuel Holtgrewe Date: Mon, 18 Sep 2023 13:39:10 +0200 Subject: [PATCH] feat: adding "tune run-optuna" command (#23) --- .gitattributes | 2 +- README.md | 23 ++++ cada_prio/cli.py | 122 +++++++++++++++++- cada_prio/inspection.py | 28 +++- cada_prio/train_model.py | 16 +++ data/classic/cases_test.jsonl | 3 + data/classic/cases_train.jsonl | 3 + data/classic/cases_validate.jsonl | 3 + ..._to_phenotype.all_source_all_freqs_etc.txt | 3 + data/classic/hgnc_complete_set.json | 3 + data/classic/hp.obo | 3 + ...CES_ALL_FREQUENCIES_genes_to_phenotype.txt | 3 + data/classic/orig/cases_test.tsv | 3 + data/classic/orig/cases_train.tsv | 3 + data/classic/orig/cases_validate.tsv | 3 + data/classic/orig/genes_to_phenotype.txt | 3 + data/classic/orig/hp.obo | 3 + data/classic/transform.py | 3 + local_data/.gitkeep | 0 requirements/base.txt | 1 + requirements/test.txt | 1 + requirements/tune.txt | 1 + setup.py | 4 + tests/test_quality.py | 21 +++ 24 files changed, 245 insertions(+), 13 deletions(-) create mode 100644 data/classic/cases_test.jsonl create mode 100644 data/classic/cases_train.jsonl create mode 100644 data/classic/cases_validate.jsonl create mode 100644 data/classic/genes_to_phenotype.all_source_all_freqs_etc.txt create mode 100644 data/classic/hgnc_complete_set.json create mode 100644 data/classic/hp.obo create mode 100644 data/classic/orig/ALL_SOURCES_ALL_FREQUENCIES_genes_to_phenotype.txt create mode 100644 data/classic/orig/cases_test.tsv create mode 100644 data/classic/orig/cases_train.tsv create mode 100644 data/classic/orig/cases_validate.tsv create mode 100644 data/classic/orig/genes_to_phenotype.txt create mode 100644 data/classic/orig/hp.obo create mode 100644 data/classic/transform.py create mode 100644 local_data/.gitkeep create mode 100644 requirements/tune.txt create mode 100644 tests/test_quality.py diff --git a/.gitattributes b/.gitattributes index cc824ba..bbe750e 100644 --- 
a/.gitattributes +++ b/.gitattributes @@ -1,2 +1,2 @@ tests/data/** filter=lfs diff=lfs merge=lfs -text -data/param_opt/* filter=lfs diff=lfs merge=lfs -text +data/** filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md index 928a71c..d1f1c80 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,29 @@ This is a re-implementation of the [CADA](https://github.com/Chengyao-Peng/CADA) - Discussion Forum: https://github.com/bihealth/cada-prio/discussions - Bug Reports: https://github.com/bihealth/cada-prio/issues +## Running Hyperparameter Tuning + +Install with `tune` feature enabled: + +``` +pip install cada-prio[tune] +``` + +Run tuning, e.g., on the "classic" model. +Thanks to [optuna](https://optuna.org/), you can run this in parallel as long as the database is shared. +Each run will use 4 CPUs in the example below and perform 1 trial. + +``` +cada-prio tune run-optuna \ + sqlite:///local_data/cada-tune.sqlite \ + --path-hgnc-json data/classic/hgnc_complete_set.json \ + --path-hpo-genes-to-phenotype data/classic/genes_to_phenotype.all_source_all_freqs_etc.txt \ + --path-hpo-obo data/classic/hp.obo \ + --path-clinvar-phenotype-links data/classic/cases_train.jsonl \ + --path-validation-links data/classic/cases_validate.jsonl \ + --n-trials 1 \ + --cpus=4 +``` ## Managing GitHub Project with Terraform diff --git a/cada_prio/cli.py b/cada_prio/cli.py index efffc80..5b6b623 100644 --- a/cada_prio/cli.py +++ b/cada_prio/cli.py @@ -1,18 +1,24 @@ """Console script for CADA""" +import json import logging -import os import sys +import tempfile import typing +import cattr import click import logzero -from cada_prio import _version, inspection, param_opt, predict, train_model +try: + import optuna + + _ = optuna + HAVE_OPTUNA = True +except ImportError: + HAVE_OPTUNA = False -# Lower the update interval of tqdm to 5 seconds if stdout is not a TTY. 
-if not sys.stdout.isatty(): - os.environ["TQDM_MININTERVAL"] = "5" +from cada_prio import _version, inspection, param_opt, predict, train_model @click.group() @@ -117,6 +123,9 @@ def cli_utils(): @cli_utils.command("dump-graph") +@click.option( + "--hgnc-to-entrez/--no-hgnc-to-entrez", help="enable HGNC to Entrez mapping", default=True +) @click.argument("path_graph", type=str) @click.argument("path_hgnc_info", type=str) @click.pass_context @@ -124,10 +133,11 @@ def cli_dump_graph( ctx: click.Context, path_graph: str, path_hgnc_info: str, + hgnc_to_entrez: bool, ): """dump graph edges for debugging""" ctx.ensure_object(dict) - inspection.dump_graph(path_graph, path_hgnc_info) + inspection.dump_graph(path_graph, path_hgnc_info, hgnc_to_entrez) @cli.group("tune") @@ -171,7 +181,7 @@ def cli_tune(): ) @click.option("--cpus", type=int, help="number of CPUs to use", default=1) @click.pass_context -def cli_param_opt( +def cli_train_eval( ctx: click.Context, path_out: str, path_hgnc_json: str, @@ -203,3 +213,101 @@ def cli_param_opt( seed=seed, cpus=cpus, ) + + +if HAVE_OPTUNA: + + @cli_tune.command("run-optuna") + @click.argument("storage", type=str) + @click.option("--n-trials", type=int, help="number of trials to run; default: 100", default=100) + @click.option( + "--study-name", + type=str, + help="name of Optuna study; default: cada-tune", + default="cada-tune", + ) + @click.option("--path-hgnc-json", type=str, help="path to HGNC JSON", required=True) + @click.option( + "--path-hpo-genes-to-phenotype", + type=str, + help="path to genes_to_phenotype.txt file", + required=True, + ) + @click.option("--path-hpo-obo", type=str, help="path HPO OBO file", required=True) + @click.option( + "--path-clinvar-phenotype-links", + type=str, + help="path to ClinVar phenotype links JSONL", + required=True, + ) + @click.option( + "--fraction-links", + type=float, + help="fraction of links to add to the graph (conflicts with --path-validation-links)", + ) + @click.option( + 
"--path-validation-links", + type=str, + help="path to validation links JSONL (conflicts with --fraction-links)", + ) + @click.option( + "--seed", + type=int, + help="seed for random number generator", + ) + @click.option("--cpus", type=int, help="number of CPUs to use", default=1) + def cli_run_optuna( + storage: str, + study_name: str, + n_trials: int, + path_hgnc_json: str, + path_hpo_genes_to_phenotype: str, + path_hpo_obo: str, + path_clinvar_phenotype_links: str, + fraction_links: typing.Optional[float], + path_validation_links: typing.Optional[str], + seed: typing.Optional[int], + cpus: int, + ): + """run hyperparameter tuning""" + + def objective(trial: optuna.trial.BaseTrial) -> float: + with tempfile.TemporaryDirectory() as tmpdir, open( + f"{tmpdir}/params.json", "wt" + ) as tmpf: + json.dump( + cattr.unstructure( + train_model.EmbeddingParams( + dimensions=trial.suggest_int("dimensions", low=100, high=500, step=10), + walk_length=trial.suggest_int("walk_length", low=1, high=100), + p=trial.suggest_float("p", low=0.1, high=2.5), + q=trial.suggest_float("q", low=0.0, high=1.0), + num_walks=trial.suggest_int("num_walks", low=10, high=50), + window=trial.suggest_int("window", low=4, high=8), + epochs=trial.suggest_int("epochs", low=3, high=6), + use_skipgram=trial.suggest_categorical("use_skipgram", [True, False]), + min_count=trial.suggest_int("min_count", low=1, high=5), + batch_words=trial.suggest_int("batch_words", low=3, high=6), + ) + ), + tmpf, + indent=2, + ) + tmpf.flush() + result = param_opt.train_and_validate( + path_out=f"{tmpdir}/out", + path_hgnc_json=path_hgnc_json, + path_hpo_genes_to_phenotype=path_hpo_genes_to_phenotype, + path_hpo_obo=path_hpo_obo, + path_clinvar_phenotype_links=path_clinvar_phenotype_links, + fraction_links=fraction_links, + path_validation_links=path_validation_links, + path_embedding_params=f"{tmpdir}/params.json", + seed=seed, + cpus=cpus, + ) + return result[100] + + 
optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout)) + study = optuna.create_study(study_name=study_name, storage=storage, direction="maximize", load_if_exists=True) + study.optimize(objective, n_trials=n_trials) diff --git a/cada_prio/inspection.py b/cada_prio/inspection.py index 54e9a3c..513759c 100644 --- a/cada_prio/inspection.py +++ b/cada_prio/inspection.py @@ -2,19 +2,37 @@ import pickle +from logzero import logger import networkx as nx from cada_prio.predict import load_hgnc_info -def dump_graph(path_graph: str, path_hgnc_info: str): +def dump_graph(path_graph: str, path_hgnc_info: str, hgnc_to_entrez: bool): _, hgnc_info_by_id = load_hgnc_info(path_hgnc_info) + hgnc_info_by_ncbi_gene_id = { + "Entrez:%s" % x.ncbi_gene_id: x for x in hgnc_info_by_id.values() if x.ncbi_gene_id + } with open(path_graph, "rb") as inputf: graph: nx.Graph = pickle.load(inputf) for edge in sorted(graph.edges): lhs, rhs = edge - if lhs.startswith("HGNC:"): - lhs = "Entrez:%s" % hgnc_info_by_id[lhs].ncbi_gene_id - if rhs.startswith("HGNC:"): - rhs = "Entrez:%s" % hgnc_info_by_id[rhs].ncbi_gene_id + + if hgnc_to_entrez: + if lhs.startswith("HGNC:"): + lhs = "Entrez:%s" % hgnc_info_by_id[lhs].ncbi_gene_id + if rhs.startswith("HGNC:"): + rhs = "Entrez:%s" % hgnc_info_by_id[rhs].ncbi_gene_id + else: # entrez to hgnc + if lhs.startswith("Entrez:"): + if lhs not in hgnc_info_by_ncbi_gene_id: + logger.warning("no HGNC ID for Entrez ID %s", lhs) + continue + lhs = hgnc_info_by_ncbi_gene_id[lhs].hgnc_id + if rhs.startswith("Entrez:"): + if rhs not in hgnc_info_by_ncbi_gene_id: + logger.warning("no HGNC ID for Entrez ID %s", rhs) + continue + rhs = hgnc_info_by_ncbi_gene_id[rhs].hgnc_id + print(f"{lhs}\t{rhs}") diff --git a/cada_prio/train_model.py b/cada_prio/train_model.py index b74148c..40f83c0 100644 --- a/cada_prio/train_model.py +++ b/cada_prio/train_model.py @@ -225,10 +225,23 @@ class EmbeddingParams: batch_words: int = 4 #: RNG seed for embedding seed_embedding: int = 1 + #: 
Whether to use skipgram for model fitting instead of CBOW + use_skipgram: bool = False + #: Number of epochs for fitting + epochs: int = 5 #: RNG seed for fitting seed_fit: int = 1 +def log_graph_examples(edges: typing.List[Edge]): + edges = list(sorted(edges)) + all_nodes = set(itertools.chain([e[0] for e in edges], [e[1] for e in edges])) + hgnc_nodes = list(sorted([v for v in all_nodes if v.startswith("HGNC:")])) + hpo_nodes = list(sorted([v for v in all_nodes if v.startswith("HP:")])) + logger.info(" - ten HGNC nodes: %s", hgnc_nodes[:10]) + logger.info(" - ten HPO nodes: %s", hpo_nodes[:10]) + + def build_and_fit_model( *, clinvar_gen2phen, @@ -247,6 +260,7 @@ def build_and_fit_model( yield_gene2phen_edges(clinvar_gen2phen), ) ) + log_graph_examples(training_edges) logger.info("- graph construction") training_graph = nx.Graph() training_graph.add_edges_from(training_edges) @@ -277,6 +291,8 @@ def build_and_fit_model( min_count=embedding_params.min_count, batch_words=embedding_params.batch_words, seed=embedding_params.seed_fit, + sg=1 if embedding_params.use_skipgram else 0, + epochs=embedding_params.epochs, workers=cpus, ) fit_elapsed = time.time() - fit_start diff --git a/data/classic/cases_test.jsonl b/data/classic/cases_test.jsonl new file mode 100644 index 0000000..c8989cf --- /dev/null +++ b/data/classic/cases_test.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e3615b0de5045745e777ac1a8505a6b936c11cc41c64fb8f937b5f2915b9d49b +size 275852 diff --git a/data/classic/cases_train.jsonl b/data/classic/cases_train.jsonl new file mode 100644 index 0000000..df0e5c1 --- /dev/null +++ b/data/classic/cases_train.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5fb521614fb561498ce590ccdb132b6a7243a0e8eb0c86e0248ff14e22aad4e4 +size 824010 diff --git a/data/classic/cases_validate.jsonl b/data/classic/cases_validate.jsonl new file mode 100644 index 0000000..7b258fd --- /dev/null +++ 
b/data/classic/cases_validate.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7bc1bcf011e32023201dab7d81af0263aa58fb5d2c3029c6f6a54e8efd366604 +size 271390 diff --git a/data/classic/genes_to_phenotype.all_source_all_freqs_etc.txt b/data/classic/genes_to_phenotype.all_source_all_freqs_etc.txt new file mode 100644 index 0000000..4fd70eb --- /dev/null +++ b/data/classic/genes_to_phenotype.all_source_all_freqs_etc.txt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6211ca9029533ab68d5d81ad3d60b22854a1a196972b850d4ef71e2b9f38c358 +size 2787100 diff --git a/data/classic/hgnc_complete_set.json b/data/classic/hgnc_complete_set.json new file mode 100644 index 0000000..bcdb24e --- /dev/null +++ b/data/classic/hgnc_complete_set.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7784602e002cf9776d46a3df6eaeb24ceeda6f649068aff56265537d0a3467af +size 34856401 diff --git a/data/classic/hp.obo b/data/classic/hp.obo new file mode 100644 index 0000000..e4b3baf --- /dev/null +++ b/data/classic/hp.obo @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:128ceb286df38adf37d631ca765d392e50b1830049306a9d369619eade18debb +size 6621428 diff --git a/data/classic/orig/ALL_SOURCES_ALL_FREQUENCIES_genes_to_phenotype.txt b/data/classic/orig/ALL_SOURCES_ALL_FREQUENCIES_genes_to_phenotype.txt new file mode 100644 index 0000000..5175305 --- /dev/null +++ b/data/classic/orig/ALL_SOURCES_ALL_FREQUENCIES_genes_to_phenotype.txt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ee9a64b2fead673742dc2185fcb4abe3c88ff1a0adbe5d17eaa6fff8823a4e8b +size 7072518 diff --git a/data/classic/orig/cases_test.tsv b/data/classic/orig/cases_test.tsv new file mode 100644 index 0000000..59bdff1 --- /dev/null +++ b/data/classic/orig/cases_test.tsv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d4ef63304270b02ea95cedcbb23191c07856dec40d98f14690f03738e25cfc95 +size 166520 diff 
--git a/data/classic/orig/cases_train.tsv b/data/classic/orig/cases_train.tsv new file mode 100644 index 0000000..1a79b1f --- /dev/null +++ b/data/classic/orig/cases_train.tsv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:228a395049da39359ccc3dd867ce622479df696bd7ab00faab46795ef33c27c3 +size 497272 diff --git a/data/classic/orig/cases_validate.tsv b/data/classic/orig/cases_validate.tsv new file mode 100644 index 0000000..9b31fe8 --- /dev/null +++ b/data/classic/orig/cases_validate.tsv @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8647f19dd42d0d39d5b18ad57b00a1e2e82cc5218fbc073695ff751d966a5fe3 +size 162704 diff --git a/data/classic/orig/genes_to_phenotype.txt b/data/classic/orig/genes_to_phenotype.txt new file mode 100644 index 0000000..aa95061 --- /dev/null +++ b/data/classic/orig/genes_to_phenotype.txt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:317b5e45be1588b73693e76e3d96f24d14abf4e22f95185e13db4ba1716211f9 +size 14402613 diff --git a/data/classic/orig/hp.obo b/data/classic/orig/hp.obo new file mode 100644 index 0000000..e4b3baf --- /dev/null +++ b/data/classic/orig/hp.obo @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:128ceb286df38adf37d631ca765d392e50b1830049306a9d369619eade18debb +size 6621428 diff --git a/data/classic/transform.py b/data/classic/transform.py new file mode 100644 index 0000000..fa4e01c --- /dev/null +++ b/data/classic/transform.py @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c7e7bb336b1e6b74249ec0d59d70a34abf6e507f48c403b6b5e6e89ef49c6853 +size 1161 diff --git a/local_data/.gitkeep b/local_data/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/requirements/base.txt b/requirements/base.txt index 86a85d3..ba66a08 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -7,6 +7,7 @@ tqdm >=4.0 pronto >=2.5, <3.0 networkx node2vec >=0.4.6, <0.5 +gensim >=4.3.2, <5.0 uvicorn >=0.23.2 fastapi 
>=0.103, <0.104 python-dotenv >=1.0, <2.0 diff --git a/requirements/test.txt b/requirements/test.txt index 56b7a5a..684f57a 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,3 +1,4 @@ +-r tune.txt -r base.txt black ==23.9.1 diff --git a/requirements/tune.txt b/requirements/tune.txt new file mode 100644 index 0000000..22c38a9 --- /dev/null +++ b/requirements/tune.txt @@ -0,0 +1 @@ +optuna >=3.3.0, <3.4.0 diff --git a/setup.py b/setup.py index ee9d107..aca3ef5 100644 --- a/setup.py +++ b/setup.py @@ -29,6 +29,7 @@ def parse_requirements(path): test_requirements = parse_requirements("requirements/test.txt") install_requirements = parse_requirements("requirements/base.txt") +tune_requirements = parse_requirements("requirements/tune.txt") package_root = os.path.abspath(os.path.dirname(__file__)) @@ -54,6 +55,9 @@ def parse_requirements(path): description="Phenotype-based prioritization of variants with CADA", entry_points={"console_scripts": ["cada-prio=cada_prio.cli:cli"]}, install_requires=install_requirements, + extras_require={ + "tune": tune_requirements, + }, license="MIT license", long_description=readme + "\n\n" + history, long_description_content_type="text/markdown", diff --git a/tests/test_quality.py b/tests/test_quality.py new file mode 100644 index 0000000..c97cf08 --- /dev/null +++ b/tests/test_quality.py @@ -0,0 +1,21 @@ +"""Model quality tests""" + +from cada_prio import param_opt + + +def test_quality(tmpdir): + """Test quality of model built from 'classic' data from the original paper.""" + result = param_opt.train_and_validate( + path_out=f"{tmpdir}/quality-model", + path_hgnc_json="data/classic/hgnc_complete_set.json", + path_hpo_genes_to_phenotype="data/classic/genes_to_phenotype.all_source_all_freqs_etc.txt", + path_hpo_obo="data/classic/hp.obo", + path_embedding_params=None, + path_clinvar_phenotype_links="/dev/null", + path_validation_links="data/classic/cases_validate.jsonl", + fraction_links=None, + seed=1, + cpus=1, 
+ ) + + assert result[100] >= 56.0, f"result={result}"