From 5cda022c5da800453653ea4a0a5a4390cc18b9e4 Mon Sep 17 00:00:00 2001 From: Trent Smith <1429913+Bento007@users.noreply.github.com> Date: Tue, 21 Nov 2023 13:45:39 -0800 Subject: [PATCH] fix: restrict allow categorical types (#691) ## Reason for Change - During schema migration 4.0, additional categorical type combinations were found to be invalid. This PR adds additional test coverage and checks for those error types. - related to https://github.com/chanzuckerberg/single-cell-curation/issues/405 ## Changes - add check for mixed categorical types - add check for illegal categorical types - add test to verify which categorical types are allowed. ## Testing - unit tests. ## Notes for Reviewer --- .../cellxgene_schema/validate.py | 18 ++++- cellxgene_schema_cli/tests/test_validate.py | 80 ++++++++++++++----- 2 files changed, 76 insertions(+), 22 deletions(-) diff --git a/cellxgene_schema_cli/cellxgene_schema/validate.py b/cellxgene_schema_cli/cellxgene_schema/validate.py index c8d6b14d5..c19d0581e 100644 --- a/cellxgene_schema_cli/cellxgene_schema/validate.py +++ b/cellxgene_schema_cli/cellxgene_schema/validate.py @@ -707,12 +707,22 @@ def _validate_dataframe(self, df_name: str): f"Column '{column_name}' in dataframe '{df_name}' contains a category '{category}' with " f"zero observations. These categories will be removed when `--add-labels` flag is present." 
) - # Check for columns that have none string categories, which is not supported by anndata 0.8.0 + categorical_types = {type(x) for x in column.dtype.categories.values} + # Check for columns that have illegal categories, which are not supported by anndata 0.8.0 # TODO: check if this can be removed after upgading to anndata 0.10.0 - category_types = {type(x) for x in column.dtype.categories.values} - if len(category_types) > 1 or str not in category_types: + blocked_categorical_types = {bool} + illegal_categorical_types = categorical_types & blocked_categorical_types + if illegal_categorical_types: self.errors.append( - f"Column '{column_name}' in dataframe '{df_name}' must only contain string categories. Found {category_types}." + f"Column '{column_name}' in dataframe '{df_name}' contains {illegal_categorical_types=}." + ) + # Check whether a categorical column has mixed types, which is not supported by anndata 0.8.0 + # TODO: check if this can be removed after upgrading to anndata 0.10.0 + categorical_types = {type(x) for x in column.dtype.categories.values} + if len(categorical_types) > 1: + self.errors.append( + f"Column '{column_name}' in dataframe '{df_name}' contains {len(categorical_types)} categorical types. " + f"Only one type is allowed." 
) # Validate columns diff --git a/cellxgene_schema_cli/tests/test_validate.py b/cellxgene_schema_cli/tests/test_validate.py index c299abfa7..b5b266637 100644 --- a/cellxgene_schema_cli/tests/test_validate.py +++ b/cellxgene_schema_cli/tests/test_validate.py @@ -40,7 +40,7 @@ def schema_def(): @pytest.fixture() def validator_with_minimal_adata(): validator = Validator() - validator.adata = adata_minimal + validator.adata = adata_minimal.copy() return validator @@ -52,6 +52,11 @@ def label_writer(): return AnnDataLabelAppender(validator) +@pytest.fixture +def valid_adata(): + return adata_valid.copy() + + class TestFieldValidation: def test_schema_definition(self, schema_def): """ @@ -355,35 +360,74 @@ def test_determine_seurat_convertibility(self): class TestValidatorValidateDataFrame: - def test_fail_category_not_string(self, tmp_path): - validator = Validator() - validator._set_schema_def() - adata = adata_valid.copy() - t = pd.CategoricalDtype(categories=[True, False]) - adata.obs["not_string"] = pd.Series(data=[True, False], index=["X", "Y"], dtype=t) - validator.adata = adata + @pytest.mark.parametrize("_type", [np.int64, np.int32, int, np.float64, np.float32, float, str]) + def test_succeed_categorical_types(self, tmp_path, _type, valid_adata): + # Arrange + categories = [*map(_type, range(adata_valid.n_obs))] + self._add_catagorical_obs(valid_adata, categories) + validator = self._create_validator(valid_adata) + + # Act + validator._validate_dataframe("obs") + # Assert + assert not validator.errors + valid_adata.write_h5ad(f"{tmp_path}/test.h5ad") # Succeed write + + def test_fail_categorical_mixed_types(self, tmp_path, valid_adata): + # Arrange + categories = ["hello", 123] + self._add_catagorical_obs(valid_adata, categories) + validator = self._create_validator(valid_adata) + + # Act validator._validate_dataframe("obs") - assert "must only contain string categories." 
in validator.errors[0] - with pytest.raises(TypeError): - # If this tests starts to fail here it means the anndata version has be upgraded and this check is no - # longer needed - adata.write_h5ad(f"{tmp_path}/test.h5ad") - def test_fail_mixed_column_types(self, tmp_path): + # Assert + assert "in dataframe 'obs' contains 2 categorical types. Only one type is allowed." in validator.errors[0] + self._fail_write_h5ad(tmp_path, valid_adata) + + def test_fail_categorical_bool(self, tmp_path, valid_adata): + # Arrange + categories = [True, False] + self._add_catagorical_obs(valid_adata, categories) + validator = self._create_validator(valid_adata) + + # Act + validator._validate_dataframe("obs") + + # Assert + assert "in dataframe 'obs' contains illegal_categorical_types={<class 'bool'>}." in validator.errors[0] + self._fail_write_h5ad(tmp_path, valid_adata) + + def _add_catagorical_obs(self, adata, categories): + t = pd.CategoricalDtype(categories=categories) + adata.obs["test_cat"] = pd.Series(data=categories, index=["X", "Y"], dtype=t) + + def _create_validator(self, adata): validator = Validator() validator._set_schema_def() - adata = adata_valid.copy() - adata.obs["mixed"] = pd.Series(data=["1234", 0], index=["X", "Y"]) validator.adata = adata + return validator - validator._validate_dataframe("obs") - assert "Column 'mixed' in dataframe 'obs' cannot contain mixed types." in validator.errors[0] + def _fail_write_h5ad(self, tmp_path, adata): with pytest.raises(TypeError): # If this tests starts to fail here it means the anndata version has be upgraded and this check is no # longer needed adata.write_h5ad(f"{tmp_path}/test.h5ad") + def test_fail_mixed_column_types(self, tmp_path, valid_adata): + # Arrange + valid_adata.obs["mixed"] = pd.Series(data=["1234", 0], index=["X", "Y"]) + validator = self._create_validator(valid_adata) + + # Act + validator._validate_dataframe("obs") + + # Assert + assert "in dataframe 'obs' cannot contain mixed types." 
in validator.errors[0] + self._fail_write_h5ad(tmp_path, valid_adata) + class TestIsRaw: @staticmethod