Skip to content

Commit

Permalink
neaten
Browse files Browse the repository at this point in the history
  • Loading branch information
johnkerl committed Mar 18, 2024
1 parent b2df960 commit c0d60c6
Showing 1 changed file with 58 additions and 60 deletions.
118 changes: 58 additions & 60 deletions apis/python/tests/test_registration_mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def _create_anndata(
var_ids: Sequence[str],
obs_field_name: str,
var_field_name: str,
X_base: int,
X_value_base: int,
measurement_name: str,
raw_var_ids: Optional[Sequence[str]] = None,
X_density: float = 0.3,
Expand Down Expand Up @@ -56,22 +56,22 @@ def _make_var(arg_var_ids):
var.set_index(var_field_name, inplace=True)
return var

def _make_X(n_obs, n_var, X_base):
def _make_X(n_obs, n_var, X_value_base):
X = np.zeros((n_obs, n_var))
for i in range(n_obs):
for j in range(n_var):
if (i + j) % 2 == 1:
X[i, j] = X_base + 10 * i + j
X[i, j] = X_value_base + 10 * i + j
return X

var = _make_var(var_ids)
X = _make_X(n_obs, n_var, X_base)
X = _make_X(n_obs, n_var, X_value_base)

adata = ad.AnnData(X=X, obs=obs, var=var, dtype=X.dtype)

if raw_var_ids is not None:
raw_var = _make_var(raw_var_ids)
raw_X = _make_X(n_obs, len(raw_var_ids), X_base)
raw_X = _make_X(n_obs, len(raw_var_ids), X_value_base)
raw = ad.Raw(adata, var=raw_var, X=raw_X)
adata = ad.AnnData(X=X, obs=obs, var=var, dtype=X.dtype, raw=raw)

Expand All @@ -87,66 +87,60 @@ def create_h5ad(adata, path):
# datasets, varying cell and gene IDs.
def create_anndata_canned(which: int, obs_field_name: str, var_field_name: str):
if which == 1:
return _create_anndata(
obs_ids=["AAAT", "ACTG", "AGAG"],
var_ids=["AKT1", "APOE", "ESR1", "TP53", "VEGFA"],
raw_var_ids=["AKT1", "APOE", "ESR1", "TP53", "VEGFA", "RAW1", "RAW2"],
X_base=100,
measurement_name="measname",
obs_field_name=obs_field_name,
var_field_name=var_field_name,
)

if which == 2:
return _create_anndata(
obs_ids=["CAAT", "CCTG", "CGAG"],
var_ids=["APOE", "ESR1", "TP53", "VEGFA"],
raw_var_ids=["APOE", "ESR1", "TP53", "VEGFA"],
X_base=200,
measurement_name="measname",
obs_field_name=obs_field_name,
var_field_name=var_field_name,
)

if which == 3:
return _create_anndata(
obs_ids=["GAAT", "GCTG", "GGAG"],
var_ids=["APOE", "EGFR", "ESR1", "TP53", "VEGFA"],
raw_var_ids=["APOE", "EGFR", "ESR1", "TP53", "VEGFA", "RAW1", "RAW3"],
X_base=300,
measurement_name="measname",
obs_field_name=obs_field_name,
var_field_name=var_field_name,
)
obs_ids = ["AAAT", "ACTG", "AGAG"]
var_ids = ["AKT1", "APOE", "ESR1", "TP53", "VEGFA"]
raw_var_ids = ["AKT1", "APOE", "ESR1", "TP53", "VEGFA", "RAW1", "RAW2"]
X_value_base = 100

elif which == 2:
obs_ids = ["CAAT", "CCTG", "CGAG"]
var_ids = ["APOE", "ESR1", "TP53", "VEGFA"]
raw_var_ids = ["APOE", "ESR1", "TP53", "VEGFA"]
X_value_base = 200

elif which == 3:
obs_ids = ["GAAT", "GCTG", "GGAG"]
var_ids = ["APOE", "EGFR", "ESR1", "TP53", "VEGFA"]
raw_var_ids = ["APOE", "EGFR", "ESR1", "TP53", "VEGFA", "RAW1", "RAW3"]
X_value_base = 300

elif which == 4:
obs_ids = ["TAAT", "TCTG", "TGAG"]
var_ids = ["AKT1", "APOE", "ESR1", "TP53", "VEGFA", "ZZZ3"]
raw_var_ids = [
"AKT1",
"APOE",
"ESR1",
"TP53",
"VEGFA",
"ZZZ3",
"RAW1",
"RAW3",
"RAW2",
]
X_value_base = 400

if which == 4:
return _create_anndata(
obs_ids=["TAAT", "TCTG", "TGAG"],
var_ids=["AKT1", "APOE", "ESR1", "TP53", "VEGFA", "ZZZ3"],
raw_var_ids=[
"AKT1",
"APOE",
"ESR1",
"TP53",
"VEGFA",
"ZZZ3",
"RAW1",
"RAW3",
"RAW2",
],
X_base=400,
measurement_name="measname",
obs_field_name=obs_field_name,
var_field_name=var_field_name,
)
else:
raise Exception(f"create_anndata_canned takes 1..4; got {which}")

raise Exception(f"create_anndata_canned takes 1..4; got {which}")
return _create_anndata(
obs_ids=obs_ids,
var_ids=var_ids,
raw_var_ids=raw_var_ids,
X_value_base=X_value_base,
measurement_name="measname",
obs_field_name=obs_field_name,
var_field_name=var_field_name,
)


def create_h5ad_canned(which: int, obs_field_name: str, var_field_name: str):
tmp_path = tempfile.TemporaryDirectory()
anndata = create_anndata_canned(which, obs_field_name, var_field_name)
return create_h5ad(anndata, (tmp_path.name + f"{which}.h5ad"))
return create_h5ad(
anndata,
(tmp_path.name + f"{which}.h5ad"),
)


def create_soma_canned(which: int, obs_field_name, var_field_name):
Expand All @@ -162,8 +156,10 @@ def anndata_larger():
return _create_anndata(
obs_ids=["id_%08d" % e for e in range(1000)],
var_ids=["AKT1", "APOE", "ESR1", "TP53", "VEGFA", "ZZZ3"],
X_base=0,
X_value_base=0,
measurement_name="measname",
obs_field_name="cell_id",
var_field_name="gene_id",
)


Expand Down Expand Up @@ -881,7 +877,7 @@ def test_append_with_disjoint_measurements(
tmp_path, obs_field_name, var_field_name, use_same_cells
):
anndata1 = create_anndata_canned(1, obs_field_name, var_field_name)
anndata4 = create_anndata_canned(2, obs_field_name, var_field_name)
anndata4 = create_anndata_canned(4, obs_field_name, var_field_name)
soma_uri = tmp_path.as_posix()

tiledbsoma.io.from_anndata(soma_uri, anndata1, measurement_name="one")
Expand Down Expand Up @@ -1132,6 +1128,8 @@ def test_registration_with_batched_reads(tmp_path, soma_larger, use_small_buffer
rd = registration.ExperimentAmbientLabelMapping.from_isolated_soma_experiment(
soma_larger,
context=context,
obs_field_name="cell_id",
var_field_name="gene_id",
)

assert len(rd.obs_axis.data) == 1000
Expand Down

0 comments on commit c0d60c6

Please sign in to comment.