-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
meta_tensor.py
610 lines (536 loc) · 26.9 KB
/
meta_tensor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import functools
import warnings
from collections.abc import Sequence
from copy import deepcopy
from typing import Any
import numpy as np
import torch
import monai
from monai.config.type_definitions import NdarrayTensor
from monai.data.meta_obj import MetaObj, get_track_meta
from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata
from monai.utils import look_up_option
from monai.utils.enums import LazyAttr, MetaKeys, PostFix, SpaceKeys
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_numpy, convert_to_tensor
# Names exported by `from <this module> import *`; only the `MetaTensor` class is public.
__all__ = ["MetaTensor"]
@functools.lru_cache(None)
def _get_named_tuple_like_type(func):
    """
    Look up the named-tuple-like return type that ``torch.return_types`` declares for ``func``.

    The lookup is done by ``func.__name__`` and the result is memoised per ``func``.

    Args:
        func: a (hashable) callable, typically a torch function such as ``torch.max``.

    Returns:
        The attribute of ``torch.return_types`` named after ``func`` if it exists and is a class,
        otherwise ``None``.
    """
    # nothing to look up when torch has no `return_types` or the callable is anonymous
    if not hasattr(torch, "return_types") or not hasattr(func, "__name__"):
        return None
    candidate = getattr(torch.return_types, func.__name__, None)
    return candidate if isinstance(candidate, type) else None
def _not_requiring_metadata(ret):
    """
    Tell whether ``ret`` can be returned as-is, i.e. no metadata handling is needed.

    Args:
        ret: the value returned by a torch function.

    Returns:
        ``True`` if ``ret`` is one of the plain types listed below, or if it is neither a
        ``MetaTensor`` nor a ``Sequence`` containing at least one ``MetaTensor``; ``False`` otherwise.
    """
    plain_types = (int, str, bytes, torch.Size, torch.dtype, torch.device, np.ndarray)
    # fast path: these types never carry metadata
    if isinstance(ret, plain_types):
        return True
    if isinstance(ret, MetaTensor):
        return False
    if isinstance(ret, Sequence):
        return not any(isinstance(item, MetaTensor) for item in ret)
    return True
