Skip to content

Commit

Permalink
Merge pull request #173 from nextstrain/add-derived-haplotypes-for-al…
Browse files Browse the repository at this point in the history
…l-sequences

Summarize haplotype coverage by titer references using frequencies per haplotype from all available data
  • Loading branch information
huddlej authored Jul 25, 2024
2 parents bf73f72 + dd9aae5 commit 8849483
Show file tree
Hide file tree
Showing 8 changed files with 393 additions and 15 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
data/
builds/
results/
tables/
auspice/
auspice-who/
auspice_renamed/
Expand Down
7 changes: 0 additions & 7 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -311,13 +311,6 @@ max-line-length=100
# Maximum number of lines in a module
max-module-lines=1000

# List of optional constructs for which whitespace checking is disabled. `dict-
# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
# `trailing-comma` allows a space between comma and closing bracket: (a, ).
# `empty-line` allows space-only lines.
no-space-check=trailing-comma,
dict-separator

# Allow the body of a class to be on the same line as the declaration if body
# contains single statement.
single-line-class-stmt=no
Expand Down
94 changes: 93 additions & 1 deletion profiles/nextflu-private/report.smk
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
rule all_report_outputs:
input:
derived_haplotypes=expand("tables/{lineage}/derived_haplotypes.md", lineage=["h1n1pdm", "h3n2", "vic"]),
counts_by_clade=expand("tables/{lineage}/counts_of_recent_sequences_by_clade.md", lineage=["h1n1pdm", "h3n2", "vic"]),
total_sample_count_by_lineage="figures/total-sample-count-by-lineage.png",

Expand Down Expand Up @@ -66,10 +67,22 @@ rule download_nextclade:
aws s3 cp {params.s3_path} {output.nextclade}
"""

rule filter_nextclade_by_qc:
input:
nextclade="data/{lineage}/{segment}/nextclade.tsv.xz",
output:
nextclade="data/{lineage}/{segment}/nextclade_without_bad_qc.tsv",
conda: "../../workflow/envs/nextstrain.yaml"
shell:
"""
xz -c -d {input.nextclade} \
| tsv-filter -H --str-ne "qc.overallStatus:bad" > {output.nextclade}
"""

rule count_recent_tips_by_clade:
input:
recency="tables/{lineage}/recency.json",
clades="data/{lineage}/ha/nextclade.tsv.xz",
clades="data/{lineage}/ha/nextclade_without_bad_qc.tsv",
output:
counts="tables/{lineage}/counts_of_recent_sequences_by_clade.md",
conda: "../../workflow/envs/nextstrain.yaml"
Expand All @@ -80,3 +93,82 @@ rule count_recent_tips_by_clade:
--clades {input.clades} \
--output {output.counts}
"""

rule get_derived_haplotypes:
input:
nextclade="data/{lineage}/ha/nextclade_without_bad_qc.tsv",
output:
haplotypes="data/{lineage}/nextclade_with_derived_haplotypes.tsv",
conda: "../../workflow/envs/nextstrain.yaml"
params:
genes=["HA1"],
shell:
"""
python3 scripts/add_derived_haplotypes.py \
--nextclade {input.nextclade} \
--genes {params.genes:q} \
--strip-genes \
--output {output.haplotypes}
"""

rule join_metadata_and_nextclade:
input:
metadata="data/{lineage}/metadata.tsv",
nextclade="data/{lineage}/nextclade_with_derived_haplotypes.tsv",
output:
metadata="data/{lineage}/metadata_with_derived_haplotypes.tsv",
conda: "../../workflow/envs/nextstrain.yaml"
shell:
"""
tsv-join -H -f {input.nextclade} -a haplotype -k seqName -d strain {input.metadata} > {output.metadata}
"""

