Skip to content

Commit

Permalink
Allow branch labels in node-data JSONs
Browse files Browse the repository at this point in the history
Previously branch labels could not be specified in data passed to
`augur export v2` except for two "special cases":
(i) AA mutations (stored in node-data-json -> nodes) would create branch
labels "aa", if applicable.
(ii) `clade_annotation` (stored in node-data-json -> nodes) was
interpreted to be the "clade" branch label, and exported as such.

Here we extend the allowed node-data structure to include a top-level
key `branches` as described in [1] and the test data added here [2].
This data is exported in the appropriate format for Auspice (unchanged).
This paves the way for pipelines to define a range of branch labels for
export. Currently the only usable key in this dict is 'labels'.

If a branch label (via node-data-json -> branches -> node_name -> labels)
is provided for 'aa' or 'clade' then it will overwrite the value
generated by the special cases above (i, ii).

A side-effect of this work is that the requirement for node-data JSONs
to specify "nodes" has been relaxed (see [2] for an example); however
if neither "nodes" nor "branches" are defined then we raise a validation
error.

[1] #720
[2] ./tests/functional/export_v2/branch-labels.json
  • Loading branch information
jameshadfield committed Sep 9, 2022
1 parent 862bbb6 commit e51bb6b
Show file tree
Hide file tree
Showing 8 changed files with 366 additions and 52 deletions.
135 changes: 86 additions & 49 deletions augur/export_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,23 +149,15 @@ def convert_tree_to_json_structure(node, metadata, div=0):

return node_struct

def are_mutations_defined(node_attrs):
for node, data in node_attrs.items():
if data.get("aa_muts") or data.get("muts"):
return True
return False


def are_clades_defined(node_attrs):
for node, data in node_attrs.items():
if data.get("clade_membership") or data.get("clade_annotation"):
def are_mutations_defined(branch_attrs):
    """
    Report whether at least one branch has a non-empty `mutations` entry.

    `branch_attrs` maps a node name to that branch's attribute dict; a branch
    counts only if its "mutations" value is present and truthy.
    """
    return any(attrs.get("mutations") for attrs in branch_attrs.values())


def are_dates_defined(node_attrs):
def is_node_attr_defined(node_attrs, attr_name):
for node, data in node_attrs.items():
if data.get("num_date"):
if data.get(attr_name):
return True
return False

Expand Down Expand Up @@ -221,7 +213,7 @@ def get_config_colorings_as_dict(config):
return config_colorings


def set_colorings(data_json, config, command_line_colorings, metadata_names, node_data_colorings, provided_colors, node_attrs):
def set_colorings(data_json, config, command_line_colorings, metadata_names, node_data_colorings, provided_colors, node_attrs, branch_attrs):

def _get_type(key, trait_values):
# for some keys we know what the type must be
Expand Down Expand Up @@ -370,7 +362,7 @@ def _create_coloring(key):
def _is_valid(coloring):
key = coloring["key"]
trait_values = get_values_across_nodes(node_attrs, key) # e.g. list of countries, regions etc
if key == "gt" and not are_mutations_defined(node_attrs):
if key == "gt" and not are_mutations_defined(branch_attrs):
warn("[colorings] You asked for mutations (\"gt\"), but none are defined on the tree. They cannot be used as a coloring.")
return False
if key != "gt" and not trait_values:
Expand Down Expand Up @@ -409,11 +401,11 @@ def _get_colorings():

explicitly_defined_colorings = [x["key"] for x in colorings]
# add in genotype as a special case if (a) not already set and (b) the data supports it
if "gt" not in explicitly_defined_colorings and are_mutations_defined(node_attrs):
if "gt" not in explicitly_defined_colorings and are_mutations_defined(branch_attrs):
colorings.insert(0,{'key':'gt'})
if "num_date" not in explicitly_defined_colorings and are_dates_defined(node_attrs):
if "num_date" not in explicitly_defined_colorings and is_node_attr_defined(node_attrs, "num_date"):
colorings.insert(0,{'key':'num_date'})
if "clade_membership" not in explicitly_defined_colorings and are_clades_defined(node_attrs):
if "clade_membership" not in explicitly_defined_colorings and is_node_attr_defined(node_attrs, "clade_membership"):
colorings.insert(0,{'key':'clade_membership'})

return colorings
Expand Down Expand Up @@ -682,10 +674,23 @@ def node_to_author_tuple(data):

return node_author_info