class MetaTensor(MetaObj, torch.Tensor):
    """
    Class that inherits from both `torch.Tensor` and `MetaObj`, adding support for metadata.

    Metadata is stored in the form of a dictionary. Nested, an affine matrix will be
    stored. This should be in the form of `torch.Tensor`.

    Behavior should be the same as `torch.Tensor` aside from the extended
    meta functionality.

    Copying of information:

        * For `c = a + b`, then auxiliary data (e.g., metadata) will be copied from the
          first instance of `MetaTensor` if `a.is_batch` is False
          (For batched data, the metadata will be shallow copied for efficiency purposes).

    Example:
        .. code-block:: python

            import torch
            from monai.data import MetaTensor

            t = torch.tensor([1,2,3])
            affine = torch.as_tensor([[2,0,0,0],
                                      [0,2,0,0],
                                      [0,0,2,0],
                                      [0,0,0,1]], dtype=torch.float64)
            meta = {"some": "info"}
            m = MetaTensor(t, affine=affine, meta=meta)
            m2 = m + m
            assert isinstance(m2, MetaTensor)
            assert m2.meta["some"] == "info"
            assert torch.all(m2.affine == affine)

    Notes:
        - Requires pytorch 1.9 or newer for full compatibility.
        - Older versions of pytorch (<=1.8), `torch.jit.trace(net, im)` may
          not work if `im` is of type `MetaTensor`. This can be resolved with
          `torch.jit.trace(net, im.as_tensor())`.
        - For pytorch < 1.8, sharing `MetaTensor` instances across processes may not be supported.
        - For pytorch < 1.9, next(iter(meta_tensor)) returns a torch.Tensor.
          see: https://github.com/pytorch/pytorch/issues/54457
        - A warning will be raised if in the constructor `affine` is not `None` and
          `meta` already contains the key `affine`.
        - You can query whether the `MetaTensor` is a batch with the `is_batch` attribute.
        - With a batch of data, `batch[0]` will return the 0th image
          with the 0th metadata. When the batch dimension is non-singleton, e.g.,
          `batch[:, 0]`, `batch[..., -1]` and `batch[1:3]`, then all (or a subset in the
          last example) of the metadata will be returned, and `is_batch` will return `True`.
        - When creating a batch with this class, use `monai.data.DataLoader` as opposed
          to `torch.utils.data.DataLoader`, as this will take care of collating the
          metadata properly.
    """

    @staticmethod
    def __new__(
        cls,
        x,
        affine: torch.Tensor | None = None,
        meta: dict | None = None,
        applied_operations: list | None = None,
        *args,
        **kwargs,
    ) -> MetaTensor:
        """
        Create the tensor storage from `x` via `torch.as_tensor` and view it as `cls`.

        Only the `device` and `dtype` keyword arguments are forwarded to `torch.as_tensor`;
        `affine`, `meta` and `applied_operations` are not used here (they are handled in `__init__`).
        """
        _kwargs = {"device": kwargs.pop("device", None), "dtype": kwargs.pop("dtype", None)} if kwargs else {}
        return torch.as_tensor(x, *args, **_kwargs).as_subclass(cls)

    def __init__(
        self,
        x,
        affine: torch.Tensor | None = None,
        meta: dict | None = None,
        applied_operations: list | None = None,
        *_args,
        **_kwargs,
    ) -> None:
        """
        Args:
            x: initial array for the MetaTensor. Can be a list, tuple, NumPy ndarray, scalar, and other types.
            affine: optional 4x4 array.
            meta: dictionary of metadata.
            applied_operations: list of previously applied operations on the MetaTensor,
                the list is typically maintained by `monai.transforms.TraceableTransform`.
                See also: :py:class:`monai.transforms.TraceableTransform`
            _args: additional args (currently not in use in this constructor).
            _kwargs: additional kwargs (currently not in use in this constructor).

        Note:
            If a `meta` dictionary is given, use it. Else, if `meta` exists in the input tensor `x`, use it.
            Else, use the default value. Similar for the affine, except this could come from
            four places, priority: `affine`, `meta["affine"]`, `x.affine`, `get_default_affine`.

        """
        super().__init__()
        # set meta
        if meta is not None:
            self.meta = meta
        elif isinstance(x, MetaObj):
            self.__dict__ = deepcopy(x.__dict__)

        # set the affine
        if affine is not None:
            if MetaKeys.AFFINE in self.meta:
                warnings.warn("Setting affine, but the applied meta contains an affine. This will be overwritten.")
            self.affine = affine
        elif MetaKeys.AFFINE in self.meta:
            # by using the setter function, we ensure it is converted to torch.Tensor if not already
            self.affine = self.meta[MetaKeys.AFFINE]
        else:
            self.affine = self.get_default_affine()

        # applied_operations
        if applied_operations is not None:
            self.applied_operations = applied_operations
        else:
            self.applied_operations = MetaObj.get_default_applied_operations()

        # if we are creating a new MetaTensor, then deep copy attributes
        if isinstance(x, torch.Tensor) and not isinstance(x, MetaTensor):
            self.copy_meta_from(self)

        if MetaKeys.SPACE not in self.meta:
            self.meta[MetaKeys.SPACE] = SpaceKeys.RAS  # defaulting to the right-anterior-superior space

    @staticmethod
    def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
        """
        Update the metadata from the output of `MetaTensor.__torch_function__`.

        The output of `torch.Tensor.__torch_function__` could be a single object or a
        sequence of them. Hence, in `MetaTensor.__torch_function__` we convert them to a
        list if not already, and then we loop across each element, processing metadata
        as necessary. For each element, if not of type `MetaTensor`, then nothing to do.

        Args:
            rets: the output from `torch.Tensor.__torch_function__`, which has been
                converted to a list in `MetaTensor.__torch_function__` if it wasn't
                already a `Sequence`.
            func: the torch function that was applied. Examples might be `torch.squeeze`
                or `torch.Tensor.__add__`. We need this since the metadata need to be
                treated differently if a batch of data is considered. For example,
                slicing (`torch.Tensor.__getitem__`) the ith element of the 0th
                dimension of a batch of data should return a ith tensor with the ith
                metadata.
            args: positional arguments that were passed to `func`.
            kwargs: keyword arguments that were passed to `func`.

        Returns:
            A sequence with the same number of elements as `rets`. For each element, if
            the input type was not `MetaTensor`, then no modifications will have been
            made. If global parameters have been set to false (e.g.,
            `not get_track_meta()`), then any `MetaTensor` will be converted to
            `torch.Tensor`. Else, metadata will be propagated as necessary (see
            :py:func:`MetaTensor._copy_meta`).
        """
        out = []
        # optional output metadicts for each of the return value in `rets`.
        # NOTE(review): `metas` is never reassigned in this method, so `_handle_batched`
        # always receives `None` here and cannot reuse a decollation across return values.
        metas = None
        is_batch = any(x.is_batch for x in MetaObj.flatten_meta_objs(args, kwargs.values()) if hasattr(x, "is_batch"))
        for idx, ret in enumerate(rets):
            # if not `MetaTensor`, nothing to do.
            if not isinstance(ret, MetaTensor):
                pass
            # if not tracking, convert to `torch.Tensor`.
            elif not get_track_meta():
                ret = ret.as_tensor()
            # else, handle the `MetaTensor` metadata.
            else:
                meta_args = MetaObj.flatten_meta_objs(args, kwargs.values())
                ret.is_batch = is_batch
                ret.copy_meta_from(meta_args, copy_attr=not is_batch)
                # the following is not implemented but the network arch may run into this case:
                # if func == torch.cat and any(m.is_batch if hasattr(m, "is_batch") else False for m in meta_args):
                #     raise NotImplementedError("torch.cat is not implemented for batch of MetaTensors.")
                if is_batch:
                    ret = MetaTensor._handle_batched(ret, idx, metas, func, args, kwargs)
            out.append(ret)
        # if the input was a tuple, then return it as a tuple
        return tuple(out) if isinstance(rets, tuple) else out

    @classmethod
    def _handle_batched(cls, ret, idx, metas, func, args, kwargs):
        """
        Utility function to handle batched MetaTensors.

        Args:
            ret: the `MetaTensor` output currently being processed.
            idx: position of `ret` within the outputs of `func`.
            metas: optional pre-computed decollated batch, indexed by `idx`; only read
                for `torch.Tensor.unbind` along dimension 0. Computed here when `None`.
            func: the torch function that produced `ret`.
            args: positional arguments that were passed to `func`.
            kwargs: keyword arguments that were passed to `func`.

        Returns:
            `ret`, with its `__dict__` replaced by the selected item's metadata when the
            operation picked item(s) along the batch dimension.

        Raises:
            ValueError: when a slice of the batch is selected and its metadata cannot be re-collated.
        """
        # If we have a batch of data, then we need to be careful if a slice of
        # the data is returned. Depending on how the data are indexed, we return
        # some or all of the metadata, and the return object may or may not be a
        # batch of data (e.g., `batch[:,-1]` versus `batch[0]`).
        # if indexing e.g., `batch[0]`
        if func == torch.Tensor.__getitem__:
            if idx > 0 or len(args) < 2 or len(args[0]) < 1:
                return ret
            batch_idx = args[1][0] if isinstance(args[1], Sequence) else args[1]
            # if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the
            # first element will be `slice(None, None, None)` and `Ellipsis`,
            # respectively. Don't need to do anything with the metadata.
            # The tensor check goes first so that tensor indices are never run through
            # the `==` comparisons of the membership test.
            if isinstance(batch_idx, torch.Tensor) or batch_idx in (slice(None, None, None), Ellipsis, None):
                return ret
            dec_batch = decollate_batch(args[0], detach=False)
            ret_meta = dec_batch[batch_idx]
            if isinstance(ret_meta, list) and ret_meta:  # e.g. batch[0:2], re-collate
                try:
                    ret_meta = list_data_collate(ret_meta)
                except (TypeError, ValueError, RuntimeError, IndexError) as e:
                    raise ValueError(
                        "Inconsistent batched metadata dicts when slicing a batch of MetaTensors, "
                        "please consider converting it into a torch Tensor using `x.as_tensor()` or "
                        "a numpy array using `x.array`."
                    ) from e
            elif isinstance(ret_meta, MetaObj):  # e.g. `batch[0]` or `batch[0, 1]`, batch_idx is int
                ret_meta.is_batch = False
            if hasattr(ret_meta, "__dict__"):
                ret.__dict__ = ret_meta.__dict__.copy()
        # `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`.
        # But we only want to split the batch if the `unbind` is along the 0th dimension.
        elif func == torch.Tensor.unbind:
            if len(args) > 1:
                dim = args[1]
            elif "dim" in kwargs:
                dim = kwargs["dim"]
            else:
                dim = 0
            if dim == 0:
                if metas is None:
                    metas = decollate_batch(args[0], detach=False)
                if hasattr(metas[idx], "__dict__"):
                    ret.__dict__ = metas[idx].__dict__.copy()
                ret.is_batch = False
        return ret

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None) -> Any:
        """Wraps all torch functions."""
        if kwargs is None:
            kwargs = {}
        ret = super().__torch_function__(func, types, args, kwargs)
        # if `out` has been used as argument, metadata is not copied, nothing to do.
        # if "out" in kwargs:
        #     return ret
        if _not_requiring_metadata(ret):
            return ret
        named_tuple_type = _get_named_tuple_like_type(func)
        if named_tuple_type is not None and isinstance(ret, named_tuple_type):
            # for torch.max(torch.tensor(1.0), dim=0), the return type is named-tuple like
            out_items = MetaTensor.update_meta(ret, func, args, kwargs)
            for idx in range(ret.n_fields):
                ret[idx].meta = out_items[idx].meta
                ret[idx].applied_operations = out_items[idx].applied_operations
            return ret
        # we might have 1 or multiple outputs. Might be MetaTensor, might be something
        # else (e.g., `__repr__` returns a string).
        # Convert to list (if necessary), process, and at end remove list if one was added.
        if not isinstance(ret, Sequence):
            ret = [ret]
            unpack = True
        else:
            unpack = False
        ret = MetaTensor.update_meta(ret, func, args, kwargs)
        return ret[0] if unpack else ret

    @staticmethod
    def _convert(x):
        """Convert tensors, tuples and lists to numpy arrays; return anything else unchanged."""
        if isinstance(x, (MetaTensor, torch.Tensor, tuple, list)):
            return convert_data_type(x, output_type=np.ndarray, wrap_sequence=False)[0]
        return x

    def __array_function__(self, func, types, args, kwargs):
        """for numpy Interoperability, so that we can compute ``np.sum(MetaTensor([1.0]))``."""
        try:
            if not func.__module__.startswith("numpy"):
                return NotImplemented
        except AttributeError:
            return NotImplemented
        _args = list(map(MetaTensor._convert, args))
        _kwargs = {k: MetaTensor._convert(v) for k, v in kwargs.items()}
        return func(*_args, **_kwargs)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        """
        For numpy interoperability, so that we can compute ``MetaTensor([1.0]) >= np.asarray([1.0])``.
        This is for pytorch > 1.8.
        """
        try:
            if not type(ufunc).__module__.startswith("numpy"):
                return NotImplemented
        except AttributeError:
            return NotImplemented
        if method != "__call__":
            return NotImplemented
        _inputs = map(MetaTensor._convert, inputs)
        _kwargs = {k: MetaTensor._convert(v) for k, v in kwargs.items()}
        if "out" in _kwargs:
            return NotImplemented  # not supported
        try:
            return getattr(ufunc, method)(*_inputs, **_kwargs)
        except AttributeError:
            return NotImplemented

    @staticmethod
    def get_default_affine(dtype=torch.float64) -> torch.Tensor:
        """Return a 4x4 identity matrix of the given `dtype` on the CPU."""
        return torch.eye(4, device=torch.device("cpu"), dtype=dtype)

    def as_tensor(self) -> torch.Tensor:
        """
        Return the `MetaTensor` as a `torch.Tensor`.
        It is OS dependent as to whether this will be a deep copy or not.
        """
        return self.as_subclass(torch.Tensor)

    def get_array(self, output_type=np.ndarray, dtype=None, device=None, *_args, **_kwargs):
        """
        Returns a new array in `output_type`, the array shares the same underlying storage when the output is a
        numpy array. Changes to self tensor will be reflected in the ndarray and vice versa.

        Args:
            output_type: output type, see also: :py:func:`monai.utils.convert_data_type`.
            dtype: dtype of output data. Converted to correct library type (e.g.,
                `np.float32` is converted to `torch.float32` if output type is `torch.Tensor`).
                If left blank, it remains unchanged.
            device: if the output is a `torch.Tensor`, select device (if `None`, unchanged).
            _args: currently unused parameters.
            _kwargs: currently unused parameters.
        """
        return convert_data_type(self, output_type=output_type, dtype=dtype, device=device, wrap_sequence=True)[0]

    def set_array(self, src, non_blocking: bool = False, *_args, **_kwargs):
        """
        Copies the elements from src into self tensor and returns self.
        The src tensor must be broadcastable with the self tensor.
        It may be of a different data type or reside on a different device.

        See also: `https://pytorch.org/docs/stable/generated/torch.Tensor.copy_.html`

        Args:
            src: the source tensor to copy from.
            non_blocking: if True and this copy is between CPU and GPU, the copy may occur
                asynchronously with respect to the host. For other cases, this argument has no effect.
            _args: currently unused parameters.
            _kwargs: currently unused parameters.
        """
        converted: torch.Tensor = convert_to_tensor(src, track_meta=False, wrap_sequence=True)
        try:
            return self.copy_(converted, non_blocking=non_blocking)
        except RuntimeError:  # skip the shape checking
            # the in-place copy failed, so replace the underlying data with `converted` instead
            self.data = converted
            return self

    @property
    def array(self):
        """
        Returns a numpy array of ``self``. The array and ``self`` shares the same underlying storage if self is on cpu.
        Changes to ``self`` (it's a subclass of torch.Tensor) will be reflected in the ndarray and vice versa.
        If ``self`` is not on cpu, the call will move the array to cpu and then the storage is not shared.

        :getter: see also: :py:func:`MetaTensor.get_array()`

        :setter: see also: :py:func:`MetaTensor.set_array()`
        """
        return self.get_array()

    @array.setter
    def array(self, src) -> None:
        """A default setter using ``self.set_array()``"""
        self.set_array(src)

    def as_dict(self, key: str, output_type=torch.Tensor, dtype=None) -> dict:
        """
        Get the object as a dictionary for backwards compatibility.
        This method does not make a deep copy of the objects.

        Args:
            key: Base key to store main data. The key for the metadata will be determined using `PostFix`.
            output_type: `torch.Tensor` or `np.ndarray` for the main data.
            dtype: dtype of output data. Converted to correct library type (e.g.,
                `np.float32` is converted to `torch.float32` if output type is `torch.Tensor`).
                If left blank, it remains unchanged.

        Return:
            A dictionary consisting of three keys, the main data (stored under `key`) and the metadata.

        Raises:
            ValueError: when `output_type` is neither `torch.Tensor` nor `np.ndarray`.
        """
        if output_type not in (torch.Tensor, np.ndarray):
            raise ValueError(f"output_type must be torch.Tensor or np.ndarray, got {output_type}.")
        return {
            key: self.get_array(output_type=output_type, dtype=dtype),
            PostFix.meta(key): self.meta,
            PostFix.transforms(key): self.applied_operations,
        }

    def astype(self, dtype, device=None, *_args, **_kwargs):
        """
        Cast to ``dtype``, sharing data whenever possible.

        Args:
            dtype: dtypes such as np.float32, torch.float, "np.float32", float.
            device: the device if `dtype` is a torch data type.
            _args: additional args (currently unused).
            _kwargs: additional kwargs (currently unused).

        Returns:
            data array instance
        """
        if isinstance(dtype, str):
            # split "module.dtype" into its module prefix and the dtype name;
            # a string without a "." is used as both the module string and the dtype
            mod_str, *dtype = dtype.split(".", 1)
            dtype = mod_str if not dtype else dtype[0]
        else:
            mod_str = getattr(dtype, "__module__", "torch")
        mod_str = look_up_option(mod_str, {"torch", "numpy", "np"}, default="numpy")

        out_type: type[torch.Tensor] | type[np.ndarray] | None
        if mod_str == "torch":
            out_type = torch.Tensor
        elif mod_str in ("numpy", "np"):
            out_type = np.ndarray
        else:
            # NOTE(review): `look_up_option` is given a default, so `mod_str` is presumably always
            # one of the three options above and this branch may be unreachable -- confirm.
            out_type = None
        return self.get_array(output_type=out_type, dtype=dtype, device=device)

    @property
    def affine(self) -> torch.Tensor:
        """Get the affine. Defaults to ``torch.eye(4, dtype=torch.float64)``"""
        return self.meta.get(MetaKeys.AFFINE, self.get_default_affine())  # type: ignore

    @affine.setter
    def affine(self, d: NdarrayTensor) -> None:
        """Set the affine."""
        self.meta[MetaKeys.AFFINE] = torch.as_tensor(d, device=torch.device("cpu"), dtype=torch.float64)

    @property
    def pixdim(self):
        """Get the spacing"""
        if self.is_batch:
            return [affine_to_spacing(a) for a in self.affine]
        return affine_to_spacing(self.affine)

    def peek_pending_shape(self):
        """
        Get the currently expected spatial shape as if all the pending operations are executed.
        For tensors that have more than 3 spatial dimensions, only the shapes of the top 3 dimensions will be returned.
        """
        res = None
        if self.pending_operations:
            res = self.pending_operations[-1].get(LazyAttr.SHAPE, None)
        # default to spatial shape (assuming channel-first input)
        return tuple(convert_to_numpy(self.shape, wrap_sequence=True).tolist()[1:]) if res is None else res

    def peek_pending_affine(self):
        """
        Get the affine obtained by combining ``self.affine`` with the affine of every pending operation.

        Pending operations without an affine are skipped. A warning is issued when the
        current affine is neither 2d nor 3d (i.e., not a 3x3 or 4x4 matrix).
        """
        res = self.affine
        r = len(res) - 1
        if r not in (2, 3):
            warnings.warn(f"Only 2d and 3d affine are supported, got {r}d input.")
        for p in self.pending_operations:
            next_matrix = convert_to_tensor(p.get(LazyAttr.AFFINE), dtype=torch.float64)
            if next_matrix is None:
                continue
            res = convert_to_dst_type(res, next_matrix)[0]
            next_matrix = monai.data.utils.to_affine_nd(r, next_matrix)
            res = monai.transforms.lazy.utils.combine_transforms(res, next_matrix)
        return res

    def peek_pending_rank(self):
        """
        Get the rank implied by the most recent affine, computed as ``max(1, len(affine) - 1)``.

        The affine is taken from the last pending operation when there are pending operations,
        otherwise from ``self.affine``. Returns 1 when that affine is ``None``.
        """
        a = self.pending_operations[-1].get(LazyAttr.AFFINE, None) if self.pending_operations else self.affine
        return 1 if a is None else int(max(1, len(a) - 1))

    def new_empty(self, size, dtype=None, device=None, requires_grad=False):  # type: ignore[override]
        """
        must be defined for deepcopy to work

        See:
            - https://pytorch.org/docs/stable/generated/torch.Tensor.new_empty.html#torch-tensor-new-empty
        """
        return type(self)(
            self.as_tensor().new_empty(size=size, dtype=dtype, device=device, requires_grad=requires_grad)
        )

    def clone(self, **kwargs):
        """
        Returns a copy of the MetaTensor instance.

        Args:
            kwargs: additional keyword arguments to `torch.clone`.

        See also: https://pytorch.org/docs/stable/generated/torch.clone.html
        """
        new_inst = MetaTensor(self.as_tensor().clone(**kwargs))
        # attributes (including the metadata) are deep copied so the clone is independent of `self`
        new_inst.__dict__ = deepcopy(self.__dict__)
        return new_inst

    @staticmethod
    def ensure_torch_and_prune_meta(
        im: NdarrayTensor, meta: dict | None, simple_keys: bool = False, pattern: str | None = None, sep: str = "."
    ):
        """
        Convert the image to MetaTensor (when meta is not None). If `affine` is in the `meta` dictionary,
        convert that to `torch.Tensor`, too. Remove any superfluous metadata.

        Args:
            im: Input image (`np.ndarray` or `torch.Tensor`)
            meta: Metadata dictionary. When it's None, the metadata is not tracked, this method returns a torch.Tensor.
            simple_keys: whether to keep only a simple subset of metadata keys.
            pattern: combined with `sep`, a regular expression used to match and prune keys
                in the metadata (nested dictionary), default to None, no key deletion.
            sep: combined with `pattern`, used to match and delete keys in the metadata (nested dictionary).
                default is ".", see also :py:class:`monai.transforms.DeleteItemsd`.
                e.g. ``pattern=".*_code$", sep=" "`` removes any meta keys that ends with ``"_code"``.

        Returns:
            By default, a `MetaTensor` is returned.
            However, if `get_track_meta()` is `False` or meta=None, a `torch.Tensor` is returned.
        """
        img = convert_to_tensor(im, track_meta=get_track_meta() and meta is not None)  # potentially ascontiguousarray

        # if not tracking metadata, return `torch.Tensor`
        if not isinstance(img, MetaTensor):
            return img

        if meta is None:
            meta = {}

        # remove any superfluous metadata.
        if simple_keys:
            # ensure affine is of type `torch.Tensor`
            if MetaKeys.AFFINE in meta:
                meta[MetaKeys.AFFINE] = convert_to_tensor(meta[MetaKeys.AFFINE])  # bc-breaking
            remove_extra_metadata(meta)  # bc-breaking

        if pattern is not None:
            meta = monai.transforms.DeleteItemsd(keys=pattern, sep=sep, use_re=True)(meta)

        # return the `MetaTensor`
        if meta is None:
            meta = {}
        img.meta = meta
        if MetaKeys.AFFINE in meta:
            img.affine = meta[MetaKeys.AFFINE]  # this uses the affine property setter
        else:
            img.affine = MetaTensor.get_default_affine()
        return img

    def __repr__(self):  # type: ignore[override]
        """
        Prints a representation of the tensor.
        Prepends "meta" to ``torch.Tensor.__repr__``.
        Use ``print_verbose`` for associated metadata.
        """
        return f"meta{self.as_tensor().__repr__()}"

    def __str__(self):
        """
        Prints a representation of the tensor.
        Prepends "meta" to ``torch.Tensor.__str__``.
        Use ``print_verbose`` for associated metadata.
        """
        return f"meta{str(self.as_tensor())}"

    def __format__(self, format_spec):
        """
        returns the output of pytorch tensor's ``__format__`` method.
        """
        return self.as_tensor().__format__(format_spec)

    def print_verbose(self) -> None:
        """Verbose print with meta data."""
        print(self)
        if self.meta is not None:
            print(repr(self.meta))