From 6d4390d8342c5c235e0be9fbff356ef1de786402 Mon Sep 17 00:00:00 2001 From: Philipp Schlegel Date: Thu, 2 Jan 2025 18:43:30 +0000 Subject: [PATCH] read_swc: deal with potential additional columns + formatting --- navis/io/swc_io.py | 222 +++++++++++++++++++++++++-------------------- 1 file changed, 123 insertions(+), 99 deletions(-) diff --git a/navis/io/swc_io.py b/navis/io/swc_io.py index debddaa5..d42faead 100644 --- a/navis/io/swc_io.py +++ b/navis/io/swc_io.py @@ -31,12 +31,12 @@ # Set up logging logger = config.get_logger(__name__) -NODE_COLUMNS = ('node_id', 'label', 'x', 'y', 'z', 'radius', 'parent_id') +NODE_COLUMNS = ("node_id", "label", "x", "y", "z", "radius", "parent_id") COMMENT = "#" DEFAULT_DELIMITER = " " DEFAULT_PRECISION = 32 DEFAULT_FMT = "{name}.swc" -NA_VALUES = [None, 'None'] +NA_VALUES = [None, "None"] class SwcReader(base.BaseReader): @@ -48,17 +48,15 @@ def __init__( precision: int = DEFAULT_PRECISION, read_meta: bool = False, fmt: str = DEFAULT_FMT, - errors: str = 'raise', - attrs: Optional[Dict[str, Any]] = None + errors: str = "raise", + attrs: Optional[Dict[str, Any]] = None, ): - if not fmt.endswith('.swc'): + if not fmt.endswith(".swc"): raise ValueError('`fmt` must end with ".swc"') - super().__init__(fmt=fmt, - attrs=attrs, - file_ext='.swc', - errors=errors, - name_fallback='SWC') + super().__init__( + fmt=fmt, attrs=attrs, file_ext=".swc", errors=errors, name_fallback="SWC" + ) self.connector_labels = connector_labels or dict() self.soma_label = soma_label self.delimiter = delimiter @@ -66,19 +64,19 @@ def __init__( int_, float_ = base.parse_precision(precision) self._dtypes = { - 'node_id': int_, - 'parent_id': int_, - 'label': 'category', - 'x': float_, - 'y': float_, - 'z': float_, - 'radius': float_, + "node_id": int_, + "parent_id": int_, + "label": "category", + "x": float_, + "y": float_, + "z": float_, + "radius": float_, } @base.handle_errors def read_buffer( self, f: IO, attrs: Optional[Dict[str, Any]] = None - ) -> 'core.TreeNeuron': + ) -> "core.TreeNeuron": """Read buffer into a TreeNeuron. Parameters @@ -107,9 +105,22 @@ def read_buffer( skiprows=len(header_rows), comment=COMMENT, header=None, - na_values=NA_VALUES + na_values=NA_VALUES, ) - nodes.columns = NODE_COLUMNS + if len(nodes.columns) < len(NODE_COLUMNS): + raise ValueError("Not enough columns in SWC file.") + elif len(nodes.columns) > len(NODE_COLUMNS): + logger.warning( + f"Found {len(nodes.columns)} instead of the expected 7 " + "columns in SWC file. Assuming additional columns are " + "custom properties. You can silence this warning by setting " + "`navis.set_loggers('ERROR')`." 
+ ) + nodes.columns = ( + list(NODE_COLUMNS) + nodes.columns[len(NODE_COLUMNS) :].tolist() + ) + else: + nodes.columns = NODE_COLUMNS except pd.errors.EmptyDataError: # If file is totally empty, return an empty neuron # Note that the TreeNeuron will still complain but it's a better @@ -119,17 +130,19 @@ def read_buffer( # Check for row with JSON-formatted meta data # Expected format '# Meta: {"id": "12345"}' if self.read_meta: - meta_row = [r for r in header_rows if r.lower().startswith('# meta:')] + meta_row = [r for r in header_rows if r.lower().startswith("# meta:")] if meta_row: meta_data = json.loads(meta_row[0][7:].strip()) attrs = base.merge_dicts(meta_data, attrs) - return self.read_dataframe(nodes, base.merge_dicts({'swc_header': '\n'.join(header_rows)}, attrs)) + return self.read_dataframe( + nodes, base.merge_dicts({"swc_header": "\n".join(header_rows)}, attrs) + ) @base.handle_errors def read_dataframe( self, nodes: pd.DataFrame, attrs: Optional[Dict[str, Any]] = None - ) -> 'core.TreeNeuron': + ) -> "core.TreeNeuron": """Convert a SWC-like DataFrame into a TreeNeuron. Parameters @@ -143,20 +156,19 @@ def read_dataframe( core.TreeNeuron """ n = core.TreeNeuron( - sanitise_nodes( - nodes.astype(self._dtypes, errors='ignore', copy=False) - ), - connectors=self._extract_connectors(nodes)) + sanitise_nodes(nodes.astype(self._dtypes, errors="ignore", copy=False)), + connectors=self._extract_connectors(nodes), + ) if self.soma_label is not None: is_soma_node = n.nodes.label.values == self.soma_label if any(is_soma_node): n.soma = n.nodes.node_id.values[is_soma_node][0] - attrs = self._make_attributes({'name': 'SWC', 'origin': 'DataFrame'}, attrs) + attrs = self._make_attributes({"name": "SWC", "origin": "DataFrame"}, attrs) # SWC is special - we do not want to register it - n.swc_header = attrs.pop('swc_header', '') + n.swc_header = attrs.pop("swc_header", "") # Try adding properties one-by-one. If one fails, we'll keep track of it # in the `.meta` attribute @@ -172,9 +184,7 @@ def read_dataframe( return n - def _extract_connectors( - self, nodes: pd.DataFrame - ) -> Optional[pd.DataFrame]: + def _extract_connectors(self, nodes: pd.DataFrame) -> Optional[pd.DataFrame]: """Infer outgoing/incoming connectors from node labels. 
Parameters @@ -190,14 +200,12 @@ def _extract_connectors( return None to_concat = [ - pd.DataFrame( - [], columns=['node_id', 'connector_id', 'type', 'x', 'y', 'z'] - ) + pd.DataFrame([], columns=["node_id", "connector_id", "type", "x", "y", "z"]) ] for name, val in self.connector_labels.items(): - cn = nodes[nodes.label == val][['node_id', 'x', 'y', 'z']].copy() - cn['connector_id'] = None - cn['type'] = name + cn = nodes[nodes.label == val][["node_id", "x", "y", "z"]].copy() + cn["connector_id"] = None + cn["type"] = name to_concat.append(cn) return pd.concat(to_concat, axis=0) @@ -215,9 +223,9 @@ def sanitise_nodes(nodes: pd.DataFrame, allow_empty=True) -> pd.DataFrame: pandas.DataFrame """ if not allow_empty and nodes.empty: - raise ValueError('No data found in SWC.') + raise ValueError("No data found in SWC.") - is_na = nodes[['node_id', 'parent_id', 'x', 'y', 'z']].isna().any(axis=1) + is_na = nodes[["node_id", "parent_id", "x", "y", "z"]].isna().any(axis=1) if is_na.any(): # Remove nodes with missing data @@ -225,7 +233,7 @@ def sanitise_nodes(nodes: pd.DataFrame, allow_empty=True) -> pd.DataFrame: # Because we removed nodes, we'll have to run a more complicated root # detection - nodes.loc[~nodes.parent_id.isin(nodes.node_id), 'parent_id'] = -1 + nodes.loc[~nodes.parent_id.isin(nodes.node_id), "parent_id"] = -1 return nodes @@ -424,27 +432,31 @@ def read_swc( failed = [] for n in core.NeuronList(res): - if not hasattr(n, 'meta'): + if not hasattr(n, "meta"): continue failed += list(n.meta.keys()) if failed: failed = list(set(failed)) - logger.warning('Some meta data could not be directly attached to the ' - 'neuron(s) - probably some clash with intrinsic ' - 'properties. You can find these data attached as ' - '`.meta` dictionary.') + logger.warning( + "Some meta data could not be directly attached to the " + "neuron(s) - probably some clash with intrinsic " + "properties. You can find these data attached as " + "`.meta` dictionary." + ) return res -def write_swc(x: 'core.NeuronObject', - filepath: Union[str, Path], - header: Optional[str] = None, - write_meta: Union[bool, List[str], dict] = True, - labels: Union[str, dict, bool] = True, - export_connectors: bool = False, - return_node_map: bool = False) -> None: +def write_swc( + x: "core.NeuronObject", + filepath: Union[str, Path], + header: Optional[str] = None, + write_meta: Union[bool, List[str], dict] = True, + labels: Union[str, dict, bool] = True, + export_connectors: bool = False, + return_node_map: bool = False, +) -> None: """Write TreeNeuron(s) to SWC. Follows the format specified @@ -547,40 +559,50 @@ def write_swc(x: 'core.NeuronObject', if not isinstance(n, core.TreeNeuron): msg = f'Can only write TreeNeurons to SWC, not "{type(n)}"' if isinstance(n, core.Dotprops): - msg += (". For Dotprops, you can use either `navis.write_nrrd`" - " or `navis.write_parquet`.") + msg += ( + ". For Dotprops, you can use either `navis.write_nrrd`" + " or `navis.write_parquet`." + ) raise TypeError(msg) elif not isinstance(x, core.TreeNeuron): msg = f'Can only write TreeNeurons to SWC, not "{type(n)}"' if isinstance(n, core.Dotprops): - msg += (". For Dotprops, you can use either `navis.write_nrrd`" - " or `navis.write_parquet`.") + msg += ( + ". For Dotprops, you can use either `navis.write_nrrd`" + " or `navis.write_parquet`." 
+ ) raise TypeError(msg) - writer = base.Writer(write_func=_write_swc, ext='.swc') + writer = base.Writer(write_func=_write_swc, ext=".swc") - return writer.write_any(x, - filepath=filepath, - header=header, - write_meta=write_meta, - labels=labels, - export_connectors=export_connectors, - return_node_map=return_node_map) + return writer.write_any( + x, + filepath=filepath, + header=header, + write_meta=write_meta, + labels=labels, + export_connectors=export_connectors, + return_node_map=return_node_map, + ) -def _write_swc(x: Union['core.TreeNeuron', 'core.Dotprops'], - filepath: Union[str, Path], - header: Optional[str] = None, - write_meta: Union[bool, List[str], dict] = True, - labels: Union[str, dict, bool] = True, - export_connectors: bool = False, - return_node_map: bool = False) -> None: +def _write_swc( + x: Union["core.TreeNeuron", "core.Dotprops"], + filepath: Union[str, Path], + header: Optional[str] = None, + write_meta: Union[bool, List[str], dict] = True, + labels: Union[str, dict, bool] = True, + export_connectors: bool = False, + return_node_map: bool = False, +) -> None: """Write single TreeNeuron to file.""" # Generate SWC table - res = make_swc_table(x, - labels=labels, - export_connectors=export_connectors, - return_node_map=return_node_map) + res = make_swc_table( + x, + labels=labels, + export_connectors=export_connectors, + return_node_map=return_node_map, + ) if return_node_map: swc, node_map = res[0], res[1] @@ -602,7 +624,7 @@ def _write_swc(x: Union['core.TreeNeuron', 'core.Dotprops'], elif isinstance(write_meta, list): props = {k: str(getattr(x, k, None)) for k in write_meta} else: - props = {k: str(getattr(x, k, None)) for k in ['id', 'name', 'units']} + props = {k: str(getattr(x, k, None)) for k in ["id", "name", "units"]} header += f"# Meta: {json.dumps(props)}\n" header += dedent("""\ # PointNo Label X Y Z Radius Parent @@ -613,25 +635,27 @@ def _write_swc(x: Union['core.TreeNeuron', 'core.Dotprops'], header += dedent("""\ # 7 = presynapses, 8 = postsynapses """) - elif not header.endswith('\n'): - header += '\n' + elif not header.endswith("\n"): + header += "\n" - with open(filepath, 'w') as file: + with open(filepath, "w") as file: # Write header file.write(header) # Write data - writer = csv.writer(file, delimiter=' ') + writer = csv.writer(file, delimiter=" ") writer.writerows(swc.astype(str).values) if return_node_map: return node_map -def make_swc_table(x: Union['core.TreeNeuron', 'core.Dotprops'], - labels: Union[str, dict, bool] = None, - export_connectors: bool = False, - return_node_map: bool = False) -> pd.DataFrame: +def make_swc_table( + x: Union["core.TreeNeuron", "core.Dotprops"], + labels: Union[str, dict, bool] = None, + export_connectors: bool = False, + return_node_map: bool = False, +) -> pd.DataFrame: """Generate a node table compliant with the SWC format. 
Follows the format specified @@ -673,28 +697,28 @@ def make_swc_table(x: Union['core.TreeNeuron', 'core.Dotprops'], swc = x.nodes.copy() # Add labels - swc['label'] = 0 + swc["label"] = 0 if isinstance(labels, dict): - swc['label'] = swc.index.map(labels) + swc["label"] = swc.index.map(labels) elif isinstance(labels, str): - swc['label'] = swc[labels] + swc["label"] = swc[labels] elif labels: # Add end/branch labels - swc.loc[swc.type == 'branch', 'label'] = 5 - swc.loc[swc.type == 'end', 'label'] = 6 + swc.loc[swc.type == "branch", "label"] = 5 + swc.loc[swc.type == "end", "label"] = 6 # Add soma label if not isinstance(x.soma, type(None)): soma = utils.make_iterable(x.soma) - swc.loc[swc.node_id.isin(soma), 'label'] = 1 + swc.loc[swc.node_id.isin(soma), "label"] = 1 if export_connectors: # Add synapse label pre_ids = x.presynapses.node_id.values post_ids = x.postsynapses.node_id.values - swc.loc[swc.node_id.isin(pre_ids), 'label'] = 7 - swc.loc[swc.node_id.isin(post_ids), 'label'] = 8 + swc.loc[swc.node_id.isin(pre_ids), "label"] = 7 + swc.loc[swc.node_id.isin(post_ids), "label"] = 8 # Sort such that the parent is always before the child - swc.sort_values('parent_id', ascending=True, inplace=True) + swc.sort_values("parent_id", ascending=True, inplace=True) # Reset index swc.reset_index(drop=True, inplace=True) @@ -702,18 +726,18 @@ def make_swc_table(x: Union['core.TreeNeuron', 'core.Dotprops'], # Generate mapping new_ids = dict(zip(swc.node_id.values, swc.index.values + 1)) - swc['node_id'] = swc.node_id.map(new_ids) + swc["node_id"] = swc.node_id.map(new_ids) # Lambda prevents potential issue with missing parents - swc['parent_id'] = swc.parent_id.map(lambda x: new_ids.get(x, -1)) + swc["parent_id"] = swc.parent_id.map(lambda x: new_ids.get(x, -1)) # Get things in order - swc = swc[['node_id', 'label', 'x', 'y', 'z', 'radius', 'parent_id']] + swc = swc[["node_id", "label", "x", "y", "z", "radius", "parent_id"]] # Make sure radius has no `None` - swc['radius'] = swc.radius.fillna(0) + swc["radius"] = swc.radius.fillna(0) # Adjust column titles - swc.columns = ['PointNo', 'Label', 'X', 'Y', 'Z', 'Radius', 'Parent'] + swc.columns = ["PointNo", "Label", "X", "Y", "Z", "Radius", "Parent"] if return_node_map: return swc, new_ids
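
For illustration, a minimal sketch of how `read_swc` is expected to behave with a surplus column once this change is applied. The file name and the extra "my_prop" column here are assumptions made up for the example; because the table is read with `header=None`, pandas gives the surplus column the positional name `7`, which the new code keeps, and the exact set of columns on `n.nodes` may differ slightly depending on what `TreeNeuron` adds on top.

    # Hypothetical example -- assumes a navis install with this patch applied.
    import os
    import tempfile

    import navis

    # Standard 7 SWC columns plus one extra per-node value ("my_prop").
    swc_text = (
        "# PointNo Label X Y Z Radius Parent my_prop\n"
        "1 1 0.0 0.0 0.0 1.0 -1 0.5\n"
        "2 0 1.0 0.0 0.0 1.0 1 0.7\n"
    )

    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "example.swc")
        with open(path, "w") as f:
            f.write(swc_text)

        # Emits the new "Found 8 instead of the expected 7 columns ..." warning;
        # the first seven columns get the canonical names, the eighth keeps its
        # positional name (7) and travels along as a custom node property.
        n = navis.read_swc(path)
        print(n.nodes.columns.tolist())
        # Expected to include: ['node_id', 'label', 'x', 'y', 'z', 'radius', 'parent_id', 7, ...]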