v0.5.0: `consolidate`, compile compatibility and better non-tensor support
This release is packed with new features and performance improvements.
What's new
TensorDict.consolidate
There is now a `TensorDict.consolidate` method that puts all the tensors in a single storage. This greatly speeds up serialization in multiprocessed and distributed settings.
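As an illustration, here is a minimal sketch of what consolidation looks like in practice (key names and shapes are arbitrary):

```python
import torch
from tensordict import TensorDict

td = TensorDict(
    {"obs": torch.randn(4, 3), "reward": torch.randn(4, 1)},
    batch_size=[4],
)
# After consolidation all leaves share a single contiguous storage,
# so sending the tensordict to another process ships one buffer.
td_cons = td.consolidate()
assert (td_cons["obs"] == td["obs"]).all()
```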
PT2 support
TensorDict common ops (`get`, `set`, indexing, arithmetic ops, etc.) now work within `torch.compile`.
The list of supported operations can be found in test/test_compile.py. We encourage users to report any graph break caused by tensordict, as we aim to improve the coverage as much as possible.
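As a rough sketch, a compiled function mixing `get`, `set` and arithmetic could look like this (refer to test/test_compile.py for the operations that are actually covered):

```python
import torch
from tensordict import TensorDict

@torch.compile
def update(td: TensorDict) -> TensorDict:
    # get/set and arithmetic ops on the tensordict are traced by torch.compile
    return td.set("b", td.get("a") + 1)

td = TensorDict({"a": torch.zeros(3)}, batch_size=[3])
assert (update(td)["b"] == 1).all()
```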
Python 3.12 support
#807 enables Python 3.12 support, a long-awaited feature!
Global reduction for mean, std and other reduction methods
It is now possible to get the grand average of a tensordict's content using `tensordict.mean(reduce=True)`.
This applies to `mean`, `nanmean`, `prod`, `std`, `sum`, `nansum` and `var`.
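A small sketch of the new global reduction (key names and values are arbitrary):

```python
import torch
from tensordict import TensorDict

td = TensorDict(
    {"a": torch.ones(3, 4), "b": torch.zeros(3, 4)},
    batch_size=[3],
)
# reduce=True collapses the result across every leaf into a single scalar
assert td.mean(reduce=True) == 0.5
```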
from_pytree and to_pytree
We made it easy to convert a tensordict to a given pytree structure and to build one back from any pytree using `to_pytree` and `from_pytree`. #832
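A minimal sketch, assuming `to_pytree` returns a plain nested container of tensors that `from_pytree` can consume again:

```python
import torch
from tensordict import TensorDict

td = TensorDict(
    {"a": torch.zeros(3), "nested": {"b": torch.ones(3)}},
    batch_size=[3],
)
pytree = td.to_pytree()               # plain nested container of tensors
td2 = TensorDict.from_pytree(pytree)  # rebuild a TensorDict from it
assert (td2["nested", "b"] == 1).all()
```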
Similarly, conversion to `namedtuple` is now made easy thanks to #788.
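Likewise for namedtuples, a small sketch assuming `to_namedtuple` and `from_namedtuple` mirror the key structure:

```python
import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.zeros(3), "b": torch.ones(3)}, batch_size=[3])
nt = td.to_namedtuple()               # one namedtuple field per key
td2 = TensorDict.from_namedtuple(nt)  # and back
assert (td2["b"] == 1).all()
```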
map_iter
One can now iterate through a TensorDict's batch dimension and apply a function on a separate process thanks to `map_iter`.
This should enable the construction of datasets using `TensorDict`, where the preprocessing step is executed on a separate process. #847
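A rough sketch of the intended pattern; the keyword arguments shown here (e.g. `num_workers`) are assumptions, so check the `map_iter` documentation for the exact signature:

```python
import torch
from tensordict import TensorDict

def double(td):
    # runs in a worker process on one chunk of the batch dimension
    td["x"] = td["x"] * 2
    return td

if __name__ == "__main__":
    data = TensorDict({"x": torch.arange(8)}, batch_size=[8])
    for chunk in data.map_iter(double, num_workers=2):
        print(chunk["x"])
```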
Using flatten and unflatten, flatten_keys and unflatten_keys as context managers
It is now possible to use `flatten_keys` and `flatten` as context managers (#908, #779):

    with td.flatten_keys() as flat_td:
        flat_td["flat.key"] = 0
    assert td["flat", "key"] == 0
Building a tensordict using keyword arguments
We made it easy to build tensordicts with simple keyword arguments, like a dict is built in Python:

    td = TensorDict(a=0, b=1)
    assert td["a"] == torch.tensor(0)
    assert td["b"] == torch.tensor(1)
The `batch_size` is now optional for both tensordict and tensorclasses. #905
Load tensordicts directly on device
Thanks to #769, it is now possible to load a tensordict directly on a destination device (including "meta" device):

    td = TensorDict.load(path, device=device)
New features
- [Feature,Performance] `to(device, pin_memory, num_threads)` by @vmoens in #846
- [Feature] Allow calls to get_mode, get_mean and get_median in case mode, mean or median is not present by @vmoens in #804
- [Feature] Arithmetic ops for tensorclass by @vmoens in #786
- [Feature] Best attempt to densely stack sub-tds when LazyStacked TDs are passed to maybe_dense_stack by @vmoens in #799
- [Feature] Better dtype coverage by @vmoens in #834
- [Feature] Change default interaction types to DETERMINISTIC by @vmoens in #825
- [Feature] DETERMINISTIC interaction mode by @vmoens in #824
- [Feature] Expose call_on_nested to apply and named_apply by @vmoens in #768
- [Feature] Expose stack / cat as class methods by @vmoens in #793
- [Feature] Load tensordicts on device, incl. meta by @vmoens in #769
- [Feature] Make Probabilistic modules aware of CompositeDistributions out_keys by @vmoens in #810
- [Feature] Memory-mapped nested tensors by @vmoens in #618
- [Feature] Multithreaded apply by @vmoens in #844
- [Feature] Multithreaded pin_memory by @vmoens in #845
- [Feature] Support for non tensor data in h5 by @vmoens in #772
- [Feature] TensorDict.consolidate by @vmoens in #814
- [Feature] TensorDict.numpy() by @vmoens in #787
- [Feature] TensorDict.replace by @vmoens in #774
- [Feature] `out` argument in apply by @vmoens in #794
- [Feature] `to` for consolidated TDs by @vmoens in #851
- [Feature] `zero_grad` and `requires_grad_` by @vmoens in #901
- [Feature] add_custom_mapping and NPE refactors by @vmoens in #910
- [Feature] construct tds with kwargs by @vmoens in #905
- [Feature] deterministic_sample for composite dist by @vmoens in #827
- [Feature] expand_as by @vmoens in #792
- [Feature] flatten and unflatten as decorators by @vmoens in #779
- [Feature] from and to_pytree by @vmoens in #832
- [Feature] from_modules expand_identical kwarg by @vmoens in #911
- [Feature] grad and data for tensorclasses by @vmoens in #904
- [Feature] isfinite, isnan, isreal by @vmoens in #829
- [Feature] map_iter by @vmoens in #847
- [Feature] map_names for composite dists by @vmoens in #809
- [Feature] online editing of memory-mapped tensordicts by @vmoens in #775
- [Feature] remove distutils dependency and enable 3.12 support by @GaetanLepage in #807
- [Feature] to_namedtuple and from_namedtuple by @vmoens in #788
- [Feature] view(dtype) by @vmoens in #835
Performance
- [Performance] Faster getattr in TC by @vmoens in #912
- [Performance] Faster lock_/unlock_ when sub-tds are already locked by @vmoens in #816
- [Performance] Faster multithreaded pin_memory by @vmoens in #919
- [Performance] Faster tensorclass by @vmoens in #791
- [Performance] Faster tensorclass set by @vmoens in #880
- [Performance] Faster to-module by @vmoens in #914
Bug Fixes
- [BugFix,CI] Fix storage filename tests by @vmoens in #850
- [BugFix] @property setter in tensorclass by @vmoens in #813
- [BugFix] Allow any tensorclass to have a data field by @vmoens in #906
- [BugFix] Allow fake-tensor detection pass through in torch 2.0 by @vmoens in #802
- [BugFix] Avoid collapsing NonTensorStack when calling where by @vmoens in #837
- [BugFix] Check if the current user has write access by @MateuszGuzek in #781
- [BugFix] Ensure dtype is preserved with autocast by @vmoens in #773
- [BugFix] Fix non-tensor writing in modules by @vmoens in #822
- [BugFix] Fix (keys, values) in sub by @vmoens in #907
- [BugFix] Fix `_make_dtype_promotion` backward compat by @vmoens in #842
- [BugFix] Fix `pad_sequence` behavior for non-tensor attributes of tensorclass by @kurtamohler in #884
- [BugFix] Fix builds by @vmoens in #849
- [BugFix] Fix compile + vmap by @vmoens in #924
- [BugFix] Fix deterministic fallback when the dist has no support by @vmoens in #830
- [BugFix] Fix device parsing in augmented funcs by @vmoens in #770
- [BugFix] Fix empty tuple index by @vmoens in #811
- [BugFix] Fix fallback of deterministic samples when mean is not available by @vmoens in #828
- [BugFix] Fix functorch dim mock by @vmoens in #777
- [BugFix] Fix gather device by @vmoens in #815
- [BugFix] Fix h5 auto batch size by @vmoens in #798
- [BugFix] Fix key ordering in pointwise ops by @vmoens in #855
- [BugFix] Fix lazy stack features (where and norm) by @vmoens in #795
- [BugFix] Fix map by @vmoens in #862
- [BugFix] Fix map test with fork on cuda by @vmoens in #765
- [BugFix] Fix pad_sequence for non tensors by @vmoens in #784
- [BugFix] Fix setting non-tensors as data in NonTensorData by @vmoens in #864
- [BugFix] Fix stack of tensorclasses (and nontensors) by @vmoens in #820
- [BugFix] Fix storage.filename compat with torch 2.0 by @vmoens in #803
- [BugFix] Fix tensorclass register by @vmoens in #817
- [BugFix] Fix torch version assertion by @vmoens in #917
- [BugFix] Fix vmap compatibility with torch<2.2 by @vmoens in #925
- [BugFix] Fix vmap for tensorclass by @vmoens in #778
- [BugFix] Fix wheels by @vmoens in #856
- [BugFix] Keep stack dim name in LazyStackedTensorDict copy ops by @vmoens in #801
- [BugFix] Read-only compatibility in MemoryMappedTensor by @vmoens in #780
- [BugFix] Refactor map and map_iter by @vmoens in #869
- [BugFix] Sync cuda only if initialized by @vmoens in #767
- [BugFix] fix _expand_to_match_shape for single bool tensor by @vmoens in #902
- [BugFix] fix construction of lazy stacks from tds by @vmoens in #903
- [BugFix] fix tensorclass set by @vmoens in #854
- [BugFix] remove inplace updates when using td as a decorator by @vmoens in #796
- [BugFix] use as_subclass in Buffer by @vmoens in #913
Refactoring and code quality
- [Quality] Better nested detection in numpy() by @vmoens in #800
- [Quality] Better repr of keys by @vmoens in #897
- [Quality] fix c++ binaries formatting by @vmoens in #859
- [Quality] non_blocking_pin instead of pin_memory by @vmoens in #915
- [Quality] zip-strict when possible by @vmoens in #886
- [Refactor] Better tensorclass method registration by @vmoens in #797
- [Refactor] Make all leaves in tensorclass part of `_tensordict`, except for NonTensorData by @vmoens in #841
- [Refactor] Refactor c++ binaries location by @vmoens in #860
- [Refactor] Refactor is_dynamo_compile imports by @vmoens in #916
- [Refactor] Remove `_run_checks` from `__init__` by @vmoens in #843
- [Refactor] use from_file instead of mmap+from_buffer for readonly files by @vmoens in #808
Others
- Bump jinja2 from 3.1.3 to 3.1.4 in /docs by @dependabot in #840
- [Benchmark] Benchmark tensorclass ops by @vmoens in #790
- [Benchmark] Fix recursion and cache errors in benchmarks by @vmoens in #900
- [CI] Fix nightly build by @vmoens in #861
- [CI] Python 3.12 compatibility by @kurtamohler in #818
- [Doc] Fix symbolic trace reference in doc by @vmoens in #918
- [Formatting] Lint revamp by @vmoens in #890
- [Test] Test FC of memmap save and load by @vmoens in #838
- [Versioning] Allow any torch version for local builds by @vmoens in #764
- [Versioning] Make dependence on uint16 optional for older PT versions by @vmoens in #839
- [Versioning] tree_leaves for pytorch < 2.3 by @vmoens in #806
- [Versioning] v0.5 bump by @vmoens in #848
New Contributors
- @MateuszGuzek made their first contribution in #781
- @GaetanLepage made their first contribution in #807
- @kurtamohler made their first contribution in #818
Full Changelog: v0.4.0...v0.5.0