Commit
polish merge.py;
add unittest;
Cloud-Pku committed Jun 6, 2023
1 parent 43645fe commit e0522c9
Showing 3 changed files with 45 additions and 7 deletions.
1 change: 1 addition & 0 deletions ding/rl_utils/__init__.py
@@ -23,3 +23,4 @@
from .acer import acer_policy_error, acer_value_error, acer_trust_region_update
from .sampler import ArgmaxSampler, MultinomialSampler, MuSampler, ReparameterizationSampler, HybridStochasticSampler, \
HybridDeterminsticSampler
+ from .merge import GatingType, SumMerge, VectorMerge
17 changes: 10 additions & 7 deletions ding/rl_utils/merge.py
@@ -18,7 +18,7 @@ class GatingType(enum.Enum):


class SumMerge(nn.Module):
"""Merge streams using a simple sum (faster than Merge, for large stream).
"""Merge streams using a simple sum.
Streams must have the same size.
This module can merge any type of stream (vector, units or visual).
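The forward pass is not part of this hunk, but conceptually SumMerge just stacks the equally-sized streams and sums them elementwise. A minimal sketch of that idea (a hypothetical ToySumMerge, not the committed implementation):

```python
import torch
import torch.nn as nn


class ToySumMerge(nn.Module):
    """Hypothetical stand-in for SumMerge, not the committed code."""

    def forward(self, streams) -> torch.Tensor:
        # streams: a sequence of tensors that all share the same shape
        return torch.stack(list(streams), dim=0).sum(dim=0)
```

Matching the unit test below, a tuple of four (3, 5) tensors stacks to (4, 3, 5) and sums over the stream dimension back to (3, 5).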
@@ -43,6 +43,15 @@ class VectorMerge(nn.Module):
If gating_type is not none, the sum is weighted using a softmax
of the intermediate activations labelled above.
+ Specifically,
+ GatingType.NONE:
+     Simple addition of the streams; the sum is not weighted by gate features.
+ GatingType.GLOBAL:
+     Each data stream is weighted by a global (scalar) gate value, and the gate values across all streams sum to 1.
+ GatingType.POINTWISE:
+     Like GLOBAL, but each element of a stream's feature tensor gets its own weight.
"""

def __init__(
@@ -58,15 +67,9 @@ def __init__(
input_sizes: A dictionary mapping input names to their size (a single
integer for 1d inputs, or None for 0d inputs).
If an input size is None, we assume it's ().
- output_name: The name to give to the output of this module, of shape
-     [output_size] and dtype float32.
output_size: The size of the output vector.
gating_type: The type of gating mechanism to use.
use_layer_norm: Whether to use layer normalization.
- input_dtypes: An optional dictionary with the dtypes of the inputs. If an
-     input is missing from this dictionary, its dtype is assumed to be
-     float32.
- name: The name of this component.
"""
super().__init__()
self._input_sizes = OrderedDict(input_sizes)
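The diff above only touches VectorMerge's docstring and constructor signature, so the forward pass is not shown. Based on the docstring's description, here is a minimal sketch of the gating idea (a hypothetical ToyVectorMerge; the names, layer choices, and shapes are assumptions, not the repository's implementation):

```python
import torch
import torch.nn as nn


class ToyVectorMerge(nn.Module):
    """Hypothetical illustration of the gating described above, not DI-engine's code."""

    def __init__(self, input_sizes: dict, output_size: int, pointwise: bool = False):
        super().__init__()
        # Encode every stream to a common size so the streams can be summed.
        self.encoders = nn.ModuleDict({k: nn.Linear(v, output_size) for k, v in input_sizes.items()})
        # GLOBAL: one scalar gate per stream; POINTWISE: one gate per output element.
        gate_size = output_size if pointwise else 1
        self.gates = nn.ModuleDict({k: nn.Linear(v, gate_size) for k, v in input_sizes.items()})

    def forward(self, inputs: dict) -> torch.Tensor:
        keys = list(inputs.keys())
        encoded = torch.stack([self.encoders[k](inputs[k]) for k in keys])  # (S, B, output_size)
        logits = torch.stack([self.gates[k](inputs[k]) for k in keys])  # (S, B, 1) or (S, B, output_size)
        weights = torch.softmax(logits, dim=0)  # softmax across streams, so the weights sum to 1
        return (weights * encoded).sum(dim=0)  # weighted sum over streams -> (B, output_size)
```

With pointwise=False this mimics the GLOBAL behavior (one scalar weight per stream, weights summing to 1 across streams); with pointwise=True every output element gets its own softmax weight, as in POINTWISE.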
34 changes: 34 additions & 0 deletions ding/rl_utils/tests/test_merge.py
@@ -0,0 +1,34 @@
import pytest
import torch
from ding.rl_utils import GatingType, SumMerge, VectorMerge


@pytest.mark.unittest
def test_SumMerge():
    input_shape = (3, 5)
    input = tuple([torch.rand(input_shape) for i in range(4)])
    sum_merge = SumMerge()

    output = sum_merge(input)
    assert output.shape == (3, 5)


@pytest.mark.unittest
def test_VectorMerge():
    input_sizes = {'in1': 3, 'in2': 16, 'in3': 27}
    output_size = 512
    input_dict = {}
    for k, v in input_sizes.items():
        input_dict[k] = torch.rand((64, v))

    vector_merge = VectorMerge(input_sizes, output_size, GatingType.NONE)
    output = vector_merge(input_dict)
    assert output.shape == (64, output_size)

    vector_merge = VectorMerge(input_sizes, output_size, GatingType.GLOBAL)
    output = vector_merge(input_dict)
    assert output.shape == (64, output_size)

    vector_merge = VectorMerge(input_sizes, output_size, GatingType.POINTWISE)
    output = vector_merge(input_dict)
    assert output.shape == (64, output_size)
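These tests use DI-engine's `@pytest.mark.unittest` marker, so (assuming a standard pytest setup) they can be run with, e.g., `pytest -m unittest ding/rl_utils/tests/test_merge.py`.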
