diff --git a/docs/release-notes.rst b/docs/release-notes.rst
index 96342da8..23329aac 100644
--- a/docs/release-notes.rst
+++ b/docs/release-notes.rst
@@ -41,6 +41,9 @@ Security
 Features
 ~~~~~~~~
 
+* Added dict-like methods, :meth:`Dataset.keys`, :meth:`Dataset.values` and :meth:`Dataset.items`.
+* Added dict-like item setter/getter, :meth:`Dataset.__getitem__` and :meth:`Dataset.__setitem__`.
+
 Breaking changes
 ~~~~~~~~~~~~~~~~
 
diff --git a/src/scitacean/dataset.py b/src/scitacean/dataset.py
index 8fe2ad5c..b8de16db 100644
--- a/src/scitacean/dataset.py
+++ b/src/scitacean/dataset.py
@@ -507,6 +507,127 @@ def validate(self) -> None:
         """
         self.make_upload_model()
 
+    def keys(self) -> Iterable[str]:
+        """Dict-like keys(names of fields) method.
+
+        Returns
+        -------
+        :
+            Iterator over the names of all fields corresponding to ``self.type``
+            and other fields that are not ``None``.
+            The order of the names is unspecified.
+
+
+        .. versionadded:: RELEASE_PLACEHOLDER
+        """
+        from itertools import chain
+
+        all_fields = {field.name for field in self.fields()}
+        my_fields = {field.name for field in self.fields(dataset_type=self.type)}
+        other_fields = all_fields - my_fields
+        invalid_fields = (
+            f_name for f_name in other_fields if getattr(self, f_name) is not None
+        )
+
+        return chain(my_fields, invalid_fields)
+
+    def values(self) -> Iterable[Any]:
+        """Dict-like values(values of fields) method.
+
+        Returns
+        -------
+        :
+            Generator of values of all fields corresponding to ``self.type``
+            and other fields that are not ``None``.
+
+
+        .. versionadded:: RELEASE_PLACEHOLDER
+        """
+        return (getattr(self, field_name) for field_name in self.keys())
+
+    def items(self) -> Iterable[tuple[str, Any]]:
+        """Dict-like items(name and value pairs of fields) method.
+
+        Returns
+        -------
+        :
+            Generator of (Name, Value) pairs of all fields
+            corresponding to ``self.type``
+            and other fields that are not ``None``.
+
+
+        .. versionadded:: RELEASE_PLACEHOLDER
+        """
+        return ((key, getattr(self, key)) for key in self.keys())
+
+    @classmethod
+    def _validate_field_name(cls, field_name: str) -> None:
+        """Validate ``field_name``.
+
+        Check that ``field_name`` is the ``name`` of one of the
+        :class:`DatasetBase.Field` objects in ``cls.fields()``.
+
+        Parameters
+        ----------
+        field_name:
+            Name of the field to validate.
+
+        Raises
+        ------
+        KeyError
+            If validation fails.
+        """
+        if field_name not in (field.name for field in cls.fields()):
+            raise KeyError(f"{field_name} is not a valid field name.")
+
+    def __getitem__(self, field_name: str) -> Any:
+        """Dict-like get-item method.
+
+        Parameters
+        ----------
+        field_name:
+            Name of the field to retrieve.
+
+        Returns
+        -------
+        :
+            Value of the field with the name ``field_name``.
+
+        Raises
+        ------
+        KeyError
+            If ``field_name`` does not match any names of fields.
+
+
+        .. versionadded:: RELEASE_PLACEHOLDER
+        """
+        self._validate_field_name(field_name)
+        return getattr(self, field_name)
+
+    def __setitem__(self, field_name: str, field_value: Any) -> None:
+        """Dict-like set-item method.
+
+        Set the value of the field with name ``field_name`` as ``field_value``.
+
+        Parameters
+        ----------
+        field_name:
+            Name of the field to set.
+
+        field_value:
+            Value of the field to set.
+
+        Raises
+        ------
+        KeyError
+            If ``field_name`` does not match any names of fields.
+
+
+        .. versionadded:: RELEASE_PLACEHOLDER
+        """
+        self._validate_field_name(field_name)
+        setattr(self, field_name, field_value)
+
 
 @dataclasses.dataclass
 class DatablockUploadModels:
diff --git a/tests/dataset_test.py b/tests/dataset_test.py
index 9a183d0c..876163a1 100644
--- a/tests/dataset_test.py
+++ b/tests/dataset_test.py
@@ -765,3 +765,117 @@ def test_derive_removes_attachments(initial, attachments):
     initial.attachments = attachments
     derived = initial.derive()
     assert derived.attachments == []
+
+
+def invalid_field_example(my_type):
+    """Return a (name, value) pair of a field that ``my_type`` does not use."""
+    if my_type == DatasetType.DERIVED:
+        return "data_format", "sth_not_None"
+    elif my_type == DatasetType.RAW:
+        return "job_log_data", "sth_not_None"
+    else:
+        raise ValueError(f"{my_type} is not a valid DatasetType.")
+
+
+@given(initial=sst.datasets(for_upload=True))
+@settings(max_examples=10)
+def test_dataset_dict_like_keys_per_type(initial: Dataset):
+    my_names = {
+        field.name for field in Dataset._FIELD_SPEC if field.used_by(initial.type)
+    }
+    assert set(initial.keys()) == my_names
+
+
+@given(initial=sst.datasets(for_upload=True))
+@settings(max_examples=10)
+def test_dataset_dict_like_keys_including_invalid_field(initial):
+    invalid_name, invalid_value = invalid_field_example(initial.type)
+
+    my_names = {
+        field.name for field in Dataset._FIELD_SPEC if field.used_by(initial.type)
+    }
+    assert invalid_name not in my_names
+    my_names.add(invalid_name)
+
+    setattr(initial, invalid_name, invalid_value)
+
+    assert set(initial.keys()) == my_names
+
+
+@given(initial=sst.datasets(for_upload=True))
+@settings(max_examples=10)
+def test_dataset_dict_like_values(initial: Dataset):
+    for key, value in zip(initial.keys(), initial.values()):
+        assert value == getattr(initial, key)
+
+
+@given(initial=sst.datasets(for_upload=True))
+@settings(max_examples=10)
+def test_dataset_dict_like_values_with_invalid_field(initial: Dataset):
+    setattr(initial, *invalid_field_example(initial.type))
+    for key, value in zip(initial.keys(), initial.values()):
+        assert value == getattr(initial, key)
+
+
+@given(initial=sst.datasets(for_upload=True))
+@settings(max_examples=10)
+def test_dataset_dict_like_items_with_invalid_field(initial: Dataset):
+    setattr(initial, *invalid_field_example(initial.type))
+    for key, value in initial.items():
+        assert value == getattr(initial, key)
+
+
+@given(initial=sst.datasets(for_upload=True))
+@settings(max_examples=10)
+def test_dataset_dict_like_getitem(initial):
+    assert initial["type"] == initial.type
+
+
+@pytest.mark.parametrize(
+    ("is_attr", "wrong_field"), ((True, "size"), (False, "OBVIOUSLYWRONGNAME"))
+)
+@given(initial=sst.datasets(for_upload=True))
+@settings(max_examples=10)
+def test_dataset_dict_like_getitem_wrong_field_raises(initial, is_attr, wrong_field):
+    # 'size' should be included in the field later.
+    # It is now excluded because it is ``manual`` field. See issue#151.
+    assert hasattr(initial, wrong_field) == is_attr
+    with pytest.raises(KeyError, match=f"{wrong_field} is not a valid field name."):
+        initial[wrong_field]
+
+
+@given(initial=sst.datasets(for_upload=True))
+@settings(max_examples=10)
+def test_dataset_dict_like_setitem(initial: Dataset):
+    import uuid
+
+    sample_comment = uuid.uuid4().hex
+    assert initial["comment"] != sample_comment
+    initial["comment"] = sample_comment
+    assert initial["comment"] == sample_comment
+
+
+@given(initial=sst.datasets(for_upload=True))
+@settings(max_examples=10)
+def test_dataset_dict_like_setitem_invalid_field(initial: Dataset):
+    # ``__setitem__`` doesn't check if the item is invalid for the current type or not.
+    invalid_field, invalid_value = invalid_field_example(initial.type)
+    assert initial[invalid_field] is None
+    initial[invalid_field] = invalid_value
+    assert initial[invalid_field] == invalid_value
+
+
+@pytest.mark.parametrize(
+    ("is_attr", "wrong_field", "wrong_value"),
+    ((True, "size", 10), (False, "OBVIOUSLYWRONGNAME", "OBVIOUSLYWRONGVALUE")),
+)
+@given(initial=sst.datasets(for_upload=True))
+@settings(max_examples=10)
+def test_dataset_dict_like_setitem_wrong_field_raises(
+    initial, is_attr, wrong_field, wrong_value
+):
+    # ``manual`` fields such as ``size`` should raise with ``__setitem__``.
+    # However, it may need more specific error message.
+    assert hasattr(initial, wrong_field) == is_attr
+    with pytest.raises(KeyError, match=f"{wrong_field} is not a valid field name."):
+        initial[wrong_field] = wrong_value