rule estimate_derived_haplotype_frequencies:
input:
metadata="data/{lineage}/metadata_with_derived_haplotypes.tsv",
output:
frequencies="tables/{lineage}/derived_haplotype_frequencies.json",
conda: "../../workflow/envs/nextstrain.yaml"
params:
narrow_bandwidth=1 / 12.0,
min_date="16W",
max_date=config.get("build_date", "4W"),
shell:
"""
python3 scripts/estimate_frequencies_from_metadata.py \
--metadata {input.metadata} \
--narrow-bandwidth {params.narrow_bandwidth} \
--min-date {params.min_date} \
--max-date {params.max_date} \
--output {output.frequencies}
"""

rule summarize_derived_haplotypes:
input:
metadata="data/{lineage}/metadata_with_derived_haplotypes.tsv",
frequencies="tables/{lineage}/derived_haplotype_frequencies.json",
titers=lambda wildcards: [
collection["data"]
for collection in config["builds"][f"{wildcards.lineage}_2y_titers"]["titer_collections"]
if "ferret" in collection["data"]
],
output:
table="tables/{lineage}/derived_haplotypes.tsv",
markdown_table="tables/{lineage}/derived_haplotypes.md",
conda: "../../workflow/envs/nextstrain.yaml"
params:
titer_names=lambda wildcards: [
collection["name"]
for collection in config["builds"][f"{wildcards.lineage}_2y_titers"]["titer_collections"]
if "ferret" in collection["data"]
],
shell:
"""
python3 scripts/summarize_haplotypes.py \
--metadata {input.metadata} \
--frequencies {input.frequencies} \
--titers {input.titers:q} \
--titer-names {params.titer_names:q} \
--output-table {output.table} \
--output-markdown-table {output.markdown_table}
"""
80 changes: 80 additions & 0 deletions scripts/add_derived_haplotypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""
Annotate derived haplotypes per node from annotated clades and store as node data JSON.
"""
import argparse
import pandas as pd


def create_haplotype_for_record(record, clade_column, mutations_column, genes=None, strip_genes=False):
"""Create a haplotype string for the given record based on the values in its
clade and mutations column. If a list of genes is given, filter mutations to
only those in the requested genes.
"""
clade = record[clade_column]
mutations = record[mutations_column].split(",")

# Filter mutations to requested genes.
if genes is not None:
mutations = [
mutation
for mutation in mutations
if mutation.split(":")[0] in genes
]

mutations = "-".join(mutations).replace(":", "-")

if mutations:
if strip_genes and genes is not None:
for gene in genes:
mutations = mutations.replace(f"{gene}-", "")

return f"{clade}:{mutations}"
else:
return clade


if __name__ == '__main__':
parser = argparse.ArgumentParser(
description="Annotate derived haplotypes per record in Nextclade annotations",
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument("--nextclade", required=True, help="TSV file of Nextclade annotations with columns for clade and AA mutations derived from clade")
parser.add_argument("--clade-column", help="name of the branch attribute for clade labels in the given Nextclade annotations", default="subclade")
parser.add_argument("--mutations-column", help="name of the attribute for mutations relative to clades in the given Nextclade annotations", default="founderMuts['subclade'].aaSubstitutions")
parser.add_argument("--genes", nargs="+", help="list of genes to filter mutations to. If not provided, all mutations will be used.")
parser.add_argument("--strip-genes", action="store_true", help="strip gene names from coordinates in output haplotypes")
parser.add_argument("--attribute-name", default="haplotype", help="name of attribute to store the derived haplotype in the output file")
parser.add_argument("--output", help="TSV file of Nextclade annotations with derived haplotype column added", required=True)
args = parser.parse_args()

# Load Nextclade annotations.
df = pd.read_csv(
args.nextclade,
sep="\t",
dtype={
args.clade_column: "str",
args.mutations_column: "str",
},
na_filter=False,
)

# Annotate derived haplotypes.
df[args.attribute_name] = df.apply(
lambda record: create_haplotype_for_record(
record,
args.clade_column,
args.mutations_column,
args.genes,
args.strip_genes,
),
axis=1
)

# Save updated Nextclade annotations
df.to_csv(
args.output,
sep="\t",
index=False,
)
10 changes: 5 additions & 5 deletions scripts/annotate_haplotypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,12 @@
mutations = []
for i in range(len(sequence_by_node[node.name])):
if sequence_by_node[node.name][i] != sequence_by_clade[clade][i]:
# Store 1-based mutation position and derived allele.
mutations.append(f"{i + 1}{sequence_by_node[node.name][i]}")
# Store ancestral allele, 1-based mutation position, and derived allele.
mutations.append(f"{sequence_by_clade[clade][i]}{i + 1}{sequence_by_node[node.name][i]}")

# Store the clade name plus a comma-delimited list of derived
# mutations present in the current node.
haplotype = f"{clade}:{','.join(mutations)}"
# Store the clade name plus a delimited list of derived mutations
# present in the current node.
haplotype = f"{clade}:{'-'.join(mutations)}"

# Store the clade and haplotype values for this node.
haplotypes[node.name][args.attribute_name] = haplotype
Expand Down
3 changes: 1 addition & 2 deletions scripts/count_recent_tips_by_clade.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,13 @@
clades = pd.read_csv(
args.clades,
sep="\t",
usecols=["seqName", "subclade", "qc.overallStatus"],
usecols=["seqName", "subclade"],
)


# Filter clade labels to recent non-low-quality sequences and count the
# clade membership for each recent tip.
count_by_clade = clades[
(clades["qc.overallStatus"] != "bad") &
(clades["seqName"].isin(recent_tips))
].groupby(
"subclade"
Expand Down
73 changes: 73 additions & 0 deletions scripts/estimate_frequencies_from_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#!/usr/bin/env python3
import argparse
import numpy as np

from augur.dates import get_numerical_dates, numeric_date_type
from augur.frequencies import format_frequencies
from augur.frequency_estimators import get_pivots, KdeFrequencies
from augur.io import read_metadata
from augur.utils import write_json


if __name__ == '__main__':
parser = argparse.ArgumentParser(
description="Estimate sequence frequencies from metadata with collection dates",
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument("--metadata", required=True, help="TSV file of metadata with at least 'strain' and 'date' columns")
parser.add_argument("--narrow-bandwidth", required=True, type=float, help="narrow bandwidth for KDE frequencies")
parser.add_argument("--proportion-wide", type=float, default=0.0, help="proportion of wide bandwidth to use for KDE frequencies")
parser.add_argument("--pivot-interval", type=int, default=4, help="interval between pivots in weeks")
parser.add_argument("--min-date", type=numeric_date_type, help="minimum date to estimate frequencies for")
parser.add_argument("--max-date", type=numeric_date_type, help="maximum date to estimate frequencies for")
parser.add_argument("--output", required=True, help="JSON file in tip-frequencies format")
args = parser.parse_args()

columns_to_load = ["strain", "date"]
metadata = read_metadata(
args.metadata,
columns=columns_to_load,
dtype="string",
)
dates = get_numerical_dates(metadata, fmt='%Y-%m-%d')

strains = []
observations = []
for strain in metadata.index.values:
if dates.get(strain):
strains.append(strain)
observations.append(np.mean(dates[strain]))

pivots = get_pivots(
observations,
args.pivot_interval,
args.min_date,
args.max_date,
"weeks",
)

frequencies = KdeFrequencies(
sigma_narrow=args.narrow_bandwidth,
proportion_wide=args.proportion_wide,
pivot_frequency=args.pivot_interval,
start_date=args.min_date,
end_date=args.max_date,
)
frequency_matrix = frequencies.estimate_frequencies(
observations,
pivots,
)
tip_frequencies = {
strain: frequency_matrix[index]
for index, strain in enumerate(strains)
if frequency_matrix[index].sum() > 0
}

frequency_dict = {"pivots": list(pivots)}
for node_name in tip_frequencies:
frequency_dict[node_name] = {
"frequencies": format_frequencies(tip_frequencies[node_name])
}

write_json(frequency_dict, args.output)
Loading

0 comments on commit 8849483

Please sign in to comment.