pytorchcomponents: Add PytorchFunctionWrapper

SamKG committed Feb 2, 2021
1 parent 89d6cec commit 0b65504
Showing 2 changed files with 72 additions and 38 deletions.
108 changes: 71 additions & 37 deletions psyneulink/library/compositions/pytorchcomponents.py
@@ -6,43 +6,7 @@
from psyneulink.core.components.functions.transferfunctions import Linear, Logistic, ReLU
from psyneulink.library.compositions.pytorchllvmhelper import *

__all__ = ['PytorchMechanismWrapper', 'PytorchProjectionWrapper', 'wrap_mechanism']

def pytorch_function_creator(function, device, context=None):
"""
Converts a PsyNeuLink function into an equivalent PyTorch lambda function.
NOTE: This is needed due to PyTorch limitations (see: https://github.com/PrincetonUniversity/PsyNeuLink/pull/1657#discussion_r437489990)
"""
def get_fct_param_value(param_name):
val = function._get_current_parameter_value(
param_name, context=context)
if val is None:
val = getattr(function.defaults, param_name)

return float(val)

if isinstance(function, Linear):
slope = get_fct_param_value('slope')
intercept = get_fct_param_value('intercept')
return lambda x: x * slope + intercept

elif isinstance(function, Logistic):
gain = get_fct_param_value('gain')
bias = get_fct_param_value('bias')
offset = get_fct_param_value('offset')
return lambda x: 1 / (1 + torch.exp(-gain * (x + bias) + offset))

elif isinstance(function, ReLU):
gain = get_fct_param_value('gain')
bias = get_fct_param_value('bias')
leak = get_fct_param_value('leak')
return lambda x: (torch.max(input=(x - bias), other=torch.tensor([0], device=device).double()) * gain +
torch.min(input=(x - bias), other=torch.tensor([0], device=device).double()) * leak)
else:
raise Exception(f"Function {function} is not currently supported in AutodiffCompositions!")

def wrap_mechanism(mechanism, index, device, context=None):
return PytorchMechanismWrapper(mechanism, index, device, context=context)
__all__ = ['PytorchMechanismWrapper', 'PytorchProjectionWrapper']

class PytorchWrapper():
def _get_learnable_param_ids(self):
@@ -77,6 +41,71 @@ def _update_llvm_param_gradients(self, ctx, builder, state, params, node_delta_w
def _get_pytorch_params(self):
return []

class PytorchFunctionWrapper(PytorchWrapper):
"""
Wraps a PsyNeuLink function; converts its functionality to PyTorch
"""
def __init__(self, function, device, learnable_params=None, learnable_param_ids=None, context=None):
if learnable_params is None:
learnable_params = []
self._learnable_params = learnable_params

if learnable_param_ids is None:
learnable_param_ids = []
self._learnable_param_ids = learnable_param_ids

self._context = context
self._function = function
self._device = device

def __call__(self, variable):
return self._function(variable)

def _get_learnable_param_ids(self):
return self._learnable_param_ids

def _get_pytorch_params(self):
return self._learnable_params

def pytorch_function_creator(function, device, context=None):
"""
Converts a PsyNeuLink function into an equivalent PyTorch function (returned as a PytorchFunctionWrapper).
NOTE: This is needed due to PyTorch limitations (see: https://github.com/PrincetonUniversity/PsyNeuLink/pull/1657#discussion_r437489990)
"""
def get_fct_param_value(param_name):
val = getattr(function.parameters, param_name).get(context=context)
if val is None:
val = getattr(function.defaults, param_name)

return val

if isinstance(function, Linear):
slope = get_fct_param_value('slope')
intercept = get_fct_param_value('intercept')
func = lambda x: x * slope + intercept
wrapper = PytorchFunctionWrapper(func, device=device, context=context)
return wrapper

elif isinstance(function, Logistic):
gain = get_fct_param_value('gain')
bias = get_fct_param_value('bias')
offset = get_fct_param_value('offset')
func = lambda x: 1 / (1 + torch.exp(-gain * (x + bias) + offset))
wrapper = PytorchFunctionWrapper(func, device=device, context=context)
return wrapper

elif isinstance(function, ReLU):
gain = get_fct_param_value('gain')
bias = get_fct_param_value('bias')
leak = get_fct_param_value('leak')
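# leaky ReLU: gain * max(x - bias, 0) + leak * min(x - bias, 0)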
func = lambda x: (torch.max(input=(x - bias), other=torch.tensor([0], device=device).double()) * gain +
torch.min(input=(x - bias), other=torch.tensor([0], device=device).double()) * leak)
wrapper = PytorchFunctionWrapper(func, device=device, context=context)
return wrapper

else:
raise Exception(f"Function {function} is not currently supported in AutodiffCompositions!")

class PytorchMechanismWrapper(PytorchWrapper):
"""
An interpretation of a mechanism as an equivalent PyTorch object
@@ -103,6 +132,11 @@ def add_afferent(self, afferent):
assert afferent not in self.afferents
self.afferents.append(afferent)

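# Parameter queries delegate to the wrapped function (expected to be a
# PytorchFunctionWrapper), so mechanisms expose their function's learnable
# parameters through the common PytorchWrapper interface.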
def _get_learnable_param_ids(self):
return self.function._get_learnable_param_ids()

def _get_pytorch_params(self):
return self.function._get_pytorch_params()

def collate_afferents(self):
"""
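For orientation, a minimal usage sketch (illustrative only, not part of the commit): pytorch_function_creator now returns a PytorchFunctionWrapper rather than a bare lambda, so callers can both apply the function and query its learnable parameters through the common PytorchWrapper interface. The gain/bias/offset values below are placeholders.

import torch
from psyneulink.library.compositions.pytorchcomponents import PytorchFunctionWrapper

# Stand-in for the Logistic branch of pytorch_function_creator; the real
# values are read from the PsyNeuLink function's parameters.
gain, bias, offset = 1.0, 0.0, 0.0
func = lambda x: 1 / (1 + torch.exp(-gain * (x + bias) + offset))
wrapper = PytorchFunctionWrapper(func, device=torch.device('cpu'))

y = wrapper(torch.tensor([0.0, 1.0]).double())   # __call__ delegates to func
assert wrapper._get_pytorch_params() == []       # no learnable_params supplied
assert wrapper._get_learnable_param_ids() == []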
2 changes: 1 addition & 1 deletion psyneulink/library/compositions/pytorchmodelcreator.py
@@ -45,7 +45,7 @@ def __init__(self, composition, device, context=None):

# Instantiate pytorch mechanisms
for node in set(composition.nodes) - set(composition.get_nodes_by_role(NodeRole.LEARNING)):
pytorch_node = wrap_mechanism(node, self._composition._get_node_index(node), device, context=context)
pytorch_node = PytorchMechanismWrapper(node, self._composition._get_node_index(node), device, context=context)
self.node_map[node] = pytorch_node
self.nodes.append(pytorch_node)
self.components.append(pytorch_node)
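A hypothetical follow-on sketch (assumed usage; the optimizer wiring is not part of this commit) of why the mechanism wrapper forwards parameter queries: a model can now collect every learnable tensor across its nodes in one pass and hand them to a torch optimizer.

import torch

learnable_params = []
for node in model.nodes:  # 'model' is an assumed PytorchModelCreator instance
    learnable_params.extend(node._get_pytorch_params())
if learnable_params:
    optimizer = torch.optim.SGD(learnable_params, lr=0.01)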
