diff --git a/pudu/perturbation.py b/pudu/perturbation.py index a01e205..c1f6962 100644 --- a/pudu/perturbation.py +++ b/pudu/perturbation.py @@ -362,6 +362,23 @@ def apply(self, x, row, col, window, bias): return x, None +class CustomSum: + """ + Sums a custom vector as a perturbation to the array. + + :type func: callable + :param func: A function that takes a single argument and returns a single value. + + :rtype: 4d array + :return: Custom perturbated array + """ + def __init__(self, custom): + self.custom = custom + def apply(self, x, row, col, window, bias): + x[0, row:row+window[0], col:col+window[1], 0] = x[0, row:row+window[0], col:col+window[1], 0] + self.custom[0, row:row+window[0], col:col+window[1], 0] + return x, None + + class UpperThreshold: """ Sets values above a certain threshold to a specified value.