diff --git a/brainio_collection/packaging.py b/brainio_collection/packaging.py index 2e3e42f..dd1897f 100644 --- a/brainio_collection/packaging.py +++ b/brainio_collection/packaging.py @@ -5,9 +5,9 @@ import boto3 from tqdm import tqdm -from xarray import DataArray import brainio_base.assemblies +from brainio_base.assemblies import get_levels from brainio_collection import lookup, list_stimulus_sets from brainio_collection.lookup import TYPE_ASSEMBLY, TYPE_STIMULUS_SET, sha1_hash @@ -111,9 +111,7 @@ def package_stimulus_set(proto_stimulus_set, stimulus_set_identifier, bucket_nam def write_netcdf(assembly, target_netcdf_file): _logger.debug(f"Writing assembly to {target_netcdf_file}") - assembly = DataArray(assembly) # if we're passed a BrainIO DataAssembly, it will automatically re-index otherwise - for index in assembly.indexes.keys(): - assembly.reset_index(index, inplace=True) + assembly = assembly.reset_index(list(assembly.indexes)) assembly.to_netcdf(target_netcdf_file) sha1 = sha1_hash(target_netcdf_file) return sha1 diff --git a/tests/test_assemblies.py b/tests/test_assemblies.py index b8481a2..8628ba1 100644 --- a/tests/test_assemblies.py +++ b/tests/test_assemblies.py @@ -239,6 +239,13 @@ def test_aperture(self, identifier, image_id, expected_amount_gray, ratio_gray): assert amount_gray == expected_amount_gray +def test_inplace(): + d = xr.DataArray(0, None, None, None, None, None, False) + with pytest.raises(TypeError) as te: + d = d.reset_index(None, inplace=True) + assert "inplace" in str(te.value) + + class TestSeibert: @pytest.mark.private_access def test_dims(self): diff --git a/tests/test_packaging.py b/tests/test_packaging.py new file mode 100644 index 0000000..67ed7d0 --- /dev/null +++ b/tests/test_packaging.py @@ -0,0 +1,57 @@ +import pytest +from pathlib import Path + +from brainio_base.assemblies import DataAssembly, get_levels +from brainio_collection.packaging import write_netcdf + + +def test_write_netcdf(): + assy = DataAssembly( + data=[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15], [16, 17, 18]], + coords={ + 'up': ("a", ['alpha', 'alpha', 'beta', 'beta', 'beta', 'beta']), + 'down': ("a", [1, 1, 1, 1, 2, 2]), + 'sideways': ('b', ['x', 'y', 'z']) + }, + dims=['a', 'b'] + ) + netcdf_path = Path("test.nc") + netcdf_sha1 = write_netcdf(assy, str(netcdf_path)) + assert netcdf_path.exists() + + +def test_reset_index(): + assy = DataAssembly( + data=[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15], [16, 17, 18]], + coords={ + 'up': ("a", ['alpha', 'alpha', 'beta', 'beta', 'beta', 'beta']), + 'down': ("a", [1, 1, 1, 1, 2, 2]), + 'sideways': ('b', ['x', 'y', 'z']) + }, + dims=['a', 'b'] + ) + assert assy["a"].variable.level_names == ["up", "down"] + assert list(assy.indexes) == ["a", "b"] + assy = assy.reset_index(list(assy.indexes)) + assert assy["a"].variable.level_names is None + assert get_levels(assy) == [] + assert list(assy.indexes) == [] + + + +def test_reset_index_levels(): + assy = DataAssembly( + data=[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15], [16, 17, 18]], + coords={ + 'up': ("a", ['alpha', 'alpha', 'beta', 'beta', 'beta', 'beta']), + 'down': ("a", [1, 1, 1, 1, 2, 2]), + 'sideways': ('b', ['x', 'y', 'z']) + }, + dims=['a', 'b'] + ) + assert assy["a"].variable.level_names == ["up", "down"] + assy = assy.reset_index(["up", "down"]) + assert get_levels(assy) == [] + + +