
v0.5.0: `consolidate`, compile compatibility and better non-tensor support

@vmoens released this on 30 Jul 21:39

This release is packed with new features and performance improvements.

What's new

TensorDict.consolidate

There is now a TensorDict.consolidate method that puts all the tensors in a single storage. This greatly speeds up serialization in multiprocess and distributed settings.
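Conceptually, consolidation replaces many small per-leaf storages with one contiguous buffer plus a table of offsets, so a serializer ships a single blob instead of one object per leaf. The stdlib-only sketch below illustrates that idea on nested dicts of float lists. `consolidate` and `leaf_view` are hypothetical helpers written for illustration; this is not tensordict's implementation, which works on torch tensors.

```python
from array import array


def consolidate(nested):
    """Pack every leaf of a nested dict (a flat list of floats) into ONE buffer.

    Returns the shared buffer and a metadata dict mapping each key path
    to the (offset, length) of that leaf inside the buffer.
    """
    storage = array("d")  # one contiguous block of C doubles
    metadata = {}

    def visit(node, path):
        for key, value in node.items():
            if isinstance(value, dict):
                visit(value, path + (key,))
            else:
                metadata[path + (key,)] = (len(storage), len(value))
                storage.extend(value)

    visit(nested, ())
    return storage, metadata


def leaf_view(storage, metadata, path):
    """Zero-copy view of one leaf inside the shared buffer."""
    offset, length = metadata[path]
    return memoryview(storage)[offset : offset + length]


data = {"obs": [1.0, 2.0, 3.0], "next": {"reward": [0.5], "done": [0.0]}}
storage, metadata = consolidate(data)

# The whole structure serializes as one bytes object.
payload = storage.tobytes()
print(len(payload))  # 5 doubles * 8 bytes = 40
print(leaf_view(storage, metadata, ("next", "reward")).tolist())  # [0.5]
```

The offsets table is what lets each leaf be recovered as a view after the single buffer has been transferred, rather than copying leaves out one by one.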

PT2 support

TensorDict common ops (get, set, indexing, arithmetic ops, etc.) now work within torch.compile.
The list of supported operations can be found in test/test_compile.py. We encourage users to report any graph break caused by tensordict, as we are keen to extend the coverage as much as we can.

Python 3.12 support

#807 enables Python 3.12 support, a long-awaited feature!

Global reduction for mean, std and other reduction methods

It is now possible to get the grand average of a tensordict's contents using tensordict.mean(reduce=True).
This applies to mean, nanmean, prod, std, sum, nansum and var.
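Assuming "grand average" means pooling every element of every leaf before averaging (rather than averaging the per-leaf means), the two notions differ as soon as leaves have different sizes. The stdlib-only sketch below contrasts them on a nested dict of number lists; `grand_mean` and `mean_of_leaf_means` are hypothetical helpers, not tensordict code.

```python
from statistics import fmean


def iter_leaves(nested):
    """Yield every leaf (a flat list of numbers) of a nested dict."""
    for value in nested.values():
        if isinstance(value, dict):
            yield from iter_leaves(value)
        else:
            yield value


def mean_of_leaf_means(nested):
    """Average each leaf first, then average those per-leaf results."""
    return fmean(fmean(leaf) for leaf in iter_leaves(nested))


def grand_mean(nested):
    """Pool all elements of all leaves, then average once."""
    return fmean(x for leaf in iter_leaves(nested) for x in leaf)


data = {"a": [1.0, 2.0, 3.0], "b": {"c": [10.0]}}
print(grand_mean(data))          # (1 + 2 + 3 + 10) / 4 = 4.0
print(mean_of_leaf_means(data))  # (2.0 + 10.0) / 2 = 6.0
```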

from_pytree and to_pytree

We made it easy to convert a tensordict to a given pytree structure and build it from any pytree using to_pytree and from_pytree. #832
Similarly, conversion to namedtuple is now made easy thanks to #788.
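A pytree is a nested container whose non-container entries are leaves: flattening yields the list of leaves plus a "spec" describing the structure, and unflattening rebuilds the same structure from a (possibly transformed) list of leaves. The stdlib-only round trip below illustrates that concept with hypothetical `tree_flatten`/`tree_unflatten` helpers; it is neither tensordict's nor PyTorch's pytree implementation.

```python
def tree_flatten(tree):
    """Split a nested dict/list/tuple structure into (leaves, spec)."""
    if isinstance(tree, dict):
        keys = list(tree)
        children = [tree_flatten(tree[k]) for k in keys]
        kind = "dict"
    elif isinstance(tree, (list, tuple)):
        keys = None
        children = [tree_flatten(v) for v in tree]
        kind = "tuple" if isinstance(tree, tuple) else "list"
    else:
        return [tree], ("leaf", None, [])
    leaves = [leaf for child_leaves, _ in children for leaf in child_leaves]
    return leaves, (kind, keys, [spec for _, spec in children])


def tree_unflatten(leaves, spec):
    """Rebuild a structure matching `spec`, consuming `leaves` in order."""
    it = iter(leaves)

    def build(node_spec):
        kind, keys, child_specs = node_spec
        if kind == "leaf":
            return next(it)
        built = [build(child) for child in child_specs]
        if kind == "dict":
            return dict(zip(keys, built))
        return tuple(built) if kind == "tuple" else built

    return build(spec)


tree = {"obs": 1, "next": (2, [3, 4])}
leaves, spec = tree_flatten(tree)
print(leaves)  # [1, 2, 3, 4]
print(tree_unflatten([x * 10 for x in leaves], spec))
# {'obs': 10, 'next': (20, [30, 40])}
```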

map_iter

One can now iterate through a TensorDict's batch dimension and apply a function in a separate process thanks to map_iter.
This should enable the construction of datasets using TensorDict, where the preprocessing step is executed in a separate process. #847

Using flatten and unflatten, flatten_keys and unflatten_keys as context managers

It is now possible to use flatten_keys and flatten as context managers (#908, #779):

from tensordict import TensorDict

td = TensorDict({"flat": {"key": 1}}, batch_size=[])
with td.flatten_keys() as flat_td:
    flat_td["flat.key"] = 0
assert td["flat", "key"] == 0
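The with-statement pattern above amounts to "flatten on enter, write the changes back on exit". The stdlib-only sketch below reproduces that behaviour on plain nested dicts, assuming "." as the separator; `flatten_keys`, `unflatten_keys` and `flattened` here are hypothetical dict helpers, not tensordict's implementation.

```python
from contextlib import contextmanager


def flatten_keys(nested, sep="."):
    """{"a": {"b": 1}} -> {"a.b": 1} (empty sub-dicts are dropped)."""
    flat = {}
    for key, value in nested.items():
        if isinstance(value, dict):
            for sub_key, sub_value in flatten_keys(value, sep).items():
                flat[f"{key}{sep}{sub_key}"] = sub_value
        else:
            flat[key] = value
    return flat


def unflatten_keys(flat, sep="."):
    """{"a.b": 1} -> {"a": {"b": 1}}"""
    nested = {}
    for key, value in flat.items():
        *parents, leaf = key.split(sep)
        node = nested
        for parent in parents:
            node = node.setdefault(parent, {})
        node[leaf] = value
    return nested


@contextmanager
def flattened(nested, sep="."):
    """Expose a flat copy; write it back into `nested` on a clean exit."""
    flat = flatten_keys(nested, sep)
    yield flat
    # Only reached if the with-body did not raise.
    nested.clear()
    nested.update(unflatten_keys(flat, sep))


td = {"flat": {"key": 1}, "other": 2}
with flattened(td) as flat_td:
    flat_td["flat.key"] = 0
assert td["flat"]["key"] == 0
```

Writing back only on a clean exit means an exception inside the `with` body leaves the original dict untouched in this sketch.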

Building a tensordict using keyword arguments

We made it easy to build tensordicts with simple keyword arguments, just as a dict is built in Python:

import torch
from tensordict import TensorDict

td = TensorDict(a=0, b=1)
assert td["a"] == torch.tensor(0)
assert td["b"] == torch.tensor(1)

batch_size is now optional for both TensorDict and tensorclasses (#905).

Load tensordicts directly on device

Thanks to #769, it is now possible to load a tensordict directly on a destination device (including the "meta" device):

td = TensorDict.load(path, device=device)

New features

  • [Feature,Performance] to(device, pin_memory, num_threads) by @vmoens in #846
  • [Feature] Allow calls to get_mode, get_mean and get_median in case mode, mean or median is not present by @vmoens in #804
  • [Feature] Arithmetic ops for tensorclass by @vmoens in #786
  • [Feature] Best attempt to densely stack sub-tds when LazyStacked TDs are passed to maybe_dense_stack by @vmoens in #799
  • [Feature] Better dtype coverage by @vmoens in #834
  • [Feature] Change default interaction types to DETERMINISTIC by @vmoens in #825
  • [Feature] DETERMINISTIC interaction mode by @vmoens in #824
  • [Feature] Expose call_on_nested to apply and named_apply by @vmoens in #768
  • [Feature] Expose stack / cat as class methods by @vmoens in #793
  • [Feature] Load tensordicts on device, incl. meta by @vmoens in #769
  • [Feature] Make Probabilistic modules aware of CompositeDistributions out_keys by @vmoens in #810
  • [Feature] Memory-mapped nested tensors by @vmoens in #618
  • [Feature] Multithreaded apply by @vmoens in #844
  • [Feature] Multithreaded pin_memory by @vmoens in #845
  • [Feature] Support for non tensor data in h5 by @vmoens in #772
  • [Feature] TensorDict.consolidate by @vmoens in #814
  • [Feature] TensorDict.numpy() by @vmoens in #787
  • [Feature] TensorDict.replace by @vmoens in #774
  • [Feature] out argument in apply by @vmoens in #794
  • [Feature] to for consolidated TDs by @vmoens in #851
  • [Feature] zero_grad and requires_grad_ by @vmoens in #901
  • [Feature] add_custom_mapping and NPE refactors by @vmoens in #910
  • [Feature] construct tds with kwargs by @vmoens in #905
  • [Feature] deterministic_sample for composite dist by @vmoens in #827
  • [Feature] expand_as by @vmoens in #792
  • [Feature] flatten and unflatten as decorators by @vmoens in #779
  • [Feature] from and to_pytree by @vmoens in #832
  • [Feature] from_modules expand_identical kwarg by @vmoens in #911
  • [Feature] grad and data for tensorclasses by @vmoens in #904
  • [Feature] isfinite, isnan, isreal by @vmoens in #829
  • [Feature] map_iter by @vmoens in #847
  • [Feature] map_names for composite dists by @vmoens in #809
  • [Feature] online edition of memory mapped tensordicts by @vmoens in #775
  • [Feature] remove distutils dependency and enable 3.12 support by @GaetanLepage in #807
  • [Feature] to_namedtuple and from_namedtuple by @vmoens in #788
  • [Feature] view(dtype) by @vmoens in #835
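Multithreaded apply (#844) rests on the fact that leaves are independent, so a function can be mapped over them concurrently and the results reassembled into the original structure. The stdlib-only sketch below shows that idea on nested dicts; `apply_threaded` is a hypothetical helper, not tensordict's API. Threads mainly pay off when the function releases the GIL, as many native tensor operations do.

```python
from concurrent.futures import ThreadPoolExecutor


def apply_threaded(nested, fn, num_threads=4):
    """Apply `fn` to every leaf of a nested dict using a thread pool.

    Returns a new nested dict with the same keys (empty sub-dicts are not
    reproduced in this sketch).
    """
    paths, values = [], []

    def collect(node, path):
        for key, value in node.items():
            if isinstance(value, dict):
                collect(value, path + (key,))
            else:
                paths.append(path + (key,))
                values.append(value)

    collect(nested, ())
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        results = list(pool.map(fn, values))  # map preserves input order

    out = {}
    for path, result in zip(paths, results):
        node = out
        for key in path[:-1]:
            node = node.setdefault(key, {})
        node[path[-1]] = result
    return out


data = {"a": [1.0, 2.0], "b": {"c": [3.0]}}
print(apply_threaded(data, sum, num_threads=2))  # {'a': 3.0, 'b': {'c': 3.0}}
```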

Performance

  • [Performance] Faster getattr in TC by @vmoens in #912
  • [Performance] Faster lock_/unlock_ when sub-tds are already locked by @vmoens in #816
  • [Performance] Faster multithreaded pin_memory by @vmoens in #919
  • [Performance] Faster tensorclass by @vmoens in #791
  • [Performance] Faster tensorclass set by @vmoens in #880
  • [Performance] Faster to-module by @vmoens in #914

Bug Fixes

  • [BugFix,CI] Fix storage filename tests by @vmoens in #850
  • [BugFix] @property setter in tensorclass by @vmoens in #813
  • [BugFix] Allow any tensorclass to have a data field by @vmoens in #906
  • [BugFix] Allow fake-tensor detection pass through in torch 2.0 by @vmoens in #802
  • [BugFix] Avoid collapsing NonTensorStack when calling where by @vmoens in #837
  • [BugFix] Check if the current user has write access by @MateuszGuzek in #781
  • [BugFix] Ensure dtype is preserved with autocast by @vmoens in #773
  • [BugFix] Fix non-tensor writing in modules by @vmoens in #822
  • [BugFix] Fix (keys, values) in sub by @vmoens in #907
  • [BugFix] Fix _make_dtype_promotion backward compat by @vmoens in #842
  • [BugFix] Fix pad_sequence behavior for non-tensor attributes of tensorclass by @kurtamohler in #884
  • [BugFix] Fix builds by @vmoens in #849
  • [BugFix] Fix compile + vmap by @vmoens in #924
  • [BugFix] Fix deterministic fallback when the dist has no support by @vmoens in #830
  • [BugFix] Fix device parsing in augmented funcs by @vmoens in #770
  • [BugFix] Fix empty tuple index by @vmoens in #811
  • [BugFix] Fix fallback of deterministic samples when mean is not available by @vmoens in #828
  • [BugFix] Fix functorch dim mock by @vmoens in #777
  • [BugFix] Fix gather device by @vmoens in #815
  • [BugFix] Fix h5 auto batch size by @vmoens in #798
  • [BugFix] Fix key ordering in pointwise ops by @vmoens in #855
  • [BugFix] Fix lazy stack features (where and norm) by @vmoens in #795
  • [BugFix] Fix map by @vmoens in #862
  • [BugFix] Fix map test with fork on cuda by @vmoens in #765
  • [BugFix] Fix pad_sequence for non tensors by @vmoens in #784
  • [BugFix] Fix setting non-tensors as data in NonTensorData by @vmoens in #864
  • [BugFix] Fix stack of tensorclasses (and nontensors) by @vmoens in #820
  • [BugFix] Fix storage.filename compat with torch 2.0 by @vmoens in #803
  • [BugFix] Fix tensorclass register by @vmoens in #817
  • [BugFix] Fix torch version assertion by @vmoens in #917
  • [BugFix] Fix vmap compatibility with torch<2.2 by @vmoens in #925
  • [BugFix] Fix vmap for tensorclass by @vmoens in #778
  • [BugFix] Fix wheels by @vmoens in #856
  • [BugFix] Keep stack dim name in LazyStackedTensorDict copy ops by @vmoens in #801
  • [BugFix] Read-only compatibility in MemoryMappedTensor by @vmoens in #780
  • [BugFix] Refactor map and map_iter by @vmoens in #869
  • [BugFix] Sync cuda only if initialized by @vmoens in #767
  • [BugFix] fix _expand_to_match_shape for single bool tensor by @vmoens in #902
  • [BugFix] fix construction of lazy stacks from tds by @vmoens in #903
  • [BugFix] fix tensorclass set by @vmoens in #854
  • [BugFix] remove inplace updates when using td as a decorator by @vmoens in #796
  • [BugFix] use as_subclass in Buffer by @vmoens in #913

Refactoring and code quality

  • [Quality] Better nested detection in numpy() by @vmoens in #800
  • [Quality] Better repr of keys by @vmoens in #897
  • [Quality] fix c++ binaries formatting by @vmoens in #859
  • [Quality] non_blocking_pin instead of pin_memory by @vmoens in #915
  • [Quality] zip-strict when possible by @vmoens in #886
  • [Refactor] Better tensorclass method registration by @vmoens in #797
  • [Refactor] Make all leaves in tensorclass part of _tensordict, except for NonTensorData by @vmoens in #841
  • [Refactor] Refactor c++ binaries location by @vmoens in #860
  • [Refactor] Refactor is_dynamo_compile imports by @vmoens in #916
  • [Refactor] Remove _run_checks from __init__ by @vmoens in #843
  • [Refactor] use from_file instead of mmap+from_buffer for readonly files by @vmoens in #808

Others

  • Bump jinja2 from 3.1.3 to 3.1.4 in /docs by @dependabot in #840
  • [Benchmark] Benchmark tensorclass ops by @vmoens in #790
  • [Benchmark] Fix recursion and cache errors in benchmarks by @vmoens in #900
  • [CI] Fix nightly build by @vmoens in #861
  • [CI] Python 3.12 compatibility by @kurtamohler in #818
  • [Doc] Fix symbolic trace reference in doc by @vmoens in #918
  • [Formatting] Lint revamp by @vmoens in #890
  • [Test] Test FC of memmap save and load by @vmoens in #838
  • [Versioning] Allow any torch version for local builds by @vmoens in #764
  • [Versioning] Make dependence on uint16 optional for older PT versions by @vmoens in #839
  • [Versioning] tree_leaves for pytorch < 2.3 by @vmoens in #806
  • [Versioning] v0.5 bump by @vmoens in #848

New Contributors

Full Changelog: v0.4.0...v0.5.0