def set_branch_attrs_on_tree(data_json, branch_attrs):
    """
    Transfer the provided `branch_attrs` onto the nodes of the (auspice) `data_json` tree.

    All supplied data is transferred; there is currently no way for (e.g.) the
    set of exported labels to be restricted by the user in a config.

    Parameters
    ----------
    data_json : dict
        Dataset being built. `data_json["tree"]` is the root node; every node
        is a dict with a "name" and, optionally, a list of "children".
    branch_attrs : dict
        Maps node name -> dict of branch attributes for that node.

    Returns
    -------
    None
        The tree inside `data_json` is modified in place. Nodes which have no
        entry (or an empty entry) in `branch_attrs` are left untouched.
    """
    def _recursively_set_data(node):
        attrs = branch_attrs.get(node['name'])
        if attrs:
            # Merge into the node's own dict instead of replacing it, so that
            # anything already stored under `branch_attrs` survives and the
            # caller's per-node dict is not aliased into the tree.
            # Keys supplied via `branch_attrs` win on conflict.
            node.setdefault('branch_attrs', {}).update(attrs)
        for child in node.get("children", []):
            _recursively_set_data(child)

    _recursively_set_data(data_json["tree"])


def set_node_attrs_on_tree(data_json, node_attrs):
'''
Assign desired colorings, metadata etc to the tree structure
Assign desired colorings, metadata etc to the `node_attrs` of nodes in the tree
Parameters
----------
Expand All @@ -696,33 +701,10 @@ def set_node_attrs_on_tree(data_json, node_attrs):

author_data = create_author_data(node_attrs)

def _transfer_mutations(node, raw_data):
if "aa_muts" in raw_data or "muts" in raw_data:
node["branch_attrs"]["mutations"] = {}
if "muts" in raw_data and len(raw_data["muts"]):
node["branch_attrs"]["mutations"]["nuc"] = raw_data["muts"]
if "aa_muts" in raw_data:
aa = {gene:data for gene, data in raw_data["aa_muts"].items() if len(data)}
node["branch_attrs"]["mutations"].update(aa)
#convert mutations into a label
if aa:
aa_lab = '; '.join("{!s}: {!s}".format(key,', '.join(val)) for (key,val) in aa.items())
if 'labels' in node["branch_attrs"]:
node["branch_attrs"]["labels"]["aa"] = aa_lab
else:
node["branch_attrs"]["labels"] = { "aa": aa_lab }

def _transfer_vaccine_info(node, raw_data):
if raw_data.get("vaccine"):
node["node_attrs"]['vaccine'] = raw_data['vaccine']

def _transfer_labels(node, raw_data):
if "clade_annotation" in raw_data and is_valid(raw_data["clade_annotation"]):
if 'labels' in node["branch_attrs"]:
node["branch_attrs"]["labels"]['clade'] = raw_data["clade_annotation"]
else:
node["branch_attrs"]["labels"] = { "clade": raw_data["clade_annotation"] }

def _transfer_hidden_flag(node, raw_data):
hidden = raw_data.get("hidden", None)
if hidden:
Expand Down Expand Up @@ -771,9 +753,7 @@ def _recursively_set_data(node):
# get all the available information for this particular node
raw_data = node_attrs[node["name"]]
# transfer "special cases"
_transfer_mutations(node, raw_data)
_transfer_vaccine_info(node, raw_data)
_transfer_labels(node, raw_data)
_transfer_hidden_flag(node, raw_data)
_transfer_num_date(node, raw_data)
_transfer_url_accession(node, raw_data)
Expand All @@ -790,8 +770,6 @@ def node_data_prop_is_normal_trait(name):
# those traits / keys / attrs which are not "special" and can be exported
# as normal attributes on nodes
excluded = [
"clade_annotation", # Clade annotation is label, not colorby!
"clade_membership", # will be auto-detected if it is available
"authors", # authors are set as a node property, not a trait property
"author", # see above
"vaccine", # vaccine info is stored as a "special" node prop
Expand Down Expand Up @@ -939,6 +917,51 @@ def set_description(data_json, cmd_line_description_file):
except FileNotFoundError:
fatal("Provided desciption file {} does not exist".format(cmd_line_description_file))

def create_branch_mutations(branch_attrs, node_data):
    """
    Populate `branch_attrs[<node name>]["mutations"]` from the `muts` and
    `aa_muts` keys found in `node_data["nodes"]`.

    Parameters
    ----------
    branch_attrs : dict
        Maps node name -> dict of branch attributes. Modified in place; only
        names already present in this dict are considered.
    node_data : dict
        Node-data structure. Its optional "nodes" entry maps node name -> dict
        which may contain "muts" (list of nucleotide mutations) and/or
        "aa_muts" (dict of gene name -> list of mutations).

    Notes
    -----
    A node which defines either key always receives a "mutations" dict, which
    is empty if no mutations were actually listed. Nucleotide mutations are
    stored under "nuc"; genes with no mutations are omitted.
    """
    # "nodes" is optional in node-data, so a missing key means nothing to do
    for node_name, node_info in node_data.get('nodes', {}).items():
        if node_name not in branch_attrs:
            continue  # strain name not in the tree
        if "aa_muts" not in node_info and "muts" not in node_info:
            continue
        mutations = {}
        if node_info.get("muts"):
            mutations["nuc"] = node_info["muts"]
        # `or {}` tolerates an explicit null value for "aa_muts"
        for gene, gene_muts in (node_info.get("aa_muts") or {}).items():
            if gene_muts:
                mutations[gene] = gene_muts
        branch_attrs[node_name]['mutations'] = mutations

def create_branch_labels(branch_attrs, node_data, branch_data):
    """
    Set branch labels on `branch_attrs` (modified in place) from three sources,
    applied in this order so that later sources overwrite earlier ones:

    1. an 'aa' label summarising the amino acid mutations already stored in
       `branch_attrs[<name>]['mutations']` (every key except 'nuc');
    2. a 'clade' label taken from the special node-data key `clade_annotation`
       (for historical reasons this is interpreted as a branch label);
    3. any labels defined via <NODE DATA JSON> -> 'branches' -> <name> -> 'labels'.

    Parameters
    ----------
    branch_attrs : dict
        Maps node name -> dict of branch attributes. Names absent from this
        dict are ignored in steps 2 and 3.
    node_data : dict
        Maps node name -> dict of node-data values (the "nodes" mapping).
    branch_data : dict
        Maps node name -> dict which may contain a 'labels' dict.

    Notes
    -----
    Label values from steps 2 and 3 are only used if `is_valid(value)` is
    truthy. `is_valid` is defined elsewhere in this file.
    """
    # (1) 'aa' label. Done first so that user-defined 'aa' labels overwrite it.
    for branch_info in branch_attrs.values():
        aa_parts = [
            "{!s}: {!s}".format(gene, ', '.join(gene_muts))
            for gene, gene_muts in branch_info.get('mutations', {}).items()
            # skip genes without mutations so the label has no dangling "gene: "
            if gene != 'nuc' and gene_muts
        ]
        if aa_parts:
            branch_info.setdefault('labels', {})['aa'] = '; '.join(aa_parts)

    # (2) 'clade' label from the legacy `clade_annotation` node-data key
    for node_name, node_info in node_data.items():
        if node_name not in branch_attrs:
            continue
        if "clade_annotation" in node_info and is_valid(node_info["clade_annotation"]):
            branch_attrs[node_name].setdefault('labels', {})['clade'] = node_info["clade_annotation"]

    # (3) labels explicitly provided via the 'branches' section of node-data
    for node_name, branch_info in branch_data.items():
        if node_name not in branch_attrs:
            continue
        for label_key, label_value in branch_info.get('labels', {}).items():
            if is_valid(label_value):
                branch_attrs[node_name].setdefault('labels', {})[label_key] = label_value

def parse_node_data_and_metadata(T, node_data, metadata):
node_data_names = set()
metadata_names = set()
Expand All @@ -955,14 +978,24 @@ def parse_node_data_and_metadata(T, node_data, metadata):
metadata_names.add(corrected_key)

# second pass: node data JSONs (overwrites keys of same name found in metadata)
node_attrs_which_are_actually_branch_attrs = ["clade_annotation", "aa_muts", "muts"]
for name, info in node_data['nodes'].items():
if name in node_attrs: # i.e. this node name is in the tree
for key, value in info.items():
if key in node_attrs_which_are_actually_branch_attrs:
continue # these will be handled below
corrected_key = update_deprecated_names(key)
node_attrs[name][corrected_key] = value
node_data_names.add(corrected_key)

return (node_data, node_attrs, node_data_names, metadata_names)
# third pass: create `branch_attrs`. The data comes from
# (a) some keys within `node_data['nodes']` (for legacy reasons)
# (b) the `node_data['branches']` dictionary, which currently only defines labels
branch_attrs = {clade.name: {} for clade in T.root.find_clades()}
create_branch_mutations(branch_attrs, node_data)
create_branch_labels(branch_attrs, node_data['nodes'], node_data.get('branches', {}))

return (node_data, node_attrs, node_data_names, metadata_names, branch_attrs)

def get_config(args):
if not args.auspice_config:
Expand Down Expand Up @@ -1010,7 +1043,8 @@ def run(args):

# parse input files
T = Phylo.read(args.tree, 'newick')
node_data, node_attrs, node_data_names, metadata_names = parse_node_data_and_metadata(T, node_data_file, metadata_file)
node_data, node_attrs, node_data_names, metadata_names, branch_attrs = \
parse_node_data_and_metadata(T, node_data_file, metadata_file)
config = get_config(args)

# set metadata data structures
Expand All @@ -1030,7 +1064,8 @@ def run(args):
metadata_names=metadata_names,
node_data_colorings=node_data_names,
provided_colors=read_colors(args.colors),
node_attrs=node_attrs
node_attrs=node_attrs,
branch_attrs=branch_attrs
)
except FileNotFoundError as e:
print(f"ERROR: required file could not be read: {e}")
Expand All @@ -1040,6 +1075,8 @@ def run(args):
# set tree structure
data_json["tree"] = convert_tree_to_json_structure(T.root, node_attrs)
set_node_attrs_on_tree(data_json, node_attrs)
set_branch_attrs_on_tree(data_json, branch_attrs)

set_geo_resolutions(data_json, config, args.geo_resolutions, read_lat_longs(args.lat_longs), node_attrs)
set_panels(data_json, config, args.panels)
set_data_provenance(data_json, config)
Expand Down
16 changes: 15 additions & 1 deletion augur/util_support/node_data_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@ def annotations(self):

@property
def nodes(self):
return self.attrs.get("nodes")
return self.attrs.get("nodes", {})

@property
def branches(self):
# these are optional, so we provide an empty dict as a default
return self.attrs.get("branches", {})

@property
def generated_by(self):
Expand Down Expand Up @@ -71,6 +76,15 @@ def validate(self):
f"`nodes` value in {self.fname} is not a dictionary. Please check the formatting of this JSON!"
)

if not isinstance(self.branches, dict):
raise RuntimeError(
f"`branches` value in {self.fname} is not a dictionary. Please check the formatting of this JSON!" )

if not self.nodes and not self.branches:
raise RuntimeError(
f"{self.fname} did not contain either `nodes` or `branches`. Please check the formatting of this JSON!"
)

if not self.skip_validation and self.is_generated_by_incompatible_augur:
raise AugurError(
f"Augur version incompatibility detected: the JSON {self.fname} was generated by "
Expand Down
11 changes: 11 additions & 0 deletions tests/functional/export_v2.t
Original file line number Diff line number Diff line change
Expand Up @@ -137,4 +137,15 @@ Check the output from the above command against its expected contents
> --exclude-paths "root['meta']['updated']"
{}

Export using a node-data file defining branch attrs (mutations, labels)
$ ${AUGUR} export v2 \
> --tree export_v2/tree.nwk \
> --node-data export_v2/div_node-data.json export_v2/nt_muts_1.json export_v2/aa_muts_1.json export_v2/branch-labels.json \
> --maintainers "Nextstrain Team" \
> --output "$TMP/dataset-with-branch-labels.json" > /dev/null

$ python3 "$TESTDIR/../../scripts/diff_jsons.py" export_v2/dataset-with-branch-labels.json "$TMP/dataset-with-branch-labels.json" \
> --exclude-paths "root['meta']['updated']"
{}

$ popd > /dev/null
33 changes: 33 additions & 0 deletions tests/functional/export_v2/aa_muts_1.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
{
"annotations": {
"gene1": {
"end": 150,
"start": 50,
"strand": "+"
},
"gene2": {
"end": 300,
"start": 200,
"strand": "+"
},
"nuc": {
"end": 500,
"start": 1,
"strand": "+"
}
},
"nodes": {
"internalBC": {
"aa_muts": {
"gene1": ["S10G", "P20S"],
"gene2": []
}
},
"internalDEF": {
"aa_muts": {
"gene1": ["P20S"],
"gene2": []
}
}
}
}
36 changes: 36 additions & 0 deletions tests/functional/export_v2/branch-labels.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
{
"nodes": {
"tipD": {
"clade_membership": "membership D",
"clade_annotation": "set via nodes→clade_annotation"
},
"tipC": {
"clade_membership": "membership C",
"clade_annotation": "this should be overwritten by custom clade label"
}
},
"branches": {
"ROOT": {
"labels": {
"fruit": "apple"
}
},
"tipA": {
"labels": {
"fruit": "orange"
}
},
"tipC": {
"labels": {
"clade": "clade C"
}
},
"internalBC": {
"labels": {
"fruit": "pomegranate",
"vegetable": "pumpkin",
"aa": "custom aa label"
}
}
}
}
Loading

0 comments on commit e51bb6b

Please sign in to comment.