Skip to content

Commit

Permalink
[Feature] DETERMINISTIC interaction mode (#824)
Browse files Browse the repository at this point in the history
Co-authored-by: Matteo Bettini <[email protected]>
  • Loading branch information
vmoens and matteobettini authored Jun 20, 2024
1 parent d7ba913 commit d14db1c
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 6 deletions.
4 changes: 4 additions & 0 deletions tensordict/nn/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,3 +264,7 @@ def mode(self) -> torch.Tensor:
# Mean of the distribution, exposed as a read-only property: returns ``self.param`` unchanged.
@property
def mean(self) -> torch.Tensor:
return self.param

# Deterministic sample of the distribution, exposed as a read-only property:
# returns ``self.param`` unchanged (the same value the ``mean`` property returns).
@property
def deterministic_sample(self) -> torch.Tensor:
return self.param
2 changes: 2 additions & 0 deletions tensordict/nn/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def mode(self) -> torch.Tensor:
else:
return (self.probs == self.probs.max(-1, True)[0]).to(torch.long)

# Expose ``mode`` under the attribute name that the DETERMINISTIC interaction type
# queries (``deterministic_sample``). The original spelling was missing an "i", so
# the attribute lookup failed and the caller fell back on ``mean`` with a warning.
deterministic_sample = mode
# Keep the misspelled alias so any existing reference to the old name keeps working.
determnistic_sample = mode

def sample(
self,
sample_shape: torch.Size | Sequence[int] | None = None,
Expand Down
40 changes: 34 additions & 6 deletions tensordict/nn/probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,24 @@


class InteractionType(Enum):
"""A list of possible interaction types with a distribution.
MODE, MEDIAN and MEAN point to the property / attribute with the same name.
RANDOM points to ``rsample()`` if that method exists or ``sample()`` if not.
DETERMINISTIC can be used as a generic fallback if ``MEAN`` or ``MODE`` are not guaranteed to
be analytically tractable. In such cases, a crude deterministic estimate can be used
even if it lacks a true algebraic meaning.
This value will trigger a query to the ``deterministic_sample`` attribute in the distribution
and if it does not exist, the ``mean`` will be used.
"""

MODE = auto()
MEDIAN = auto()
MEAN = auto()
RANDOM = auto()
DETERMINISTIC = auto()

@classmethod
def from_str(cls, type_str: str) -> InteractionType:
Expand Down Expand Up @@ -452,9 +466,27 @@ def _dist_sample(
if interaction_type is None:
interaction_type = self.default_interaction_type

if interaction_type is InteractionType.DETERMINISTIC:
try:
return dist.deterministic_sample
except AttributeError:
try:
return dist.mean
except AttributeError:
raise NotImplementedError(
f"method {type(dist)}.deterministic_sample is not implemented."
)
finally:
warnings.warn(
"deterministic_sample wasn't found when queried. "
f"{type(self).__name__} is falling back on mean instead. "
f"For better code quality and efficiency, make sure to either "
f"provide a distribution with a deterministic_sample attribute or "
f"to change the InteractionMode to the desired value.",
category=UserWarning,
)

if interaction_type is InteractionType.MODE:
if hasattr(dist, "get_mode"):
return dist.get_mode()
try:
return dist.mode
except AttributeError:
Expand All @@ -463,8 +495,6 @@ def _dist_sample(
)

elif interaction_type is InteractionType.MEDIAN:
if hasattr(dist, "get_median"):
return dist.get_median()
try:
return dist.median
except AttributeError:
Expand All @@ -473,8 +503,6 @@ def _dist_sample(
)

elif interaction_type is InteractionType.MEAN:
if hasattr(dist, "get_mean"):
return dist.get_mean()
try:
return dist.mean
except (AttributeError, NotImplementedError):
Expand Down

2 comments on commit d14db1c

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: d14db1c Previous: d7ba913 Ratio
benchmarks/common/common_ops_test.py::test_membership_stacked_nested_last 74403.07569975217 iter/sec (stddev: 7.215903750179117e-7) 162351.08164522887 iter/sec (stddev: 4.554618447908789e-7) 2.18
benchmarks/common/common_ops_test.py::test_membership_stacked_nested_leaf_last 74893.42368791699 iter/sec (stddev: 6.57667565889069e-7) 160675.5368826902 iter/sec (stddev: 4.1728210792213916e-7) 2.15

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'GPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: d14db1c Previous: d7ba913 Ratio
benchmarks/tensorclass/test_torch_functions.py::test_zeros_like 60.92872871302357 iter/sec (stddev: 0.00103245727128582) 126.61028994323557 iter/sec (stddev: 0.0003301679464965167) 2.08
benchmarks/tensorclass/test_torch_functions.py::test_ones_like 61.37947161134608 iter/sec (stddev: 0.00095755512893378) 125.79704426990067 iter/sec (stddev: 0.0003317291705128687) 2.05

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.