Skip to content

Commit

Permalink
📝 [array] Add examples for EnsureType and CastToType (Project-MONAI#7245
Browse files Browse the repository at this point in the history
)

Fixes Project-MONAI#7101

### Description

Added examples in the docstrings for `EnsureType` and `CastToType`
transforms which show how they function under different circumstances.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

Signed-off-by: Ishan Dutta <[email protected]>
Signed-off-by: Yu0610 <[email protected]>
  • Loading branch information
ishandutta0098 authored and Yu0610 committed Apr 11, 2024
1 parent c2b2fd6 commit 10dd6fc
Showing 1 changed file with 34 additions and 1 deletion.
35 changes: 34 additions & 1 deletion monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,23 @@ class CastToType(Transform):
"""
Cast the Numpy data to specified numpy data type, or cast the PyTorch Tensor to
specified PyTorch data type.
Example:
>>> import numpy as np
>>> import torch
>>> transform = CastToType(dtype=np.float32)
>>> # Example with a numpy array
>>> img_np = np.array([0, 127, 255], dtype=np.uint8)
>>> img_np_casted = transform(img_np)
>>> img_np_casted
array([ 0. , 127. , 255. ], dtype=float32)
>>> # Example with a PyTorch tensor
>>> img_tensor = torch.tensor([0, 127, 255], dtype=torch.uint8)
>>> img_tensor_casted = transform(img_tensor)
>>> img_tensor_casted
tensor([ 0., 127., 255.]) # dtype is float32
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
Expand Down Expand Up @@ -413,10 +430,26 @@ class EnsureType(Transform):
dtype: target data content type to convert, for example: np.float32, torch.float, etc.
device: for Tensor data type, specify the target device.
wrap_sequence: if `False`, then lists will recursively call this function, default to `True`.
E.g., if `False`, `[1, 2]` -> `[tensor(1), tensor(2)]`, if `True`, then `[1, 2]` -> `tensor([1, 2])`.
track_meta: if `True` convert to ``MetaTensor``, otherwise to Pytorch ``Tensor``,
if ``None`` behave according to return value of py:func:`monai.data.meta_obj.get_track_meta`.
Example with wrap_sequence=True:
>>> import numpy as np
>>> import torch
>>> transform = EnsureType(data_type="tensor", wrap_sequence=True)
>>> # Converting a list to a tensor
>>> data_list = [1, 2., 3]
>>> tensor_data = transform(data_list)
>>> tensor_data
tensor([1., 2., 3.]) # All elements have dtype float32
Example with wrap_sequence=False:
>>> transform = EnsureType(data_type="tensor", wrap_sequence=False)
>>> # Converting each element in a list to individual tensors
>>> data_list = [1, 2, 3]
>>> tensors_list = transform(data_list)
>>> tensors_list
[tensor(1), tensor(2.), tensor(3)] # Only second element is float32 rest are int64
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
Expand Down

0 comments on commit 10dd6fc

Please sign in to comment.