Skip to content

Commit

Permalink
[python] Allow uns-key restriction for ingest and outgest (#1815)
Browse files Browse the repository at this point in the history
* [python] Allow uns-key restriction for ingest and outgest

* typofix

* update on-line help

* code-review feedback

* debug

* no debug
  • Loading branch information
johnkerl authored Oct 25, 2023
1 parent 01440e5 commit 19dd9cc
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 41 deletions.
42 changes: 39 additions & 3 deletions apis/python/src/tiledbsoma/io/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ def from_h5ad(
use_relative_uri: Optional[bool] = None,
X_kind: Union[Type[SparseNDArray], Type[DenseNDArray]] = SparseNDArray,
registration_mapping: Optional[ExperimentAmbientLabelMapping] = None,
uns_keys: Optional[Sequence[str]] = None,
) -> str:
"""Reads an ``.h5ad`` file and writes it to an :class:`Experiment`.
Expand Down Expand Up @@ -308,6 +309,10 @@ def from_h5ad(
registration_mapping=rd,
)
uns_keys: Only ingest the specified top-level ``uns`` keys.
The default is to ingest them all. Use ``uns_keys=[]``
to not ingest any ``uns`` keys.
Returns:
The URI of the newly created experiment.
Expand Down Expand Up @@ -344,6 +349,7 @@ def from_h5ad(
use_relative_uri=use_relative_uri,
X_kind=X_kind,
registration_mapping=registration_mapping,
uns_keys=uns_keys,
)

logging.log_io(
Expand All @@ -366,6 +372,7 @@ def from_anndata(
use_relative_uri: Optional[bool] = None,
X_kind: Union[Type[SparseNDArray], Type[DenseNDArray]] = SparseNDArray,
registration_mapping: Optional[ExperimentAmbientLabelMapping] = None,
uns_keys: Optional[Sequence[str]] = None,
) -> str:
"""Writes an `AnnData <https://anndata.readthedocs.io/>`_ object to an :class:`Experiment`.
Expand Down Expand Up @@ -490,6 +497,7 @@ def from_anndata(
context=context,
ingestion_params=ingestion_params,
use_relative_uri=use_relative_uri,
uns_keys=uns_keys,
)

# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
Expand Down Expand Up @@ -2297,6 +2305,7 @@ def _maybe_ingest_uns(
context: Optional[SOMATileDBContext],
ingestion_params: IngestionParams,
use_relative_uri: Optional[bool],
uns_keys: Optional[Sequence[str]] = None,
) -> None:
# Don't try to ingest an empty uns.
if not uns:
Expand All @@ -2309,6 +2318,7 @@ def _maybe_ingest_uns(
context=context,
ingestion_params=ingestion_params,
use_relative_uri=use_relative_uri,
uns_keys=uns_keys,
)


Expand All @@ -2321,6 +2331,8 @@ def _ingest_uns_dict(
context: Optional[SOMATileDBContext],
ingestion_params: IngestionParams,
use_relative_uri: Optional[bool],
uns_keys: Optional[Sequence[str]] = None,
level: int = 0,
) -> None:
with _create_or_open_collection(
Collection,
Expand All @@ -2331,6 +2343,8 @@ def _ingest_uns_dict(
_maybe_set(parent, parent_key, coll, use_relative_uri=use_relative_uri)
coll.metadata["soma_tiledbsoma_type"] = "uns"
for key, value in dct.items():
if level == 0 and uns_keys is not None and key not in uns_keys:
continue
_ingest_uns_node(
coll,
key,
Expand All @@ -2339,6 +2353,7 @@ def _ingest_uns_dict(
context=context,
ingestion_params=ingestion_params,
use_relative_uri=use_relative_uri,
level=level + 1,
)

msg = f"Wrote {coll.uri} (uns collection)"
Expand All @@ -2354,6 +2369,7 @@ def _ingest_uns_node(
context: Optional[SOMATileDBContext],
ingestion_params: IngestionParams,
use_relative_uri: Optional[bool],
level: int,
) -> None:
if isinstance(value, np.generic):
# This is some kind of numpy scalar value. Metadata entries
Expand All @@ -2375,6 +2391,7 @@ def _ingest_uns_node(
context=context,
ingestion_params=ingestion_params,
use_relative_uri=use_relative_uri,
level=level + 1,
)
return

Expand Down Expand Up @@ -2627,9 +2644,11 @@ def to_h5ad(
X_layer_name: str = "data",
obs_id_name: str = "obs_id",
var_id_name: str = "var_id",
obsm_varm_width_hints: Optional[Dict[str, Dict[str, int]]] = None,
uns_keys: Optional[Sequence[str]] = None,
) -> None:
"""Converts the experiment group to `AnnData <https://anndata.readthedocs.io/>`_
format and writes it to the specified ``.h5ad`` file.
format and writes it to the specified ``.h5ad`` file. Arguments are as in ``to_anndata``.
Lifecycle:
Experimental.
Expand All @@ -2643,6 +2662,8 @@ def to_h5ad(
obs_id_name=obs_id_name,
var_id_name=var_id_name,
X_layer_name=X_layer_name,
obsm_varm_width_hints=obsm_varm_width_hints,
uns_keys=uns_keys,
)

s2 = _util.get_start_stamp()
Expand All @@ -2666,6 +2687,7 @@ def to_anndata(
obs_id_name: str = "obs_id",
var_id_name: str = "var_id",
obsm_varm_width_hints: Optional[Dict[str, Dict[str, int]]] = None,
uns_keys: Optional[Sequence[str]] = None,
) -> ad.AnnData:
"""Converts the experiment group to `AnnData <https://anndata.readthedocs.io/>`_
format. Choice of matrix formats is following what we often see in input
Expand All @@ -2679,6 +2701,10 @@ def to_anndata(
The ``obsm_varm_width_hints`` is optional. If provided, it should be of the form
``{"obsm":{"X_tSNE":2}}`` to aid with export errors.
If ``uns_keys`` is provided, only the specified top-level ``uns`` keys
are extracted. The default is to extract them all. Use ``uns_keys=[]``
to not extract any ``uns`` keys.
Lifecycle:
Experimental.
"""
Expand Down Expand Up @@ -2763,7 +2789,10 @@ def to_anndata(
if "uns" in measurement:
s = _util.get_start_stamp()
logging.log_io(None, f'Start writing uns for {measurement["uns"].uri}')
uns = _extract_uns(cast(Collection[Any], measurement["uns"]))
uns = _extract_uns(
cast(Collection[Any], measurement["uns"]),
uns_keys=uns_keys,
)
logging.log_io(
None,
_util.format_elapsed(s, f'Finish writing uns for {measurement["uns"].uri}'),
Expand Down Expand Up @@ -2861,15 +2890,20 @@ def _extract_obsm_or_varm(

def _extract_uns(
collection: Collection[Any],
uns_keys: Optional[Sequence[str]] = None,
level: int = 0,
) -> Dict[str, Any]:
"""
This is a helper function for ``to_anndata`` of ``uns`` elements.
"""

extracted: Dict[str, Any] = {}
for key, element in collection.items():
if level == 0 and uns_keys is not None and key not in uns_keys:
continue

if isinstance(element, Collection):
extracted[key] = _extract_uns(element)
extracted[key] = _extract_uns(element, level=level + 1)
elif isinstance(element, DataFrame):
hint = element.metadata.get(_UNS_OUTGEST_HINT_KEY)
pdf = element.read().concat().to_pandas()
Expand Down Expand Up @@ -2897,6 +2931,8 @@ def _extract_uns(

# Primitives got set on the SOMA-experiment uns metadata.
for key, value in collection.metadata.items():
if level == 0 and uns_keys is not None and key not in uns_keys:
continue
if not key.startswith("soma_"):
extracted[key] = value

Expand Down
99 changes: 61 additions & 38 deletions apis/python/tests/test_basic_anndata_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,38 +364,48 @@ def test_ingest_relative(h5ad_file_extended, use_relative_uri):
exp.close()


def test_ingest_uns(tmp_path: pathlib.Path, h5ad_file_extended):
@pytest.mark.parametrize("ingest_uns_keys", [["louvain_colors"], None])
def test_ingest_uns(tmp_path: pathlib.Path, h5ad_file_extended, ingest_uns_keys):
tmp_uri = tmp_path.as_uri()
original = anndata.read(h5ad_file_extended)
uri = tiledbsoma.io.from_anndata(tmp_uri, original, measurement_name="hello")
uri = tiledbsoma.io.from_anndata(
tmp_uri,
original,
measurement_name="hello",
uns_keys=ingest_uns_keys,
)

with tiledbsoma.Experiment.open(uri) as exp:
uns = exp.ms["hello"]["uns"]
assert isinstance(uns, tiledbsoma.Collection)
assert uns.metadata["soma_tiledbsoma_type"] == "uns"
assert set(uns) == {
"draw_graph",
"louvain",
"louvain_colors",
"neighbors",
"pca",
"rank_genes_groups",
}
rgg = uns["rank_genes_groups"]
assert set(rgg) == {"params"}, "structured arrays not imported"
assert rgg["params"].metadata.items() >= {
("groupby", "louvain"),
("method", "t-test_overestim_var"),
("reference", "rest"),
}
dg_params = uns["draw_graph"]["params"]
assert isinstance(dg_params, tiledbsoma.Collection)
assert dg_params.metadata["layout"] == "fr"
random_state = dg_params["random_state"]
assert isinstance(random_state, tiledbsoma.DenseNDArray)
assert np.array_equal(random_state.read().to_numpy(), np.array([0]))
got_pca_variance = uns["pca"]["variance"].read().to_numpy()
assert np.array_equal(got_pca_variance, original.uns["pca"]["variance"])
if ingest_uns_keys is None:
assert set(uns) == {
"draw_graph",
"louvain",
"louvain_colors",
"neighbors",
"pca",
"rank_genes_groups",
}
assert isinstance(uns["louvain_colors"], tiledbsoma.DataFrame)
rgg = uns["rank_genes_groups"]
assert set(rgg) == {"params"}, "structured arrays not imported"
assert rgg["params"].metadata.items() >= {
("groupby", "louvain"),
("method", "t-test_overestim_var"),
("reference", "rest"),
}
dg_params = uns["draw_graph"]["params"]
assert isinstance(dg_params, tiledbsoma.Collection)
assert dg_params.metadata["layout"] == "fr"
random_state = dg_params["random_state"]
assert isinstance(random_state, tiledbsoma.DenseNDArray)
assert np.array_equal(random_state.read().to_numpy(), np.array([0]))
got_pca_variance = uns["pca"]["variance"].read().to_numpy()
assert np.array_equal(got_pca_variance, original.uns["pca"]["variance"])
else:
assert set(uns) == set(ingest_uns_keys)


def test_ingest_uns_string_arrays(h5ad_file_uns_string_arrays):
Expand Down Expand Up @@ -833,7 +843,10 @@ def test_id_names(tmp_path, obs_id_name, var_id_name, indexify_obs, indexify_var
assert list(bdata.var.index) == list(soma_var[var_id_name])


def test_uns_io(tmp_path):
@pytest.mark.parametrize(
"outgest_uns_keys", [["int_scalar", "strings", "np_ndarray_2d"], None]
)
def test_uns_io(tmp_path, outgest_uns_keys):
obs = pd.DataFrame(
data={"obs_id": np.asarray(["a", "b", "c"])},
index=np.arange(3).astype(str),
Expand Down Expand Up @@ -882,23 +895,33 @@ def test_uns_io(tmp_path):
tiledbsoma.io.from_anndata(soma_uri, adata, measurement_name="RNA")

with tiledbsoma.Experiment.open(soma_uri) as exp:
bdata = tiledbsoma.io.to_anndata(exp, measurement_name="RNA")
bdata = tiledbsoma.io.to_anndata(
exp,
measurement_name="RNA",
uns_keys=outgest_uns_keys,
)

# Keystroke-savers
a = adata.uns
b = bdata.uns

assert a["int_scalar"] == b["int_scalar"]
assert a["float_scalar"] == b["float_scalar"]
assert a["string_scalar"] == b["string_scalar"]
if outgest_uns_keys is None:
assert a["int_scalar"] == b["int_scalar"]
assert a["float_scalar"] == b["float_scalar"]
assert a["string_scalar"] == b["string_scalar"]

assert all(a["pd_df_indexed"]["column_1"] == b["pd_df_indexed"]["column_1"])
assert all(
a["pd_df_nonindexed"]["column_1"] == b["pd_df_nonindexed"]["column_1"]
)

assert all(a["pd_df_indexed"]["column_1"] == b["pd_df_indexed"]["column_1"])
assert all(a["pd_df_nonindexed"]["column_1"] == b["pd_df_nonindexed"]["column_1"])
assert (a["np_ndarray_1d"] == b["np_ndarray_1d"]).all()
assert (a["np_ndarray_2d"] == b["np_ndarray_2d"]).all()

assert (a["np_ndarray_1d"] == b["np_ndarray_1d"]).all()
assert (a["np_ndarray_2d"] == b["np_ndarray_2d"]).all()
sa = a["strings"]
sb = b["strings"]
assert (sa["string_np_ndarray_1d"] == sb["string_np_ndarray_1d"]).all()
assert (sa["string_np_ndarray_2d"] == sb["string_np_ndarray_2d"]).all()

sa = a["strings"]
sb = b["strings"]
assert (sa["string_np_ndarray_1d"] == sb["string_np_ndarray_1d"]).all()
assert (sa["string_np_ndarray_2d"] == sb["string_np_ndarray_2d"]).all()
else:
assert set(b.keys()) == set(outgest_uns_keys)

0 comments on commit 19dd9cc

Please sign in to comment.