Skip to content

Commit

Permalink
[python] update_uns feature
Browse files Browse the repository at this point in the history
  • Loading branch information
johnkerl committed Oct 20, 2023
1 parent bbe7d68 commit 3dbb7db
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 10 deletions.
2 changes: 2 additions & 0 deletions apis/python/src/tiledbsoma/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
to_anndata,
to_h5ad,
update_obs,
update_uns,
update_var,
)

Expand All @@ -30,4 +31,5 @@
"to_h5ad",
"update_obs",
"update_var",
"update_uns",
)
73 changes: 63 additions & 10 deletions apis/python/src/tiledbsoma/io/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2285,8 +2285,52 @@ def _chunk_is_contained_in_axis(
return True


def update_uns(
    exp: Experiment,
    new_data: Mapping[str, object],
    measurement_name: str,
    *,
    context: Optional[SOMATileDBContext] = None,
    platform_config: Optional[PlatformConfig] = None,
) -> None:
    """
    Re-writes ``uns`` data to an existing SOMA experiment.

    The ``uns`` collection of the named measurement is written via
    ``_ingest_uns_dict`` using ``"update"`` ingestion parameters.

    Args:
        exp: The experiment to modify. It must be open in write mode.
        new_data: Mapping of ``uns`` keys to values; handed unchanged to
            ``_ingest_uns_dict``.
        measurement_name: Key into ``exp.ms`` selecting the measurement
            whose ``uns`` is written.
        context: Optional SOMA/TileDB context, forwarded to
            ``_ingest_uns_dict``.
        platform_config: Optional platform configuration, forwarded to
            ``_ingest_uns_dict``.

    Raises:
        SOMAError: If ``exp`` is closed or not opened with mode ``"w"``, or
            if ``measurement_name`` is not present in ``exp.ms``.

    Lifecycle:
        Experimental.
    """
    # Guard clauses: reject unusable inputs before touching any storage.
    is_writable = not exp.closed and exp.mode == "w"
    if not is_writable:
        raise SOMAError(f"Experiment must be open for write: {exp.uri}")

    measurements = exp.ms
    if measurement_name not in measurements:
        raise SOMAError(
            f"Experiment {exp.uri} has no measurement named {measurement_name}"
        )
    measurement = measurements[measurement_name]

    # On local disk and S3 the creation URI and the storage URI of an object
    # are the same string. On cloud they differ: a creation URI has the form
    # tiledb://namespace/s3://bucket/path/to/obj while the storage URI of that
    # same object has the form tiledb://namespace/uuid. The caller is required
    # to supply a creation URI through exp.uri, so for cloud we derive the
    # measurement's URI from exp.uri instead of using the stored one.
    stored_uri = measurement.uri
    meas_uri = (
        f"{exp.uri}/ms/{measurement_name}"
        if stored_uri.startswith("tiledb://")
        else stored_uri
    )

    update_params = IngestionParams("update", label_mapping=None)
    _ingest_uns_dict(
        measurement,
        meas_uri,
        "uns",
        new_data,
        platform_config=platform_config,
        context=context,
        ingestion_params=update_params,
        use_relative_uri=None,
    )


def _maybe_ingest_uns(
m: Measurement,
measurement: Measurement,
uns: Mapping[str, object],
*,
platform_config: Optional[PlatformConfig],
Expand All @@ -2298,7 +2342,8 @@ def _maybe_ingest_uns(
if not uns:
return
_ingest_uns_dict(
m,
measurement,
measurement.uri,
"uns",
uns,
platform_config=platform_config,
Expand All @@ -2310,6 +2355,7 @@ def _maybe_ingest_uns(

def _ingest_uns_dict(
parent: AnyTileDBCollection,
parent_uri: str,
parent_key: str,
dct: Mapping[str, object],
*,
Expand All @@ -2320,7 +2366,7 @@ def _ingest_uns_dict(
) -> None:
with _create_or_open_collection(
Collection,
_util.uri_joinpath(parent.uri, parent_key),
_util.uri_joinpath(parent_uri, parent_key),
ingestion_params=ingestion_params,
context=context,
) as coll:
Expand All @@ -2338,7 +2384,7 @@ def _ingest_uns_dict(
)

msg = f"Wrote {coll.uri} (uns collection)"
logging.log_io(msg, msg)
logging.log_io_same(msg)


def _ingest_uns_node(
Expand All @@ -2365,6 +2411,7 @@ def _ingest_uns_node(
# Mappings are represented as sub-dictionaries.
_ingest_uns_dict(
coll,
coll.uri,
key,
value,
platform_config=platform_config,
Expand Down Expand Up @@ -2394,7 +2441,7 @@ def _ingest_uns_node(
if value.dtype.names is not None:
msg = f"Skipped {coll.uri}[{key!r}]" " (uns): unsupported structured array"
# This is a structured array, which we do not support.
logging.log_io(msg, msg)
logging.log_io_same(msg)
return

if value.dtype.char in ("U", "O"):
Expand Down Expand Up @@ -2424,7 +2471,7 @@ def _ingest_uns_node(
msg = (
f"Skipped {coll.uri}[{key!r}]" f" (uns object): unrecognized type {type(value)}"
)
logging.log_io(msg, msg)
logging.log_io_same(msg)


def _ingest_uns_string_array(
Expand Down Expand Up @@ -2455,7 +2502,7 @@ def _ingest_uns_string_array(
f"Skipped {coll.uri}[{key!r}]"
f" (uns object): string array is neither one-dimensional nor two-dimensional"
)
logging.log_io(msg, msg)
logging.log_io_same(msg)
return

helper(
Expand Down Expand Up @@ -2565,10 +2612,11 @@ def _ingest_uns_ndarray(
ingestion_params: IngestionParams,
) -> None:
arr_uri = _util.uri_joinpath(coll.uri, key)
print("ARR_URI", arr_uri)

if any(e <= 0 for e in value.shape):
msg = f"Skipped {arr_uri} (uns ndarray): zero in shape {value.shape}"
logging.log_io(msg, msg)
logging.log_io_same(msg)
return

try:
Expand All @@ -2578,10 +2626,15 @@ def _ingest_uns_ndarray(
f"Skipped {arr_uri} (uns ndarray):"
f" unsupported dtype {value.dtype!r} ({value.dtype})"
)
logging.log_io(msg, msg)
logging.log_io_same(msg)
return

try:
soma_arr = _factory.open(arr_uri, "w", soma_type=DenseNDArray, context=context)
# TODO
# old_shape = soma_arr.shape
# new_shape = value.shape
# if old_shape != new_shape:
except DoesNotExistError:
soma_arr = DenseNDArray.create(
arr_uri,
Expand Down Expand Up @@ -2611,7 +2664,7 @@ def _ingest_uns_ndarray(
platform_config=platform_config,
)
msg = f"Wrote {soma_arr.uri} (uns ndarray)"
logging.log_io(msg, msg)
logging.log_io_same(msg)


# ----------------------------------------------------------------
Expand Down

0 comments on commit 3dbb7db

Please sign in to comment.