We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Currently you cannot implement a custom sampling technique for the ActionDiscretizer transform.
ActionDiscretizer
Bring custom_arange out of transform_input_spec and make it a method of ActionDiscretizer. Wrappers around ActionDiscretizer can then update this.
custom_arange
transform_input_spec
class SamplingStrategy(IntEnum):
    # How a continuous interval is mapped to a representative discrete value.
    # MEDIAN averages the left and right bin edges (bin midpoints), LOW uses
    # the left edges, HIGH the right edges; RANDOM shares LOW's grid here —
    # presumably the randomization happens elsewhere (TODO confirm, not
    # visible in this excerpt).
    MEDIAN = 0
    LOW = 1
    HIGH = 2
    RANDOM = 3

def __init__(
    self,
    num_intervals: int | torch.Tensor,
    action_key: NestedKey = "action",
    out_action_key: NestedKey = None,
    sampling=None,
    categorical: bool = True,
):
    """Discretizing transform setup.

    Args:
        num_intervals: number of discretization bins; either a single int
            shared by all action dimensions, or a tensor with one bin count
            per dimension (registered as a buffer so it moves with the module).
        action_key: key of the continuous action read from the input.
        out_action_key: key the discretized action is written to; defaults
            to ``action_key`` when ``None``.
        sampling: a ``SamplingStrategy`` member; defaults to ``MEDIAN``.
        categorical: stored flag; its effect is outside this excerpt
            (NOTE(review): presumably selects categorical vs one-hot specs —
            confirm against the full class).
    """
    if out_action_key is None:
        out_action_key = action_key
    # The inverse keys drive the inverse transform: continuous in, discrete out.
    super().__init__(in_keys_inv=[action_key], out_keys_inv=[out_action_key])
    self.action_key = action_key
    self.out_action_key = out_action_key
    if not isinstance(num_intervals, torch.Tensor):
        # Plain int: keep as a regular attribute.
        self.num_intervals = num_intervals
    else:
        # Tensor: register as a buffer so device/dtype moves apply to it.
        self.register_buffer("num_intervals", num_intervals)
    if sampling is None:
        sampling = self.SamplingStrategy.MEDIAN
    self.sampling = sampling
    self.categorical = categorical

def __repr__(self):
    # Multi-line repr listing each configured field, indented by 4 spaces.
    def _indent(s):
        return indent(s, 4 * " ")

    num_intervals = f"num_intervals={self.num_intervals}"
    action_key = f"action_key={self.action_key}"
    out_action_key = f"out_action_key={self.out_action_key}"
    sampling = f"sampling={self.sampling}"
    categorical = f"categorical={self.categorical}"
    return (
        f"{type(self).__name__}(\n{_indent(num_intervals)},\n{_indent(action_key)},"
        f"\n{_indent(out_action_key)},\n{_indent(sampling)},\n{_indent(categorical)})"
    )

def _custom_arange(self, nint, device):
    """Return a 1-D grid of ``nint`` normalized bin anchors in [0, 1].

    The base grid is the left bin edges ``[0, 1/nint, ..., (nint-1)/nint]``.
    Depending on ``self.sampling`` the anchors are the left edges (LOW,
    RANDOM), the right edges (HIGH, via ``(1 - grid).flip(0)``), or the bin
    midpoints (MEDIAN, the average of left and right edges).

    This is the extension point the issue asks to expose: subclasses or
    wrappers can override it to implement a custom sampling scheme.
    """
    # NOTE: relies on self.dtype, which transform_input_spec sets from the
    # action spec — this method is only valid after the spec has been seen.
    result = torch.arange(
        start=0.0,
        end=1.0,
        step=1 / nint,
        dtype=self.dtype,
        device=device,
    )
    result_ = result
    if self.sampling in (
        self.SamplingStrategy.HIGH,
        self.SamplingStrategy.MEDIAN,
    ):
        # Right bin edges: 1 - [0, ..., (n-1)/n] reversed gives [1/n, ..., 1].
        result_ = (1 - result).flip(0)
    if self.sampling == self.SamplingStrategy.MEDIAN:
        # Midpoints: average of left and right edges.
        result = (result + result_) / 2
    else:
        result = result_
    return result

def transform_input_spec(self, input_spec):
    """Rewrite the env's action spec for discretized actions.

    Reads the bounded continuous action spec, caches ``n_act``/``dtype``,
    and precomputes per-dimension interval anchors used to map discrete
    choices back to continuous values.
    """
    try:
        action_spec = input_spec["full_action_spec", self.in_keys_inv[0]]
        # Only bounded continuous specs can be discretized.
        if not isinstance(action_spec, Bounded):
            raise TypeError(
                f"action spec type {type(action_spec)} is not supported."
            )
        # Number of action dimensions: last dim of the spec shape, or 1 for
        # a scalar (empty-shape) spec.
        n_act = action_spec.shape
        if not n_act:
            n_act = 1
        else:
            n_act = n_act[-1]
        self.n_act = n_act
        # Cached for _custom_arange, which reads self.dtype.
        self.dtype = action_spec.dtype
        # Per-dimension span of the action range, shape (..., n_act, 1).
        interval = (action_spec.high - action_spec.low).unsqueeze(-1)
        num_intervals = self.num_intervals
        if isinstance(num_intervals, int):
            # Same bin count everywhere: one grid broadcast over all dims,
            # scaled by each dim's span and shifted by its lower bound.
            arange = (
                self._custom_arange(num_intervals, action_spec.device).expand(
                    n_act, num_intervals
                )
                * interval
            )
            self.register_buffer(
                "intervals", action_spec.low.unsqueeze(-1) + arange
            )
        else:
            # Per-dimension bin counts: grids have ragged lengths, so they
            # are kept as a plain list rather than a single stacked buffer.
            # NOTE(review): unlike the int branch this is a regular attribute,
            # not a registered buffer — these tensors won't follow .to()/
            # state_dict; confirm this is intentional.
            arange = [
                self._custom_arange(_num_intervals, action_spec.device) * interval
                for _num_intervals, interval in zip(
                    num_intervals.tolist(), interval.unbind(-2)
                )
            ]
            self.intervals = [
                low + arange
                for low, arange in zip(
                    action_spec.low.unsqueeze(-1).unbind(-2), arange
                )
            ]
        # NOTE(review): the source excerpt is truncated here — the rest of
        # this try block and its except clause are outside this view.
The text was updated successfully, but these errors were encountered:
vmoens
No branches or pull requests
Motivation
Currently you cannot implement a custom sampling technique for the
ActionDiscretizer
transform.

Solution
Bring
custom_arange
out of transform_input_spec
and make it a method of ActionDiscretizer
. Wrappers aroundActionDiscretizer
can then update this.

Alternatives
Additional context
Checklist
The text was updated successfully, but these errors were encountered: