Skip to content

Commit

Permalink
fix: restrict allowed categorical types (#691)
Browse files Browse the repository at this point in the history
## Reason for Change

- During schema migration 4.0, additional categorical type combinations were found to be invalid. This PR adds test coverage and checks for those error types.
- related to #405
## Changes

- add check for mixed categorical types
- add check for illegal categorical types
- add test to verify which categorical types are allowed

## Testing

- unit tests.

## Notes for Reviewer
  • Loading branch information
Bento007 authored Nov 21, 2023
1 parent 5260fcb commit 5cda022
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 22 deletions.
18 changes: 14 additions & 4 deletions cellxgene_schema_cli/cellxgene_schema/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,12 +707,22 @@ def _validate_dataframe(self, df_name: str):
f"Column '{column_name}' in dataframe '{df_name}' contains a category '{category}' with "
f"zero observations. These categories will be removed when `--add-labels` flag is present."
)
# Check for columns that have non-string categories, which are not supported by anndata 0.8.0
categorical_types = {type(x) for x in column.dtype.categories.values}
# Check for columns that have illegal categories, which are not supported by anndata 0.8.0
# TODO: check if this can be removed after upgrading to anndata 0.10.0
category_types = {type(x) for x in column.dtype.categories.values}
if len(category_types) > 1 or str not in category_types:
blocked_categorical_types = {bool}
illegal_categorical_types = categorical_types & blocked_categorical_types
if illegal_categorical_types:
self.errors.append(
f"Column '{column_name}' in dataframe '{df_name}' must only contain string categories. Found {category_types}."
f"Column '{column_name}' in dataframe '{df_name}' contains {illegal_categorical_types=}."
)
# Check for categorical column has mixed types, which is not supported by anndata 0.8.0
# TODO: check if this can be removed after upgrading to anndata 0.10.0
categorical_types = {type(x) for x in column.dtype.categories.values}
if len(categorical_types) > 1:
self.errors.append(
f"Column '{column_name}' in dataframe '{df_name}' contains {len(categorical_types)} categorical types. "
f"Only one type is allowed."
)

# Validate columns
Expand Down
80 changes: 62 additions & 18 deletions cellxgene_schema_cli/tests/test_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def schema_def():
@pytest.fixture()
def validator_with_minimal_adata():
validator = Validator()
validator.adata = adata_minimal
validator.adata = adata_minimal.copy()
return validator


Expand All @@ -52,6 +52,11 @@ def label_writer():
return AnnDataLabelAppender(validator)


@pytest.fixture
def valid_adata():
return adata_valid.copy()


class TestFieldValidation:
def test_schema_definition(self, schema_def):
"""
Expand Down Expand Up @@ -355,35 +360,74 @@ def test_determine_seurat_convertibility(self):


class TestValidatorValidateDataFrame:
def test_fail_category_not_string(self, tmp_path):
validator = Validator()
validator._set_schema_def()
adata = adata_valid.copy()
t = pd.CategoricalDtype(categories=[True, False])
adata.obs["not_string"] = pd.Series(data=[True, False], index=["X", "Y"], dtype=t)
validator.adata = adata
@pytest.mark.parametrize("_type", [np.int64, np.int32, int, np.float64, np.float32, float, str])
def test_succeed_categorical_types(self, tmp_path, _type, valid_adata):
# Arrange
categories = [*map(_type, range(adata_valid.n_obs))]
self._add_catagorical_obs(valid_adata, categories)
validator = self._create_validator(valid_adata)

# Act
validator._validate_dataframe("obs")

# Assert
assert not validator.errors
valid_adata.write_h5ad(f"{tmp_path}/test.h5ad") # Succeed write

def test_fail_categorical_mixed_types(self, tmp_path, valid_adata):
# Arrange
categories = ["hello", 123]
self._add_catagorical_obs(valid_adata, categories)
validator = self._create_validator(valid_adata)

# Act
validator._validate_dataframe("obs")
assert "must only contain string categories." in validator.errors[0]
with pytest.raises(TypeError):
# If this test starts to fail here, it means the anndata version has been upgraded and this check is no
# longer needed
adata.write_h5ad(f"{tmp_path}/test.h5ad")

def test_fail_mixed_column_types(self, tmp_path):
# Assert
assert "in dataframe 'obs' contains 2 categorical types. Only one type is allowed." in validator.errors[0]
self._fail_write_h5ad(tmp_path, valid_adata)

def test_fail_categorical_bool(self, tmp_path, valid_adata):
# Arrange
categories = [True, False]
self._add_catagorical_obs(valid_adata, categories)
validator = self._create_validator(valid_adata)

# Act
validator._validate_dataframe("obs")

# Assert
assert "in dataframe 'obs' contains illegal_categorical_types={<class 'bool'>}." in validator.errors[0]
self._fail_write_h5ad(tmp_path, valid_adata)

def _add_catagorical_obs(self, adata, categories):
t = pd.CategoricalDtype(categories=categories)
adata.obs["test_cat"] = pd.Series(data=categories, index=["X", "Y"], dtype=t)

def _create_validator(self, adata):
validator = Validator()
validator._set_schema_def()
adata = adata_valid.copy()
adata.obs["mixed"] = pd.Series(data=["1234", 0], index=["X", "Y"])
validator.adata = adata
return validator

validator._validate_dataframe("obs")
assert "Column 'mixed' in dataframe 'obs' cannot contain mixed types." in validator.errors[0]
def _fail_write_h5ad(self, tmp_path, adata):
with pytest.raises(TypeError):
# If this test starts to fail here, it means the anndata version has been upgraded and this check is no
# longer needed
adata.write_h5ad(f"{tmp_path}/test.h5ad")

def test_fail_mixed_column_types(self, tmp_path, valid_adata):
# Arrange
valid_adata.obs["mixed"] = pd.Series(data=["1234", 0], index=["X", "Y"])
validator = self._create_validator(valid_adata)

# Act
validator._validate_dataframe("obs")

# Assert
assert "in dataframe 'obs' cannot contain mixed types." in validator.errors[0]
self._fail_write_h5ad(tmp_path, valid_adata)


class TestIsRaw:
@staticmethod
Expand Down

0 comments on commit 5cda022

Please sign in to comment.