From ce99cfd1fbec92ba9a0d6cc95f47be446668e1eb Mon Sep 17 00:00:00 2001 From: "Samyak K. G" Date: Tue, 15 Sep 2020 12:38:42 -0400 Subject: [PATCH] pytorch: Add conv2d wrapper --- .../library/compositions/pytorchcomponents.py | 17 ++++++++++++- tests/composition/test_autodiffcomposition.py | 25 ++++++++++++++++++- 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/psyneulink/library/compositions/pytorchcomponents.py b/psyneulink/library/compositions/pytorchcomponents.py index 99648fc81d1..82b99b0b218 100644 --- a/psyneulink/library/compositions/pytorchcomponents.py +++ b/psyneulink/library/compositions/pytorchcomponents.py @@ -3,7 +3,7 @@ from psyneulink.core import llvm as pnlvm from psyneulink.core.globals.log import LogCondition -from psyneulink.core.components.functions.transferfunctions import Linear, Logistic, ReLU +from psyneulink.core.components.functions.transferfunctions import Linear, Logistic, ReLU, Conv2d from psyneulink.library.compositions.pytorchllvmhelper import * __all__ = ['PytorchMechanismWrapper', 'PytorchProjectionWrapper'] @@ -103,6 +103,21 @@ def get_fct_param_value(param_name): wrapper = PytorchFunctionWrapper(func, device=device, context=context) return wrapper + elif isinstance(function, Conv2d): + kernel = get_fct_param_value('kernel') + kernel = torch.nn.Parameter(torch.tensor(np.reshape(kernel, (1, 1, *kernel.shape)), device=device, dtype=torch.double), requires_grad=True) + + stride = get_fct_param_value('stride') + padding = get_fct_param_value('padding') + dilation = get_fct_param_value('dilation') + + conv2d = torch.nn.functional.conv2d + def func(x): + x = torch.reshape(x, (1, 1, *x.shape)) + return conv2d(x, weight=kernel, stride=stride, padding=padding, dilation=dilation)[0][0] + + wrapper = PytorchFunctionWrapper(func, learnable_params = [kernel], device=device, context=context) + return wrapper else: raise Exception(f"Function {function} is not currently supported in AutodiffCompositions!") diff --git a/tests/composition/test_autodiffcomposition.py b/tests/composition/test_autodiffcomposition.py index 8fbd849b7b3..82ab656fd17 100644 --- a/tests/composition/test_autodiffcomposition.py +++ b/tests/composition/test_autodiffcomposition.py @@ -7,12 +7,13 @@ import psyneulink as pnl -from psyneulink.core.components.functions.transferfunctions import Logistic +from psyneulink.core.components.functions.transferfunctions import Logistic, Conv2d from psyneulink.core.components.functions.learningfunctions import BackPropagation from psyneulink.core.compositions.composition import Composition from psyneulink.core.globals import Context from psyneulink.core.globals.keywords import TRAINING_SET from psyneulink.core.components.mechanisms.processing.transfermechanism import TransferMechanism +from psyneulink.core.components.mechanisms.processing.processingmechanism import ProcessingMechanism from psyneulink.core.components.projections.pathway.mappingprojection import MappingProjection from psyneulink.library.compositions.autodiffcomposition import AutodiffComposition @@ -1218,6 +1219,28 @@ def test_pytorch_equivalence_with_autodiff_training_disabled_on_proj(self): assert np.allclose(output, comparator) + def test_conv2d_pytorch_equivalence_with_autodiff_composition(self): + variable, kernel, stride, padding, dilation, target, comparator = (np.ones((2, 2)), np.ones((2, 2)), (1,1), (0,0), (1,1), np.ones((1,1)), [[0.98340034484863281250]]) + + il = ProcessingMechanism(name='input', function=Conv2d(default_variable=variable, kernel=kernel, stride=stride, padding=padding, dilation=dilation), default_variable=variable) + comp = AutodiffComposition(optimizer_type='adam', learning_rate=1) + comp.add_node(il) + + input_set = { + 'inputs': { + il: [variable] + }, + 'targets': { + il: [target] + } + } + + results = comp.learn( + inputs=input_set, + epochs=100 + ) + + assert np.allclose(comparator, results[-1][-1]) @pytest.mark.pytorch @pytest.mark.actime